Utilisez les stratégies¶
Flower permet la personnalisation complète du processus d’apprentissage à travers l’abstraction Strategy. Un certain nombre de stratégies intégrées strategies sont fournies dans le cadre de base.
Il existe quatre façons de personnaliser la manière dont Flower orchestre le processus d’apprentissage côté serveur :
Utilisez une stratégie existante, par exemple
FedAvgPersonnalisez une stratégie existante avec des fonctions d’appel-back pour son méthode
startPersonnalisez une stratégie existante en surchargeant l’une ou plusieurs de ses méthodes.
Implémentez une nouvelle stratégie à partir de zéro
Note
Les stratégies intégrées de Flower communiquent un ArrayRecord et un MetricRecord dans un Message au ClientApps. Les stratégies attendent des réponses contenant un MetricRecord et, si c’est une ronde où les clients effectuent l’entraînement local, un ArrayRecord également. L’abstraction Message permet des enregistrements illimités de tout type. Si vous voulez communiquer plusieurs enregistrements, vous devrez soit élargir une stratégie existante, soit en implémenter une nouvelle à partir de zéro.
Utilise une stratégie existante¶
Flower est accompagné d’un certain nombre de popularités d’apprentissage fédéré Strategies qui peuvent être instantiées comme suit dans un simple ServerApp:
# Create ServerApp
app = ServerApp()
@app.main()
def main(grid: Grid, context: Context) -> None:
"""Main entry point for the ServerApp."""
# Load global model
global_model = Net()
arrays = ArrayRecord(global_model.state_dict())
# Initialize FedAvg strategy with default settings
strategy = FedAvg()
# Start strategy, run FedAvg for `num_rounds`
result = strategy.start(
grid=grid,
initial_arrays=arrays,
)
Dans le code ci-dessus, l’instantiation de FedAvg ne déclenche pas la logique intégrée à la stratégie (c’est-à-dire l’échantillonnage des nœuds, la communication Message, la mise en forme d’agrégat, etc.). Pour y parvenir, nous devons exécuter le méthode start.
Le code ci-dessus est très minimal, utilise les paramètres par défaut pour ServerApp et ne passe que les arguments requis à la méthode FedAvg. Voyons un peu plus en détail quelles options nous avons lors de l’instanciation des stratégies et lors de leur exécution.
Paramétrer une stratégie existante¶
Le constructeur des stratégies accepte différents paramètres en fonction, principalement, de l’algorithme d’agrégation qu’ils mettent en œuvre. Par exemple, FedAdam accepte des arguments supplémentaires (c’est-à-dire pour appliquer la vitesse pendant l’agrégation) par rapport à ceux que FedAvg exige. Cependant, communs à toutes les stratégies sont les paramètres de contrôle du fait que les nœuds qui exécutent des instances ClientApp soient échantillonnés. Voyons ce ensemble d’arguments :
from flwr.serverapp.strategy import FedAvg
# Initialize FedAvg strategy
strategy = FedAvg(
fraction_train=0.5, # fraction of nodes to involve in a round of training
fraction_evaluate=1.0, # fraction of nodes to involve in a round of evaluation
min_available_nodes=100, # minimum connected nodes required before FL starts
)
Pour la plupart des applications, spécifier un ou plusieurs des arguments présentés ci-dessus est suffisant. Une stratégie de type Flower définie comme celle-ci attendrait 100 nœuds connectés avant que toute étape fédérée ne commence. Ensuite, 50% des nœuds connectés seront impliqués dans une étape d’entraînement fédéré, suivie d’une autre étape d’évaluation fédérée où tous les nœuds connectés participeront. Il est possible de définir les arguments min_train_nodes et min_evaluate_nodes pour un contrôle plus fin.
En outre des arguments permettant de personnaliser la façon dont la stratégie effectue l’échantillonnage, nous pouvons définir à l’heure de construction les clés qui seront utilisées pour communiquer différentes informations entre la stratégie et l’application ServerApp et l’application ClientApp. Notez que ces clés sont utilisées dans les deux types d’étapes au sein de la logique de stratégie start, c’est-à-dire l’entraînement fédéré et l’évaluation fédérée.
from flwr.serverapp.strategy import FedAvg
# Initialize FedAvg strategy
# Here we define our own keys instead of using the default
strategy = FedAvg(
arrayrecord_key="my-arrays",
configrecord_key="super-config",
weighted_by_key="num-batches",
)
arrayrecord_key: laMessagetransmise auClientAppcontiendra unArrayRecordcontenant les tableaux du modèle global sous cette clé. Par défaut, la clé est"arrays".configrecord_key: laMessagetransmise auClientAppcontiendra unConfigRecordcontenant les paramètres de configuration. Par défaut, la clé est"config".weighted_by_key: Une clé à l’intérieur duMetricRecordque leClientAppretourne en tant que partie de sa réponse auServerApp. La valeur sous cette clé est utilisée pour effectuer une agrégation pondérée desMetricRecordset, après un tour d’entraînement fédéré, desArrayRecords. La valeur par défaut est"num-examples".
Avec une stratégie définie comme ci-dessus, l’application ClientApp devrait recevoir un message Message avec la structure suivante:
# The content of a Message arriving to the ClientApp will have
# the following structure and using the keys defined in the strategy
msg = Message(
# ....
content=RecordDict(
{
"my-arrays": ArrayRecord(...),
"super-config": ConfigRecord(...),
}
)
)
# The reply Message should contain a MetricRecord and inside it
# an item associated with the key used to initialize the strategy
reply_msg_content = RecordDict(
{
"locally-updated-params": ArrayRecord(...),
"local-metrics": MetricRecord(
{
"num-batches": N,
# ... Other metrics
}
),
}
)
Note
Même si les stratégies fixent les clés utilisées pour transmettre la ArrayRecord et la MetricRecord au ClientApps, les réponses que ces derniers envoient à la ServerApp peuvent utiliser des clés différentes. Dans l’exemple de code ci-dessus, nous avons utilisé "locally-updated-params" et "local-metrics". Cependant, toutes les ClientApps doivent utiliser les mêmes clés dans leur réponse Messages, sinon l’agrégation des réponses (ArrayRecord et MetricRecord) ne peut pas être effectuée.
Finalement, le constructeur de stratégie permet également de passer deux appels-back pour contrôler comment les messages que ClientApps envoient sont agrégés. Suivez le guide Aggréguez les résultats d’évaluation pour une marche à suivre sur la façon de définir ces appels-back.
Utiliser la méthode start de la stratégie¶
Comme mentionné plus tôt, c’est la méthode start de la stratégie qui démarre le processus d’apprentissage fédéré. Voyons ce que chaque argument passé à cette méthode représente.
Astuce
Vérifiez l’explorateur Flower Strategy Abstraction pour un aperçu approfondi de la façon dont les différentes étapes implémentées comme partie de la méthode start fonctionnent.
Les seuls arguments requis sont le Grid et un ArrayRecord. Le premier est un objet qui sera utilisé pour interagir avec les nœuds exécutant l’application ClientApp pour les impliquer dans une ronde d’entraînement/évaluation/requête ou autre. Le second contient les paramètres du modèle que nous voulons fédérer. Par conséquent, une exécution minimale de la méthode start ressemble à ceci:
# Start strategy
result = strategy.start(
grid=grid,
initial_arrays=ArrayRecord(...),
)
Dans la plupart des configurations, nous voulons personnaliser la façon dont la méthode start est exécutée en passant également le nombre de tours à exécuter et un couple d’objets ConfigRecord à envoyer à l’application ClientApp lors d’une étape d’entraînement et d’évaluation respectivement.
# Define configs to send to ClientApp
train_cfg = ConfigRecord({"lr": 0.1, "optim": "adam"})
eval_cfg = ConfigRecord({"max-steps": 500, "local-checkpoint": True})
# Start strategy
result = strategy.start(
grid=grid,
initial_arrays=ArrayRecord(...),
train_config=train_cfg,
evaluate_config=eval_cfg,
num_rounds=100,
)
La méthode start permet également de limiter pendant combien de temps l’application strategy attend des réponses de l’application ClientApps avant de poursuivre avec le reste des étapes. Cela peut être contrôlé à l’aide de l’argument timeout (qui par défaut est fixé à 3600s, c’est-à-dire 1h). Par exemple, si nous voulons augmenter la durée d’attente à 2 heures, nous ferions:
# Define configs to send to ClientApp
train_cfg = ConfigRecord({"lr": 0.1, "optim": "adam"})
eval_cfg = ConfigRecord({"max-steps": 500, "local-checkpoint": True})
# Start strategy
result = strategy.start(
grid=grid,
initial_arrays=ArrayRecord(...),
train_config=train_cfg,
evaluate_config=eval_cfg,
num_rounds=100,
timeout=7200, # 2 hours
)
Enfin, l’argument final dans start est nommé evaluate_fn et permet de passer une fonction callback pour évaluer le modèle agrégé sur des données locales que le ServerApp pourrait avoir accès. Cette fonctionnalité est également utile si vous souhaitez sauvegarder le modèle global à la fin de chaque tour (ou tous les N tours). Voyons ce que est la signature de cette fonction callback et comment l’utiliser :
# Callback definition. The function can have any name
# but the arguments are fixed
def my_callback(server_round: int, arrays: ArrayRecord) -> MetricRecord:
"""Evaluate model on central data."""
# Save checkpoint
state_dict = arrays.to_torch_state_dict()
torch.save(state_dict, f"model_at_round_{server_round}.pt")
# eval model on local data
model = MyModel()
model.load_state_dict(state_dict)
acc, loss = test(model, ...)
# Return MetricRecord
return MetricRecord({"acc": acc, "loss": loss})
# Pass the callback to the start method
strategy.start(..., evaluate_fn=my_callback)
Astuce
Prenez un coup d’œil à l’exemple quickstart-pytorch sur GitHub pour un exemple complet utilisant plusieurs des concepts présentés dans ce guide.