# Copyright 2025 Flower Labs GmbH. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Federated Averaging (FedAvg) [McMahan et al., 2016] strategy.Paper: arxiv.org/abs/1602.05629"""fromloggingimportWARNINGfromtypingimportCallable,Optional,Unionfromflwr.commonimport(EvaluateIns,EvaluateRes,FitIns,FitRes,MetricsAggregationFn,NDArrays,Parameters,Scalar,ndarrays_to_parameters,parameters_to_ndarrays,)fromflwr.common.loggerimportlogfromflwr.server.client_managerimportClientManagerfromflwr.server.client_proxyimportClientProxyfrom.aggregateimportaggregate,aggregate_inplace,weighted_loss_avgfrom.strategyimportStrategyWARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW="""Setting `min_available_clients` lower than `min_fit_clients` or`min_evaluate_clients` can cause the server to fail when there are too few clientsconnected to the server. `min_available_clients` must be set to a value largerthan or equal to the values of `min_fit_clients` and `min_evaluate_clients`."""# pylint: disable=line-too-long
class FedAvg(Strategy):
    """Federated Averaging strategy.

    Implementation based on https://arxiv.org/abs/1602.05629

    Parameters
    ----------
    fraction_fit : float, optional
        Fraction of available clients sampled for training. If
        `fraction_fit * available_clients` is smaller than `min_fit_clients`,
        `min_fit_clients` clients are requested instead. Defaults to 1.0.
    fraction_evaluate : float, optional
        Fraction of available clients sampled for validation. If
        `fraction_evaluate * available_clients` is smaller than
        `min_evaluate_clients`, `min_evaluate_clients` clients are requested
        instead. A value of exactly 0.0 disables federated evaluation.
        Defaults to 1.0.
    min_fit_clients : int, optional
        Minimum number of clients used during training. Defaults to 2.
    min_evaluate_clients : int, optional
        Minimum number of clients used during validation. Defaults to 2.
    min_available_clients : int, optional
        Minimum number of total clients in the system; forwarded to
        `ClientManager.sample` as `min_num_clients`. It should be at least
        `min_fit_clients` and `min_evaluate_clients`; otherwise a warning is
        logged at construction time. Defaults to 2.
    evaluate_fn : Optional[Callable[[int, NDArrays, dict[str, Scalar]], Optional[tuple[float, dict[str, Scalar]]]]]
        Optional server-side evaluation function, called with the server
        round, the global model as NDArrays, and an empty config dict.
        Defaults to None.
    on_fit_config_fn : Callable[[int], dict[str, Scalar]], optional
        Called with the server round to build the training config sent to
        the sampled clients. Defaults to None.
    on_evaluate_config_fn : Callable[[int], dict[str, Scalar]], optional
        Called with the server round to build the evaluation config sent to
        the sampled clients. Defaults to None.
    accept_failures : bool, optional
        Whether or not to accept rounds containing failures. When False, any
        failure makes aggregation return `None`. Defaults to True.
    initial_parameters : Parameters, optional
        Initial global model parameters. The strategy drops its reference to
        them after the first call to `initialize_parameters`.
    fit_metrics_aggregation_fn : Optional[MetricsAggregationFn]
        Optional function that aggregates the `(num_examples, metrics)` pairs
        reported by clients after training.
    evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
        Optional function that aggregates the `(num_examples, metrics)` pairs
        reported by clients after evaluation.
    inplace : bool (default: True)
        Enable (True) or disable (False) in-place aggregation of model updates.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        inplace: bool = True,
    ) -> None:
        super().__init__()

        # A misconfigured client budget is only warned about, never rejected.
        if any(
            required > min_available_clients
            for required in (min_fit_clients, min_evaluate_clients)
        ):
            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

        # Client sampling settings
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients

        # User-supplied hooks
        self.evaluate_fn = evaluate_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.on_evaluate_config_fn = on_evaluate_config_fn
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn

        # Aggregation behaviour and starting model
        self.accept_failures = accept_failures
        self.initial_parameters = initial_parameters
        self.inplace = inplace

    def __repr__(self) -> str:
        """Return a short description of the strategy and its failure policy."""
        return f"FedAvg(accept_failures={self.accept_failures})"

    def num_fit_clients(self, num_available_clients: int) -> tuple[int, int]:
        """Return the training sample size and the required number of available clients.

        The sample size is `fraction_fit` of the available clients (truncated
        to an int), but never fewer than `min_fit_clients`.
        """
        sample_size = max(
            int(num_available_clients * self.fraction_fit), self.min_fit_clients
        )
        return sample_size, self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> tuple[int, int]:
        """Return the evaluation sample size and the required number of available clients.

        The sample size is `fraction_evaluate` of the available clients
        (truncated to an int), but never fewer than `min_evaluate_clients`.
        """
        sample_size = max(
            int(num_available_clients * self.fraction_evaluate),
            self.min_evaluate_clients,
        )
        return sample_size, self.min_available_clients

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Hand out the initial global parameters exactly once.

        The stored reference is cleared so the strategy does not keep the
        initial model in memory; later calls return `None`.
        """
        handed_out, self.initial_parameters = self.initial_parameters, None
        return handed_out

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[tuple[float, dict[str, Scalar]]]:
        """Run the optional server-side evaluation on the current global model.

        Returns `None` when no `evaluate_fn` is configured or when it returns
        `None`; otherwise returns its `(loss, metrics)` pair.
        """
        evaluate_fn = self.evaluate_fn
        if evaluate_fn is None:
            return None

        outcome = evaluate_fn(server_round, parameters_to_ndarrays(parameters), {})
        if outcome is None:
            return None

        # Unpacking enforces the (loss, metrics) shape of the result.
        loss, metrics = outcome
        return loss, metrics

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> list[tuple[ClientProxy, FitIns]]:
        """Sample clients for this training round and pair each with its instructions."""
        # Every sampled client receives the same FitIns object.
        config_fn = self.on_fit_config_fn
        config: dict[str, Scalar] = {} if config_fn is None else config_fn(server_round)
        instructions = FitIns(parameters, config)

        num_clients, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        selected = client_manager.sample(
            num_clients=num_clients, min_num_clients=min_num_clients
        )
        return [(proxy, instructions) for proxy in selected]

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> list[tuple[ClientProxy, EvaluateIns]]:
        """Sample clients for this evaluation round and pair each with its instructions.

        When `fraction_evaluate` is 0.0, federated evaluation is skipped: no
        config is built, no clients are sampled, and an empty list is returned.
        """
        if self.fraction_evaluate == 0.0:
            return []

        # Every sampled client receives the same EvaluateIns object.
        config_fn = self.on_evaluate_config_fn
        config: dict[str, Scalar] = {} if config_fn is None else config_fn(server_round)
        instructions = EvaluateIns(parameters, config)

        num_clients, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        selected = client_manager.sample(
            num_clients=num_clients, min_num_clients=min_num_clients
        )
        return [(proxy, instructions) for proxy in selected]

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """Combine client model updates into new global parameters.

        Updates are combined with a weighted average via `aggregate_inplace`
        (when `inplace` is set) or `aggregate`, where the latter receives each
        client's parameters paired with its `num_examples`. Returns
        `(None, {})` when there are no results, or when failures occurred and
        `accept_failures` is False.
        """
        if not results or (not self.accept_failures and failures):
            return None, {}

        if self.inplace:
            # Delegate to the in-place weighted average over the raw results.
            aggregated_ndarrays = aggregate_inplace(results)
        else:
            weighted_updates = [
                (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
                for _, fit_res in results
            ]
            aggregated_ndarrays = aggregate(weighted_updates)
        parameters_aggregated = ndarrays_to_parameters(aggregated_ndarrays)

        metrics_aggregated = self._aggregate_client_metrics(
            server_round,
            results,
            self.fit_metrics_aggregation_fn,
            "No fit_metrics_aggregation_fn provided",
        )
        return parameters_aggregated, metrics_aggregated

    def aggregate_evaluate(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, EvaluateRes]],
        failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> tuple[Optional[float], dict[str, Scalar]]:
        """Combine client evaluation losses with a weighted average.

        Each loss is weighted by the client's `num_examples` (via
        `weighted_loss_avg`). Returns `(None, {})` when there are no results,
        or when failures occurred and `accept_failures` is False.
        """
        if not results or (not self.accept_failures and failures):
            return None, {}

        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )

        metrics_aggregated = self._aggregate_client_metrics(
            server_round,
            results,
            self.evaluate_metrics_aggregation_fn,
            "No evaluate_metrics_aggregation_fn provided",
        )
        return loss_aggregated, metrics_aggregated

    @staticmethod
    def _aggregate_client_metrics(
        server_round: int,
        results: list[tuple[ClientProxy, Union[FitRes, EvaluateRes]]],
        aggregation_fn: Optional[MetricsAggregationFn],
        missing_fn_warning: str,
    ) -> dict[str, Scalar]:
        """Aggregate clients' custom metrics with `aggregation_fn`, if one is set.

        Without an aggregation function, an empty dict is returned. The
        `missing_fn_warning` is logged only in round 1, so it appears once.
        """
        if not aggregation_fn:
            if server_round == 1:
                log(WARNING, missing_fn_warning)
            return {}
        return aggregation_fn([(res.num_examples, res.metrics) for _, res in results])