모델 체크포인트 저장 및 로드¶
Flower는 서버 측에서 모델 업데이트를 자동으로 저장하지 않습니다. 이 사용법 가이드에서는 Flower에서 모델 체크포인트를 저장(및 로드)하는 단계에 대해 설명합니다.
모델 체크포인트¶
Model updates can be persisted on the server-side by customizing Strategy
methods.
Implementing custom strategies is always an option, but for many cases it may be more
convenient to simply customize an existing strategy. The following code example defines
a new SaveModelStrategy
which customized the existing built-in FedAvg
strategy.
In particular, it customizes aggregate_fit
by calling aggregate_fit
in the base
class (FedAvg
). It then continues to save returned (aggregated) weights before it
returns those aggregated weights to the caller (i.e., the server):
class SaveModelStrategy(fl.server.strategy.FedAvg):
def aggregate_fit(
self,
server_round: int,
results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
# Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
aggregated_parameters, aggregated_metrics = super().aggregate_fit(
server_round, results, failures
)
if aggregated_parameters is not None:
# Convert `Parameters` to `List[np.ndarray]`
aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(
aggregated_parameters
)
# Save aggregated_ndarrays
print(f"Saving round {server_round} aggregated_ndarrays...")
np.savez(f"round-{server_round}-weights.npz", *aggregated_ndarrays)
return aggregated_parameters, aggregated_metrics
# Create strategy and run server
strategy = SaveModelStrategy(
# (same arguments as FedAvg here)
)
fl.server.start_server(strategy=strategy)
파이토치 체크포인트 저장 및 로드¶
이전 예제와 비슷하지만 몇 가지 단계가 추가되어 torch.save
함수를 사용하여 파이토치 체크포인트를 저장하는 방법을 보여드리겠습니다. 먼저, aggregate_fit``은 ``Parameters
객체를 반환하는데, 이 객체는 NumPy ndarray``의 목록으로 변환되어야 하며, ``OrderedDict
클래스 구조에 따라 파이토치 ``state_dict``로 변환됩니다.
net = cifar.Net().to(DEVICE)
class SaveModelStrategy(fl.server.strategy.FedAvg):
def aggregate_fit(
self,
server_round: int,
results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate model weights using weighted average and store checkpoint"""
# Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
aggregated_parameters, aggregated_metrics = super().aggregate_fit(
server_round, results, failures
)
if aggregated_parameters is not None:
print(f"Saving round {server_round} aggregated_parameters...")
# Convert `Parameters` to `List[np.ndarray]`
aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(
aggregated_parameters
)
# Convert `List[np.ndarray]` to PyTorch`state_dict`
params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
# Save the model
torch.save(net.state_dict(), f"model_round_{server_round}.pth")
return aggregated_parameters, aggregated_metrics
진행 상황을 로드하려면 코드에 다음 줄을 추가하기만 하면 됩니다. 이렇게 하면 저장된 모든 체크포인트를 반복하고 최신 체크포인트를 로드합니다:
list_of_files = [fname for fname in glob.glob("./model_round_*")]
latest_round_file = max(list_of_files, key=os.path.getctime)
print("Loading pre-trained model from: ", latest_round_file)
state_dict = torch.load(latest_round_file)
net.load_state_dict(state_dict)
state_dict_ndarrays = [v.cpu().numpy() for v in net.state_dict().values()]
parameters = fl.common.ndarrays_to_parameters(state_dict_ndarrays)
전략``을 정의할 때 ``초기_파라미터``와 같이 필요한 경우 ``파라미터
유형의 이 객체를 반환/사용합니다.