# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fair Resource Allocation in Federated Learning [Li et al., 2020] strategy.
Paper: openreview.net/pdf?id=ByexElSYDr
"""
from collections import OrderedDict
from collections.abc import Iterable
from logging import INFO
from typing import Callable, Optional
import numpy as np
from flwr.common import (
Array,
ArrayRecord,
ConfigRecord,
Message,
MetricRecord,
NDArray,
RecordDict,
)
from flwr.common.logger import log
from flwr.server import Grid
from ..exception import AggregationError
from .fedavg import FedAvg


class QFedAvg(FedAvg):
"""Q-FedAvg strategy.
Implementation based on openreview.net/pdf?id=ByexElSYDr
Parameters
----------
    client_learning_rate : float
        Local learning rate used by clients during training. The strategy uses
        this value to approximate the Lipschitz constant L of the local
        objective, via L = 1 / client_learning_rate.
    q : float (default: 0.1)
        The parameter q that controls the degree of fairness of the algorithm:
        larger values give clients with higher loss more influence on the
        aggregate, yielding a more uniform accuracy distribution across
        clients. When set to 0, q-FedAvg is equivalent to FedAvg. Tune this
        parameter based on your use case.
train_loss_key : str (default: "train_loss")
The key within the MetricRecord whose value is used as the training loss when
aggregating ArrayRecords following q-FedAvg.
fraction_train : float (default: 1.0)
Fraction of nodes used during training. In case `min_train_nodes`
is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
will still be sampled.
fraction_evaluate : float (default: 1.0)
Fraction of nodes used during validation. In case `min_evaluate_nodes`
is larger than `fraction_evaluate * total_connected_nodes`,
`min_evaluate_nodes` will still be sampled.
min_train_nodes : int (default: 2)
Minimum number of nodes used during training.
min_evaluate_nodes : int (default: 2)
Minimum number of nodes used during validation.
min_available_nodes : int (default: 2)
Minimum number of total nodes in the system.
weighted_by_key : str (default: "num-examples")
The key within each MetricRecord whose value is used as the weight when
computing weighted averages for MetricRecords.
arrayrecord_key : str (default: "arrays")
Key used to store the ArrayRecord when constructing Messages.
configrecord_key : str (default: "config")
Key used to store the ConfigRecord when constructing Messages.
train_metrics_aggr_fn : Optional[callable] (default: None)
Function with signature (list[RecordDict], str) -> MetricRecord,
used to aggregate MetricRecords from training round replies.
If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a
        weighted average using the provided weight factor key.
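
    Examples
    --------
    A minimal usage sketch (``start`` and its arguments follow the common
    Message-API strategy pattern; names such as ``grid`` and ``global_arrays``
    are placeholders and may differ across Flower versions)::

        strategy = QFedAvg(client_learning_rate=0.1, q=0.2)
        result = strategy.start(
            grid=grid,
            initial_arrays=global_arrays,
            num_rounds=3,
        )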
"""
def __init__( # pylint: disable=R0913, R0917
self,
client_learning_rate: float,
q: float = 0.1,
train_loss_key: str = "train_loss",
fraction_train: float = 1.0,
fraction_evaluate: float = 1.0,
min_train_nodes: int = 2,
min_evaluate_nodes: int = 2,
min_available_nodes: int = 2,
weighted_by_key: str = "num-examples",
arrayrecord_key: str = "arrays",
configrecord_key: str = "config",
train_metrics_aggr_fn: Optional[
Callable[[list[RecordDict], str], MetricRecord]
] = None,
evaluate_metrics_aggr_fn: Optional[
Callable[[list[RecordDict], str], MetricRecord]
] = None,
) -> None:
super().__init__(
fraction_train=fraction_train,
fraction_evaluate=fraction_evaluate,
min_train_nodes=min_train_nodes,
min_evaluate_nodes=min_evaluate_nodes,
min_available_nodes=min_available_nodes,
weighted_by_key=weighted_by_key,
arrayrecord_key=arrayrecord_key,
configrecord_key=configrecord_key,
train_metrics_aggr_fn=train_metrics_aggr_fn,
evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
)
self.q = q
self.client_learning_rate = client_learning_rate
self.train_loss_key = train_loss_key
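        # Latest global model sent out for training; expected to be cached by
        # `configure_train` so `aggregate_train` can compute w - w_k against it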
        self.current_arrays: Optional[ArrayRecord] = None

    def summary(self) -> None:
"""Log summary configuration of the strategy."""
log(INFO, "\t├──> q-FedAvg settings:")
log(INFO, "\t│\t├── client_learning_rate: %s", self.client_learning_rate)
log(INFO, "\t│\t├── q: %s", self.q)
log(INFO, "\t│\t└── train_loss_key: '%s'", self.train_loss_key)
super().summary()
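
    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training.

        NOTE: This override caches the outgoing global model, which
        `aggregate_train` requires; its signature is assumed to mirror
        `FedAvg.configure_train`.
        """
        # Keep a reference to the current global model for use at aggregation time
        self.current_arrays = arrays
        return super().configure_train(server_round, arrays, config, grid)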

    def aggregate_train(  # pylint: disable=too-many-locals
self,
server_round: int,
replies: Iterable[Message],
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
        # Validate replies, keeping only those suitable for aggregation
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
if not valid_replies:
return None, None
# Compute estimate of Lipschitz constant L
L = 1.0 / self.client_learning_rate # pylint: disable=C0103
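        # The q-FedAvg derivation assumes L-smooth local objectives; since clients
        # run SGD with step size `client_learning_rate`, its inverse serves as a
        # practical stand-in for L (a heuristic, not a measured constant)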
# q-FedAvg aggregation
if self.current_arrays is None:
            raise AggregationError(
                "Current global model weights are not available. Make sure to call "
                "`configure_train` before calling `aggregate_train`."
            )
array_keys = list(self.current_arrays.keys()) # Preserve keys
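        # (`keep_input=False` releases each source Array as it is converted,
        # reducing peak memory for large models)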
global_weights = self.current_arrays.to_numpy_ndarrays(keep_input=False)
sum_delta = None
sum_h = 0.0
for msg in valid_replies:
# Extract local weights and training loss from Message
local_weights = get_local_weights(msg)
loss = get_train_loss(msg, self.train_loss_key)
# Compute delta and h
delta, h = compute_delta_and_h(
global_weights, local_weights, self.q, L, loss
)
# Compute sum of deltas and sum of h
if sum_delta is None:
sum_delta = delta
else:
sum_delta = [sd + d for sd, d in zip(sum_delta, delta)]
sum_h += h
        # Compute new global weights: w_new = w - sum_k(delta_k) / sum_k(h_k)
        # (`np.asarray` converts any numpy scalars to 0-dim arrays)
assert sum_delta is not None # Make mypy happy
array_list = [
Array(np.asarray(gw - (d / sum_h)))
for gw, d in zip(global_weights, sum_delta)
]
# Aggregate MetricRecords
metrics = self.train_metrics_aggr_fn(
[msg.content for msg in valid_replies],
self.weighted_by_key,
)
        return ArrayRecord(OrderedDict(zip(array_keys, array_list))), metrics


def get_train_loss(msg: Message, loss_key: str) -> float:
"""Extract training loss from a Message."""
metrics = list(msg.content.metric_records.values())[0]
if (loss := metrics.get(loss_key)) is None or not isinstance(loss, (int, float)):
raise AggregationError(
"Missing or invalid training loss. "
f"The strategy expected a float value for the key '{loss_key}' "
"as the training loss in each MetricRecord from the clients. "
f"Ensure that '{loss_key}' is present and maps to a valid float."
)
    return float(loss)


def get_local_weights(msg: Message) -> list[NDArray]:
"""Extract local weights from a Message."""
arrays = list(msg.content.array_records.values())[0]
    return arrays.to_numpy_ndarrays(keep_input=False)


def l2_norm(ndarrays: list[NDArray]) -> float:
"""Compute the squared L2 norm of a list of numpy.ndarray."""
    return float(sum(np.sum(np.square(g)) for g in ndarrays))


def compute_delta_and_h(
global_weights: list[NDArray],
local_weights: list[NDArray],
q: float,
L: float, # Lipschitz constant # pylint: disable=C0103
loss: float,
) -> tuple[list[NDArray], float]:
"""Compute delta and h used in q-FedAvg aggregation."""
# Compute gradient_k = L * (w - w_k)
for gw, lw in zip(global_weights, local_weights):
np.subtract(gw, lw, out=lw)
lw *= L
grad = local_weights # After in-place operations, local_weights is now grad
    # Compute ||grad_k||^2 (= L^2 * ||w - w_k||^2)
norm = l2_norm(grad)
# Compute delta_k = loss_k^q * gradient_k
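    # (the small epsilon below guards against 0^q here and 0^(q-1) in h
    # when the reported loss is exactly zero)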
loss_pow_q: float = np.float_power(loss + 1e-10, q)
for g in grad:
g *= loss_pow_q
delta = grad # After in-place multiplication, grad is now delta
# Compute h_k
h = q * np.float_power(loss + 1e-10, q - 1) * norm + L * loss_pow_q
return delta, h