Save and load model checkpoints

Flower does not automatically save model updates on the server side. This how-to guide describes how to save (and load) model checkpoints in Flower.

Model checkpointing

Model updates can be persisted on the server side by customizing Strategy methods. Implementing a custom strategy from scratch is always an option, but in many cases it is more convenient to customize an existing one. The following code example defines a new SaveModelStrategy that customizes the built-in FedAvg strategy. In particular, it overrides aggregate_fit by first calling aggregate_fit in the base class (FedAvg). It then saves the returned (aggregated) weights before returning them to the caller (i.e., the server):

from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
import numpy as np
from flwr.common import FitRes, Parameters, Scalar
from flwr.server.client_proxy import ClientProxy


class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:

        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)

        if aggregated_parameters is not None:
            # Convert `Parameters` to `List[np.ndarray]`
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)

            # Save aggregated_ndarrays (stored under the keys arr_0, arr_1, ...)
            print(f"Saving round {server_round} aggregated_ndarrays...")
            np.savez(f"round-{server_round}-weights.npz", *aggregated_ndarrays)

        return aggregated_parameters, aggregated_metrics

# Create strategy and run server
strategy = SaveModelStrategy(
    # (same arguments as FedAvg here)
)
fl.server.start_server(strategy=strategy)
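
The guide saves the aggregated weights with np.savez but does not show how to read them back. The following is a minimal sketch of that step, assuming the file was written with positional arrays as in the strategy above, so NumPy stored them under the keys arr_0, arr_1, and so on. The helper name load_round_weights is hypothetical and not part of Flower, and the two arrays are stand-ins for real model layers.

```python
import os
import tempfile
from typing import List

import numpy as np


def load_round_weights(path: str) -> List[np.ndarray]:
    """Load arrays written by np.savez(path, *ndarrays) in their original order."""
    with np.load(path) as data:
        # Positional arrays are stored under the keys arr_0, arr_1, ...
        return [data[f"arr_{i}"] for i in range(len(data.files))]


# Round-trip demo with two stand-in "layers"
weights = [np.ones((2, 3), dtype=np.float32), np.zeros(4, dtype=np.float32)]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "round-1-weights.npz")
    np.savez(path, *weights)
    restored = load_round_weights(path)
```

The restored list has the same layout as aggregated_ndarrays, so it can be passed to fl.common.ndarrays_to_parameters to obtain a Parameters object again.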

Save and load PyTorch checkpoints

Similar to the previous example but with a few extra steps, we'll show how to store a PyTorch checkpoint using the torch.save function. First, aggregate_fit returns a Parameters object that has to be converted into a list of NumPy ndarrays. Those are then converted into a PyTorch state_dict following the OrderedDict class structure.

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
import numpy as np
import torch
from flwr.common import FitRes, Parameters, Scalar
from flwr.server.client_proxy import ClientProxy

# `cifar` and `DEVICE` are expected to be defined elsewhere in your project
net = cifar.Net().to(DEVICE)


class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate model weights using weighted average and store checkpoint"""

        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)

        if aggregated_parameters is not None:
            print(f"Saving round {server_round} aggregated_parameters...")

            # Convert `Parameters` to `List[np.ndarray]`
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)

            # Convert `List[np.ndarray]` to PyTorch `state_dict`
            params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

            # Save the model
            torch.save(net.state_dict(), f"model_round_{server_round}.pth")

        return aggregated_parameters, aggregated_metrics

To load your progress, simply append the following lines to your code. Note that this iterates over all saved checkpoints and loads the most recent one:

import glob
import os

list_of_files = [fname for fname in glob.glob("./model_round_*")]
latest_round_file = max(list_of_files, key=os.path.getctime)
print("Loading pre-trained model from: ", latest_round_file)
state_dict = torch.load(latest_round_file)
# Load the checkpoint into the model before extracting its weights
net.load_state_dict(state_dict)
state_dict_ndarrays = [v.cpu().numpy() for v in net.state_dict().values()]
parameters = fl.common.ndarrays_to_parameters(state_dict_ndarrays)

Return or use this Parameters object wherever necessary, for example as initial_parameters when defining a Strategy.
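
The loading snippet above picks the latest checkpoint with os.path.getctime, which on Unix systems reports the last metadata change rather than the creation time, so copied or touched files can be picked by mistake. As a possible alternative, the sketch below selects the checkpoint by the round number in its file name. It assumes the model_round_<N>.pth naming used in this guide. The helper latest_checkpoint is hypothetical and not part of Flower, and the demo uses empty placeholder files instead of real checkpoints.

```python
import glob
import os
import re
import tempfile
from typing import Optional


def latest_checkpoint(directory: str = ".") -> Optional[str]:
    """Return the model_round_<N>.pth path with the highest N, or None if there is none."""
    pattern = re.compile(r"model_round_(\d+)\.pth$")
    best_round, best_path = -1, None
    for path in glob.glob(os.path.join(directory, "model_round_*.pth")):
        match = pattern.search(os.path.basename(path))
        if match and int(match.group(1)) > best_round:
            best_round, best_path = int(match.group(1)), path
    return best_path


# Demo with empty placeholder files (not real checkpoints)
with tempfile.TemporaryDirectory() as tmp:
    for rnd in (1, 2, 10):
        open(os.path.join(tmp, f"model_round_{rnd}.pth"), "w").close()
    newest = os.path.basename(latest_checkpoint(tmp))
    missing = latest_checkpoint(os.path.join(tmp, "does-not-exist"))
```

Comparing the parsed integers also avoids the lexicographic trap where model_round_2.pth would sort after model_round_10.pth. The returned path can be passed to torch.load in place of latest_round_file.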