Source code for flwr.serverapp.strategy.fedxgb_bagging
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Flower message-based FedXgbBagging strategy."""fromcollections.abcimportIterablefromtypingimportOptional,castimportnumpyasnpfromflwr.commonimportArrayRecord,ConfigRecord,Message,MetricRecordfromflwr.serverimportGridfrom..exceptionimportInconsistentMessageRepliesfrom.fedavgimportFedAvgfrom.strategy_utilsimportaggregate_bagging# pylint: disable=line-too-long
class FedXgbBagging(FedAvg):
    """Configurable FedXgbBagging strategy implementation.

    The global model is carried as raw bytes in the single Array (key ``"0"``)
    of an ``ArrayRecord`` (presumably a serialized XGBoost booster). Each
    round, the model sent to the clients is remembered in ``current_bst``, and
    the models returned by the clients are folded into it, in reply order,
    with ``aggregate_bagging``.

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    """

    # Raw bytes of the global model (Array "0") for the current round. Set by
    # `configure_train` and replaced by the bagged result in `aggregate_train`.
    current_bst: Optional[bytes] = None

    def _ensure_single_array(self, arrays: ArrayRecord) -> None:
        """Check that ensures there's only one Array in the ArrayRecord.

        Raises
        ------
        InconsistentMessageReplies
            If ``arrays`` does not contain exactly one Array.
        """
        n = len(arrays)
        if n != 1:
            raise InconsistentMessageReplies(
                reason="Expected exactly one Array in ArrayRecord. "
                "Skipping aggregation."
            )

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated training.

        Validates that ``arrays`` holds exactly one Array and records its bytes
        in ``current_bst`` so that ``aggregate_train`` can bag client results
        onto it, then delegates message construction to the parent strategy.

        Raises
        ------
        InconsistentMessageReplies
            If ``arrays`` does not contain exactly one Array.
        """
        self._ensure_single_array(arrays)
        # Keep track of array record being communicated
        self.current_bst = arrays["0"].numpy().tobytes()
        return super().configure_train(server_round, arrays, config, grid)

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        The bytes of Array ``"0"`` from each valid reply are folded, in reply
        order, into ``current_bst`` via ``aggregate_bagging``.

        Returns
        -------
        tuple[Optional[ArrayRecord], Optional[MetricRecord]]
            The bagged global model wrapped in an ArrayRecord (``None`` when
            there are no valid replies or ``current_bst`` was never set by
            ``configure_train``), and the result of ``train_metrics_aggr_fn``
            (``None`` when there are no valid replies).

        Raises
        ------
        InconsistentMessageReplies
            If any reply's ArrayRecord does not hold exactly one Array. This is
            checked for all replies before ``current_bst`` is modified, so a bad
            reply never leaves a partially-bagged model behind.
        """
        valid_replies, _ = self._check_and_log_replies(replies, is_train=True)

        arrays, metrics = None, None
        if valid_replies:
            reply_contents = [msg.content for msg in valid_replies]
            array_record_key = next(iter(reply_contents[0].array_records.keys()))

            # Validate every reply up front so an inconsistent one cannot
            # interrupt aggregation halfway through and corrupt `current_bst`.
            client_records = [
                cast(ArrayRecord, content[array_record_key])
                for content in reply_contents
            ]
            for record in client_records:
                self._ensure_single_array(record)

            # Aggregate ArrayRecords: bag into a local and commit once at the end
            if self.current_bst is not None:
                bagged = self.current_bst
                for record in client_records:
                    bagged = aggregate_bagging(bagged, record["0"].numpy().tobytes())
                self.current_bst = bagged
                arrays = ArrayRecord([np.frombuffer(self.current_bst, dtype=np.uint8)])

            # Aggregate MetricRecords
            metrics = self.train_metrics_aggr_fn(
                reply_contents,
                self.weighted_by_key,
            )
        return arrays, metrics