Source code for flwr.serverapp.strategy.fedavg

# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower message-based FedAvg strategy."""


from collections.abc import Iterable
from logging import INFO, WARNING
from typing import Callable, Optional

from flwr.common import (
    ArrayRecord,
    ConfigRecord,
    Message,
    MessageType,
    MetricRecord,
    RecordDict,
    log,
)
from flwr.server import Grid

from .strategy import Strategy
from .strategy_utils import (
    aggregate_arrayrecords,
    aggregate_metricrecords,
    sample_nodes,
    validate_message_reply_consistency,
)


# pylint: disable=too-many-instance-attributes
class FedAvg(Strategy):
    """Federated Averaging strategy.

    Implementation based on https://arxiv.org/abs/1602.05629

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes` is
        larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for both ArrayRecords and MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a
        weighted average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a
        weighted average using the provided weight factor key.
    """

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
    ) -> None:
        self.fraction_train = fraction_train
        self.fraction_evaluate = fraction_evaluate
        self.min_train_nodes = min_train_nodes
        self.min_evaluate_nodes = min_evaluate_nodes
        self.min_available_nodes = min_available_nodes
        self.weighted_by_key = weighted_by_key
        self.arrayrecord_key = arrayrecord_key
        self.configrecord_key = configrecord_key
        # Fall back to weighted averaging when no custom aggregators are given
        self.train_metrics_aggr_fn = train_metrics_aggr_fn or aggregate_metricrecords
        self.evaluate_metrics_aggr_fn = (
            evaluate_metrics_aggr_fn or aggregate_metricrecords
        )
        # A zero fraction disables that stage entirely; force the matching
        # minimum to 0 so sampling is skipped, and warn the user once.
        if self.fraction_evaluate == 0.0:
            self.min_evaluate_nodes = 0
            log(
                WARNING,
                "fraction_evaluate is set to 0.0. "
                "Federated evaluation will be skipped.",
            )
        if self.fraction_train == 0.0:
            self.min_train_nodes = 0
            log(
                WARNING,
                "fraction_train is set to 0.0. Federated training will be skipped.",
            )
[문서] def summary(self) -> None: """Log summary configuration of the strategy.""" log(INFO, "\t├──> Sampling:") log( INFO, "\t\t├──Fraction: train (%.2f) | evaluate ( %.2f)", self.fraction_train, self.fraction_evaluate, ) # pylint: disable=line-too-long log( INFO, "\t\t├──Minimum nodes: train (%d) | evaluate (%d)", self.min_train_nodes, self.min_evaluate_nodes, ) # pylint: disable=line-too-long log(INFO, "\t\t└──Minimum available nodes: %d", self.min_available_nodes) log(INFO, "\t└──> Keys in records:") log(INFO, "\t\t├── Weighted by: '%s'", self.weighted_by_key) log(INFO, "\t\t├── ArrayRecord key: '%s'", self.arrayrecord_key) log(INFO, "\t\t└── ConfigRecord key: '%s'", self.configrecord_key)
def _construct_messages( self, record: RecordDict, node_ids: list[int], message_type: str ) -> Iterable[Message]: """Construct N Messages carrying the same RecordDict payload.""" messages = [] for node_id in node_ids: # one message for each node message = Message( content=record, message_type=message_type, dst_node_id=node_id, ) messages.append(message) return messages
[문서] def configure_train( self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid ) -> Iterable[Message]: """Configure the next round of federated training.""" # Do not configure federated train if fraction_train is 0. if self.fraction_train == 0.0: return [] # Sample nodes num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_train) sample_size = max(num_nodes, self.min_train_nodes) node_ids, num_total = sample_nodes(grid, self.min_available_nodes, sample_size) log( INFO, "configure_train: Sampled %s nodes (out of %s)", len(node_ids), len(num_total), ) # Always inject current server round config["server-round"] = server_round # Construct messages record = RecordDict( {self.arrayrecord_key: arrays, self.configrecord_key: config} ) return self._construct_messages(record, node_ids, MessageType.TRAIN)
def _check_and_log_replies( self, replies: Iterable[Message], is_train: bool, validate: bool = True ) -> tuple[list[Message], list[Message]]: """Check replies for errors and log them. Parameters ---------- replies : Iterable[Message] Iterable of reply Messages. is_train : bool Set to True if the replies are from a training round; False otherwise. This impacts logging and validation behavior. validate : bool (default: True) Whether to validate the reply contents for consistency. Returns ------- tuple[list[Message], list[Message]] A tuple containing two lists: - Messages with valid contents. - Messages with errors. """ if not replies: return [], [] # Filter messages that carry content valid_replies: list[Message] = [] error_replies: list[Message] = [] for msg in replies: if msg.has_error(): error_replies.append(msg) else: valid_replies.append(msg) log( INFO, "%s: Received %s results and %s failures", "aggregate_train" if is_train else "aggregate_evaluate", len(valid_replies), len(error_replies), ) # Log errors for msg in error_replies: log( INFO, "\t> Received error in reply from node %d: %s", msg.metadata.src_node_id, msg.error.reason, ) # Ensure expected ArrayRecords and MetricRecords are received if validate and valid_replies: validate_message_reply_consistency( replies=[msg.content for msg in valid_replies], weighted_by_key=self.weighted_by_key, check_arrayrecord=is_train, ) return valid_replies, error_replies
[문서] def aggregate_train( self, server_round: int, replies: Iterable[Message], ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]: """Aggregate ArrayRecords and MetricRecords in the received Messages.""" valid_replies, _ = self._check_and_log_replies(replies, is_train=True) arrays, metrics = None, None if valid_replies: reply_contents = [msg.content for msg in valid_replies] # Aggregate ArrayRecords arrays = aggregate_arrayrecords( reply_contents, self.weighted_by_key, ) # Aggregate MetricRecords metrics = self.train_metrics_aggr_fn( reply_contents, self.weighted_by_key, ) return arrays, metrics
[문서] def configure_evaluate( self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid ) -> Iterable[Message]: """Configure the next round of federated evaluation.""" # Do not configure federated evaluation if fraction_evaluate is 0. if self.fraction_evaluate == 0.0: return [] # Sample nodes num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_evaluate) sample_size = max(num_nodes, self.min_evaluate_nodes) node_ids, num_total = sample_nodes(grid, self.min_available_nodes, sample_size) log( INFO, "configure_evaluate: Sampled %s nodes (out of %s)", len(node_ids), len(num_total), ) # Always inject current server round config["server-round"] = server_round # Construct messages record = RecordDict( {self.arrayrecord_key: arrays, self.configrecord_key: config} ) return self._construct_messages(record, node_ids, MessageType.EVALUATE)
[문서] def aggregate_evaluate( self, server_round: int, replies: Iterable[Message], ) -> Optional[MetricRecord]: """Aggregate MetricRecords in the received Messages.""" valid_replies, _ = self._check_and_log_replies(replies, is_train=False) metrics = None if valid_replies: reply_contents = [msg.content for msg in valid_replies] # Aggregate MetricRecords metrics = self.evaluate_metrics_aggr_fn( reply_contents, self.weighted_by_key, ) return metrics