Source code for flwr.server.strategy.fault_tolerant_fedavg
# Copyright 2020 Flower Labs GmbH. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Fault-tolerant variant of FedAvg strategy."""fromloggingimportWARNINGfromtypingimportCallable,Optional,Unionfromflwr.commonimport(EvaluateRes,FitRes,MetricsAggregationFn,NDArrays,Parameters,Scalar,ndarrays_to_parameters,parameters_to_ndarrays,)fromflwr.common.loggerimportlogfromflwr.server.client_proxyimportClientProxyfrom.aggregateimportaggregate,weighted_loss_avgfrom.fedavgimportFedAvg
class FaultTolerantFedAvg(FedAvg):
    """Configurable fault-tolerant FedAvg strategy implementation.

    Works like ``FedAvg`` constructed with ``accept_failures=True``, but a round
    is only aggregated when the share of successful client replies,
    ``len(results) / (len(results) + len(failures))``, is not below a
    configurable minimum. The fit and evaluate stages use separate minimums.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 1,
        min_evaluate_clients: int = 1,
        min_available_clients: int = 1,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        min_completion_rate_fit: float = 0.5,
        min_completion_rate_evaluate: float = 0.5,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        """Create the strategy.

        Parameters
        ----------
        fraction_fit, fraction_evaluate, min_fit_clients, min_evaluate_clients,
        min_available_clients, evaluate_fn, on_fit_config_fn,
        on_evaluate_config_fn, initial_parameters, fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn :
            Forwarded unchanged to ``FedAvg.__init__``.
        min_completion_rate_fit : float (default: 0.5)
            Minimum ratio of successful fit results to all fit replies
            (results + failures). Below it, ``aggregate_fit`` returns
            ``(None, {})``.
        min_completion_rate_evaluate : float (default: 0.5)
            Same threshold for ``aggregate_evaluate``.
        """
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            # Failures are always tolerated here; the completion-rate
            # thresholds below decide whether a round is aggregated.
            accept_failures=True,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        # Public attribute names differ from the constructor argument names;
        # kept as-is for backward compatibility.
        self.completion_rate_fit = min_completion_rate_fit
        self.completion_rate_evaluate = min_completion_rate_evaluate

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        return "FaultTolerantFedAvg()"

    @staticmethod
    def _completion_rate_too_low(
        num_results: int, num_failures: int, min_completion_rate: float
    ) -> bool:
        """Return True if ``results / (results + failures)`` is below the minimum.

        Callers must ensure ``num_results > 0`` so the denominator is non-zero.
        Uses ``<`` (not ``not >=``) so that a NaN threshold never blocks
        aggregation.
        """
        completion_rate = num_results / (num_results + num_failures)
        return completion_rate < min_completion_rate

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """Aggregate fit results using weighted average.

        Returns ``(None, {})`` when there are no results, or when the share of
        successful replies is below ``self.completion_rate_fit``. Otherwise,
        returns the aggregated parameters together with the metrics produced by
        ``fit_metrics_aggregation_fn`` (``{}`` if none was provided).
        """
        if not results:
            return None, {}
        # Not enough results for aggregation
        if self._completion_rate_too_low(
            len(results), len(failures), self.completion_rate_fit
        ):
            return None, {}

        # Pair each client's parameters with its example count; ``aggregate``
        # uses the count as the averaging weight.
        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: dict[str, Scalar] = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated

    def aggregate_evaluate(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, EvaluateRes]],
        failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> tuple[Optional[float], dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average.

        Returns ``(None, {})`` when there are no results, or when the share of
        successful replies is below ``self.completion_rate_evaluate``.
        Otherwise, returns the loss averaged by ``weighted_loss_avg`` over
        ``(num_examples, loss)`` pairs, together with the metrics produced by
        ``evaluate_metrics_aggregation_fn`` (``{}`` if none was provided).
        """
        if not results:
            return None, {}
        # Not enough results for aggregation
        if self._completion_rate_too_low(
            len(results), len(failures), self.completion_rate_evaluate
        ):
            return None, {}

        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: dict[str, Scalar] = {}
        if self.evaluate_metrics_aggregation_fn:
            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")

        return loss_aggregated, metrics_aggregated