# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Message-based Central differential privacy with fixed clipping.
Papers: https://arxiv.org/abs/1712.07557, https://arxiv.org/abs/1710.06963
"""
from abc import ABC
from collections.abc import Iterable
from logging import INFO, WARNING
from flwr.common import Array, ArrayRecord, ConfigRecord, Message, MetricRecord, log
from flwr.common.differential_privacy import (
add_gaussian_noise_inplace,
compute_clip_model_update,
compute_stdv,
)
from flwr.common.differential_privacy_constants import (
CLIENTS_DISCREPANCY_WARNING,
KEY_CLIPPING_NORM,
)
from flwr.server import Grid
from .strategy import Strategy
class DifferentialPrivacyFixedClippingBase(Strategy, ABC):
"""Base class for DP strategies with fixed clipping.
This class contains common functionality shared between server-side and
client-side fixed clipping implementations.
Parameters
----------
strategy : Strategy
The strategy to which DP functionalities will be added by this wrapper.
noise_multiplier : float
The noise multiplier for the Gaussian mechanism for model updates.
A value of 1.0 or higher is recommended for strong privacy.
clipping_norm : float
The value of the clipping norm.
num_sampled_clients : int
The number of clients that are sampled on each round.
"""
# pylint: disable=too-many-arguments,too-many-instance-attributes
def __init__(
self,
strategy: Strategy,
noise_multiplier: float,
clipping_norm: float,
num_sampled_clients: int,
) -> None:
super().__init__()
self.strategy = strategy
if noise_multiplier < 0:
raise ValueError("The noise multiplier should be a non-negative value.")
if clipping_norm <= 0:
raise ValueError("The clipping norm should be a positive value.")
if num_sampled_clients <= 0:
raise ValueError(
"The number of sampled clients should be a positive value."
)
self.noise_multiplier = noise_multiplier
self.clipping_norm = clipping_norm
self.num_sampled_clients = num_sampled_clients
def _add_noise_to_aggregated_arrays(
self, aggregated_arrays: ArrayRecord
) -> ArrayRecord:
"""Add Gaussian noise to aggregated arrays.
Parameters
----------
aggregated_arrays : ArrayRecord
The aggregated arrays to add noise to.
Returns
-------
ArrayRecord
The aggregated arrays with noise added.
"""
aggregated_ndarrays = aggregated_arrays.to_numpy_ndarrays()
stdv = compute_stdv(
self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
)
add_gaussian_noise_inplace(aggregated_ndarrays, stdv)
log(
INFO,
"aggregate_fit: central DP noise with %.4f stdev added",
stdv,
)
return ArrayRecord(
{
k: Array(v)
for k, v in zip(
aggregated_arrays.keys(), aggregated_ndarrays, strict=True
)
}
)
def configure_evaluate(
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
"""Configure the next round of federated evaluation."""
return self.strategy.configure_evaluate(server_round, arrays, config, grid)
def aggregate_evaluate(
self,
server_round: int,
replies: Iterable[Message],
) -> MetricRecord | None:
"""Aggregate MetricRecords in the received Messages."""
return self.strategy.aggregate_evaluate(server_round, replies)
def summary(self) -> None:
"""Log summary configuration of the strategy."""
self.strategy.summary()
[문서]
class DifferentialPrivacyServerSideFixedClipping(DifferentialPrivacyFixedClippingBase):
"""Strategy wrapper for central DP with server-side fixed clipping.
Parameters
----------
strategy : Strategy
The strategy to which DP functionalities will be added by this wrapper.
noise_multiplier : float
The noise multiplier for the Gaussian mechanism for model updates.
A value of 1.0 or higher is recommended for strong privacy.
clipping_norm : float
The value of the clipping norm.
num_sampled_clients : int
The number of clients that are sampled on each round.
Examples
--------
Create a strategy::
strategy = fl.serverapp.FedAvg( ... )
Wrap the strategy with the `DifferentialPrivacyServerSideFixedClipping` wrapper::
dp_strategy = DifferentialPrivacyServerSideFixedClipping(
strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients
)
"""
def __init__(
self,
strategy: Strategy,
noise_multiplier: float,
clipping_norm: float,
num_sampled_clients: int,
) -> None:
super().__init__(strategy, noise_multiplier, clipping_norm, num_sampled_clients)
self.current_arrays: ArrayRecord = ArrayRecord()
def __repr__(self) -> str:
"""Compute a string representation of the strategy."""
return "Differential Privacy Strategy Wrapper (Server-Side Fixed Clipping)"
[문서]
def summary(self) -> None:
"""Log summary configuration of the strategy."""
log(INFO, "\t├──> DP settings:")
log(INFO, "\t│\t├── Noise multiplier: %s", self.noise_multiplier)
log(INFO, "\t│\t└── Clipping norm: %s", self.clipping_norm)
super().summary()
[문서]
def aggregate_train(
self,
server_round: int,
replies: Iterable[Message],
) -> tuple[ArrayRecord | None, MetricRecord | None]:
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
if not validate_replies(replies, self.num_sampled_clients):
return None, None
# Clip arrays in replies
current_ndarrays = self.current_arrays.to_numpy_ndarrays()
for reply in replies:
for arr_name, record in reply.content.array_records.items():
# Clip
reply_ndarrays = record.to_numpy_ndarrays()
compute_clip_model_update(
param1=reply_ndarrays,
param2=current_ndarrays,
clipping_norm=self.clipping_norm,
)
# Replace content while preserving keys
reply.content[arr_name] = ArrayRecord(
dict(zip(record.keys(), map(Array, reply_ndarrays), strict=True))
)
log(
INFO,
"aggregate_fit: parameters are clipped by value: %.4f.",
self.clipping_norm,
)
# Pass the new parameters for aggregation
aggregated_arrays, aggregated_metrics = self.strategy.aggregate_train(
server_round, replies
)
# Add Gaussian noise to the aggregated arrays
if aggregated_arrays:
aggregated_arrays = self._add_noise_to_aggregated_arrays(aggregated_arrays)
return aggregated_arrays, aggregated_metrics
[문서]
class DifferentialPrivacyClientSideFixedClipping(DifferentialPrivacyFixedClippingBase):
"""Strategy wrapper for central DP with client-side fixed clipping.
Use `fixedclipping_mod` modifier at the client side.
In comparison to `DifferentialPrivacyServerSideFixedClipping`,
which performs clipping on the server-side,
`DifferentialPrivacyClientSideFixedClipping` expects clipping to happen
on the client-side, usually by using the built-in `fixedclipping_mod`.
Parameters
----------
strategy : Strategy
The strategy to which DP functionalities will be added by this wrapper.
noise_multiplier : float
The noise multiplier for the Gaussian mechanism for model updates.
A value of 1.0 or higher is recommended for strong privacy.
clipping_norm : float
The value of the clipping norm.
num_sampled_clients : int
The number of clients that are sampled on each round.
Examples
--------
Create a strategy::
strategy = fl.serverapp.FedAvg(...)
Wrap the strategy with the `DifferentialPrivacyClientSideFixedClipping` wrapper::
dp_strategy = DifferentialPrivacyClientSideFixedClipping(
strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients
)
On the client, add the `fixedclipping_mod` to the client-side mods::
app = fl.client.ClientApp(mods=[fixedclipping_mod])
"""
def __repr__(self) -> str:
"""Compute a string representation of the strategy."""
return "Differential Privacy Strategy Wrapper (Client-Side Fixed Clipping)"
[문서]
def summary(self) -> None:
"""Log summary configuration of the strategy."""
log(INFO, "\t├──> DP settings:")
log(INFO, "\t│\t├── Noise multiplier: %s", self.noise_multiplier)
log(INFO, "\t│\t└── Clipping norm: %s", self.clipping_norm)
super().summary()
[문서]
def aggregate_train(
self,
server_round: int,
replies: Iterable[Message],
) -> tuple[ArrayRecord | None, MetricRecord | None]:
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
if not validate_replies(replies, self.num_sampled_clients):
return None, None
# Aggregate
aggregated_arrays, aggregated_metrics = self.strategy.aggregate_train(
server_round, replies
)
# Add Gaussian noise to the aggregated arrays
if aggregated_arrays:
aggregated_arrays = self._add_noise_to_aggregated_arrays(aggregated_arrays)
return aggregated_arrays, aggregated_metrics
def validate_replies(replies: Iterable[Message], num_sampled_clients: int) -> bool:
"""Validate replies and log errors/warnings.
Arguments
----------
replies : Iterable[Message]
The replies to validate.
num_sampled_clients : int
The expected number of sampled clients.
Returns
-------
bool
True if replies are valid for aggregation, False otherwise.
"""
num_errors = 0
num_replies_with_content = 0
for msg in replies:
if msg.has_error():
log(
INFO,
"Received error in reply from node %d: %s",
msg.metadata.src_node_id,
msg.error,
)
num_errors += 1
else:
num_replies_with_content += 1
# Errors are not allowed
if num_errors:
log(
INFO,
"aggregate_train: Some clients reported errors. Skipping aggregation.",
)
return False
log(
INFO,
"aggregate_train: Received %s results and %s failures",
num_replies_with_content,
num_errors,
)
if num_replies_with_content != num_sampled_clients:
log(
WARNING,
CLIENTS_DISCREPANCY_WARNING,
num_replies_with_content,
num_sampled_clients,
)
return True