Save and load model checkpoints

Flower does not automatically save model updates on the server side. This how-to guide describes the steps to save (and load) model checkpoints in Flower.

Model checkpointing

Model updates can be persisted on the server side by customizing Strategy methods. Implementing a custom strategy from scratch is always an option, but in many cases it is more convenient to customize an existing one. The code example below defines a new SaveModelStrategy that customizes the built-in FedAvg strategy. In particular, it overrides aggregate_fit and calls aggregate_fit in the base class (FedAvg) to do the actual aggregation. It then saves the returned (aggregated) parameters before handing them back to the caller (i.e., the server):

from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
import numpy as np
from flwr.common import FitRes, Parameters, Scalar
from flwr.server.client_proxy import ClientProxy


class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(
            server_round, results, failures
        )

        if aggregated_parameters is not None:
            # Convert `Parameters` to `List[np.ndarray]`
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(
                aggregated_parameters
            )

            # Save aggregated_ndarrays to disk
            print(f"Saving round {server_round} aggregated_ndarrays...")
            np.savez(f"round-{server_round}-weights.npz", *aggregated_ndarrays)

        return aggregated_parameters, aggregated_metrics


# Create strategy and run server
strategy = SaveModelStrategy(
    # (same arguments as FedAvg here)
)
fl.server.start_server(strategy=strategy)
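
If you later want to resume from one of these NumPy checkpoints, the saved arrays can be read back and converted into a Flower Parameters object. A minimal sketch, assuming a file named round-3-weights.npz produced by the strategy above (the file name is illustrative):

import flwr as fl
import numpy as np

# `np.savez(f, *arrays)` stores positional arrays under the keys
# "arr_0", "arr_1", ...; read them back in that order
npz = np.load("round-3-weights.npz")  # hypothetical checkpoint from round 3
ndarrays = [npz[f"arr_{i}"] for i in range(len(npz.files))]

# Convert back into a Flower `Parameters` object
parameters = fl.common.ndarrays_to_parameters(ndarrays)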

Save and load PyTorch checkpoints

Similar to the previous example, but with a few extra steps, we'll show how to store a PyTorch checkpoint using the torch.save function. First, aggregate_fit returns a Parameters object that has to be transformed into a list of NumPy ndarray's; those are then transformed into a PyTorch state_dict following the OrderedDict class structure.

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
import numpy as np
import torch
from flwr.common import FitRes, Parameters, Scalar
from flwr.server.client_proxy import ClientProxy

# `cifar.Net()` and `DEVICE` come from the accompanying example code
net = cifar.Net().to(DEVICE)


class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate model weights using weighted average and store checkpoint."""

        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(
            server_round, results, failures
        )

        if aggregated_parameters is not None:
            print(f"Saving round {server_round} aggregated_parameters...")

            # Convert `Parameters` to `List[np.ndarray]`
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(
                aggregated_parameters
            )

            # Convert `List[np.ndarray]` to PyTorch `state_dict`
            params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

            # Save the model checkpoint
            torch.save(net.state_dict(), f"model_round_{server_round}.pth")

        return aggregated_parameters, aggregated_metrics
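
The customized strategy is then plugged into the server exactly as in the previous example; a minimal sketch:

# Create the checkpointing strategy and run the server, as before
strategy = SaveModelStrategy(
    # (same arguments as FedAvg here)
)
fl.server.start_server(strategy=strategy)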

To load your progress, simply append the following lines to your code. Note that this will iterate over all saved checkpoints and load the most recent one:

import glob
import os

import flwr as fl
import torch

# Find and load the most recently created checkpoint
list_of_files = glob.glob("./model_round_*")
latest_round_file = max(list_of_files, key=os.path.getctime)
print("Loading pre-trained model from:", latest_round_file)
state_dict = torch.load(latest_round_file)
net.load_state_dict(state_dict)
state_dict_ndarrays = [v.cpu().numpy() for v in net.state_dict().values()]
parameters = fl.common.ndarrays_to_parameters(state_dict_ndarrays)

Return or use this Parameters object wherever necessary, such as for the initial_parameters argument when defining a Strategy.
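
For example, the loaded parameters can serve as the server-side starting point for federated training; a minimal sketch using the built-in FedAvg strategy:

# Resume federated training from the loaded checkpoint by passing the
# recovered `Parameters` object as the initial global model
strategy = fl.server.strategy.FedAvg(
    initial_parameters=parameters,
)
fl.server.start_server(strategy=strategy)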