# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Federated XGBoost cyclic aggregation strategy."""
from logging import WARNING
from typing import Any, cast
from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from .fedavg import FedAvg
[docs]
class FedXgbCyclic(FedAvg):
    """Configurable FedXgbCyclic strategy implementation.

    The strategy does not combine client models. After each fit round it
    keeps the last model tensor reported by the clients and hands that
    single tensor back as the new global parameters.

    All keyword arguments are forwarded unchanged to ``FedAvg``.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
    def __init__(
        self,
        **kwargs: Any,
    ):
        # Most recent serialized model received from a client; ``None`` until
        # the first fit round delivers at least one tensor.
        self.global_model: bytes | None = None
        super().__init__(**kwargs)

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        return f"FedXgbCyclic(accept_failures={self.accept_failures})"

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[tuple[ClientProxy, FitRes] | BaseException],
    ) -> tuple[Parameters | None, dict[str, Scalar]]:
        """Adopt the latest client model as the new global model.

        Parameters
        ----------
        server_round : int
            Current round of federated learning (unused here).
        results : list[tuple[ClientProxy, FitRes]]
            Successful fit results returned by the clients.
        failures : list[tuple[ClientProxy, FitRes] | BaseException]
            Failed fit attempts.

        Returns
        -------
        tuple[Parameters | None, dict[str, Scalar]]
            The global model wrapped in ``Parameters`` plus an empty metrics
            dict, or ``(None, {})`` when there are no results, when failures
            occurred and are not accepted, or when no model has been received
            so far.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # Keep the last tensor of the last result that carries any tensors.
        # Results without tensors leave the previously stored model untouched.
        for _, fit_res in results:
            tensors = fit_res.parameters.tensors
            if tensors:
                self.global_model = tensors[-1]

        # No client has ever delivered a model: report "no update" rather than
        # emitting Parameters that contain ``None`` in place of a tensor.
        latest_model = self.global_model
        if latest_model is None:
            return None, {}

        return Parameters(tensor_type="", tensors=[latest_model]), {}

    def aggregate_evaluate(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, EvaluateRes]],
        failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
    ) -> tuple[float | None, dict[str, Scalar]]:
        """Aggregate custom evaluation metrics reported by the clients.

        Client losses are not aggregated: on success the returned loss is
        always the constant ``0``. Metrics are aggregated only through
        ``evaluate_metrics_aggregation_fn`` when one is configured.

        Parameters
        ----------
        server_round : int
            Current round of federated learning; only used to emit the
            "no aggregation function" warning once, in round 1.
        results : list[tuple[ClientProxy, EvaluateRes]]
            Successful evaluation results returned by the clients.
        failures : list[tuple[ClientProxy, EvaluateRes] | BaseException]
            Failed evaluation attempts.

        Returns
        -------
        tuple[float | None, dict[str, Scalar]]
            ``(0, aggregated_metrics)`` on success, or ``(None, {})`` when
            there are no results or when failures occurred and are not
            accepted.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: dict[str, Scalar] = {}
        if self.evaluate_metrics_aggregation_fn:
            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")

        # Placeholder loss: this strategy does not compute an aggregated loss.
        return 0, metrics_aggregated