# Source code for flwr.serverapp.strategy.krum

# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent.

[Blanchard et al., 2017].

Paper: proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf
"""


from logging import INFO
from typing import Callable, Optional

from flwr.common import MetricRecord, RecordDict, log

from .multikrum import MultiKrum


# pylint: disable=too-many-instance-attributes
class Krum(MultiKrum):
    """Krum [Blanchard et al., 2017] strategy.

    Implementation based on https://arxiv.org/abs/1703.02757

    This class is a thin specialization of `MultiKrum`: every setting is
    forwarded unchanged to the parent constructor, with `num_nodes_to_select`
    fixed to 1.

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`,
        `min_train_nodes` will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    num_malicious_nodes : int (default: 0)
        Number of malicious nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    """

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        num_malicious_nodes: int = 0,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
    ) -> None:
        """Initialize the parent strategy with a single selected node."""
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            num_malicious_nodes=num_malicious_nodes,
            num_nodes_to_select=1,  # Krum selects 1 node
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
        )

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        # Emit the Krum-specific lines first, then let the parent class log
        # its own part of the summary.
        log(INFO, "\t├──> Krum settings:")
        log(INFO, "\t\t└── Number of malicious nodes: %d", self.num_malicious_nodes)
        super().summary()