Personnalisez une Stratégie Flower

Bienvenue dans la prochaine partie du tutoriel d’intelligence artificielle collaborative Flower !

Dans les précédents tutoriels, vous avez créé une fédération simulée sur SuperGrid, exécuté et personnalisé des Applications Flower, passé de la démo NumPy à l’application quickstart PyTorch, et puis personnalisé cette application PyTorch en changeant et en étendant sa stratégie. Dans ce tutoriel, vous allez aller un peu plus loin et créer une version plus personnalisée de FedAdagrad.

Astuce

Star Flower on GitHub ⭐️ et rejoignez la communauté Flower sur Flower Discuss ou Flower Slack pour vous présenter, poser des questions et obtenir de l’aide.

Créons une nouvelle Strategy avec un méthode de start personnalisée qui :

  • sauvegarde une copie du modèle global lorsque l’on trouve un nouveau meilleur taux d’exactitude global;

  • enregistre les métriques générées pendant l’exécution dans Weights & Biases !

Préparation

Ce tutoriel continue à partir de previous tutorial. Si vous l’avez terminé, ouvrez le répertoire existant quickstart-pytorch et continuez d’où vous en étiez.

Installation des dépendances

Si vous commencez ici directement, créez d’abord l’application comme montré dans le précédent tutoriel :

# Install Flower
$ pip install -U "flwr[simulation]"
# Create a new Flower App using the PyTorch quickstart template
$ flwr new @flwrlabs/quickstart-pytorch

Dans ce tutoriel, vous allez utiliser Weights & Biases pour enregistrer les métriques de stratégie. À partir du répertoire quickstart-pytorch, ajoutez wandb à la liste des dépendances dans pyproject.toml:

$ cd quickstart-pytorch
"wandb>=0.17.8"

Ensuite installez les dépendances de projet mises à jour :

$ pip install -e .

Note

Si c’est la première fois que vous installez wandb, vous pourriez être invité à créer un compte puis à vous connecter à votre système. Vous pouvez démarrer ce processus en tapant cela dans votre terminal:

$ wandb login

Personnalisez la méthode start d’une stratégie

Les stratégies Flower ont un certain nombre de méthodes qui peuvent être surchargées pour personnaliser leur comportement. Dans le précédent tutoriel de stratégie, vous avez appris comment personnaliser la méthode configure_train pour effectuer une décroissance du taux d’apprentissage et communiquer l’actualisation du taux d’apprentissage comme partie des messages ConfigRecord envoyés aux clients dans le tour Message. Dans ce tutoriel, vous allez apprendre à personnaliser la méthode start. Si vous inspectez le code de cette méthode, vous verrez qu’elle contient un boucle for où chaque itération représente une ronde d’apprentissage fédéré. Chaque tour se compose de trois étapes distinctes :

  1. Une étape d’entraînement, où un sous-ensemble de clients est sélectionné pour entraîner le modèle global actuel sur leurs données locales.

  2. Une étape d’évaluation, où un sous-ensemble de clients est sélectionné pour évaluer le modèle global mis à jour sur leurs jeux de validation locaux.

  3. Une étape facultative pour évaluer le modèle global sur le côté serveur. Notez que c’est ce que vous avez activé dans le tutoriel précédent par l’intermédiaire de l’appel-back global_evaluate.

Élargissons la stratégie CustomFedAdagrad que nous avons créée plus tôt et introduisons :

  1. _update_best_acc: Une méthode auxiliaire pour sauvegarder le modèle global chaque fois qu’un nouveau meilleur taux d’exactitude est trouvé.

  2. set_save_path: Une méthode auxiliaire pour définir le chemin où les journaux et les checkpoints de modèle seront sauvegardés. Cette méthode sera appelée à partir du server_app.py après avoir instancié la stratégie.

  3. Une méthode personnalisée start pour enregistrer les métriques dans Weights & Biases (W&B) et sauvegarder les checkpoints de modèle sur le disque.

import io
import time
from logging import INFO
from pathlib import Path
from typing import Callable, Iterable, Optional

import torch
import wandb
from flwr.app import ArrayRecord, ConfigRecord, Message, MetricRecord
from flwr.common import log, logger
from flwr.serverapp import Grid
from flwr.serverapp.strategy import FedAdagrad, Result
from flwr.serverapp.strategy.strategy_utils import log_strategy_start_info

PROJECT_NAME = "FLOWER-advanced-pytorch"


