# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower message-based FedAvg strategy."""
from collections.abc import Iterable
from logging import INFO
from typing import Callable, Optional
from flwr.common import (
ArrayRecord,
ConfigRecord,
Message,
MessageType,
MetricRecord,
RecordDict,
log,
)
from flwr.server import Grid
from .strategy import Strategy
from .strategy_utils import (
aggregate_arrayrecords,
aggregate_metricrecords,
sample_nodes,
validate_message_reply_consistency,
)
# pylint: disable=too-many-instance-attributes
class FedAvg(Strategy):
    """Federated Averaging strategy.

    Implementation based on https://arxiv.org/abs/1602.05629

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for both ArrayRecords and MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    """

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
    ) -> None:
        """Store the sampling, record-key and aggregation settings."""
        self.fraction_train = fraction_train
        self.fraction_evaluate = fraction_evaluate
        self.min_train_nodes = min_train_nodes
        self.min_evaluate_nodes = min_evaluate_nodes
        self.min_available_nodes = min_available_nodes
        self.weighted_by_key = weighted_by_key
        self.arrayrecord_key = arrayrecord_key
        self.configrecord_key = configrecord_key
        # Fall back to the default aggregation when no custom function is given
        self.train_metrics_aggr_fn = train_metrics_aggr_fn or aggregate_metricrecords
        self.evaluate_metrics_aggr_fn = (
            evaluate_metrics_aggr_fn or aggregate_metricrecords
        )

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        log(INFO, "\t├──> Sampling:")
        log(
            INFO,
            "\t│\t├──Fraction: train (%.2f) | evaluate ( %.2f)",
            self.fraction_train,
            self.fraction_evaluate,
        )
        log(
            INFO,
            "\t│\t├──Minimum nodes: train (%d) | evaluate (%d)",
            self.min_train_nodes,
            self.min_evaluate_nodes,
        )
        log(INFO, "\t│\t└──Minimum available nodes: %d", self.min_available_nodes)
        log(INFO, "\t└──> Keys in records:")
        log(INFO, "\t\t├── Weighted by: '%s'", self.weighted_by_key)
        log(INFO, "\t\t├── ArrayRecord key: '%s'", self.arrayrecord_key)
        log(INFO, "\t\t└── ConfigRecord key: '%s'", self.configrecord_key)

    def _construct_messages(
        self, record: RecordDict, node_ids: list[int], message_type: str
    ) -> Iterable[Message]:
        """Construct N Messages carrying the same RecordDict payload.

        One Message is created per entry in `node_ids`; every Message is given
        the same `record` object as its content.
        """
        return [
            Message(
                content=record,
                message_type=message_type,
                dst_node_id=node_id,
            )
            for node_id in node_ids
        ]

    @staticmethod
    def _collect_reply_contents(
        replies: Iterable[Message], caller: str
    ) -> list[RecordDict]:
        """Return the content of every error-free reply and log the outcome.

        Each reply carrying an error is logged individually and skipped. A
        summary line prefixed with `caller` reports how many replies had content
        and how many had errors.
        """
        contents = []
        num_errors = 0
        for msg in replies:
            if msg.has_error():
                log(
                    INFO,
                    "Received error in reply from node %d: %s",
                    msg.metadata.src_node_id,
                    msg.error,
                )
                num_errors += 1
            else:
                contents.append(msg.content)
        log(
            INFO,
            "%s: Received %s results and %s failures",
            caller,
            len(contents),
            num_errors,
        )
        return contents

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        Parameters
        ----------
        server_round : int
            The current round; not used by this implementation.
        replies : Iterable[Message]
            Replies received from the nodes. Replies carrying an error are
            logged and excluded from aggregation.

        Returns
        -------
        tuple[Optional[ArrayRecord], Optional[MetricRecord]]
            The aggregated arrays and metrics, or `(None, None)` when there are
            no replies or no reply carries content.
        """
        # Materialize first: a one-shot iterable (e.g. a generator) is always
        # truthy, so testing it directly would never detect "no replies".
        replies = list(replies or ())
        if not replies:
            return None, None

        contents = self._collect_reply_contents(replies, "aggregate_train")
        if not contents:
            # Every reply was an error: nothing to validate or aggregate
            return None, None

        # Ensure expected ArrayRecords and MetricRecords are received
        validate_message_reply_consistency(
            replies=contents,
            weighted_by_key=self.weighted_by_key,
            check_arrayrecord=True,
        )
        arrays = aggregate_arrayrecords(contents, self.weighted_by_key)
        metrics = self.train_metrics_aggr_fn(contents, self.weighted_by_key)
        return arrays, metrics

    def aggregate_evaluate(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> Optional[MetricRecord]:
        """Aggregate MetricRecords in the received Messages.

        Parameters
        ----------
        server_round : int
            The current round; not used by this implementation.
        replies : Iterable[Message]
            Replies received from the nodes. Replies carrying an error are
            logged and excluded from aggregation.

        Returns
        -------
        Optional[MetricRecord]
            The aggregated metrics, or `None` when there are no replies or no
            reply carries content.
        """
        # Materialize first: a one-shot iterable (e.g. a generator) is always
        # truthy, so testing it directly would never detect "no replies".
        replies = list(replies or ())
        if not replies:
            return None

        contents = self._collect_reply_contents(replies, "aggregate_evaluate")
        if not contents:
            # Every reply was an error: nothing to validate or aggregate
            return None

        # Ensure expected MetricRecords are received (ArrayRecords not checked)
        validate_message_reply_consistency(
            replies=contents,
            weighted_by_key=self.weighted_by_key,
            check_arrayrecord=False,
        )
        return self.evaluate_metrics_aggr_fn(contents, self.weighted_by_key)