모델 체크포인트 저장 및 로드#

Flower는 서버 측에서 모델 업데이트를 자동으로 저장하지 않습니다. 이 사용법 가이드에서는 Flower에서 모델 체크포인트를 저장(및 로드)하는 단계에 대해 설명합니다.

모델 체크포인트#

Strategy 메소드를 사용자 지정하여 서버 측에서 모델 업데이트를 지속할 수 있습니다. 사용자 지정 전략을 구현하는 것은 항상 옵션이지만 대부분의 경우 기존 전략을 간단히 사용자 지정하는 것이 더 편리할 수 있습니다. 다음 코드 예시는 기존의 기본 제공 FedAvg 전략을 사용자 지정한 새로운 SaveModelStrategy`를 정의합니다. 특히, 기본 클래스(:code:`FedAvg)에서 :code:`aggregate_fit`을 호출하여 :code:`aggregate_fit`을 사용자 지정합니다. 그런 다음 호출자(즉, 서버)에게 집계된 가중치를 반환하기 전에 반환된(집계된) 가중치를 계속 저장합니다:

class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:

        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)

        if aggregated_parameters is not None:
            # Convert `Parameters` to `List[np.ndarray]`
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)

            # Save aggregated_ndarrays
            print(f"Saving round {server_round} aggregated_ndarrays...")
            np.savez(f"round-{server_round}-weights.npz", *aggregated_ndarrays)

        return aggregated_parameters, aggregated_metrics

# Create strategy and run server
strategy = SaveModelStrategy(
    # (same arguments as FedAvg here)
)
fl.server.start_server(strategy=strategy)

파이토치 체크포인트 저장 및 로드#

이전 예제와 비슷하지만 몇 가지 단계가 추가되어 torch.save 함수를 사용하여 파이토치 체크포인트를 저장하는 방법을 보여드리겠습니다. 먼저, aggregate_fit``은 ``Parameters 객체를 반환하는데, 이 객체는 NumPy ndarray``의 목록으로 변환되어야 하며, ``OrderedDict 클래스 구조에 따라 파이토치 ``state_dict``로 변환됩니다.

net = cifar.Net().to(DEVICE)
class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate model weights using weighted average and store checkpoint"""

        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)

        if aggregated_parameters is not None:
            print(f"Saving round {server_round} aggregated_parameters...")

            # Convert `Parameters` to `List[np.ndarray]`
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)

            # Convert `List[np.ndarray]` to PyTorch`state_dict`
            params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

            # Save the model
            torch.save(net.state_dict(), f"model_round_{server_round}.pth")

        return aggregated_parameters, aggregated_metrics

진행 상황을 로드하려면 코드에 다음 줄을 추가하기만 하면 됩니다. 이렇게 하면 저장된 모든 체크포인트를 반복하고 최신 체크포인트를 로드합니다:

list_of_files = [fname for fname in glob.glob("./model_round_*")]
latest_round_file = max(list_of_files, key=os.path.getctime)
print("Loading pre-trained model from: ", latest_round_file)
state_dict = torch.load(latest_round_file)
net.load_state_dict(state_dict)
state_dict_ndarrays = [v.cpu().numpy() for v in net.state_dict().values()]
parameters = fl.common.ndarrays_to_parameters(state_dict_ndarrays)

전략``을 정의할 ``초기_파라미터``와 같이 필요한 경우 ``파라미터 유형의 이 객체를 반환/사용합니다.