Save and Load Model Checkpoints
Flower does not automatically save model updates on the server side. This how-to guide describes the steps to save (and load) model checkpoints in Flower.
Model Checkpointing
Model updates can be persisted on the server side by customizing Strategy methods. Implementing custom strategies is always an option, but in many cases it may be more convenient to simply customize an existing strategy. The following code example defines a new SaveModelStrategy which customizes the existing built-in FedAvg strategy. In particular, it customizes aggregate_fit by calling aggregate_fit in the base class (FedAvg). It then saves the returned (aggregated) weights before returning them to the caller (i.e., the server):
from typing import Optional, Union

import numpy as np
import flwr as fl
from flwr.common import FitRes, Parameters, Scalar
from flwr.server.client_proxy import ClientProxy


class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(
            server_round, results, failures
        )

        if aggregated_parameters is not None:
            # Convert `Parameters` to `list[np.ndarray]`
            aggregated_ndarrays: list[np.ndarray] = fl.common.parameters_to_ndarrays(
                aggregated_parameters
            )

            # Save aggregated_ndarrays to disk
            print(f"Saving round {server_round} aggregated_ndarrays...")
            np.savez(f"round-{server_round}-weights.npz", *aggregated_ndarrays)

        return aggregated_parameters, aggregated_metrics
# Create strategy and pass into ServerApp
from flwr.server import ServerApp, ServerAppComponents, ServerConfig


def server_fn(context):
    strategy = SaveModelStrategy(
        # (same arguments as FedAvg here)
    )
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(strategy=strategy, config=config)


app = ServerApp(server_fn=server_fn)
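To load such a raw NumPy checkpoint later, for example to initialize a subsequent run, the saved arrays can be read back and converted into Flower's Parameters type. A minimal sketch, assuming the file produced after round 3 above (the round number is illustrative):

import numpy as np
import flwr as fl

# Load the arrays written by SaveModelStrategy (np.savez names them arr_0, arr_1, ...)
npzfile = np.load("round-3-weights.npz")
ndarrays = [npzfile[f"arr_{i}"] for i in range(len(npzfile.files))]

# Convert back to Flower's `Parameters` type
parameters = fl.common.ndarrays_to_parameters(ndarrays)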
Save and Load PyTorch Checkpoints
Similar to the previous example, but with a few extra steps, we show how to save a PyTorch checkpoint using the torch.save function. First, aggregate_fit returns a Parameters object that has to be transformed into a list of NumPy ndarray's; these are then transformed into a PyTorch state_dict following the OrderedDict class structure.
from collections import OrderedDict

import torch

# `cifar.Net()` and `DEVICE` are defined in the surrounding PyTorch example
net = cifar.Net().to(DEVICE)


class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """Aggregate model weights using weighted average and store checkpoint."""

        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(
            server_round, results, failures
        )

        if aggregated_parameters is not None:
            print(f"Saving round {server_round} aggregated_parameters...")

            # Convert `Parameters` to `list[np.ndarray]`
            aggregated_ndarrays: list[np.ndarray] = fl.common.parameters_to_ndarrays(
                aggregated_parameters
            )

            # Convert `list[np.ndarray]` to PyTorch `state_dict`
            params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

            # Save the model to disk
            torch.save(net.state_dict(), f"model_round_{server_round}.pth")

        return aggregated_parameters, aggregated_metrics
To load your progress, you simply append the following lines to your code. Note that this will iterate over all saved checkpoints and load the latest one:
import glob
import os

# Iterate over all saved checkpoints and pick the most recent one
list_of_files = [fname for fname in glob.glob("./model_round_*")]
latest_round_file = max(list_of_files, key=os.path.getctime)
print("Loading pre-trained model from: ", latest_round_file)
state_dict = torch.load(latest_round_file)
net.load_state_dict(state_dict)

# Convert the loaded `state_dict` back to Flower's `Parameters` type
state_dict_ndarrays = [v.cpu().numpy() for v in net.state_dict().values()]
parameters = fl.common.ndarrays_to_parameters(state_dict_ndarrays)
Return/use this object of type Parameters wherever necessary, such as in the initial_parameters when defining a Strategy.
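For example, a sketch of passing the loaded checkpoint to the strategy via initial_parameters (the parameters variable comes from the snippet above):

strategy = SaveModelStrategy(
    # Resume federated learning from the loaded checkpoint
    initial_parameters=parameters,
    # (other FedAvg arguments as before)
)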
Alternatively, we can save and load the model updates during the evaluation phase by overriding the evaluate() or aggregate_evaluate() method of the strategy (FedAvg). Check out the details in the Advanced PyTorch Example and the Advanced TensorFlow Example.
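As an illustration of the first variant, here is a minimal sketch that saves the current global parameters inside evaluate(); the class name SaveOnEvaluateStrategy is hypothetical, and error handling is omitted:

class SaveOnEvaluateStrategy(fl.server.strategy.FedAvg):
    def evaluate(self, server_round: int, parameters):
        # Persist the current global parameters before evaluating them
        ndarrays = fl.common.parameters_to_ndarrays(parameters)
        np.savez(f"round-{server_round}-eval-weights.npz", *ndarrays)

        # Delegate the actual (centralized) evaluation to the base class
        return super().evaluate(server_round, parameters)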