Save and load model checkpoints

This how-to guide describes the steps to save (and load) model checkpoints in ClientApp and ServerApp.

How to save model checkpoints in ClientApp

Model updates are stored in an ArrayRecord and transmitted between the ServerApp and the ClientApp. To save a model checkpoint in the ClientApp, convert the ArrayRecord into a format compatible with your ML framework (e.g., PyTorch, TensorFlow, or NumPy). Include the following code in a function registered with the ClientApp (e.g., your training function decorated with @app.train()):

PyTorch

# Convert ArrayRecord to PyTorch state dict.
state_dict = arrays.to_torch_state_dict()

# Save model weights to disk
torch.save(state_dict, "model.pt")

TensorFlow

# Convert ArrayRecord to NumPy ndarrays
ndarrays = arrays.to_numpy_ndarrays()

# Load the weights into an existing Keras model with a matching architecture
model.set_weights(ndarrays)

# Save model weights to disk
model.save("model.keras")

NumPy

# Convert ArrayRecord to NumPy ndarrays
ndarrays = arrays.to_numpy_ndarrays()

# Save model weights to disk
numpy.savez("model.npz", *ndarrays)
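
Since this guide also concerns loading, here is a minimal sketch of reading the NumPy checkpoint back. When arrays are passed to savez positionally, NumPy stores them under the keys arr_0, arr_1, and so on, so the original order can be rebuilt by index. The list of arrays and the temporary directory below are stand-ins chosen for illustration; in a real ClientApp the list would come from arrays.to_numpy_ndarrays().

```python
import os
import tempfile

import numpy as np

# Stand-in for `arrays.to_numpy_ndarrays()`: a list of arrays with different shapes.
ndarrays = [np.ones((2, 3), dtype=np.float32), np.zeros(3, dtype=np.float32)]

# Save to a temporary directory so the sketch does not touch the working directory.
path = os.path.join(tempfile.mkdtemp(), "model.npz")
np.savez(path, *ndarrays)  # positional arrays are stored as arr_0, arr_1, ...

# Load the checkpoint and rebuild the list in its original order.
with np.load(path) as data:
    restored = [data[f"arr_{i}"] for i in range(len(data.files))]
```

The restored list has the same order, shapes, and dtypes as the saved one, so it can be handed to whatever code expects the original list of ndarrays.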

How to save model checkpoints in ServerApp

To save model checkpoints in the ServerApp across FL rounds, implement the saving logic in a custom evaluate_fn and pass it to the strategy’s start method. The following example saves the global PyTorch model:

def get_evaluate_fn(save_every_round, total_round, save_path):
    def evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
        # Save the model every `save_every_round` rounds and after the final round
        if server_round != 0 and (
            server_round == total_round or server_round % save_every_round == 0
        ):
            # Convert ArrayRecord to PyTorch state dict
            state_dict = arrays.to_torch_state_dict()

            # Save model weights to disk
            torch.save(state_dict, f"{save_path}/model_{server_round}.pt")

        return MetricRecord()

    return evaluate

Then, pass it to the start method of the defined strategy:

strategy.start(
    ...,
    evaluate_fn=get_evaluate_fn(save_every_round, total_round, save_path),
)
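
To see which rounds the condition in evaluate actually selects, the check can be pulled out and run on its own. The sketch below uses only the standard library; should_save and checkpoint_path are hypothetical helper names, not part of any framework API. The second helper also creates save_path when it is missing, on the assumption that torch.save writes to the given path without creating parent directories.

```python
from pathlib import Path


def should_save(server_round: int, total_round: int, save_every_round: int) -> bool:
    """Mirror the condition in `evaluate`: skip round 0, save periodically and at the end."""
    if server_round == 0:
        return False
    return server_round == total_round or server_round % save_every_round == 0


def checkpoint_path(save_path: str, server_round: int) -> Path:
    """Hypothetical helper: create `save_path` if needed and return the checkpoint file path."""
    directory = Path(save_path)
    directory.mkdir(parents=True, exist_ok=True)
    return directory / f"model_{server_round}.pt"


# With 10 rounds and a checkpoint every 3 rounds, rounds 3, 6, 9 and the final round are saved.
saved_rounds = [r for r in range(0, 11) if should_save(r, total_round=10, save_every_round=3)]
print(saved_rounds)  # [3, 6, 9, 10]
```

Inside evaluate, the value returned by checkpoint_path(save_path, server_round) could be passed to torch.save in place of the f-string path.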

For more details, see the Advanced PyTorch Example and the Advanced TensorFlow Example.