# Source code for flwr.server.strategy.fedavg_android
# Copyright 2021 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FedAvg [McMahan et al., 2016] strategy with custom serialization for Android devices.

Paper: arxiv.org/abs/1602.05629
"""


from typing import Callable, Optional, Union, cast

import numpy as np

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    NDArray,
    NDArrays,
    Parameters,
    Scalar,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

from .aggregate import aggregate, weighted_loss_avg
from .strategy import Strategy

# pylint: disable=line-too-long
class FedAvgAndroid(Strategy):
    """Federated Averaging strategy with custom serialization for Android devices.

    Implementation based on https://arxiv.org/abs/1602.05629

    Model parameters are exchanged as raw ``float32`` byte buffers, one buffer
    per array (see ``ndarray_to_bytes`` / ``bytes_to_ndarray``). Array shapes
    are not transmitted, so deserialized arrays are always 1-D.

    Parameters
    ----------
    fraction_fit : float
        Fraction of clients used during training. Defaults to 1.0.
    fraction_evaluate : float
        Fraction of clients used during validation. Defaults to 1.0.
        If 0.0, federated evaluation is skipped entirely.
    min_fit_clients : int
        Minimum number of clients used during training. Defaults to 2.
    min_evaluate_clients : int
        Minimum number of clients used during validation. Defaults to 2.
    min_available_clients : int
        Minimum number of total clients in the system. Defaults to 2.
    evaluate_fn : Optional[Callable[[int, NDArrays, dict[str, Scalar]], Optional[tuple[float, dict[str, Scalar]]]]]
        Optional function used for server-side validation. Defaults to None.
    on_fit_config_fn : Optional[Callable[[int], dict[str, Scalar]]]
        Function used to configure training. Defaults to None.
    on_evaluate_config_fn : Optional[Callable[[int], dict[str, Scalar]]]
        Function used to configure validation. Defaults to None.
    accept_failures : bool
        Whether or not to accept rounds containing failures. Defaults to True.
    initial_parameters : Optional[Parameters]
        Initial global model parameters. Defaults to None.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
    ) -> None:
        super().__init__()
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_available_clients = min_available_clients
        self.evaluate_fn = evaluate_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.on_evaluate_config_fn = on_evaluate_config_fn
        self.accept_failures = accept_failures
        self.initial_parameters = initial_parameters

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        # Name the concrete strategy; the generic "FedAvg" label was misleading.
        return f"FedAvgAndroid(accept_failures={self.accept_failures})"

    def num_fit_clients(self, num_available_clients: int) -> tuple[int, int]:
        """Return the sample size and the required number of available clients.

        The sample size is ``fraction_fit`` of the available clients (truncated
        to an int) but never less than ``min_fit_clients``.
        """
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> tuple[int, int]:
        """Use a fraction of available clients for evaluation.

        The sample size is ``fraction_evaluate`` of the available clients
        (truncated to an int) but never less than ``min_evaluate_clients``.
        """
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters.

        Returns the ``initial_parameters`` given at construction (possibly
        None) and drops the strategy's own reference to them.
        """
        initial_parameters = self.initial_parameters
        self.initial_parameters = None  # Don't keep initial parameters in memory
        return initial_parameters

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[tuple[float, dict[str, Scalar]]]:
        """Evaluate model parameters using the server-side evaluation function.

        Returns None when no ``evaluate_fn`` is configured or when it returns
        None; otherwise the ``(loss, metrics)`` pair it produced.
        """
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None
        weights = self.parameters_to_ndarrays(parameters)
        eval_res = self.evaluate_fn(server_round, weights, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> list[tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""
        config: dict[str, Scalar] = {}
        if self.on_fit_config_fn is not None:
            # Custom fit config function provided
            config = self.on_fit_config_fn(server_round)
        fit_ins = FitIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Every sampled client receives the same FitIns instance
        return [(client, fit_ins) for client in clients]

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> list[tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        # Do not configure federated evaluation if fraction_evaluate is 0
        if self.fraction_evaluate == 0.0:
            return []

        # Parameters and config
        config: dict[str, Scalar] = {}
        if self.on_evaluate_config_fn is not None:
            # Custom evaluation config function provided
            config = self.on_evaluate_config_fn(server_round)
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Every sampled client receives the same EvaluateIns instance
        return [(client, evaluate_ins) for client in clients]

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """Aggregate fit results using weighted average.

        Each client's arrays are paired with its ``num_examples`` and passed to
        ``aggregate``. Returns ``(None, {})`` when there are no results, or
        when there are failures and ``accept_failures`` is False.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}
        # Convert results
        weights_results = [
            (self.parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        return self.ndarrays_to_parameters(aggregate(weights_results)), {}

    def aggregate_evaluate(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, EvaluateRes]],
        failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> tuple[Optional[float], dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average.

        Each client's loss is paired with its ``num_examples`` and passed to
        ``weighted_loss_avg``. Returns ``(None, {})`` when there are no
        results, or when there are failures and ``accept_failures`` is False.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}
        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        return loss_aggregated, {}

    def ndarrays_to_parameters(self, ndarrays: NDArrays) -> Parameters:
        """Convert NumPy ndarrays to parameters object.

        Each array becomes one raw byte buffer (see ``ndarray_to_bytes``).
        """
        tensors = [self.ndarray_to_bytes(ndarray) for ndarray in ndarrays]
        return Parameters(tensors=tensors, tensor_type="numpy.nda")

    def parameters_to_ndarrays(self, parameters: Parameters) -> NDArrays:
        """Convert parameters object to NumPy weights.

        Each byte buffer is decoded with ``bytes_to_ndarray``; the resulting
        arrays are 1-D because no shape information is stored.
        """
        return [self.bytes_to_ndarray(tensor) for tensor in parameters.tensors]

    def ndarray_to_bytes(self, ndarray: NDArray) -> bytes:
        """Serialize NumPy array to raw native-order ``float32`` bytes.

        The array is cast to native ``float32`` first (no conversion happens
        when it already has that dtype). ``bytes_to_ndarray`` always decodes
        buffers as native ``float32``, so writing any other dtype or byte
        order verbatim would produce garbage values or the wrong element count
        on the round trip. Shape information is not preserved.
        """
        return np.asarray(ndarray, dtype=np.float32).tobytes()

    def bytes_to_ndarray(self, tensor: bytes) -> NDArray:
        """Deserialize NumPy array from raw native-order ``float32`` bytes.

        Returns a 1-D array viewing ``tensor`` without copying. Because
        ``bytes`` is immutable, the returned array is read-only.
        """
        ndarray_deserialized = np.frombuffer(tensor, dtype=np.float32)
        return cast(NDArray, ndarray_deserialized)