# Source code for flwr.server.strategy.fedxgb_bagging

# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Federated XGBoost bagging aggregation strategy."""


import json
from logging import WARNING
from typing import Any, Callable, Optional, Union, cast

from flwr.common import EvaluateRes, FitRes, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy

from .fedavg import FedAvg


class FedXgbBagging(FedAvg):
    """Configurable FedXgbBagging strategy implementation.

    The strategy keeps one serialized global model (``self.global_model``)
    and, on every fit round, folds each client-provided model into it via
    the module-level ``aggregate`` function.

    Parameters
    ----------
    evaluate_function : Optional[Callable]
        Optional callable invoked by ``evaluate`` as
        ``evaluate_function(server_round, parameters, {})``. It returns either
        ``None`` or a ``(loss, metrics)`` tuple.
    **kwargs : Any
        Forwarded unchanged to ``FedAvg.__init__``.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
    def __init__(
        self,
        evaluate_function: Optional[
            Callable[
                [int, Parameters, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        **kwargs: Any,
    ) -> None:
        self.evaluate_function = evaluate_function
        # Serialized global model; stays None until the first successful
        # aggregation in `aggregate_fit`.
        self.global_model: Optional[bytes] = None
        super().__init__(**kwargs)

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        return f"FedXgbBagging(accept_failures={self.accept_failures})"

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """Aggregate fit results using bagging.

        Every tensor of every client result is merged, in order, into the
        stored global model. The merged model is saved on the strategy and
        returned as the single tensor of a ``Parameters`` object.

        Returns ``(None, {})`` when there are no results, when failures
        occurred and ``accept_failures`` is false, or when no model is
        available after processing the results.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if failures and not self.accept_failures:
            return None, {}

        # Aggregate all the client trees
        merged_model = self.global_model
        for _, fit_res in results:
            for client_model in fit_res.parameters.tensors:
                merged_model = aggregate(merged_model, client_model)

        if merged_model is None:
            # No model stored yet and no client sent any tensor: there is
            # nothing to return, so avoid emitting Parameters(tensors=[None]).
            return None, {}

        self.global_model = merged_model
        return Parameters(tensor_type="", tensors=[merged_model]), {}

    def aggregate_evaluate(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, EvaluateRes]],
        failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> tuple[Optional[float], dict[str, Scalar]]:
        """Aggregate evaluation metrics using average.

        Only custom metrics are aggregated, through
        ``self.evaluate_metrics_aggregation_fn`` when it is set. The loss
        element of the returned tuple is always the constant ``0``.

        Returns ``(None, {})`` when there are no results, or when failures
        occurred and ``accept_failures`` is false.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if failures and not self.accept_failures:
            return None, {}

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: dict[str, Scalar] = {}
        if self.evaluate_metrics_aggregation_fn:
            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")

        return 0, metrics_aggregated

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[tuple[float, dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function.

        Returns ``None`` when no ``evaluate_function`` was configured or when
        it returns ``None``; otherwise returns its ``(loss, metrics)`` result.
        """
        if self.evaluate_function is None:
            # No evaluation function provided
            return None
        # The evaluation function always receives an empty config dict.
        eval_res = self.evaluate_function(server_round, parameters, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics
def aggregate(
    bst_prev_org: Optional[bytes],
    bst_curr_org: bytes,
) -> bytes:
    """Conduct bagging aggregation for given trees.

    Both arguments are JSON documents (as bytes) that hold the tree model
    under ``learner -> gradient_booster -> model``.

    Parameters
    ----------
    bst_prev_org : Optional[bytes]
        The model aggregated so far. When it is ``None`` or empty,
        ``bst_curr_org`` is returned unchanged.
    bst_curr_org : bytes
        The model to merge in. Only its first ``num_parallel_tree`` trees
        are appended to the previous model.

    Returns
    -------
    bytes
        The UTF-8 encoded JSON of the previous model with the new trees
        appended, ``num_trees`` increased, one new ``iteration_indptr`` entry
        and one ``0`` added to ``tree_info`` per appended tree.
    """
    if not bst_prev_org:
        return bst_curr_org

    # Decode each model exactly once (json.loads accepts bytes directly) and
    # read the tree counts from the parsed documents instead of re-parsing.
    bst_prev = json.loads(bst_prev_org)
    tree_num_prev, _ = _tree_nums_from_model(bst_prev)
    bst_curr = json.loads(bst_curr_org)
    _, paral_tree_num_curr = _tree_nums_from_model(bst_curr)

    model_prev = bst_prev["learner"]["gradient_booster"]["model"]
    model_prev["gbtree_model_param"]["num_trees"] = str(
        tree_num_prev + paral_tree_num_curr
    )
    # Extend the index list by one entry covering the appended trees.
    iteration_indptr = model_prev["iteration_indptr"]
    iteration_indptr.append(iteration_indptr[-1] + paral_tree_num_curr)

    # Aggregate new trees
    trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
    for tree_count in range(paral_tree_num_curr):
        tree = trees_curr[tree_count]
        # Re-number the tree so ids continue after the existing ones.
        tree["id"] = tree_num_prev + tree_count
        model_prev["trees"].append(tree)
        model_prev["tree_info"].append(0)

    return bytes(json.dumps(bst_prev), "utf-8")


def _get_tree_nums(xgb_model_org: bytes) -> tuple[int, int]:
    """Return ``(num_trees, num_parallel_tree)`` of a serialized JSON model."""
    return _tree_nums_from_model(json.loads(xgb_model_org))


def _tree_nums_from_model(xgb_model: dict[str, Any]) -> tuple[int, int]:
    """Return ``(num_trees, num_parallel_tree)`` of an already-parsed model."""
    model_param = xgb_model["learner"]["gradient_booster"]["model"][
        "gbtree_model_param"
    ]
    # Get the number of trees
    tree_num = int(model_param["num_trees"])
    # Get the number of parallel trees
    paral_tree_num = int(model_param["num_parallel_tree"])
    return tree_num, paral_tree_num