# Copyright 2022 Flower Labs GmbH. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Federated Averaging with Momentum (FedAvgM) [Hsu et al., 2019] strategy.Paper: arxiv.org/pdf/1909.06335.pdf"""fromloggingimportWARNINGfromtypingimportCallable,Optional,Unionfromflwr.commonimport(FitRes,MetricsAggregationFn,NDArrays,Parameters,Scalar,ndarrays_to_parameters,parameters_to_ndarrays,)fromflwr.common.loggerimportlogfromflwr.server.client_managerimportClientManagerfromflwr.server.client_proxyimportClientProxyfrom.aggregateimportaggregatefrom.fedavgimportFedAvg# pylint: disable=line-too-long
class FedAvgM(FedAvg):
    """Federated Averaging with Momentum strategy.

    Implementation based on https://arxiv.org/abs/1909.06335

    Parameters
    ----------
    fraction_fit : float, optional
        Fraction of clients used during training. Defaults to 1.0.
    fraction_evaluate : float, optional
        Fraction of clients used during validation. Defaults to 1.0.
    min_fit_clients : int, optional
        Minimum number of clients used during training. Defaults to 2.
    min_evaluate_clients : int, optional
        Minimum number of clients used during validation. Defaults to 2.
    min_available_clients : int, optional
        Minimum number of total clients in the system. Defaults to 2.
    evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
        Optional function used for validation. Defaults to None.
    on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
        Function used to configure training. Defaults to None.
    on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
        Function used to configure validation. Defaults to None.
    accept_failures : bool, optional
        Whether or not to accept rounds containing failures. Defaults to True.
    initial_parameters : Parameters, optional
        Initial global model parameters. Required when server-side
        optimization is active (``server_learning_rate != 1.0`` or
        ``server_momentum != 0.0``), because the pseudo-gradient is computed
        relative to them.
    fit_metrics_aggregation_fn : Optional[MetricsAggregationFn]
        Function used to aggregate the metrics returned by clients after
        training. Defaults to None.
    evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
        Function used to aggregate the metrics returned by clients after
        evaluation. Defaults to None.
    server_learning_rate : float
        Server-side learning rate used in server-side optimization.
        Defaults to 1.0.
    server_momentum : float
        Server-side momentum factor used for FedAvgM. Defaults to 0.0.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        server_learning_rate: float = 1.0,
        server_momentum: float = 0.0,
    ) -> None:
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.server_learning_rate = server_learning_rate
        self.server_momentum = server_momentum
        # Server-side optimization is only needed when it would change the
        # plain FedAvg result (lr == 1.0 and momentum == 0.0 is a no-op).
        self.server_opt: bool = (self.server_momentum != 0.0) or (
            self.server_learning_rate != 1.0
        )
        # Accumulated momentum buffer; created lazily in aggregate_fit.
        self.momentum_vector: Optional[NDArrays] = None

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        rep = f"FedAvgM(accept_failures={self.accept_failures})"
        return rep

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters.

        The parameters are kept on the strategy because ``aggregate_fit``
        uses them as the reference point for server-side optimization.
        """
        return self.initial_parameters

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """Aggregate fit results using weighted average.

        When server-side optimization is active, the weighted average is
        turned into a pseudo-gradient relative to the current global
        parameters and applied with SGD (optionally with momentum).

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        results : list[tuple[ClientProxy, FitRes]]
            Successful fit results from clients.
        failures : list[Union[tuple[ClientProxy, FitRes], BaseException]]
            Failed fit results or exceptions raised by clients.

        Returns
        -------
        tuple[Optional[Parameters], dict[str, Scalar]]
            The new global parameters (``None`` when no aggregation is
            performed) and the aggregated fit metrics.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # Convert results
        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        fedavg_result = aggregate(weights_results)

        # following convention described in
        # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
        if self.server_opt:
            # You need to initialize the model
            assert (
                self.initial_parameters is not None
            ), "When using server-side optimization, model needs to be initialized."
            # Deserialize once and reuse; nothing below mutates these arrays.
            initial_weights = parameters_to_ndarrays(self.initial_parameters)
            # remember that updates are the opposite of gradients
            pseudo_gradient: NDArrays = [
                x - y for x, y in zip(initial_weights, fedavg_result)
            ]
            if self.server_momentum > 0.0:
                if server_round > 1 and self.momentum_vector is not None:
                    self.momentum_vector = [
                        self.server_momentum * x + y
                        for x, y in zip(self.momentum_vector, pseudo_gradient)
                    ]
                else:
                    # Start (or restart) the momentum buffer: on round 1, or
                    # when no earlier round produced one (e.g. a previous round
                    # returned early because it had no usable results).
                    self.momentum_vector = pseudo_gradient

                # No nesterov for now
                pseudo_gradient = self.momentum_vector

            # SGD
            fedavg_result = [
                x - self.server_learning_rate * y
                for x, y in zip(initial_weights, pseudo_gradient)
            ]
            # Update current weights
            self.initial_parameters = ndarrays_to_parameters(fedavg_result)

        parameters_aggregated = ndarrays_to_parameters(fedavg_result)

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: dict[str, Scalar] = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated