# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fault-tolerant variant of FedAvg strategy."""
from logging import WARNING
from typing import Callable, Optional, Union
from flwr.common import (
EvaluateRes,
FitRes,
MetricsAggregationFn,
NDArrays,
Parameters,
Scalar,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from .aggregate import aggregate, weighted_loss_avg
from .fedavg import FedAvg
class FaultTolerantFedAvg(FedAvg):
    """Configurable fault-tolerant FedAvg strategy implementation.

    Always constructs the underlying :class:`FedAvg` with
    ``accept_failures=True``. A round is aggregated only when the fraction of
    clients that returned a result, ``len(results) / (len(results) +
    len(failures))``, is not below the configured minimum completion rate;
    otherwise the aggregation method returns ``(None, {})``.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 1,
        min_evaluate_clients: int = 1,
        min_available_clients: int = 1,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        min_completion_rate_fit: float = 0.5,
        min_completion_rate_evaluate: float = 0.5,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        """Create the strategy.

        Parameters
        ----------
        min_completion_rate_fit : float (default: 0.5)
            Fit results are aggregated only if
            ``len(results) / (len(results) + len(failures))`` is not below
            this value. Stored as ``self.completion_rate_fit``.
        min_completion_rate_evaluate : float (default: 0.5)
            The same threshold applied to evaluation results. Stored as
            ``self.completion_rate_evaluate``.

        All remaining keyword arguments are forwarded unchanged to
        :class:`FedAvg`. ``accept_failures`` is not configurable and is
        always passed as ``True``.
        """
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            # Failures are tolerated; the completion-rate thresholds below
            # decide whether a round is still aggregated.
            accept_failures=True,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.completion_rate_fit = min_completion_rate_fit
        self.completion_rate_evaluate = min_completion_rate_evaluate

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        return "FaultTolerantFedAvg()"

    @staticmethod
    def _completion_rate_too_low(
        num_results: int, num_failures: int, min_completion_rate: float
    ) -> bool:
        """Return True if the successful fraction is below the threshold.

        Callers must ensure ``num_results > 0`` so the denominator is
        non-zero.
        """
        completion_rate = num_results / (num_results + num_failures)
        return completion_rate < min_completion_rate

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """Aggregate fit results using weighted average.

        Returns ``(None, {})`` in two cases: when ``results`` is empty, and
        when the completion rate is below ``self.completion_rate_fit``.
        Otherwise, each client's parameters are paired with its
        ``num_examples`` and combined by ``aggregate``. The result is
        returned together with the output of ``fit_metrics_aggregation_fn``,
        or ``{}`` if no such function was provided.
        """
        if not results:
            return None, {}
        # Not enough clients succeeded for this round to be aggregated
        if self._completion_rate_too_low(
            len(results), len(failures), self.completion_rate_fit
        ):
            return None, {}

        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: dict[str, Scalar] = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:
            # Warn only in round 1 rather than repeating the message every round
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated

    def aggregate_evaluate(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, EvaluateRes]],
        failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> tuple[Optional[float], dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average.

        Returns ``(None, {})`` in two cases: when ``results`` is empty, and
        when the completion rate is below ``self.completion_rate_evaluate``.
        Otherwise, the ``(num_examples, loss)`` pairs are passed to
        ``weighted_loss_avg``. The result is returned together with the
        output of ``evaluate_metrics_aggregation_fn``, or ``{}`` if no such
        function was provided.
        """
        if not results:
            return None, {}
        # Not enough clients succeeded for this round to be aggregated
        if self._completion_rate_too_low(
            len(results), len(failures), self.completion_rate_evaluate
        ):
            return None, {}

        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: dict[str, Scalar] = {}
        if self.evaluate_metrics_aggregation_fn:
            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
        elif server_round == 1:
            # Warn only in round 1 rather than repeating the message every round
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")

        return loss_aggregated, metrics_aggregated