# Copyright 2023 Flower Labs GmbH. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Federated XGBoost [Ma et al., 2023] strategy.Strategy in the horizontal setting based on building Neural Network and averaging onprediction outcomes.Paper: arxiv.org/abs/2304.07537"""fromloggingimportWARNINGfromtypingimportAny,Optional,Unionfromflwr.commonimportFitRes,Scalar,ndarrays_to_parameters,parameters_to_ndarraysfromflwr.common.loggerimportlog,warn_deprecated_featurefromflwr.server.client_proxyimportClientProxyfrom.aggregateimportaggregatefrom.fedavgimportFedAvg
class FedXgbNnAvg(FedAvg):
    """FedAvg variant for the Federated XGBoost [Ma et al., 2023] setting.

    Each client result is expected to carry two payloads: neural-network
    parameters, which are averaged with a weighted mean, and XGBoost trees,
    which are gathered into a list without modification.

    Warning
    -------
    This strategy is deprecated, but a copy of it is available in Flower Baselines:
    https://github.com/adap/flower/tree/main/baselines/hfedxgboost.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the strategy by delegating every argument to ``FedAvg``.

        Implementation based on https://arxiv.org/abs/2304.07537. After the
        base class is initialised, ``warn_deprecated_feature`` is called to flag
        this strategy as deprecated.
        """
        super().__init__(*args, **kwargs)
        warn_deprecated_feature("`FedXgbNnAvg` strategy")

    def __repr__(self) -> str:
        """Describe the strategy, showing its ``accept_failures`` setting."""
        return f"FedXgbNnAvg(accept_failures={self.accept_failures})"

    def evaluate(
        self, server_round: int, parameters: Any
    ) -> Optional[tuple[float, dict[str, Scalar]]]:
        """Run the configured server-side evaluation function, if there is one.

        Parameters
        ----------
        server_round : int
            Current round number, forwarded to ``evaluate_fn``.
        parameters : Any
            Model parameters, forwarded to ``evaluate_fn`` unchanged.

        Returns
        -------
        Optional[tuple[float, dict[str, Scalar]]]
            ``(loss, metrics)`` as produced by ``evaluate_fn``. Returns ``None``
            when no function is configured or when it returns ``None``.
        """
        eval_fn = self.evaluate_fn
        if eval_fn is None:
            return None

        # The config argument is always an empty dict here.
        outcome = eval_fn(server_round, parameters, {})
        if outcome is None:
            return None

        # Unpack explicitly so a result of the wrong shape fails loudly.
        loss, metrics = outcome
        return loss, metrics

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Any], dict[str, Scalar]]:
        """Combine client fit results into NN parameters plus a list of trees.

        Parameters
        ----------
        server_round : int
            Current round number. Used only to decide whether to log the
            missing-metrics-aggregator warning, which happens in round 1.
        results : list[tuple[ClientProxy, FitRes]]
            Successful client results. NOTE(review): each
            ``fit_res.parameters`` is indexed as a two-item sequence,
            ``[0].parameters`` for the NN weights and ``[1]`` for the trees,
            rather than used as a plain ``Parameters``; this relies on client
            code packing it that way.
        failures : list[Union[tuple[ClientProxy, FitRes], BaseException]]
            Failed client results. Any failure aborts aggregation unless
            ``accept_failures`` is set.

        Returns
        -------
        tuple[Optional[Any], dict[str, Scalar]]
            ``(None, {})`` when there are no results, or when failures occurred
            and are not accepted. Otherwise returns ``[nn_parameters, trees]``
            and the aggregated metrics. ``nn_parameters`` is the
            sample-count-weighted average of the clients' NN weights. ``trees``
            lists each client's trees in the order of ``results``.
        """
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        # First pass: pair each client's NN weights with its sample count.
        weights_results = []
        for _, fit_res in results:
            nn_params = fit_res.parameters[0].parameters  # type: ignore
            weights_results.append(
                (parameters_to_ndarrays(nn_params), fit_res.num_examples)
            )
        parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))

        # Second pass: collect the XGBoost trees as-is (no averaging).
        trees_aggregated = [
            fit_res.parameters[1]  # type: ignore
            for _, fit_res in results
        ]

        metrics_aggregated: dict[str, Scalar] = {}
        aggregation_fn = self.fit_metrics_aggregation_fn
        if aggregation_fn:
            metrics_aggregated = aggregation_fn(
                [(res.num_examples, res.metrics) for _, res in results]
            )
        elif server_round == 1:
            # Warn only in the first round to avoid repeating the message.
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return [parameters_aggregated, trees_aggregated], metrics_aggregated