Concevoir des ClientApps étatiques¶
Par conception, les ClientApp sont sans état. Cela signifie que l’objet ClientApp est recréé chaque fois qu’un nouveau Message doit être traité. Ce comportement est identique à celui du Simulation Runtime et du Deployment Runtime dans Flower. Pour le premier, cela nous permet de simuler la mise en œuvre d’un grand nombre de nœuds sur une seule machine ou sur plusieurs machines. Pour le second, cela permet à chaque SuperNode d’être partie de plusieurs runs, chacun exécutant un différent ClientApp.
Lorsqu’une ClientApp est exécutée, elle reçoit un Context. Ce contexte est unique pour chaque ClientApp, ce qui signifie que les exécutions ultérieures du même ClientApp depuis le même nœud recevront le même objet Context. Dans la Context, l’attribut .state (de type RecordDict) peut être utilisé pour stocker des informations que vous souhaitez que le ClientApp ait accès pendant toute la durée de l’exécution. Cela pourrait être tout, depuis les résultats intermédiaires tels que l’historique des pertes d’entraînement (par exemple, sous forme de liste de valeurs float avec une nouvelle entrée ajoutée chaque fois que le ClientApp est exécuté), certaines parties du modèle qui doivent persister côté client, ou d’autres objets Python arbitraires. Ces éléments devraient être sérialisés avant d’être enregistrés dans le contexte.
Sauvegarder des métriques dans le contexte¶
Cette section démontrera comment sauvegarder des métriques telles que les valeurs de précision/perte pour le Context afin qu’elles puissent être utilisées dans les exécutions ultérieures de la ClientApp.
Commencez par un simple paramétrage dans lequel ClientApp est défini comme suit. La fonction train() ne génère qu’un nombre aléatoire, l’imprime et renvoie un message vide.
Astuce
Vous pouvez créer un projet PyTorch avec des composants prêts à l’emploi et autres en exécutant flwr new.
import random
from flwr.app import Context, Message, RecordDict
from flwr.clientapp import ClientApp
# Flower ClientApp
app = ClientApp()
@app.train()
def train(msg: Message, context: Context):
"""Train the model on local data."""
# Generate a random integer between 0 and 10
n = random.randint(0, 10)
print(n)
return Message(RecordDict(), reply_to=msg)
Avec le minimal ClientApp ci-dessus, chaque fois que vous adressez une Message à cette fonction train, un nouveau nombre entier aléatoire sera généré et imprimé. Supposons que nous voulions sauvegarder ce nombre aléatoire généré et l’ajouter à une liste qui persiste dans le Context. De cette façon, chaque fois que la fonction est exécutée,, elle imprime l’historique des nombres entiers aléatoires. Voyons comment cela ressemble en code:
Astuce
Rappelez-vous, l’attribut state d’un objet Context est du type RecordDict, ce qui signifie que vous pouvez y sauvegarder non seulement MetricRecord comme dans l’exemple ci-dessous, mais aussi ArrayRecord et ConfigRecord objets.
import random
from flwr.app import Context, Message, RecordDict
from flwr.clientapp import ClientApp
# Flower ClientApp
app = ClientApp()
@app.train()
def train(msg: Message, context: Context):
"""Train the model on local data."""
# Generate a random integer between 0 and 10
n = random.randint(0, 10)
print(n)
# Append to list in context or initialize if it doesn't exist
if "random-metrics" not in context.state:
# Initialize MetricRecord in state
context.state["random-metrics"] = MetricRecord({"random-ints": []})
# Append to record
context.state["random-metrics"]["random-ints"].append(n)
# Print history
print(context.state["random-metrics"])
return Message(RecordDict(), reply_to=msg)
Si vous lancez une application Flower incluant la logique ci-dessus dans votre ClientApp et ayant juste deux clients dans votre fédération échantillonnés à chaque tour, vous verrez un résultat similaire à celui-ci dessous. Notez comment après chaque tour le record random-metrics dans le Context obtient une valeur supplémentaire ? Remarquez que, en Simulation Runtime, l’ordre des messages de journal peut changer à chaque tour en raison de l’ordonnancement aléatoire des clients simulés.
# round 1
config_records={'random-metrics': {'random-ints': [2]}}
config_records={'random-metrics': {'random-ints': [7]}}
# round 2
config_records={'random-metrics': {'random-ints': [2, 5]}}
config_records={'random-metrics': {'random-ints': [7, 4]}}
# round 3
config_records={'random-metrics': {'random-ints': [2, 5, 1]}}
config_records={'random-metrics': {'random-ints': [7, 4, 2]}}
Sauvegarder les paramètres du modèle dans le contexte¶
En utilisant ConfigRecord ou MetricRecord pour sauvegarder des composants « simples » est acceptable (par exemple, float, entier, booléen, chaîne, octets et listes de ces types. Notez que MetricRecord ne prend en charge que les float, les entiers et les listes de ces types). Flower a un type spécifique de registre, ArrayRecord, pour stocker les paramètres du modèle ou plus généralement des tableaux de données.
Voyons quelques exemples de sauvegarde d’arrays NumPy puis comment sauvegarder les paramètres des modèles PyTorch et TensorFlow.
Note
Les exemples suivants omettent la définition d’une ClientApp pour garder les blocs de code concis. Pour utiliser des objets ArrayRecord dans votre ClientApp, vous pouvez suivre les mêmes principes que ceux énoncés plus tôt.
Sauvegarder des arrays NumPy dans le contexte¶
Les éléments stockés dans un ArrayRecord sont du type Array, qui est une structure de données qui contient bytes et des métadonnées pouvant être utilisées pour la désérialisation. Voyons comment créer un Array à partir d’un array NumPy et l’insérer dans un ArrayRecord.
Note
Les objets Array transportent les octets comme principal payload et des métadonnées supplémentaires pour utiliser lors de la désérialisation. Vous pouvez également mettre en œuvre votre propre sérialisation/désérialisation.
Voyons comment utiliser ces fonctions pour stocker un array NumPy dans le contexte.
import numpy as np
from flwr.app import Array, ArrayRecord, Context
# Let's create a simple NumPy array
arr_np = np.random.randn(3, 3)
# If we print it
# array([[-1.84242409, -1.01539537, -0.46528405],
# [ 0.32991896, 0.55540414, 0.44085534],
# [-0.10758364, 1.97619858, -0.37120501]])
# Now, let's serialize it and construct an Array
arr = Array(arr_np)
# If we print it (note the binary data)
# Array(dtype='float64', shape=[3, 3], stype='numpy.ndarray', data=b'\x93NUMPY\x01\x00v\x00...)
# It can be inserted in an ArrayRecord like this
arr_record = ArrayRecord()
arr_record["my_array"] = arr
# You can also do it via the constructor
# arr_record = ArrayRecord({"my_array": arr})
# If you don't need the keys, you can also pass a list of Numpy arrays
# arr_record = ArrayRecord([arr_np])
# Then, it can be added to the state in the context
context.state["some_parameters"] = arr_record
Pour extraire les données dans un ArrayRecord, il suffit de désérialiser l’array d’intérêt. Par exemple, en suivant l’exemple ci-dessus :
# Get Array from context
arr = context.state["some_parameters"]["my_array"]
# If you constructed the ArrayRecord with a list of Numpy, then do
# arr = context.state["some_parameters"].to_numpy_ndarrays()[0] # get first array
# Deserialize it
arr_deserialized = arr.numpy()
# If we print it (it should show the exact same values as earlier)
# array([[-1.84242409, -1.01539537, -0.46528405],
# [ 0.32991896, 0.55540414, 0.44085534],
# [-0.10758364, 1.97619858, -0.37120501]])
Sauvegarder les paramètres PyTorch dans le contexte¶
Flower offre des utilitaires d’une ligne pour convertir les paramètres du modèle PyTorch en objets ArrayRecord. Voyons comment faire cela.
import torch
import torch.nn as nn
import torch.nn.functional as F
from flwr.app import ArrayRecord
class Net(nn.Module):
"""A very simple model"""
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 32, 5)
self.fc = nn.Linear(1024, 10)
def forward(self, x):
x = F.relu(self.conv(x))
return self.fc(x)
# Instantiate model as usual
model = Net()
# Save the state_dict into a single ArrayRecord
arr_record = ArrayRecord(model.state_dict())
# Add to a context
context.state["net_parameters"] = arr_record
Supposons maintenant que vous souhaitez appliquer les paramètres stockés dans votre contexte à une nouvelle instance du modèle (comme il se produit chaque fois qu’un ClientApp est exécuté). Vous aurez besoin de:
Récupérez le
ArrayRecorddu contexteConstruisez un
state_dictet le chargez
state_dict = {}
# Extract record from context
arr_record = context.state["net_parameters"]
# Deserialize the parameters
state_dict = arr_record.to_torch_state_dict()
# Apply state dict to a new model instance
model_ = Net()
model_.load_state_dict(state_dict)
# now this model has the exact same parameters as the one created earlier
# You can verify this by doing
for p, p_ in zip(model.state_dict().values(), model_.state_dict().values()):
assert torch.allclose(p, p_), "`state_dict`s do not match"
Et voilà ! Rappelez-vous que même si cet exemple montre comment stocker l’ensemble du state_dict dans un ArrayRecord, vous pouvez simplement sauvegarder une partie. Le processus serait identique, mais vous pourriez avoir besoin d’ajuster la façon dont il est chargé dans un modèle existant à l’aide des API PyTorch.
Sauvegarder les paramètres TensorFlow/Keras dans le contexte¶
Suivez les mêmes étapes que celles effectuées ci-dessus mais remplacez la logique state_dict par simplement get_weights() pour convertir les paramètres du modèle en une liste d’arrays NumPy pouvant ensuite être sauvegardés dans un ArrayRecord. Ensuite, après désérialisation, utilisez set_weights() pour appliquer les nouveaux paramètres à un modèle.
import tensorflow as tf
from flwr.app import ArrayRecord
# Define a simple model
model = tf.keras.Sequential(
[
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(10),
]
)
# Save model weights into an ArrayRecord and add to a context
context.state["model_weights"] = ArrayRecord(model.get_weights())
...
# Extract record from context and apply to the model
model.set_weights(context.state["model_weights"].to_numpy_ndarrays())