# Copyright 2021 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fair Resource Allocation in Federated Learning [Li et al., 2020] strategy.

Paper: https://openreview.net/pdf?id=ByexElSYDr
"""
from logging import WARNING
from typing import Callable, Optional, Union
import numpy as np
from flwr.common import (
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
MetricsAggregationFn,
NDArrays,
Parameters,
Scalar,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from .aggregate import aggregate_qffl, weighted_loss_avg
from .fedavg import FedAvg
# pylint: disable=too-many-locals
class QFedAvg(FedAvg):
    """Configurable QFedAvg strategy implementation.

    Aggregates client results with the q-FFL update rule, while reusing
    ``FedAvg`` for client sampling, evaluation and metric aggregation
    settings.

    Parameters
    ----------
    q_param : float (default: 0.2)
        Exponent ``q`` applied to the loss when scaling client updates.
    qffl_learning_rate : float (default: 0.1)
        Learning rate used to turn returned client parameters into
        pseudo-gradients. Its inverse also enters the scaling term ``h``.
    fraction_fit, fraction_evaluate, min_fit_clients, min_evaluate_clients,
    min_available_clients, evaluate_fn, on_fit_config_fn,
    on_evaluate_config_fn, accept_failures, initial_parameters,
    fit_metrics_aggregation_fn, evaluate_metrics_aggregation_fn
        Forwarded unchanged to ``FedAvg``.

    Notes
    -----
    ``aggregate_fit`` obtains the loss of the current global parameters
    through ``self.evaluate``. An ``evaluate_fn`` that returns a loss is
    therefore required for training rounds to aggregate.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        *,
        q_param: float = 0.2,
        qffl_learning_rate: float = 0.1,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 1,
        min_evaluate_clients: int = 1,
        min_available_clients: int = 1,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                Optional[tuple[float, dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.learning_rate = qffl_learning_rate
        self.q_param = q_param
        # Global parameters sent to clients by the latest ``configure_fit``.
        # ``aggregate_fit`` computes client updates relative to these.
        self.pre_weights: Optional[NDArrays] = None

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        rep = f"QffedAvg(learning_rate={self.learning_rate}, "
        rep += f"q_param={self.q_param}, pre_weights={self.pre_weights})"
        return rep

    def num_fit_clients(self, num_available_clients: int) -> tuple[int, int]:
        """Return the sample size and the required number of available clients.

        The sample size is ``fraction_fit`` of the available clients,
        truncated and floored at ``min_fit_clients``.
        """
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> tuple[int, int]:
        """Use a fraction of available clients for evaluation.

        The sample size is ``fraction_evaluate`` of the available clients,
        truncated and floored at ``min_evaluate_clients``.
        """
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> list[tuple[ClientProxy, FitIns]]:
        """Configure the next round of training and record the global weights.

        ``aggregate_fit`` needs the parameters that were sent to the clients
        so it can turn their returned parameters into updates. Without this
        override ``self.pre_weights`` would never be set.

        Returns
        -------
        list of (ClientProxy, FitIns)
            One entry per sampled client. All entries share the same
            ``FitIns``.
        """
        self.pre_weights = parameters_to_ndarrays(parameters)

        config: dict[str, Scalar] = {}
        if self.on_fit_config_fn is not None:
            # Custom fit config function provided
            config = self.on_fit_config_fn(server_round)
        fit_ins = FitIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )
        return [(client, fit_ins) for client in clients]

    @staticmethod
    def _squared_l2_norm(arrays: NDArrays) -> float:
        """Return the squared L2 norm of all ``arrays`` flattened together."""
        # A single concatenate builds the same flat array that a chain of
        # ``np.append`` calls would, without re-copying it once per layer.
        flat = np.concatenate([np.ravel(arr) for arr in arrays])
        return float(np.sum(np.square(flat)))

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """Aggregate fit results using the q-FFL update rule.

        For every client, the pseudo-gradient ``(pre_weights - new) / lr`` is
        scaled by ``(loss + 1e-10) ** q``. A per-client scaling term ``h`` is
        also computed. The new global weights are then produced by
        ``aggregate_qffl``.

        Returns
        -------
        tuple
            ``(parameters, metrics)``. The result is ``(None, {})`` when
            there are no results, or when there are failures and
            ``accept_failures`` is False.

        Raises
        ------
        AttributeError
            If ``pre_weights`` is unset (``configure_fit`` has not run).
        ValueError
            If ``evaluate`` yields no loss for the global parameters, for
            example because no ``evaluate_fn`` was configured.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        if self.pre_weights is None:
            raise AttributeError("QffedAvg pre_weights are None in aggregate_fit")
        weights_before = self.pre_weights

        eval_result = self.evaluate(
            server_round, ndarrays_to_parameters(weights_before)
        )
        if eval_result is None:
            # Previously ``loss`` was left unbound here and the loop below
            # failed with UnboundLocalError.
            raise ValueError(
                "QFedAvg requires an `evaluate_fn` that returns a loss for the "
                "current global parameters"
            )
        loss, _ = eval_result

        # NOTE(review): the q-FFL paper scales each client's update by that
        # client's own loss. Here the single server-side loss of the global
        # model is used for every client. Confirm this is intended.
        # These factors do not depend on the client, so compute them once.
        # The 1e-10 offset avoids 0 ** (q - 1) blowing up when loss == 0.
        loss_pow_q = np.float_power(loss + 1e-10, self.q_param)
        loss_pow_q_minus_1 = np.float_power(loss + 1e-10, (self.q_param - 1))
        inv_lr = 1.0 / self.learning_rate

        deltas: list[NDArrays] = []
        hs_ffl: list[float] = []
        for _, fit_res in results:
            new_weights = parameters_to_ndarrays(fit_res.parameters)
            # Plug the weight updates into the gradient: (before - after) / lr
            grads = [
                np.multiply((u - v), inv_lr)
                for u, v in zip(weights_before, new_weights)
            ]
            deltas.append([loss_pow_q * grad for grad in grads])
            # Estimation of the local Lipschitz constant
            hs_ffl.append(
                self.q_param * loss_pow_q_minus_1 * self._squared_l2_norm(grads)
                + inv_lr * loss_pow_q
            )

        weights_aggregated: NDArrays = aggregate_qffl(weights_before, deltas, hs_ffl)
        parameters_aggregated = ndarrays_to_parameters(weights_aggregated)

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: dict[str, Scalar] = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated

    def aggregate_evaluate(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, EvaluateRes]],
        failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> tuple[Optional[float], dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average.

        Each client's loss is weighted by its ``num_examples`` via
        ``weighted_loss_avg``.

        Returns
        -------
        tuple
            ``(loss, metrics)``. The result is ``(None, {})`` when there are
            no results, or when there are failures and ``accept_failures``
            is False.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # Aggregate loss
        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: dict[str, Scalar] = {}
        if self.evaluate_metrics_aggregation_fn:
            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")

        return loss_aggregated, metrics_aggregated