Flower Strategy Abstraction (abstraction Strategy de Flower)

L’abstraction stratégique permet la mise en œuvre de stratégies d’apprentissage fédéré personnalisées. Dans Flower, une stratégie est essentiellement l’algorithme d’apprentissage fédéré qui s’exécute à l’intérieur du ServerApp. Les stratégies définissent comment faire :

  • Échantillonner des clients

  • Configurer les instructions d’entraînement et d’évaluation

  • Mettre en commun les mises à jour et les métriques

  • Évaluer les modèles

Flower embarque un certain nombre de stratégies intégrées, toutes suivant la même API décrite ci-dessous. Vous pouvez également mettre en œuvre vos propres stratégies avec accès complet aux mêmes capacités.

L’abstraction Strategy

Toutes les implémentations de stratégie doivent dériver de la classe abstraite Strategy. Cela inclut à la fois les stratégies intégrées et les stratégies tierces-parties/personnalisées. En étendant cette classe, les stratégies définies par l’utilisateur gagnent le même pouvoir et la même flexibilité que celles intégrées.

La classe de base Strategy définit une méthode start et exige que les sous-classes implémentent plusieurs méthodes abstraites :

class Strategy(ABC):
    """Abstract base class for server strategy implementations."""

    @abstractmethod
    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training."""

    @abstractmethod
    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate training results from client nodes."""

    @abstractmethod
    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of evaluation."""

    @abstractmethod
    def aggregate_evaluate(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> Optional[MetricRecord]:
        """Aggregate evaluation metrics from client nodes."""

    @abstractmethod
    def summary(self) -> None:
        """Log a summary of the strategy configuration."""

    def start(
        self,
        grid: Grid,
        initial_arrays: ArrayRecord,
        num_rounds: int = 3,
        timeout: float = 3600,
        train_config: Optional[ConfigRecord] = None,
        evaluate_config: Optional[ConfigRecord] = None,
        evaluate_fn: Optional[
            Callable[[int, ArrayRecord], Optional[MetricRecord]]
        ] = None,
    ) -> Result:
        """Execute the federated learning strategy."""
        # Implementation details
        pass

Créer une nouvelle stratégie

Vous pouvez personnaliser une stratégie existante (par exemple, FedAvg) en surchargeant une ou plusieurs de ses méthodes. Pour une flexibilité totale, vous pouvez également mettre en œuvre une stratégie à partir de zéro. Pour mettre en œuvre une nouvelle stratégie, définez simplement une classe qui dérive de Strategy et implémentez les méthodes abstraites :

class SotaStrategy(Strategy):

    def configure_train(self, server_round, arrays, config, grid):
        # Your implementation here
        pass

    def aggregate_train(self, server_round, replies):
        # Your implementation here
        pass

    def configure_evaluate(self, server_round, arrays, config, grid):
        # Your implementation here
        pass

    def aggregate_evaluate(self, server_round, replies):
        # Your implementation here
        pass

    def summary(self):
        print("SotaStrategy: This is the state-of-the-art strategy!")

La méthode start est déjà mise en œuvre dans la classe de base et ne nécessite généralement pas d’être surchargée. Elle orchestre le processus d’apprentissage fédéré en invoquant les méthodes abstraites dans une séquence.

Comprenez la méthode start

La méthode start de la classe de base Strategy suit ce workflow :

  1. Appelez evaluate_fn (si fourni) pour évaluer le modèle initial sur le côté ServerApp.

  2. Appelez configure_train pour générer des messages d’entraînement pour les applications Client.

  3. Envoyer des messages d’entraînement aux ClientApps.

  4. Les ClientApps exécutent leur fonction @app.train() et retournent les réponses d’entraînement.

  5. Appelez aggregate_train pour agréger les réponses d’entraînement.

  6. Appelez configure_evaluate pour générer des messages d’évaluation pour les applications Client.

  7. Envoyer des messages d’évaluation aux ClientApps.

  8. Les ClientApps exécutent leur fonction @app.evaluate() et retournent les réponses d’évaluation.

  9. Appelez aggregate_evaluate pour agréger les réponses d’évaluation.

  10. Appelez evaluate_fn (si fourni) pour évaluer le modèle agrégé sur le côté ServerApp.

  11. Répétez les étapes 2-10 pour le nombre spécifié de tours.

  12. Retournez le modèle final, qui contient le modèle final et l’historique des métriques.

Le diagramme suivant illustre le flux.

Note

Le diagramme de séquence ci-dessous montre l’interaction entre ServerApp, Strategy (à l’intérieur de ServerApp) et ClientApp. En réalité, ils ne communiquent pas directement sur le réseau—l’infrastructure Flower (SuperLink et SuperNode) gère transparentement toutes les communications. Vous pouvez en savoir plus à ce sujet dans le guide Flower Network Communication.

sequenceDiagram participant SA as ServerApp participant ST as Strategy participant CA as ClientApps SA->>ST: start(num_rounds, ...) opt ST->>ST: evaluate_fn() end loop rounds 1..N Note over ST: --- Training Phase --- ST->>ST: configure_train() ST->>CA: train_messages CA->>CA: @app.train() callback CA-->>ST: train_replies ST->>ST: aggregate_train(train_replies) Note over ST: --- Evaluation Phase --- ST->>ST: configure_evaluate() ST->>CA: evaluate_messages CA->>CA: @app.evaluate() callback CA-->>ST: evaluate_replies ST->>ST: aggregate_evaluate(evaluate_replies) opt ST->>ST: evaluate_fn() end end ST-->>SA: final Result

La méthode configure_train

La méthode configure_train est responsable de la préparation de la prochaine ronde d’entraînement. Mais qu’est-ce que configurer signifie dans ce contexte ? Cela signifie sélectionner les clients qui devraient participer à le tour et décider quelles instructions ils devraient recevoir.

Voici l’énoncé de la méthode :

@abstractmethod
def configure_train(
    self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
    """Configure the next round of training."""

Cette méthode prend quatre arguments :

  • server_round: Le numéro actuel de tour

  • arrays: Les paramètres du modèle global actuels

  • config: Un dictionnaire de configuration pour le tour

  • grid: L’objet responsable de la gestion de la communication avec les clients

La valeur de retour est un itérateur de Message objets, où chaque message contient les instructions à envoyer à un client spécifique. Une mise en œuvre typique de configure_train sera :

  • Utilisez le grid pour échantillonner aléatoirement un sous-ensemble (ou tous) des clients disponibles

  • Créez un Message par client sélectionné, contenant les paramètres du modèle global et les valeurs de configuration

Des stratégies plus avancées peuvent mettre en œuvre une logique de sélection de client personnalisée en utilisant les capacités de grid. Un client ne participe qu’à un tour si configure_train génère un message pour son ID de nœud.

Note

Puisque la valeur de retour est définie par client, les stratégies peuvent facilement mettre en œuvre des configurations hétérogènes. Par exemple, différents clients peuvent recevoir des modèles ou des hyperparamètres différents, permettant des comportements d’entraînement hautement personnalisés.

La méthode aggregate_train

La méthode aggregate_train est responsable de l’agrégation des résultats d’entraînement reçus des clients sélectionnés dans configure_train.

Voici l’énoncé de la méthode :

@abstractmethod
def aggregate_train(
    self,
    server_round: int,
    replies: Iterable[Message],
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
    """Aggregate training results from client nodes."""

Cette méthode prend deux arguments :

  • server_round: Le numéro actuel de tour

  • replies: Un itérateur de Message objets provenant des clients participants

Elle retourne un tuple composé de :

  1. ArrayRecord: Les paramètres du modèle global mis à jour

  2. MetricRecord: Les métriques d’entraînement agrégées (telles que la perte ou l’exactitude)

Si l’agrégation ne peut pas être effectuée (par exemple, si trop de clients ont échoué pendant le tour), la méthode peut décider de retourner (None, None) au lieu.

Indication

Vous pouvez utiliser Message.has_error() pour vérifier si une réponse contient une erreur et décider comment y réagir lors de l’agrégation.

La méthode configure_evaluate

La méthode configure_evaluate est responsable de la préparation du prochain tour d’évaluation. De même que configure_train, cela implique la sélection des clients qui devraient participer et la décision de ce qu’ils devraient recevoir en instructions pour l’évaluation.

Voici l’énoncé de la méthode :

@abstractmethod
def configure_evaluate(
    self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
    """Configure the next round of evaluation."""

Cette méthode prend quatre arguments :

  • server_round: Le numéro actuel de tour

  • arrays: Les paramètres du modèle global actuels à évaluer

  • config: Un dictionnaire de configuration pour l’évaluation

  • grid: L’objet qui gère la communication avec les clients

La valeur de retour est un itérateur de Message objets, un par client sélectionné. Chaque message contient généralement les paramètres du modèle global actuels et toute configuration d’évaluation.

Une mise en œuvre typique de configure_evaluate sera :

  • Utilisez grid pour sélectionner un sous-ensemble (ou tous) des clients disponibles

  • Créez un Message par client sélectionné contenant le modèle global et la configuration d’évaluation

Comme avec l’entraînement, les stratégies plus avancées peuvent appliquer une logique de sélection de client personnalisée ou envoyer différentes configurations d’évaluation à différents clients.

Note

Puisque chaque client reçoit son propre message, les stratégies peuvent mettre en œuvre des paramètres d’évaluation hétérogènes. Par exemple, certains clients pourraient évaluer sur des jeux de test plus grands, tandis que d’autres utiliseraient des métriques spécialisées.

La méthode aggregate_evaluate

La méthode aggregate_evaluate est responsable de l’agrégation des résultats d’évaluation reçus des clients sélectionnés dans configure_evaluate.

Voici l’énoncé de la méthode :

@abstractmethod
def aggregate_evaluate(
    self,
    server_round: int,
    replies: Iterable[Message],
) -> Optional[MetricRecord]:
    """Aggregate evaluation metrics from client nodes."""

Cette méthode prend deux arguments :

  • server_round: Le numéro actuel de tour

  • replies: Une itérable de Message objets retournés par les clients après qu’ils ont exécuté l’évaluation

Il retourne un seul MetricRecord qui représente les métriques d’évaluation agrégées sur tous les clients participant. Si l’agrégation ne peut pas être effectuée (par exemple, en raison de trop nombreuses échecs des clients ou de métriques manquantes), la méthode peut retourner None.

Indication

Comme pour l’entraînement, Message.has_error() peut être utilisé pour détecter et gérer les erreurs des clients pendant l’agrégation.