Source code for flwr.serverapp.strategy.fedavg

# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower message-based FedAvg strategy."""


from collections.abc import Iterable
from logging import INFO
from typing import Callable, Optional

from flwr.common import (
    ArrayRecord,
    ConfigRecord,
    Message,
    MessageType,
    MetricRecord,
    RecordDict,
    log,
)
from flwr.server import Grid

from .strategy import Strategy
from .strategy_utils import (
    aggregate_arrayrecords,
    aggregate_metricrecords,
    sample_nodes,
    validate_message_reply_consistency,
)


# pylint: disable=too-many-instance-attributes
class FedAvg(Strategy):
    """Federated Averaging strategy.

    Implementation based on https://arxiv.org/abs/1602.05629

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`,
        `min_train_nodes` will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for both ArrayRecords and MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a
        weighted average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a
        weighted average using the provided weight factor key.
    """

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
    ) -> None:
        self.fraction_train = fraction_train
        self.fraction_evaluate = fraction_evaluate
        self.min_train_nodes = min_train_nodes
        self.min_evaluate_nodes = min_evaluate_nodes
        self.min_available_nodes = min_available_nodes
        self.weighted_by_key = weighted_by_key
        self.arrayrecord_key = arrayrecord_key
        self.configrecord_key = configrecord_key
        # Fall back to weighted averaging when no custom aggregators are given
        self.train_metrics_aggr_fn = train_metrics_aggr_fn or aggregate_metricrecords
        self.evaluate_metrics_aggr_fn = (
            evaluate_metrics_aggr_fn or aggregate_metricrecords
        )

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        log(INFO, "\t├──> Sampling:")
        log(
            INFO,
            # Fixed format typo: was "evaluate ( %.2f)" with a stray space
            "\t\t├──Fraction: train (%.2f) | evaluate (%.2f)",
            self.fraction_train,
            self.fraction_evaluate,
        )  # pylint: disable=line-too-long
        log(
            INFO,
            "\t\t├──Minimum nodes: train (%d) | evaluate (%d)",
            self.min_train_nodes,
            self.min_evaluate_nodes,
        )  # pylint: disable=line-too-long
        log(INFO, "\t\t└──Minimum available nodes: %d", self.min_available_nodes)
        log(INFO, "\t└──> Keys in records:")
        log(INFO, "\t\t├── Weighted by: '%s'", self.weighted_by_key)
        log(INFO, "\t\t├── ArrayRecord key: '%s'", self.arrayrecord_key)
        log(INFO, "\t\t└── ConfigRecord key: '%s'", self.configrecord_key)

    def _construct_messages(
        self, record: RecordDict, node_ids: list[int], message_type: str
    ) -> Iterable[Message]:
        """Construct N Messages carrying the same RecordDict payload.

        One Message per destination node; all share the same `record` content.
        """
        messages = []
        for node_id in node_ids:  # one message for each node
            message = Message(
                content=record,
                message_type=message_type,
                dst_node_id=node_id,
            )
            messages.append(message)
        return messages

    def _filter_replies_with_content(
        self, replies: Iterable[Message], caller: str
    ) -> list[RecordDict]:
        """Log replies carrying errors and return the contents of valid ones.

        Parameters
        ----------
        replies : Iterable[Message]
            Reply Messages received from nodes; some may carry errors.
        caller : str
            Name of the calling method, used as the log-line prefix.

        Returns
        -------
        list[RecordDict]
            The `content` of every reply that did not carry an error.
        """
        num_errors = 0
        replies_with_content: list[RecordDict] = []
        for msg in replies:
            if msg.has_error():
                log(
                    INFO,
                    "Received error in reply from node %d: %s",
                    msg.metadata.src_node_id,
                    msg.error,
                )
                num_errors += 1
            else:
                replies_with_content.append(msg.content)
        log(
            INFO,
            "%s: Received %s results and %s failures",
            caller,
            len(replies_with_content),
            num_errors,
        )
        return replies_with_content

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated training."""
        # Sample nodes: fraction of connected nodes, but never fewer than
        # `min_train_nodes`
        num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_train)
        sample_size = max(num_nodes, self.min_train_nodes)
        # NOTE(review): `num_total` appears to hold the full list of node IDs
        # (it is passed to len() below) — name suggests a count; confirm
        # against `sample_nodes` in strategy_utils.
        node_ids, num_total = sample_nodes(grid, self.min_available_nodes, sample_size)
        log(
            INFO,
            "configure_train: Sampled %s nodes (out of %s)",
            len(node_ids),
            len(num_total),
        )

        # Always inject current server round
        config["server-round"] = server_round

        # Construct messages
        record = RecordDict(
            {self.arrayrecord_key: arrays, self.configrecord_key: config}
        )
        return self._construct_messages(record, node_ids, MessageType.TRAIN)

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        Returns `(None, None)` when no replies were received; otherwise a
        weighted aggregate of arrays and metrics from error-free replies.
        """
        if not replies:
            return None, None

        # Log any errored replies and keep only those carrying content
        replies_with_content = self._filter_replies_with_content(
            replies, "aggregate_train"
        )

        # Ensure expected ArrayRecords and MetricRecords are received
        validate_message_reply_consistency(
            replies=replies_with_content,
            weighted_by_key=self.weighted_by_key,
            check_arrayrecord=True,
        )

        arrays, metrics = None, None
        if replies_with_content:
            # Aggregate ArrayRecords (weighted by `weighted_by_key`)
            arrays = aggregate_arrayrecords(
                replies_with_content,
                self.weighted_by_key,
            )

            # Aggregate MetricRecords
            metrics = self.train_metrics_aggr_fn(
                replies_with_content,
                self.weighted_by_key,
            )
        return arrays, metrics

    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated evaluation."""
        # Sample nodes: fraction of connected nodes, but never fewer than
        # `min_evaluate_nodes`
        num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_evaluate)
        sample_size = max(num_nodes, self.min_evaluate_nodes)
        node_ids, num_total = sample_nodes(grid, self.min_available_nodes, sample_size)
        log(
            INFO,
            "configure_evaluate: Sampled %s nodes (out of %s)",
            len(node_ids),
            len(num_total),
        )

        # Always inject current server round
        config["server-round"] = server_round

        # Construct messages
        record = RecordDict(
            {self.arrayrecord_key: arrays, self.configrecord_key: config}
        )
        return self._construct_messages(record, node_ids, MessageType.EVALUATE)

    def aggregate_evaluate(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> Optional[MetricRecord]:
        """Aggregate MetricRecords in the received Messages.

        Returns `None` when no replies were received; otherwise a weighted
        aggregate of metrics from error-free replies.
        """
        if not replies:
            return None

        # Log any errored replies and keep only those carrying content
        replies_with_content = self._filter_replies_with_content(
            replies, "aggregate_evaluate"
        )

        # Ensure expected MetricRecords are received (no ArrayRecord expected
        # in evaluate replies)
        validate_message_reply_consistency(
            replies=replies_with_content,
            weighted_by_key=self.weighted_by_key,
            check_arrayrecord=False,
        )

        metrics = None
        if replies_with_content:
            # Aggregate MetricRecords
            metrics = self.evaluate_metrics_aggr_fn(
                replies_with_content,
                self.weighted_by_key,
            )
        return metrics