Save and load model checkpoints
This how-to guide describes the steps to save (and load) model checkpoints in ClientApp and ServerApp.
How to save model checkpoints in ClientApp
Model updates are stored in an ArrayRecord and transmitted between ServerApp and ClientApp. To save model checkpoints in ClientApp, convert the ArrayRecord into a format compatible with your ML framework (e.g., PyTorch, TensorFlow, or NumPy). Include the following code in the functions registered with the ClientApp (e.g., in your training function decorated with @app.train()):
PyTorch
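import torch

# `arrays` is the ArrayRecord received from the ServerApp (e.g., msg.content["arrays"])
# Convert ArrayRecord to PyTorch state dict
state_dict = arrays.to_torch_state_dict()
# Save model weights to disk
torch.save(state_dict, "model.pt")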
TensorFlow
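# Convert ArrayRecord to NumPy ndarrays
ndarrays = arrays.to_numpy_ndarrays()
# Load the weights into a Keras model with the same architecture
# (`model` must already be built, e.g., by your model-construction function)
model.set_weights(ndarrays)
# Save the full model (architecture and weights) in the native Keras format
model.save("model.keras")
# To save only the weights instead: model.save_weights("model.weights.h5")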
NumPy
import numpy as np

# Convert ArrayRecord to NumPy ndarrays
ndarrays = arrays.to_numpy_ndarrays()
# Save model weights to disk; the arrays are stored under the keys arr_0, arr_1, ...
np.savez("model.npz", *ndarrays)
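The snippets above only cover saving. As a sketch of the loading side for the NumPy format, a checkpoint written with `np.savez` can be read back into a list of arrays. Because `savez` stores positional arrays under the keys `arr_0`, `arr_1`, ..., the list must be rebuilt by index rather than by sorted key name: a string sort places `arr_10` before `arr_2` and would silently scramble the layers. The helper names `save_checkpoint` and `load_checkpoint` are hypothetical, and wrapping the loaded list back into an ArrayRecord is not shown (see the ArrayRecord API reference for the constructor forms your Flower version supports).

```python
import os
import tempfile

import numpy as np


def save_checkpoint(path: str, ndarrays: list[np.ndarray]) -> None:
    # Positional arrays are stored under the keys arr_0, arr_1, ...
    np.savez(path, *ndarrays)


def load_checkpoint(path: str) -> list[np.ndarray]:
    # np.load returns an NpzFile; the context manager closes the file afterwards
    with np.load(path) as data:
        # Rebuild the list in index order (sorting the key strings would put
        # arr_10 before arr_2 and scramble the layer order)
        return [data[f"arr_{i}"] for i in range(len(data.files))]


if __name__ == "__main__":
    # Twelve toy "layers" with different shapes
    weights = [np.full((i + 1,), float(i)) for i in range(12)]
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.npz")
        save_checkpoint(path, weights)
        restored = load_checkpoint(path)
    print(all(np.array_equal(a, b) for a, b in zip(weights, restored)))  # True
```

The same round-trip idea applies to the other formats: `torch.load("model.pt")` returns the saved state dict, and `keras.models.load_model("model.keras")` restores the saved Keras model.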
How to save model checkpoints in ServerApp
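To save model checkpoints in ServerApp across FL rounds, implement a custom evaluate_fn and pass it to the strategy’s start method. The strategy calls this function with the current round number and the global model as an ArrayRecord, which makes it a convenient place to write the global model to disk. Here’s an example that saves the global PyTorch model: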
from pathlib import Path

import torch
from flwr.app import ArrayRecord, MetricRecord


def get_evaluate_fn(save_every_round: int, total_round: int, save_path: str):
    def evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
        # Save the model every `save_every_round` rounds and after the last round;
        # round 0 (the initial model, before any training) is skipped
        if server_round != 0 and (
            server_round == total_round or server_round % save_every_round == 0
        ):
            # Convert ArrayRecord to PyTorch state dict
            state_dict = arrays.to_torch_state_dict()
            # Make sure the output directory exists, then save the weights
            Path(save_path).mkdir(parents=True, exist_ok=True)
            torch.save(state_dict, f"{save_path}/model_{server_round}.pt")
        return MetricRecord()

    return evaluate
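Then, pass it to the start method of the defined strategy. Set total_round to the number of rounds the strategy runs so that the final global model is always saved: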
strategy.start(
    ...,
    evaluate_fn=get_evaluate_fn(save_every_round, total_round, save_path),
)
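To see which rounds the condition in `evaluate` actually checkpoints, and to find the most recent checkpoint when loading it back (for example, to resume or to run a final evaluation), the two hypothetical helpers below restate that logic in plain Python. `should_save` mirrors the `if` condition above; `latest_checkpoint` assumes the `model_{server_round}.pt` naming used above and compares round numbers numerically, since a plain string sort would rank `model_9.pt` after `model_10.pt`. This is a sketch, not part of the Flower API.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

# Matches the file names written by `evaluate` above, e.g. "model_12.pt"
CHECKPOINT_RE = re.compile(r"model_(\d+)\.pt")


def should_save(server_round: int, save_every_round: int, total_round: int) -> bool:
    # Same condition as in `evaluate`: skip round 0, save every N rounds and the last one
    return server_round != 0 and (
        server_round == total_round or server_round % save_every_round == 0
    )


def latest_checkpoint(save_path: str) -> Optional[Path]:
    # Return the checkpoint with the highest round number, or None if there is none
    best_round, best_file = -1, None
    for file in Path(save_path).glob("model_*.pt"):
        match = CHECKPOINT_RE.fullmatch(file.name)
        if match and int(match.group(1)) > best_round:
            best_round, best_file = int(match.group(1)), file
    return best_file


if __name__ == "__main__":
    # With 10 rounds and a checkpoint every 3 rounds, the final round is also saved
    print([r for r in range(11) if should_save(r, 3, 10)])  # [3, 6, 9, 10]

    # Empty placeholder files stand in for real checkpoints
    with tempfile.TemporaryDirectory() as tmp:
        for r in (3, 6, 9, 10):
            (Path(tmp) / f"model_{r}.pt").touch()
        print(latest_checkpoint(tmp).name)  # model_10.pt
```

With real checkpoints, `torch.load(latest_checkpoint(save_path))` returns the saved state dict, which can be loaded into a model with `model.load_state_dict(...)`.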
For complete, runnable examples, see the Advanced PyTorch Example and the Advanced TensorFlow Example.