Save and load model checkpoints¶
Ce guide décrit les étapes pour sauvegarder (et charger) les checkpoints de modèle dans ClientApp et ServerApp.
Comment sauvegarder les checkpoints de modèle dans ClientApp¶
Les mises à jour du modèle sont enregistrées dans ArrayRecord et transmises entre ServerApp et ClientApp. Pour enregistrer les checkpoints de modèle dans ClientApp, vous devez convertir le format ArrayRecord compatible avec votre framework ML (par exemple, PyTorch, TensorFlow ou NumPy). Incluez le code suivant dans vos fonctions enregistrées avec le ClientApp (par exemple, dans votre fonction d’entraînement décorée avec @app.train()) :
PyTorch
# Convert ArrayRecord to PyTorch state dict.
state_dict = arrays.to_torch_state_dict()
# Save model weights to disk
torch.save(state_dict, "model.pt")
TensorFlow
# Convert ArrayRecord to NumPy ndarrays
ndarrays = arrays.to_numpy_ndarrays()
# Load weights to a keras model
model.set_weights(ndarrays)
# Save model weights to disk
model.save("model.keras")
NumPy
# Convert ArrayRecord to NumPy ndarrays
ndarrays = arrays.to_numpy_ndarrays()
# Save model weights to disk
numpy.savez("model.npz", *ndarrays)
Comment sauvegarder les checkpoints de modèle dans ServerApp¶
Pour enregistrer les checkpoints de modèle dans ServerApp au cours des différentes rondes FL, vous pouvez implémenter cela dans une stratégie personnalisée evaluate_fn et la passer à la méthode start du stratagème. Voici un exemple montrant comment sauvegarder le modèle PyTorch global :
def get_evaluate_fn(save_every_round, total_round, save_path):
def evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
# Save model every `save_every_round` round and for the last round
if server_round != 0 and (
server_round == total_round or server_round % save_every_round == 0
):
# Convert ArrayRecord to PyTorch state dict
state_dict = arrays.to_torch_state_dict()
# Save model weights to disk
torch.save(state_dict, f"{save_path}/model_{server_round}.pt")
return MetricRecord()
return evaluate
Puis, passez-le à la méthode start de la stratégie définie :
strategy.start(
...,
evaluate_fn=get_evaluate_fn(save_every_round, total_round, save_path),
)
Si vous êtes intéressé, consultez les détails dans Advanced PyTorch Example.