모델 체크포인트 저장 및 로드#
Flower는 서버 측에서 모델 업데이트를 자동으로 저장하지 않습니다. 이 사용법 가이드에서는 Flower에서 모델 체크포인트를 저장(및 로드)하는 단계에 대해 설명합니다.
모델 체크포인트#
Strategy
메소드를 사용자 지정하여 서버 측에서 모델 업데이트를 지속할 수 있습니다. 사용자 지정 전략을 구현하는 것은 항상 옵션이지만 대부분의 경우 기존 전략을 간단히 사용자 지정하는 것이 더 편리할 수 있습니다. 다음 코드 예시는 기존의 기본 제공 FedAvg
전략을 사용자 지정한 새로운 SaveModelStrategy`를 정의합니다. 특히, 기본 클래스(:code:`FedAvg
)에서 :code:`aggregate_fit`을 호출하여 :code:`aggregate_fit`을 사용자 지정합니다. 그런 다음 호출자(즉, 서버)에게 집계된 가중치를 반환하기 전에 반환된(집계된) 가중치를 계속 저장합니다:
class SaveModelStrategy(fl.server.strategy.FedAvg):
def aggregate_fit(
self,
server_round: int,
results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
# Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)
if aggregated_parameters is not None:
# Convert `Parameters` to `List[np.ndarray]`
aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)
# Save aggregated_ndarrays
print(f"Saving round {server_round} aggregated_ndarrays...")
np.savez(f"round-{server_round}-weights.npz", *aggregated_ndarrays)
return aggregated_parameters, aggregated_metrics
# Create strategy and run server
strategy = SaveModelStrategy(
# (same arguments as FedAvg here)
)
fl.server.start_server(strategy=strategy)
파이토치 체크포인트 저장 및 로드#
이전 예제와 비슷하지만 몇 가지 단계가 추가되어 torch.save
함수를 사용하여 파이토치 체크포인트를 저장하는 방법을 보여드리겠습니다. 먼저, aggregate_fit``은 ``Parameters
객체를 반환하는데, 이 객체는 NumPy ndarray``의 목록으로 변환되어야 하며, ``OrderedDict
클래스 구조에 따라 파이토치 ``state_dict``로 변환됩니다.
net = cifar.Net().to(DEVICE)
class SaveModelStrategy(fl.server.strategy.FedAvg):
def aggregate_fit(
self,
server_round: int,
results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate model weights using weighted average and store checkpoint"""
# Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)
if aggregated_parameters is not None:
print(f"Saving round {server_round} aggregated_parameters...")
# Convert `Parameters` to `List[np.ndarray]`
aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)
# Convert `List[np.ndarray]` to PyTorch`state_dict`
params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
# Save the model
torch.save(net.state_dict(), f"model_round_{server_round}.pth")
return aggregated_parameters, aggregated_metrics
진행 상황을 로드하려면 코드에 다음 줄을 추가하기만 하면 됩니다. 이렇게 하면 저장된 모든 체크포인트를 반복하고 최신 체크포인트를 로드합니다:
list_of_files = [fname for fname in glob.glob("./model_round_*")]
latest_round_file = max(list_of_files, key=os.path.getctime)
print("Loading pre-trained model from: ", latest_round_file)
state_dict = torch.load(latest_round_file)
net.load_state_dict(state_dict)
state_dict_ndarrays = [v.cpu().numpy() for v in net.state_dict().values()]
parameters = fl.common.ndarrays_to_parameters(state_dict_ndarrays)
전략``을 정의할 때 ``초기_파라미터``와 같이 필요한 경우 ``파라미터
유형의 이 객체를 반환/사용합니다.