# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bulyan [El Mhamdi et al., 2018] strategy.
Paper: arxiv.org/abs/1802.07927
"""
from collections import OrderedDict
from collections.abc import Iterable
from logging import INFO, WARN
from typing import Callable, Optional, cast
import numpy as np
from flwr.common import (
Array,
ArrayRecord,
Message,
MetricRecord,
NDArrays,
RecordDict,
log,
)
from .fedavg import FedAvg
from .multikrum import select_multikrum
# pylint: disable=too-many-instance-attributes
[๋ฌธ์]
class Bulyan(FedAvg):
"""Bulyan strategy.
Implementation based on https://arxiv.org/abs/1802.07927.
Parameters
----------
fraction_train : float (default: 1.0)
Fraction of nodes used during training. In case `min_train_nodes`
is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
will still be sampled.
fraction_evaluate : float (default: 1.0)
Fraction of nodes used during validation. In case `min_evaluate_nodes`
is larger than `fraction_evaluate * total_connected_nodes`,
`min_evaluate_nodes` will still be sampled.
min_train_nodes : int (default: 2)
Minimum number of nodes used during training.
min_evaluate_nodes : int (default: 2)
Minimum number of nodes used during validation.
min_available_nodes : int (default: 2)
Minimum number of total nodes in the system.
num_malicious_nodes : int (default: 0)
Number of malicious nodes in the system.
weighted_by_key : str (default: "num-examples")
The key within each MetricRecord whose value is used as the weight when
computing weighted averages for MetricRecords.
arrayrecord_key : str (default: "arrays")
Key used to store the ArrayRecord when constructing Messages.
configrecord_key : str (default: "config")
Key used to store the ConfigRecord when constructing Messages.
train_metrics_aggr_fn : Optional[callable] (default: None)
Function with signature (list[RecordDict], str) -> MetricRecord,
used to aggregate MetricRecords from training round replies.
If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
average using the provided weight factor key.
evaluate_metrics_aggr_fn : Optional[callable] (default: None)
Function with signature (list[RecordDict], str) -> MetricRecord,
used to aggregate MetricRecords from training round replies.
If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
average using the provided weight factor key.
selection_rule : Optional[Callable] (default: None)
Function with signature (list[RecordDict], int, int) -> list[RecordDict].
The inputs are:
- a list of contents from reply messages,
- the assumed number of malicious nodes (`num_malicious_nodes`),
- the number of nodes to select (`num_nodes_to_select`).
The function should implement a Byzantine-resilient selection rule that
serves as the first step of Bulyan. If None, defaults to `select_multikrum`,
which selects nodes according to the Multi-Krum algorithm.
"""
# pylint: disable=too-many-arguments,too-many-positional-arguments
def __init__(
self,
fraction_train: float = 1.0,
fraction_evaluate: float = 1.0,
min_train_nodes: int = 2,
min_evaluate_nodes: int = 2,
min_available_nodes: int = 2,
num_malicious_nodes: int = 0,
weighted_by_key: str = "num-examples",
arrayrecord_key: str = "arrays",
configrecord_key: str = "config",
train_metrics_aggr_fn: Optional[
Callable[[list[RecordDict], str], MetricRecord]
] = None,
evaluate_metrics_aggr_fn: Optional[
Callable[[list[RecordDict], str], MetricRecord]
] = None,
selection_rule: Optional[
Callable[[list[RecordDict], int, int], list[RecordDict]]
] = None,
) -> None:
super().__init__(
fraction_train=fraction_train,
fraction_evaluate=fraction_evaluate,
min_train_nodes=min_train_nodes,
min_evaluate_nodes=min_evaluate_nodes,
min_available_nodes=min_available_nodes,
weighted_by_key=weighted_by_key,
arrayrecord_key=arrayrecord_key,
configrecord_key=configrecord_key,
train_metrics_aggr_fn=train_metrics_aggr_fn,
evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
)
self.num_malicious_nodes = num_malicious_nodes
self.selection_rule = selection_rule or select_multikrum
[๋ฌธ์]
def summary(self) -> None:
"""Log summary configuration of the strategy."""
log(INFO, "\tโโโ> Bulyan settings:")
log(INFO, "\tโ\tโโโ Number of malicious nodes: %d", self.num_malicious_nodes)
log(INFO, "\tโ\tโโโ Selection rule: %s", self.selection_rule.__name__)
super().summary()
[๋ฌธ์]
def aggregate_train(
self,
server_round: int,
replies: Iterable[Message],
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
# Check if sufficient replies have been received
if len(valid_replies) < 4 * self.num_malicious_nodes + 3:
log(
WARN,
"Insufficient replies, skipping Bulyan aggregation: "
"Required at least %d (4*num_malicious_nodes + 3), but received %d.",
4 * self.num_malicious_nodes + 3,
len(valid_replies),
)
return None, None
reply_contents = [msg.content for msg in valid_replies]
# Compute theta and beta
theta = len(valid_replies) - 2 * self.num_malicious_nodes
beta = theta - 2 * self.num_malicious_nodes
# Byzantine-resilient selection rule
selected_contents = self.selection_rule(
reply_contents, self.num_malicious_nodes, theta
)
# Convert each ArrayRecord to a list of NDArray for easier computation
key = list(selected_contents[0].array_records.keys())[0]
array_keys = list(selected_contents[0][key].keys())
selected_ndarrays = [
cast(ArrayRecord, ctnt[key]).to_numpy_ndarrays(keep_input=False)
for ctnt in selected_contents
]
# Compute median
median_ndarrays = [np.median(arr, axis=0) for arr in zip(*selected_ndarrays)]
# Aggregate the beta closest weights element-wise
aggregated_ndarrays = aggregate_n_closest_weights(
median_ndarrays, selected_ndarrays, beta
)
# Convert to ArrayRecord
arrays = ArrayRecord(
OrderedDict(zip(array_keys, map(Array, aggregated_ndarrays)))
)
# Aggregate MetricRecords
metrics = self.train_metrics_aggr_fn(
selected_contents,
self.weighted_by_key,
)
return arrays, metrics
def aggregate_n_closest_weights(
ref_weights: NDArrays, weights_list: list[NDArrays], beta: int
) -> NDArrays:
"""Compute the element-wise mean of the `beta` closest weight arrays.
For each element (i-th coordinate), the output is the average of the
`beta` weight arrays that are closest to the reference weights.
Parameters
----------
ref_weights : NDArrays
Reference weights used to compute distances.
weights_list : list[NDArrays]
List of weight arrays (e.g., from selected nodes).
beta : int
Number of closest weight arrays to include in the averaging.
Returns
-------
aggregated_weights : NDArrays
Element-wise average of the `beta` closest weight arrays to the
reference weights.
"""
aggregated_weights = []
for layer_id, ref_layer in enumerate(ref_weights):
# Shape: (n_models, *layer_shape)
layer_stack = np.stack([weights[layer_id] for weights in weights_list])
# Compute absolute differences: shape (n_models, *layer_shape)
diffs = np.abs(layer_stack - ref_layer)
# Find indices of `beta` smallest per coordinate
idx = np.argpartition(diffs, beta - 1, axis=0)[:beta]
# Gather the closest weights
closest = np.take_along_axis(layer_stack, idx, axis=0)
# Average them
aggregated_weights.append(np.mean(closest, axis=0))
return aggregated_weights