class CustomFedAdagrad(FedAdagrad):

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated training and maybe do LR decay."""
        # Decrease learning rate by a factor of 0.5 every 5 rounds
        if server_round % 5 == 0 and server_round > 0:
            config["lr"] *= 0.5
            print("LR decreased to:", config["lr"])
        # Pass the updated config and the rest of arguments to the parent class
        return super().configure_train(server_round, arrays, config, grid)

    def set_save_path(self, path: Path):
        """Set the path where wandb logs and model checkpoints will be saved."""
        self.save_path = path

    def _update_best_acc(
        self, current_round: int, accuracy: float, arrays: ArrayRecord
    ) -> None:
        """Update best accuracy and save model checkpoint if current accuracy is
        higher."""
        if accuracy > self.best_acc_so_far:
            self.best_acc_so_far = accuracy
            logger.log(INFO, "💡 New best global model found: %f", accuracy)
            # Save the PyTorch model
            file_name = f"model_state_acc_{accuracy}_round_{current_round}.pth"
            torch.save(arrays.to_torch_state_dict(), self.save_path / file_name)
            logger.log(INFO, "💾 New best model saved to disk: %s", file_name)

    def start(
        self,
        grid: Grid,
        initial_arrays: ArrayRecord,
        num_rounds: int = 3,
        timeout: float = 3600,
        train_config: Optional[ConfigRecord] = None,
        evaluate_config: Optional[ConfigRecord] = None,
        evaluate_fn: Optional[
            Callable[[int, ArrayRecord], Optional[MetricRecord]]
        ] = None,
    ) -> Result:
        """Execute the federated learning strategy logging results to W&B and saving
        them to disk."""

        # Init W&B
        name = f"{str(self.save_path.parent.name)}/{str(self.save_path.name)}-ServerApp"
        wandb.init(project=PROJECT_NAME, name=name)

        # Keep track of best acc
        self.best_acc_so_far = 0.0

        log(INFO, "Starting %s strategy:", self.__class__.__name__)
        log_strategy_start_info(
            num_rounds, initial_arrays, train_config, evaluate_config
        )
        self.summary()
        log(INFO, "")

        # Initialize if None
        train_config = ConfigRecord() if train_config is None else train_config
        evaluate_config = ConfigRecord() if evaluate_config is None else evaluate_config
        result = Result()

        t_start = time.time()
        # Evaluate starting global parameters
        if evaluate_fn:
            res = evaluate_fn(0, initial_arrays)
            log(INFO, "Initial global evaluation results: %s", res)
            if res is not None:
                result.evaluate_metrics_serverapp[0] = res

        arrays = initial_arrays

        for current_round in range(1, num_rounds + 1):
            log(INFO, "")
            log(INFO, "[ROUND %s/%s]", current_round, num_rounds)

            # -----------------------------------------------------------------
            # --- TRAINING (CLIENTAPP-SIDE) -----------------------------------
            # -----------------------------------------------------------------

            # Call strategy to configure training round
            # Send messages and wait for replies
            train_replies = grid.send_and_receive(
                messages=self.configure_train(
                    current_round,
                    arrays,
                    train_config,
                    grid,
                ),
                timeout=timeout,
            )

            # Aggregate train
            agg_arrays, agg_train_metrics = self.aggregate_train(
                current_round,
                train_replies,
            )

            # Log training metrics and append to history
            if agg_arrays is not None:
                result.arrays = agg_arrays
                arrays = agg_arrays
            if agg_train_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_train_metrics)
                result.train_metrics_clientapp[current_round] = agg_train_metrics
                # Log to W&B
                wandb.log(dict(agg_train_metrics), step=current_round)

            # -----------------------------------------------------------------
            # --- EVALUATION (CLIENTAPP-SIDE) ---------------------------------
            # -----------------------------------------------------------------

            # Call strategy to configure evaluation round
            # Send messages and wait for replies
            evaluate_replies = grid.send_and_receive(
                messages=self.configure_evaluate(
                    current_round,
                    arrays,
                    evaluate_config,
                    grid,
                ),
                timeout=timeout,
            )

            # Aggregate evaluate
            agg_evaluate_metrics = self.aggregate_evaluate(
                current_round,
                evaluate_replies,
            )

            # Log training metrics and append to history
            if agg_evaluate_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_evaluate_metrics)
                result.evaluate_metrics_clientapp[current_round] = agg_evaluate_metrics
                # Log to W&B
                wandb.log(dict(agg_evaluate_metrics), step=current_round)
            # -----------------------------------------------------------------
            # --- EVALUATION (SERVERAPP-SIDE) ---------------------------------
            # -----------------------------------------------------------------

            # Centralized evaluation
            if evaluate_fn:
                log(INFO, "Global evaluation")
                res = evaluate_fn(current_round, arrays)
                log(INFO, "\t└──> MetricRecord: %s", res)
                if res is not None:
                    result.evaluate_metrics_serverapp[current_round] = res
                    # Maybe save to disk if new best is found
                    self._update_best_acc(current_round, res["accuracy"], arrays)
                    # Log to W&B
                    wandb.log(dict(res), step=current_round)

        log(INFO, "")
        log(INFO, "Strategy execution finished in %.2fs", time.time() - t_start)
        log(INFO, "")
        log(INFO, "Final results:")
        log(INFO, "")
        for line in io.StringIO(str(result)):
            log(INFO, "\t%s", line.strip("\n"))
        log(INFO, "")

        return result

Avec la stratégie CustomFedAdagrad étendue définie, nous avons besoin maintenant de définir l’emplacement où les checkpoints de modèle seront sauvegardés ainsi que le nom des runs dans W&B. Nous devons appeler la méthode set_save_path après avoir instancié la stratégie et avant d’appeler la méthode start. Dans server_app.py, nous pouvons créer un nouveau répertoire appelé results et puis un sous-répertoire avec l’heure courante pour stocker les résultats de chaque run. Nous pouvons ensuite appeler la méthode set_save_path. Dans ce tutoriel, nous créons le répertoire basé sur la date et l’heure actuelles, cela signifie que chaque fois que vous faites flwr run, un nouveau répertoire sera utilisé. Voyons comment cela ressemble en code :

# ... unchanged
# add this to the imports
from datetime import datetime
from pathlib import Path

# ... unchanged


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""

    # ... unchanged

    # Initialize FedAdagrad strategy
    # strategy = CustomFedAdagrad( ... )

    # Get the current date and time
    current_time = datetime.now()
    run_dir = current_time.strftime("%Y-%m-%d/%H-%M-%S")
    # Save path is based on the current directory
    save_path = Path.cwd() / f"outputs/{run_dir}"
    save_path.mkdir(parents=True, exist_ok=False)

    # Set the path where results and model checkpoints will be saved
    strategy.set_save_path(save_path)

    # ... rest unchanged

Finalement, lançons l’application Flower localement. Cette étude écrit les checkpoints de modèle dans votre répertoire de travail et enregistre des métriques dans Weights & Biases, ce qui facilite la visualisation des résultats.

$ flwr run . local --stream

Le flwr run . local soumet l’exécution, imprime l’ID d’exécution et retourne sans diffuser les journaux. Voir Exécuter Flower Localement avec un SuperLink Géré pour le workflow local complet.

Après avoir démarré l’exécution, vous remarquerez deux choses :

  1. Un nouveau répertoire sera créé dans sorties/YYYY-MM-DD/HH-MM-SSYYYY-MM-DD/HH-MM-SS est la date et l’heure actuelles. Ce répertoire contiendra les checkpoints de modèle sauvegardés pendant l’exécution. Rappelez-vous que un checkpoint est sauvegardé chaque fois qu’un nouveau meilleur taux d’exactitude est trouvé lors de la phase d’évaluation centralisée.

  2. Une nouvelle exécution sera créée dans votre W&B project où vous pouvez visualiser les métriques enregistrées pendant l’exécution.

Félicitations ! Vous avez réussi à créer une stratégie Flower personnalisée en surchargeant la méthode start. Vous avez également appris à enregistrer des métriques dans Weights & Biases et à sauvegarder les checkpoints de modèle sur le disque.

Récapitulation

Dans cette étude, nous avons vu comment personnaliser la méthode start d’une stratégie Flower. Cette méthode est l’entrée principale de toute stratégie et contient la logique pour exécuter le processus d’apprentissage fédéré. Dans cette étude, vous avez appris à enregistrer les métriques dans Weights & Biases et à sauvegarder les checkpoints de modèle sur le disque.

Dans le prochain tutoriel, vous communiquerez des informations supplémentaires entre la ClientApp et la ServerApp en les sérialisant et en les envoyant dans un Message.

Prochaines étapes

Avant de continuer, assurez-vous d’adhérer à la communauté Flower sur Flower Discuss (Join Flower Discuss) et Slack (Join Slack).

Il existe un canal Slack dédié si vous avez besoin d’aide, mais nous aimerions également entendre qui vous êtes dans #introductions !

Le Flower Collaborative AI Tutorial - Part 6: Communicate custom Messages montre comment personnaliser ce que le ClientApp envoie à nouveau au ServerApp.