# Copyright 2021 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adaptive Federated Optimization using Adam (FedAdam) strategy.
[Reddi et al., 2020]
Paper: arxiv.org/abs/2003.00295
"""


from typing import Callable, Optional, Union

import numpy as np

from flwr.common import (
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_proxy import ClientProxy

from .fedopt import FedOpt


# pylint: disable=line-too-long
class FedAdam(FedOpt):
    """FedAdam - Adaptive Federated Optimization using Adam.

    Implementation based on https://arxiv.org/abs/2003.00295v5

    Parameters
    ----------
    fraction_fit : float, optional
        Fraction of clients used during training. Defaults to 1.0.
    fraction_evaluate : float, optional
        Fraction of clients used during validation. Defaults to 1.0.
    min_fit_clients : int, optional
        Minimum number of clients used during training. Defaults to 2.
    min_evaluate_clients : int, optional
        Minimum number of clients used during validation. Defaults to 2.
    min_available_clients : int, optional
        Minimum number of total clients in the system. Defaults to 2.
    evaluate_fn : Optional[Callable[[int, NDArrays, dict[str, Scalar]], Optional[tuple[float, dict[str, Scalar]]]]]
        Optional function used for validation. Defaults to None.
    on_fit_config_fn : Callable[[int], dict[str, Scalar]], optional
        Function used to configure training. Defaults to None.
    on_evaluate_config_fn : Callable[[int], dict[str, Scalar]], optional
        Function used to configure validation. Defaults to None.
    accept_failures : bool, optional
        Whether or not to accept rounds containing failures. Defaults to True.
    initial_parameters : Parameters
        Initial global model parameters.
    fit_metrics_aggregation_fn : Optional[MetricsAggregationFn]
        Metrics aggregation function, optional.
    evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
        Metrics aggregation function, optional.
    eta : float, optional
        Server-side learning rate. Defaults to 1e-1.
    eta_l : float, optional
        Client-side learning rate. Defaults to 1e-1.
    beta_1 : float, optional
        Momentum parameter. Defaults to 0.9.
    beta_2 : float, optional
        Second moment parameter. Defaults to 0.99.
    tau : float, optional
        Controls the algorithm's degree of adaptability. Defaults to 1e-9.
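
    Examples
    --------
    A minimal usage sketch; the zero-filled arrays below are placeholders for a
    real model's weights, not part of this module::

        import numpy as np
        from flwr.common import ndarrays_to_parameters
        from flwr.server.strategy import FedAdam

        initial_weights = [np.zeros((10, 10)), np.zeros(10)]  # placeholder model
        strategy = FedAdam(
            initial_parameters=ndarrays_to_parameters(initial_weights),
            eta=1e-1,  # server-side learning rate
            beta_1=0.9,
            beta_2=0.99,
            tau=1e-9,
        )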
"""

    # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-locals
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Parameters,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        eta: float = 1e-1,
        eta_l: float = 1e-1,
        beta_1: float = 0.9,
        beta_2: float = 0.99,
        tau: float = 1e-9,
    ) -> None:
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
            eta=eta,
            eta_l=eta_l,
            beta_1=beta_1,
            beta_2=beta_2,
            tau=tau,
        )

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        rep = f"FedAdam(accept_failures={self.accept_failures})"
        return rep

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""
        fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit(
            server_round=server_round, results=results, failures=failures
        )
        if fedavg_parameters_aggregated is None:
            return None, {}

        fedavg_weights_aggregate = parameters_to_ndarrays(fedavg_parameters_aggregated)

        # Adam
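        # Treat the change from the previous global weights to the new FedAvg
        # aggregate as a pseudo-gradient: delta_t = x_agg - x_{t-1}.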
        delta_t: NDArrays = [
            x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights)
        ]

        # m_t: first-moment (momentum) estimate, lazily initialized to zeros
        if not self.m_t:
            self.m_t = [np.zeros_like(x) for x in delta_t]
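        # Exponential moving average: m_t = beta_1 * m_{t-1} + (1 - beta_1) * delta_t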
        self.m_t = [
            np.multiply(self.beta_1, x) + (1 - self.beta_1) * y
            for x, y in zip(self.m_t, delta_t)
        ]

        # v_t: second-moment estimate, lazily initialized to zeros
        if not self.v_t:
            self.v_t = [np.zeros_like(x) for x in delta_t]
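        # Exponential moving average of the squared pseudo-gradient:
        # v_t = beta_2 * v_{t-1} + (1 - beta_2) * delta_t**2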
        self.v_t = [
            self.beta_2 * x + (1 - self.beta_2) * np.multiply(y, y)
            for x, y in zip(self.v_t, delta_t)
        ]

        # Compute the bias-corrected learning rate `eta_norm` to improve
        # convergence in the early rounds of FL training. This `eta_norm` is
        # `\alpha_t` in Kingma & Ba, 2014 (http://arxiv.org/abs/1412.6980),
        # "Adam: A Method for Stochastic Optimization", in the formula right
        # before Section 2.1.
        eta_norm = (
            self.eta
            * np.sqrt(1 - np.power(self.beta_2, server_round + 1.0))
            / (1 - np.power(self.beta_1, server_round + 1.0))
        )
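        # Adam-style server update: x_t = x_{t-1} + eta_norm * m_t / (sqrt(v_t) + tau)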
        new_weights = [
            x + eta_norm * y / (np.sqrt(z) + self.tau)
            for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
        ]
        self.current_weights = new_weights

        return ndarrays_to_parameters(self.current_weights), metrics_aggregated