Flower Strategy Abstraction (abstraction Strategy de Flower)¶
L’abstraction stratégique permet la mise en œuvre de stratégies d’apprentissage fédéré personnalisées. Dans Flower, une stratégie est essentiellement l’algorithme d’apprentissage fédéré qui s’exécute à l’intérieur du ServerApp. Les stratégies définissent comment faire :
Échantillonner des clients
Configurer les instructions d’entraînement et d’évaluation
Mettre en commun les mises à jour et les métriques
Évaluer les modèles
Flower embarque un certain nombre de stratégies intégrées, toutes suivant la même API décrite ci-dessous. Vous pouvez également mettre en œuvre vos propres stratégies avec accès complet aux mêmes capacités.
L’abstraction Strategy¶
Toutes les implémentations de stratégie doivent dériver de la classe abstraite Strategy. Cela inclut à la fois les stratégies intégrées et les stratégies tierces-parties/personnalisées. En étendant cette classe, les stratégies définies par l’utilisateur gagnent le même pouvoir et la même flexibilité que celles intégrées.
La classe de base Strategy définit une méthode start et exige que les sous-classes implémentent plusieurs méthodes abstraites :
class Strategy(ABC):
"""Abstract base class for server strategy implementations."""
@abstractmethod
def configure_train(
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
"""Configure the next round of training."""
@abstractmethod
def aggregate_train(
self,
server_round: int,
replies: Iterable[Message],
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
"""Aggregate training results from client nodes."""
@abstractmethod
def configure_evaluate(
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
"""Configure the next round of evaluation."""
@abstractmethod
def aggregate_evaluate(
self,
server_round: int,
replies: Iterable[Message],
) -> Optional[MetricRecord]:
"""Aggregate evaluation metrics from client nodes."""
@abstractmethod
def summary(self) -> None:
"""Log a summary of the strategy configuration."""
def start(
self,
grid: Grid,
initial_arrays: ArrayRecord,
num_rounds: int = 3,
timeout: float = 3600,
train_config: Optional[ConfigRecord] = None,
evaluate_config: Optional[ConfigRecord] = None,
evaluate_fn: Optional[
Callable[[int, ArrayRecord], Optional[MetricRecord]]
] = None,
) -> Result:
"""Execute the federated learning strategy."""
# Implementation details
pass
Créer une nouvelle stratégie¶
Vous pouvez personnaliser une stratégie existante (par exemple, FedAvg) en surchargeant une ou plusieurs de ses méthodes. Pour une flexibilité totale, vous pouvez également mettre en œuvre une stratégie à partir de zéro. Pour mettre en œuvre une nouvelle stratégie, définez simplement une classe qui dérive de Strategy et implémentez les méthodes abstraites :
class SotaStrategy(Strategy):
def configure_train(self, server_round, arrays, config, grid):
# Your implementation here
pass
def aggregate_train(self, server_round, replies):
# Your implementation here
pass
def configure_evaluate(self, server_round, arrays, config, grid):
# Your implementation here
pass
def aggregate_evaluate(self, server_round, replies):
# Your implementation here
pass
def summary(self):
print("SotaStrategy: This is the state-of-the-art strategy!")
La méthode start est déjà mise en œuvre dans la classe de base et ne nécessite généralement pas d’être surchargée. Elle orchestre le processus d’apprentissage fédéré en invoquant les méthodes abstraites dans une séquence.
Comprenez la méthode start¶
La méthode start de la classe de base Strategy suit ce workflow :
Appelez
evaluate_fn(si fourni) pour évaluer le modèle initial sur le côté ServerApp.Appelez
configure_trainpour générer des messages d’entraînement pour les applications Client.Envoyer des messages d’entraînement aux ClientApps.
Les ClientApps exécutent leur fonction
@app.train()et retournent les réponses d’entraînement.Appelez
aggregate_trainpour agréger les réponses d’entraînement.Appelez
configure_evaluatepour générer des messages d’évaluation pour les applications Client.Envoyer des messages d’évaluation aux ClientApps.
Les ClientApps exécutent leur fonction
@app.evaluate()et retournent les réponses d’évaluation.Appelez
aggregate_evaluatepour agréger les réponses d’évaluation.Appelez
evaluate_fn(si fourni) pour évaluer le modèle agrégé sur le côté ServerApp.Répétez les étapes 2-10 pour le nombre spécifié de tours.
Retournez le modèle final, qui contient le modèle final et l’historique des métriques.
Le diagramme suivant illustre le flux.
Note
Le diagramme de séquence ci-dessous montre l’interaction entre ServerApp, Strategy (à l’intérieur de ServerApp) et ClientApp. En réalité, ils ne communiquent pas directement sur le réseau—l’infrastructure Flower (SuperLink et SuperNode) gère transparentement toutes les communications. Vous pouvez en savoir plus à ce sujet dans le guide Flower Network Communication.
La méthode configure_train¶
La méthode configure_train est responsable de la préparation de la prochaine ronde d’entraînement. Mais qu’est-ce que configurer signifie dans ce contexte ? Cela signifie sélectionner les clients qui devraient participer à le tour et décider quelles instructions ils devraient recevoir.
Voici l’énoncé de la méthode :
@abstractmethod
def configure_train(
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
"""Configure the next round of training."""
Cette méthode prend quatre arguments :
server_round: Le numéro actuel de tourarrays: Les paramètres du modèle global actuelsconfig: Un dictionnaire de configuration pour le tourgrid: L’objet responsable de la gestion de la communication avec les clients
La valeur de retour est un itérateur de Message objets, où chaque message contient les instructions à envoyer à un client spécifique. Une mise en œuvre typique de configure_train sera :
Utilisez le
gridpour échantillonner aléatoirement un sous-ensemble (ou tous) des clients disponiblesCréez un
Messagepar client sélectionné, contenant les paramètres du modèle global et les valeurs de configuration
Des stratégies plus avancées peuvent mettre en œuvre une logique de sélection de client personnalisée en utilisant les capacités de grid. Un client ne participe qu’à un tour si configure_train génère un message pour son ID de nœud.
Note
Puisque la valeur de retour est définie par client, les stratégies peuvent facilement mettre en œuvre des configurations hétérogènes. Par exemple, différents clients peuvent recevoir des modèles ou des hyperparamètres différents, permettant des comportements d’entraînement hautement personnalisés.
La méthode aggregate_train¶
La méthode aggregate_train est responsable de l’agrégation des résultats d’entraînement reçus des clients sélectionnés dans configure_train.
Voici l’énoncé de la méthode :
@abstractmethod
def aggregate_train(
self,
server_round: int,
replies: Iterable[Message],
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
"""Aggregate training results from client nodes."""
Cette méthode prend deux arguments :
server_round: Le numéro actuel de tourreplies: Un itérateur deMessageobjets provenant des clients participants
Elle retourne un tuple composé de :
ArrayRecord: Les paramètres du modèle global mis à jourMetricRecord: Les métriques d’entraînement agrégées (telles que la perte ou l’exactitude)
Si l’agrégation ne peut pas être effectuée (par exemple, si trop de clients ont échoué pendant le tour), la méthode peut décider de retourner (None, None) au lieu.
Indication
Vous pouvez utiliser Message.has_error() pour vérifier si une réponse contient une erreur et décider comment y réagir lors de l’agrégation.
La méthode configure_evaluate¶
La méthode configure_evaluate est responsable de la préparation du prochain tour d’évaluation. De même que configure_train, cela implique la sélection des clients qui devraient participer et la décision de ce qu’ils devraient recevoir en instructions pour l’évaluation.
Voici l’énoncé de la méthode :
@abstractmethod
def configure_evaluate(
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
"""Configure the next round of evaluation."""
Cette méthode prend quatre arguments :
server_round: Le numéro actuel de tourarrays: Les paramètres du modèle global actuels à évaluerconfig: Un dictionnaire de configuration pour l’évaluationgrid: L’objet qui gère la communication avec les clients
La valeur de retour est un itérateur de Message objets, un par client sélectionné. Chaque message contient généralement les paramètres du modèle global actuels et toute configuration d’évaluation.
Une mise en œuvre typique de configure_evaluate sera :
Utilisez
gridpour sélectionner un sous-ensemble (ou tous) des clients disponiblesCréez un
Messagepar client sélectionné contenant le modèle global et la configuration d’évaluation
Comme avec l’entraînement, les stratégies plus avancées peuvent appliquer une logique de sélection de client personnalisée ou envoyer différentes configurations d’évaluation à différents clients.
Note
Puisque chaque client reçoit son propre message, les stratégies peuvent mettre en œuvre des paramètres d’évaluation hétérogènes. Par exemple, certains clients pourraient évaluer sur des jeux de test plus grands, tandis que d’autres utiliseraient des métriques spécialisées.
La méthode aggregate_evaluate¶
La méthode aggregate_evaluate est responsable de l’agrégation des résultats d’évaluation reçus des clients sélectionnés dans configure_evaluate.
Voici l’énoncé de la méthode :
@abstractmethod
def aggregate_evaluate(
self,
server_round: int,
replies: Iterable[Message],
) -> Optional[MetricRecord]:
"""Aggregate evaluation metrics from client nodes."""
Cette méthode prend deux arguments :
server_round: Le numéro actuel de tourreplies: Une itérable deMessageobjets retournés par les clients après qu’ils ont exécuté l’évaluation
Il retourne un seul MetricRecord qui représente les métriques d’évaluation agrégées sur tous les clients participant. Si l’agrégation ne peut pas être effectuée (par exemple, en raison de trop nombreuses échecs des clients ou de métriques manquantes), la méthode peut retourner None.
Indication
Comme pour l’entraînement, Message.has_error() peut être utilisé pour détecter et gérer les erreurs des clients pendant l’agrégation.