# Copyright 2021 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adaptive Federated Optimization using Adam (FedAdam) strategy.

[Reddi et al., 2020]

Paper: arxiv.org/abs/2003.00295
"""


from typing import Callable, Optional, Union

import numpy as np

from flwr.common import (
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_proxy import ClientProxy

from .fedopt import FedOpt

# pylint: disable=line-too-long
class FedAdam(FedOpt):
    """FedAdam - Adaptive Federated Optimization using Adam.

    Implementation based on https://arxiv.org/abs/2003.00295v5

    Parameters
    ----------
    fraction_fit : float, optional
        Fraction of clients used during training. Defaults to 1.0.
    fraction_evaluate : float, optional
        Fraction of clients used during validation. Defaults to 1.0.
    min_fit_clients : int, optional
        Minimum number of clients used during training. Defaults to 2.
    min_evaluate_clients : int, optional
        Minimum number of clients used during validation. Defaults to 2.
    min_available_clients : int, optional
        Minimum number of total clients in the system. Defaults to 2.
    evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
        Optional function used for validation. Defaults to None.
    on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
        Function used to configure training. Defaults to None.
    on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
        Function used to configure validation. Defaults to None.
    accept_failures : bool, optional
        Whether or not accept rounds containing failures. Defaults to True.
    initial_parameters : Parameters
        Initial global model parameters.
    fit_metrics_aggregation_fn : Optional[MetricsAggregationFn]
        Metrics aggregation function, optional.
    evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
        Metrics aggregation function, optional.
    eta : float, optional
        Server-side learning rate. Defaults to 1e-1.
    eta_l : float, optional
        Client-side learning rate. Defaults to 1e-1.
    beta_1 : float, optional
        Momentum parameter. Defaults to 0.9.
    beta_2 : float, optional
        Second moment parameter. Defaults to 0.99.
    tau : float, optional
        Controls the algorithm's degree of adaptability. Defaults to 1e-9.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-locals
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Parameters,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        eta: float = 1e-1,
        eta_l: float = 1e-1,
        beta_1: float = 0.9,
        beta_2: float = 0.99,
        tau: float = 1e-9,
    ) -> None:
        # All configuration is handled by the FedOpt base class; FedAdam only
        # changes how the server-side update is computed in `aggregate_fit`.
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
            eta=eta,
            eta_l=eta_l,
            beta_1=beta_1,
            beta_2=beta_2,
            tau=tau,
        )

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        return f"FedAdam(accept_failures={self.accept_failures})"

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""
        # Step 1: plain FedAvg aggregation, delegated to the base class.
        avg_parameters, metrics_aggregated = super().aggregate_fit(
            server_round=server_round, results=results, failures=failures
        )
        if avg_parameters is None:
            return None, {}
        avg_weights = parameters_to_ndarrays(avg_parameters)

        # Step 2: Adam update on the server.
        # Pseudo-gradient: difference between the FedAvg result and the
        # current global weights.
        pseudo_grad: NDArrays = [
            avg - cur for avg, cur in zip(avg_weights, self.current_weights)
        ]

        # First moment estimate (momentum), lazily initialized to zeros.
        if not self.m_t:
            self.m_t = [np.zeros_like(grad) for grad in pseudo_grad]
        self.m_t = [
            self.beta_1 * moment + (1 - self.beta_1) * grad
            for moment, grad in zip(self.m_t, pseudo_grad)
        ]

        # Second moment estimate, lazily initialized to zeros.
        if not self.v_t:
            self.v_t = [np.zeros_like(grad) for grad in pseudo_grad]
        self.v_t = [
            self.beta_2 * moment + (1 - self.beta_2) * (grad * grad)
            for moment, grad in zip(self.v_t, pseudo_grad)
        ]

        # Compute the bias-corrected learning rate, `eta_norm` for improving convergence
        # in the early rounds of FL training. This `eta_norm` is `\alpha_t` in Kingma &
        # Ba, 2014 (http://arxiv.org/abs/1412.6980) "Adam: A Method for Stochastic
        # Optimization" in the formula line right before Section 2.1.
        eta_norm = (
            self.eta
            * np.sqrt(1 - np.power(self.beta_2, server_round + 1.0))
            / (1 - np.power(self.beta_1, server_round + 1.0))
        )

        # Step 3: apply the adaptive update to the global model.
        self.current_weights = [
            cur + eta_norm * moment / (np.sqrt(second) + self.tau)
            for cur, moment, second in zip(self.current_weights, self.m_t, self.v_t)
        ]

        return ndarrays_to_parameters(self.current_weights), metrics_aggregated