Mettre à niveau vers l’API Message¶
Bienvenue dans le guide de migration pour mettre à jour vos applications Flower pour utiliser l’API Message de Flower ! Ce guide vous guidera tout au long des étapes nécessaires pour passer d’applications Flower basées sur Strategy et NumPyClient à celles qui utilisent la nouvelle API Message. Ce guide est pertinent lors de la mise à jour des applications Flower pré-1.21 vers la dernière version stable.
Diveons-y !
Astuce
Si vous souhaitez créer une nouvelle application Flower en utilisant le message API, exécutez la commande flwr new et choisissez le modèle approprié. Alternativement, vous pouvez peut-être vouloir jeter un œil à l’exemple quickstart-pytorch.
Résumé des changements¶
Des milliers d’applications Flower ont été créées en utilisant les stratégies et les abstractions NumPyClient. Avec l’introduction du message API, ces applications peuvent désormais profiter d’une couche de communication plus puissante et flexible avec l’abstraction Message étant son pilier. Les messages remplacent les anciens structures de données FitIns et FitRes (et leurs équivalents pour les autres opérations) par une seule structure de données unifiée et plus versatile.
Pour tirer pleinement parti des nouveaux modèles de communication basés sur message, vous devrez mettre à jour le code de votre application pour utiliser les nouveaux modèles de communication. Ce guide vous montrera comment faire :
Mettez à jour votre
ServerApppour utiliser les nouvelles stratégies basées surMessage. Vous n’aurez plus besoin d’utiliser leserver_fn.Mettez à jour votre
ClientApppour qu’elle opère directement sur des objetsMessagereçus duServerApp. Vous serez en mesure de conserver la plupart du code de votre implémentationNumPyClient, mais vous n’aurez plus besoin de créer une nouvelle classe ou d’utiliser la fonction d’assistanceclient_fn.
Astuce
Les principaux objets Message transportent des données de type RecordDict. Vous pouvez les considérer comme un dictionnaire pouvant contenir d’autres types de records, à savoir ArrayRecord, MetricRecord, et ConfigRecord. Voyons quelques exemples courts pour comprendre l’utilisation prévue derrière chaque type de record.
from flwr.app import ArrayRecord, MetricRecord, ConfigRecord, RecordDict
# ConfigRecord can be used to communicate configs between ServerApp and ClientApp
# They can hold scalars, but also strings and booleans
config = ConfigRecord(
{"batch_size": 32, "use_augmentation": True, "data-path": "/my/dataset"}
)
# MetricRecords are designed for scalar-based metrics only (i.e. int/float/list[int]/list[float])
# By limiting the types Flower can aggregate MetricRecords automatically
metrics = MetricRecord({"accuracy": 0.9, "losses": [0.1, 0.001], "perplexity": 2.31})
# ArrayRecord objects are designed to communicate arrays/tensors/weights from ML models
array_record = ArrayRecord(my_model.state_dict()) # for a PyTorch model
array_record_other = ArrayRecord(my_model.to_numpy_ndarrays()) # for other ML models
# A RecordDict is like a dictionary that holds named records.
# This is the main payload of a Message
rd = RecordDict({"my-config": config, "metrics": metrics, "my-model": array_record})
Référez-vous à la documentation pour chacun des records pour obtenir les détails sur leur construction et adaptation à votre cas d’utilisation. Dans ce guide, nous ne nous attacherons pas aux spécificités de chaque type de record, mais plutôt au processus de migration global.
Installer la mise à jour¶
Le premier pas consiste à mettre à jour la version Flower définie dans le fichier pyproject.toml de votre application :
dependencies = [
"flwr[simulation]>=1.21.0", # update Flower package
# ...
]
Exécutez ensuite la commande suivante pour installer les dépendances mises à jour :
# Install the app with updated dependencies
$ pip install -e .
Update your ServerApp¶
À partir de Flower 1.21, le SuperLink n’a plus besoin d’une fonction server_fn pour utiliser des stratégies. C’est parce qu’un nouveau ensemble de stratégies (toutes partageant la classe de base commune Strategy) a été créé pour opérer directement sur les objets Message, permettant une approche plus fluide et flexible pour les tours d’apprentissage fédéré.
Note
Les nouvelles stratégies basées sur Message sont localisées dans le module flwr.serverapp.strategy contrairement aux anciennes stratégies qui étaient localisées dans le module flwr.server.strategy. Au fil du temps, plus de stratégies seront ajoutées au module flwr.serverapp.strategy. Les utilisateurs sont encouragés à utiliser ces nouvelles stratégies.
Depuis Flower 1.10, la mise en œuvre recommandée de SuperLink ressemblerait quelque chose comme le code snippet suivant. Naturellement, plus de personnalisation peut être appliquée à la Stratégie par exemple en lisant la config du fichier Context. Mais pour garder les choses focalisées, nous utiliserons un exemple simple et supposerons que nous sommes en train de fédérer un modèle PyTorch.
Note
Context a été déplacé vers flwr.app et ServerApp vers flwr.serverapp. L’importation d’eux depuis flwr.common ou flwr.server est obsolète.
from flwr.common import Context # Deprecated, import from flwr.app instead
from flwr.server import ServerApp # Deprecated, import from flwr.serverapp instead
from flwr.server import ServerAppComponents, ServerConfig, start_server
from flwr.server.strategy import FedAvg
def server_fn(context: Context):
# Instantiate strategy with initial parameters
model = MyModel()
parameters = ndarrays_to_parameters(
[v.cpu().numpy() for v in model.state_dict().values()]
)
strategy = FedAvg(fraction_fit=0.5, initial_parameters=parameters)
# Set number of rounds and return
config = ServerConfig(num_rounds=3)
return ServerAppComponents(config=config, strategy=strategy)
# Create ServerApp with helper function
app = ServerApp(server_fn=server_fn)
Avec Flower 1.21 et ultérieurement, l’équivalent de SuperLink en utilisant le nouveau message API ressemblerait à celui-ci après avoir suivi ces étapes :
Définissez la méthode
mainsous l’attribut@app.main(). Si votreserver_fnlisait des valeurs de configuration dans leContext, vous pouvez toujours procéder ainsi (considérez copier ces lignes directement depuis votre fonctionserver_fn)Instanciez votre modèle comme d’habitude et construisez un objet
ArrayRecordà partir de ses paramètres.Remplacez votre stratégie existante par une des celles du module
flwr.serverapp.strategy. Par exemple, avecFedAvg. Passer les arguments liés au échantillonnage de nœud à la constructeur de votre stratégie.Appelez la méthode
startdu nouveau stratégie en passant à elle l’objetArrayRecordreprésentant l’état initial de votre modèle global, le nombre de tours FL et l’objetGrid(qui est utilisé internement pour communiquer avec les nœuds exécutant laClientApp).
Notez que nous n’avons plus besoin de la fonction server_fn. Le Context reste accessible, ce qui vous permet de personnaliser le comportement du ServerApp au runtime. Avec les nouvelles stratégies, une nouvelle méthode start est disponible. Elle définit une boucle for qui fixe les étapes d’un tour de FL. Par défaut, elle se comporte comme les stratégies d’origine, c’est-à-dire un tour d’entraînement FL suivi d’un tour d’évaluation FL et d’une étape d’évaluation du modèle global. Notez que la méthode start retourne des résultats. Ils sont de type Result et contiennent par défaut le modèle global final (via result.arrays), ainsi que les MetricRecord agrégés des étapes fédérées et, éventuellement, les métriques des étapes d’évaluation effectuées dans le ServerApp.
Note
En plus des méthodes d’aide pour travailler avec les modèles PyTorch, la classe ArrayRecord comporte un couple de méthodes pour convertir tel record en liste d’arrays NumPy (c’est-à-dire vers to_numpy_ndarrays et from_numpy_ndarrays). Vous pouvez choisir ces méthodes si vous ne travaillez pas avec des modèles PyTorch.
Avertissement
Notez que les nouvelles stratégies ont renommé plusieurs arguments liés à l’échantillonnage de nœud/client, remplaçant le terme « fit » par « train » et « clients » par « nodes ». Les arguments suivants ont été renommés :
fraction_fit→fraction_trainmin_fit_clients→min_train_nodesmin_evaluate_clients→min_evaluate_nodesmin_available_clients→min_available_nodes
from flwr.app import ArrayRecord, ConfigRecord, Context, MetricRecord
from flwr.serverapp import Grid, ServerApp
from flwr.serverapp.strategy import FedAvg
# Create ServerApp
app = ServerApp()
@app.main()
def main(grid: Grid, context: Context) -> None:
# Defined model to federate and extract parameters
model = MyModel()
arrays = ArrayRecord(global_model.state_dict())
# Instantiate strategy
strategy = FedAvg(fraction_train=0.5)
# Start the strategy
result = strategy.start(
grid=grid,
initial_arrays=arrays,
num_rounds=3,
)
print(result)
Mettez à jour votre ClientApp¶
De même que le ServerApp, le ClientApp n’a plus besoin d’une fonction helper (c’est-à-dire client_fn) qui instancie un objet NumPyClient ou base Client. Au lieu de cela, avec le Message API, vous pouvez définir directement comment l’opération du ClientApp opère sur les objets Message reçus à partir du ServerApp.
Rappelez-vous que NumPyClient venait avec deux méthodes clés intégrées, fit et evaluate, conçues respectivement pour effectuer l’entraînement fédéré et l’évaluation en utilisant les données locales du client. Avec la nouvelle API Message, vous pouvez définir des méthodes similaires directement sur le ClientApp via des décorateurs pour gérer les objets Message entrants.
Voyons un exemple de base montrant d’abord une mise en œuvre minimale basée sur NumPyClient et puis la conception améliorée utilisant l’API Message.
Note
Context a été déplacé vers flwr.app et ClientApp vers flwr.clientapp. L’importation de ceux-ci depuis flwr.common ou flwr.client est obsolète.
from flwr.client import ClientApp # Deprecated, import from flwr.clientapp instead
from flwr.client import NumPyClient
from flwr.common import Context # Deprecated, import from flwr.app instead
from my_utils import train_fn, test_fn, get_weights, set_weights
class MyFlowerClient(NumPyClient):
def __init__(self):
self.model = MyModel()
self.train_loader = DataLoader(...)
self.test_loader = DataLoader(...)
def fit(self, parameters, config):
"""Fit the model to the local data using the parameters sent by ServerApp."""
# Update model with the latest parameters
set_weights(self.model, parameters)
# Train the model locally
train_fn(self.model, self.train_loader)
# Return the updated parameters and number of training examples
return get_weights(self.model), len(self.train_loader.dataset), {}
def evaluate(self, parameters, config):
"""Evaluate the model on the local data using the parameters sent by ServerApp."""
# Update model with the latest parameters
set_weights(self.model, parameters)
# Evaluate the model locally
loss, accuracy = test_fn(self.model, self.test_loader)
# Return the evaluation results
return float(loss), len(self.test_loader.dataset), {"accuracy": float(accuracy)}
def client_fn(context: Context):
# Return an instance of MyFlowerClient
return MyFlowerClient().to_client()
app = ClientApp(client_fn=client_fn)
L’amélioration d’un ClientApp conçu autour des abstractions NumPyClient et client_fn en utilisant le Message API entraînerait les modifications suivantes. Notez que le comportement du ClientApp est défini directement dans ses méthodes (c’est-à-dire une classe secondaire basée sur NumPyClient n’est plus nécessaire).
L’abstraction ClientApp comporte les décorateurs intégrés @app.train et @app.evaluate. Les arguments des méthodes associées ont été unifiés, et ils opèrent tous deux sur des objets Message. Chaque méthode est responsable de gérer les objets Message entrants et de retourner la réponse appropriée (également sous forme de Message). Notez que vous pourrez toujours utiliser les fonctions que vous avez pu écrire, par exemple, pour entraîner votre modèle en utilisant le framework ML de votre choix. Dans cet exemple, celles-ci sont représentées par train_fn et test_fn. Suivez ces étapes pour migrer votre application existante ClientApp :
Introduisez les décorateurs
@app.trainet@app.evaluate, ainsi que leurs méthodes respectives.Copiez les lignes de code que vous aviez dans votre
client_fnpour lire des valeurs de configuration à partir duContextdans vos méthodes d’implémentationtrainetevaluate(crées à l’étape 1).À partir de l’objet
Message, extrayez les éléments pertinents (par exemple, unArrayRecorddéfinissant le modèle global, unConfigRecordcontenant des configurations pour le tour actuelle) à utiliser dans votre logique d’entraînement et d’évaluation.Copiez les lignes appelant les fonctions qui effectuent l’entraînement/évaluation réel (dans le code snippet ci-dessous, nous avons nommé celles-ci
train_fnettest_fn).Sur la base de la méthode, construisez un
RecordDictet utilisez-le pour construire la réponseMessage.
Note
Le payload que les objets Message transportent est du type RecordDict qui peut contenir des enregistrements de type ArrayRecord, MetricRecord et ConfigRecord.
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
from flwr.clientapp import ClientApp
from my_utils import train_fn, test_fn
# Flower ClientApp
app = ClientApp()
@app.train()
def train(msg: Message, context: Context) -> Message:
"""Train the model on local data."""
# Init Model and data loader
train_loader = DataLoader(...)
model = MyModel()
# Read ArrayRecord received from ServerApp
arrays = msg.content["arrays"]
# Load weights to model
model.load_state_dict(arrays.to_torch_state_dict())
# Do local training
train_loss = train_fn(model, train_loader)
# Construct reply Message: arrays and metrics
model_record = ArrayRecord(model.state_dict())
# You can include any metric (scalar or list of scalars)
# relevant to your usecase.
# A weighting metric (`num-examples` by default) is always
# expected by FedAvg to do aggregation
metrics = MetricRecord(
{
"train_loss": train_loss,
"num-examples": len(train_loader.dataset),
}
)
# Construct RecordDict and add ArrayRecord and MetricRecord
content = RecordDict({"arrays": model_record, "metrics": metrics})
return Message(content=content, reply_to=msg)
@app.evaluate()
def evaluate(msg: Message, context: Context) -> Message:
"""Evaluate the model on local data."""
# Identical to @app.train but returning only metrics
# after doing local evaluation
# ...
# Do local evaluation
loss, accuracy = test_fn(model, test_loader)
# Construct reply Message
# Retrun metrics relevant to usecase
# THe weighting metric is also sent and will be used
# to do weighted aggregation of metrics
metrics = MetricRecord(
{
"eval_loss": loss,
"eval_accuracy": accuracy,
"num-examples": len(test_loader.dataset),
}
)
# Construct RecordDict and add MetricRecord
content = RecordDict({"metrics": metrics})
return Message(content=content, reply_to=msg)
Cela conclut le guide de migration. Nous espérons que vous avez trouvé cela utile ! Bonne fédération !