Personnalisez une Stratégie Flower¶
Bienvenue dans la prochaine partie du tutoriel d’intelligence artificielle collaborative Flower !
Dans les précédents tutoriels, vous avez créé une fédération simulée sur SuperGrid, exécuté et personnalisé des Applications Flower, passé de la démo NumPy à l’application quickstart PyTorch, et puis personnalisé cette application PyTorch en changeant et en étendant sa stratégie. Dans ce tutoriel, vous allez aller un peu plus loin et créer une version plus personnalisée de FedAdagrad.
Astuce
Star Flower on GitHub ⭐️ et rejoignez la communauté Flower sur Flower Discuss ou Flower Slack pour vous présenter, poser des questions et obtenir de l’aide.
Créons une nouvelle Strategy avec un méthode de start personnalisée qui :
sauvegarde une copie du modèle global lorsque l’on trouve un nouveau meilleur taux d’exactitude global;
enregistre les métriques générées pendant l’exécution dans Weights & Biases !
Préparation¶
Ce tutoriel continue à partir de previous tutorial. Si vous l’avez terminé, ouvrez le répertoire existant quickstart-pytorch et continuez d’où vous en étiez.
Installation des dépendances¶
Si vous commencez ici directement, créez d’abord l’application comme montré dans le précédent tutoriel :
# Install Flower
$ pip install -U "flwr[simulation]"
# Create a new Flower App using the PyTorch quickstart template
$ flwr new @flwrlabs/quickstart-pytorch
Dans ce tutoriel, vous allez utiliser Weights & Biases pour enregistrer les métriques de stratégie. À partir du répertoire quickstart-pytorch, ajoutez wandb à la liste des dépendances dans pyproject.toml:
$ cd quickstart-pytorch
"wandb>=0.17.8"
Ensuite installez les dépendances de projet mises à jour :
$ pip install -e .
Note
Si c’est la première fois que vous installez wandb, vous pourriez être invité à créer un compte puis à vous connecter à votre système. Vous pouvez démarrer ce processus en tapant cela dans votre terminal:
$ wandb login
Personnalisez la méthode start d’une stratégie¶
Les stratégies Flower ont un certain nombre de méthodes qui peuvent être surchargées pour personnaliser leur comportement. Dans le précédent tutoriel de stratégie, vous avez appris comment personnaliser la méthode configure_train pour effectuer une décroissance du taux d’apprentissage et communiquer l’actualisation du taux d’apprentissage comme partie des messages ConfigRecord envoyés aux clients dans le tour Message. Dans ce tutoriel, vous allez apprendre à personnaliser la méthode start. Si vous inspectez le code de cette méthode, vous verrez qu’elle contient un boucle for où chaque itération représente une ronde d’apprentissage fédéré. Chaque tour se compose de trois étapes distinctes :
Une étape d’entraînement, où un sous-ensemble de clients est sélectionné pour entraîner le modèle global actuel sur leurs données locales.
Une étape d’évaluation, où un sous-ensemble de clients est sélectionné pour évaluer le modèle global mis à jour sur leurs jeux de validation locaux.
Une étape facultative pour évaluer le modèle global sur le côté serveur. Notez que c’est ce que vous avez activé dans le tutoriel précédent par l’intermédiaire de l’appel-back
global_evaluate.
Élargissons la stratégie CustomFedAdagrad que nous avons créée plus tôt et introduisons :
_update_best_acc: Une méthode auxiliaire pour sauvegarder le modèle global chaque fois qu’un nouveau meilleur taux d’exactitude est trouvé.set_save_path: Une méthode auxiliaire pour définir le chemin où les journaux et les checkpoints de modèle seront sauvegardés. Cette méthode sera appelée à partir duserver_app.pyaprès avoir instancié la stratégie.Une méthode personnalisée
startpour enregistrer les métriques dans Weights & Biases (W&B) et sauvegarder les checkpoints de modèle sur le disque.
import io
import time
from logging import INFO
from pathlib import Path
from typing import Callable, Iterable, Optional
import torch
import wandb
from flwr.app import ArrayRecord, ConfigRecord, Message, MetricRecord
from flwr.common import log, logger
from flwr.serverapp import Grid
from flwr.serverapp.strategy import FedAdagrad, Result
from flwr.serverapp.strategy.strategy_utils import log_strategy_start_info
PROJECT_NAME = "FLOWER-advanced-pytorch"
class CustomFedAdagrad(FedAdagrad):
def configure_train(
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
"""Configure the next round of federated training and maybe do LR decay."""
# Decrease learning rate by a factor of 0.5 every 5 rounds
if server_round % 5 == 0 and server_round > 0:
config["lr"] *= 0.5
print("LR decreased to:", config["lr"])
# Pass the updated config and the rest of arguments to the parent class
return super().configure_train(server_round, arrays, config, grid)
def set_save_path(self, path: Path):
"""Set the path where wandb logs and model checkpoints will be saved."""
self.save_path = path
def _update_best_acc(
self, current_round: int, accuracy: float, arrays: ArrayRecord
) -> None:
"""Update best accuracy and save model checkpoint if current accuracy is
higher."""
if accuracy > self.best_acc_so_far:
self.best_acc_so_far = accuracy
logger.log(INFO, "💡 New best global model found: %f", accuracy)
# Save the PyTorch model
file_name = f"model_state_acc_{accuracy}_round_{current_round}.pth"
torch.save(arrays.to_torch_state_dict(), self.save_path / file_name)
logger.log(INFO, "💾 New best model saved to disk: %s", file_name)
def start(
self,
grid: Grid,
initial_arrays: ArrayRecord,
num_rounds: int = 3,
timeout: float = 3600,
train_config: Optional[ConfigRecord] = None,
evaluate_config: Optional[ConfigRecord] = None,
evaluate_fn: Optional[
Callable[[int, ArrayRecord], Optional[MetricRecord]]
] = None,
) -> Result:
"""Execute the federated learning strategy logging results to W&B and saving
them to disk."""
# Init W&B
name = f"{str(self.save_path.parent.name)}/{str(self.save_path.name)}-ServerApp"
wandb.init(project=PROJECT_NAME, name=name)
# Keep track of best acc
self.best_acc_so_far = 0.0
log(INFO, "Starting %s strategy:", self.__class__.__name__)
log_strategy_start_info(
num_rounds, initial_arrays, train_config, evaluate_config
)
self.summary()
log(INFO, "")
# Initialize if None
train_config = ConfigRecord() if train_config is None else train_config
evaluate_config = ConfigRecord() if evaluate_config is None else evaluate_config
result = Result()
t_start = time.time()
# Evaluate starting global parameters
if evaluate_fn:
res = evaluate_fn(0, initial_arrays)
log(INFO, "Initial global evaluation results: %s", res)
if res is not None:
result.evaluate_metrics_serverapp[0] = res
arrays = initial_arrays
for current_round in range(1, num_rounds + 1):
log(INFO, "")
log(INFO, "[ROUND %s/%s]", current_round, num_rounds)
# -----------------------------------------------------------------
# --- TRAINING (CLIENTAPP-SIDE) -----------------------------------
# -----------------------------------------------------------------
# Call strategy to configure training round
# Send messages and wait for replies
train_replies = grid.send_and_receive(
messages=self.configure_train(
current_round,
arrays,
train_config,
grid,
),
timeout=timeout,
)
# Aggregate train
agg_arrays, agg_train_metrics = self.aggregate_train(
current_round,
train_replies,
)
# Log training metrics and append to history
if agg_arrays is not None:
result.arrays = agg_arrays
arrays = agg_arrays
if agg_train_metrics is not None:
log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_train_metrics)
result.train_metrics_clientapp[current_round] = agg_train_metrics
# Log to W&B
wandb.log(dict(agg_train_metrics), step=current_round)
# -----------------------------------------------------------------
# --- EVALUATION (CLIENTAPP-SIDE) ---------------------------------
# -----------------------------------------------------------------
# Call strategy to configure evaluation round
# Send messages and wait for replies
evaluate_replies = grid.send_and_receive(
messages=self.configure_evaluate(
current_round,
arrays,
evaluate_config,
grid,
),
timeout=timeout,
)
# Aggregate evaluate
agg_evaluate_metrics = self.aggregate_evaluate(
current_round,
evaluate_replies,
)
# Log training metrics and append to history
if agg_evaluate_metrics is not None:
log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_evaluate_metrics)
result.evaluate_metrics_clientapp[current_round] = agg_evaluate_metrics
# Log to W&B
wandb.log(dict(agg_evaluate_metrics), step=current_round)
# -----------------------------------------------------------------
# --- EVALUATION (SERVERAPP-SIDE) ---------------------------------
# -----------------------------------------------------------------
# Centralized evaluation
if evaluate_fn:
log(INFO, "Global evaluation")
res = evaluate_fn(current_round, arrays)
log(INFO, "\t└──> MetricRecord: %s", res)
if res is not None:
result.evaluate_metrics_serverapp[current_round] = res
# Maybe save to disk if new best is found
self._update_best_acc(current_round, res["accuracy"], arrays)
# Log to W&B
wandb.log(dict(res), step=current_round)
log(INFO, "")
log(INFO, "Strategy execution finished in %.2fs", time.time() - t_start)
log(INFO, "")
log(INFO, "Final results:")
log(INFO, "")
for line in io.StringIO(str(result)):
log(INFO, "\t%s", line.strip("\n"))
log(INFO, "")
return result
Avec la stratégie CustomFedAdagrad étendue définie, nous avons besoin maintenant de définir l’emplacement où les checkpoints de modèle seront sauvegardés ainsi que le nom des runs dans W&B. Nous devons appeler la méthode set_save_path après avoir instancié la stratégie et avant d’appeler la méthode start. Dans server_app.py, nous pouvons créer un nouveau répertoire appelé results et puis un sous-répertoire avec l’heure courante pour stocker les résultats de chaque run. Nous pouvons ensuite appeler la méthode set_save_path. Dans ce tutoriel, nous créons le répertoire basé sur la date et l’heure actuelles, cela signifie que chaque fois que vous faites flwr run, un nouveau répertoire sera utilisé. Voyons comment cela ressemble en code :
# ... unchanged
# add this to the imports
from datetime import datetime
from pathlib import Path
# ... unchanged
@app.main()
def main(grid: Grid, context: Context) -> None:
"""Main entry point for the ServerApp."""
# ... unchanged
# Initialize FedAdagrad strategy
# strategy = CustomFedAdagrad( ... )
# Get the current date and time
current_time = datetime.now()
run_dir = current_time.strftime("%Y-%m-%d/%H-%M-%S")
# Save path is based on the current directory
save_path = Path.cwd() / f"outputs/{run_dir}"
save_path.mkdir(parents=True, exist_ok=False)
# Set the path where results and model checkpoints will be saved
strategy.set_save_path(save_path)
# ... rest unchanged
Finalement, lançons l’application Flower localement. Cette étude écrit les checkpoints de modèle dans votre répertoire de travail et enregistre des métriques dans Weights & Biases, ce qui facilite la visualisation des résultats.
$ flwr run . local --stream
Le flwr run . local soumet l’exécution, imprime l’ID d’exécution et retourne sans diffuser les journaux. Voir Exécuter Flower Localement avec un SuperLink Géré pour le workflow local complet.
Après avoir démarré l’exécution, vous remarquerez deux choses :
Un nouveau répertoire sera créé dans
sorties/YYYY-MM-DD/HH-MM-SSoùYYYY-MM-DD/HH-MM-SSest la date et l’heure actuelles. Ce répertoire contiendra les checkpoints de modèle sauvegardés pendant l’exécution. Rappelez-vous que un checkpoint est sauvegardé chaque fois qu’un nouveau meilleur taux d’exactitude est trouvé lors de la phase d’évaluation centralisée.Une nouvelle exécution sera créée dans votre W&B project où vous pouvez visualiser les métriques enregistrées pendant l’exécution.
Félicitations ! Vous avez réussi à créer une stratégie Flower personnalisée en surchargeant la méthode start. Vous avez également appris à enregistrer des métriques dans Weights & Biases et à sauvegarder les checkpoints de modèle sur le disque.
Récapitulation¶
Dans cette étude, nous avons vu comment personnaliser la méthode start d’une stratégie Flower. Cette méthode est l’entrée principale de toute stratégie et contient la logique pour exécuter le processus d’apprentissage fédéré. Dans cette étude, vous avez appris à enregistrer les métriques dans Weights & Biases et à sauvegarder les checkpoints de modèle sur le disque.
Dans le prochain tutoriel, vous communiquerez des informations supplémentaires entre la ClientApp et la ServerApp en les sérialisant et en les envoyant dans un Message.
Prochaines étapes¶
Avant de continuer, assurez-vous d’adhérer à la communauté Flower sur Flower Discuss (Join Flower Discuss) et Slack (Join Slack).
Il existe un canal Slack dédié si vous avez besoin d’aide, mais nous aimerions également entendre qui vous êtes dans #introductions !
Le Flower Collaborative AI Tutorial - Part 6: Communicate custom Messages montre comment personnaliser ce que le ClientApp envoie à nouveau au ServerApp.