Code source de flwr.serverapp.strategy.dp_fixed_clipping

# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Message-based Central differential privacy with fixed clipping.

Papers: https://arxiv.org/abs/1712.07557, https://arxiv.org/abs/1710.06963
"""


from abc import ABC
from collections.abc import Iterable
from logging import INFO, WARNING

from flwr.common import Array, ArrayRecord, ConfigRecord, Message, MetricRecord, log
from flwr.common.differential_privacy import (
    add_gaussian_noise_inplace,
    compute_clip_model_update,
    compute_stdv,
)
from flwr.common.differential_privacy_constants import (
    CLIENTS_DISCREPANCY_WARNING,
    KEY_CLIPPING_NORM,
)
from flwr.server import Grid

from .strategy import Strategy


class DifferentialPrivacyFixedClippingBase(Strategy, ABC):
    """Base class for DP strategies with fixed clipping.

    This class contains common functionality shared between server-side and
    client-side fixed clipping implementations.

    Parameters
    ----------
    strategy : Strategy
        The strategy to which DP functionalities will be added by this wrapper.
    noise_multiplier : float
        The noise multiplier for the Gaussian mechanism for model updates.
        A value of 1.0 or higher is recommended for strong privacy.
    clipping_norm : float
        The value of the clipping norm.
    num_sampled_clients : int
        The number of clients that are sampled on each round.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        strategy: Strategy,
        noise_multiplier: float,
        clipping_norm: float,
        num_sampled_clients: int,
    ) -> None:
        super().__init__()

        self.strategy = strategy

        if noise_multiplier < 0:
            raise ValueError("The noise multiplier should be a non-negative value.")

        if clipping_norm <= 0:
            raise ValueError("The clipping norm should be a positive value.")

        if num_sampled_clients <= 0:
            raise ValueError(
                "The number of sampled clients should be a positive value."
            )

        self.noise_multiplier = noise_multiplier
        self.clipping_norm = clipping_norm
        self.num_sampled_clients = num_sampled_clients

    def _add_noise_to_aggregated_arrays(
        self, aggregated_arrays: ArrayRecord
    ) -> ArrayRecord:
        """Add Gaussian noise to aggregated arrays.

        Parameters
        ----------
        aggregated_arrays : ArrayRecord
            The aggregated arrays to add noise to.

        Returns
        -------
        ArrayRecord
            The aggregated arrays with noise added.
        """
        aggregated_ndarrays = aggregated_arrays.to_numpy_ndarrays()
        stdv = compute_stdv(
            self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
        )
        add_gaussian_noise_inplace(aggregated_ndarrays, stdv)

        log(
            INFO,
            "aggregate_fit: central DP noise with %.4f stdev added",
            stdv,
        )

        return ArrayRecord(
            {
                k: Array(v)
                for k, v in zip(
                    aggregated_arrays.keys(), aggregated_ndarrays, strict=True
                )
            }
        )

    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated evaluation."""
        return self.strategy.configure_evaluate(server_round, arrays, config, grid)

    def aggregate_evaluate(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> MetricRecord | None:
        """Aggregate MetricRecords in the received Messages."""
        return self.strategy.aggregate_evaluate(server_round, replies)

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        self.strategy.summary()


[docs] class DifferentialPrivacyServerSideFixedClipping(DifferentialPrivacyFixedClippingBase): """Strategy wrapper for central DP with server-side fixed clipping. Parameters ---------- strategy : Strategy The strategy to which DP functionalities will be added by this wrapper. noise_multiplier : float The noise multiplier for the Gaussian mechanism for model updates. A value of 1.0 or higher is recommended for strong privacy. clipping_norm : float The value of the clipping norm. num_sampled_clients : int The number of clients that are sampled on each round. Examples -------- Create a strategy:: strategy = fl.serverapp.FedAvg( ... ) Wrap the strategy with the `DifferentialPrivacyServerSideFixedClipping` wrapper:: dp_strategy = DifferentialPrivacyServerSideFixedClipping( strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients ) """ def __init__( self, strategy: Strategy, noise_multiplier: float, clipping_norm: float, num_sampled_clients: int, ) -> None: super().__init__(strategy, noise_multiplier, clipping_norm, num_sampled_clients) self.current_arrays: ArrayRecord = ArrayRecord() def __repr__(self) -> str: """Compute a string representation of the strategy.""" return "Differential Privacy Strategy Wrapper (Server-Side Fixed Clipping)"
[docs] def summary(self) -> None: """Log summary configuration of the strategy.""" log(INFO, "\t├──> DP settings:") log(INFO, "\t\t├── Noise multiplier: %s", self.noise_multiplier) log(INFO, "\t\t└── Clipping norm: %s", self.clipping_norm) super().summary()
[docs] def configure_train( self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid ) -> Iterable[Message]: """Configure the next round of training.""" self.current_arrays = arrays return self.strategy.configure_train(server_round, arrays, config, grid)
[docs] def aggregate_train( self, server_round: int, replies: Iterable[Message], ) -> tuple[ArrayRecord | None, MetricRecord | None]: """Aggregate ArrayRecords and MetricRecords in the received Messages.""" if not validate_replies(replies, self.num_sampled_clients): return None, None # Clip arrays in replies current_ndarrays = self.current_arrays.to_numpy_ndarrays() for reply in replies: for arr_name, record in reply.content.array_records.items(): # Clip reply_ndarrays = record.to_numpy_ndarrays() compute_clip_model_update( param1=reply_ndarrays, param2=current_ndarrays, clipping_norm=self.clipping_norm, ) # Replace content while preserving keys reply.content[arr_name] = ArrayRecord( dict(zip(record.keys(), map(Array, reply_ndarrays), strict=True)) ) log( INFO, "aggregate_fit: parameters are clipped by value: %.4f.", self.clipping_norm, ) # Pass the new parameters for aggregation aggregated_arrays, aggregated_metrics = self.strategy.aggregate_train( server_round, replies ) # Add Gaussian noise to the aggregated arrays if aggregated_arrays: aggregated_arrays = self._add_noise_to_aggregated_arrays(aggregated_arrays) return aggregated_arrays, aggregated_metrics
[docs] class DifferentialPrivacyClientSideFixedClipping(DifferentialPrivacyFixedClippingBase): """Strategy wrapper for central DP with client-side fixed clipping. Use `fixedclipping_mod` modifier at the client side. In comparison to `DifferentialPrivacyServerSideFixedClipping`, which performs clipping on the server-side, `DifferentialPrivacyClientSideFixedClipping` expects clipping to happen on the client-side, usually by using the built-in `fixedclipping_mod`. Parameters ---------- strategy : Strategy The strategy to which DP functionalities will be added by this wrapper. noise_multiplier : float The noise multiplier for the Gaussian mechanism for model updates. A value of 1.0 or higher is recommended for strong privacy. clipping_norm : float The value of the clipping norm. num_sampled_clients : int The number of clients that are sampled on each round. Examples -------- Create a strategy:: strategy = fl.serverapp.FedAvg(...) Wrap the strategy with the `DifferentialPrivacyClientSideFixedClipping` wrapper:: dp_strategy = DifferentialPrivacyClientSideFixedClipping( strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients ) On the client, add the `fixedclipping_mod` to the client-side mods:: app = fl.client.ClientApp(mods=[fixedclipping_mod]) """ def __repr__(self) -> str: """Compute a string representation of the strategy.""" return "Differential Privacy Strategy Wrapper (Client-Side Fixed Clipping)"
[docs] def summary(self) -> None: """Log summary configuration of the strategy.""" log(INFO, "\t├──> DP settings:") log(INFO, "\t\t├── Noise multiplier: %s", self.noise_multiplier) log(INFO, "\t\t└── Clipping norm: %s", self.clipping_norm) super().summary()
[docs] def configure_train( self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid ) -> Iterable[Message]: """Configure the next round of training.""" # Inject clipping norm in config config[KEY_CLIPPING_NORM] = self.clipping_norm # Call parent method return self.strategy.configure_train(server_round, arrays, config, grid)
[docs] def aggregate_train( self, server_round: int, replies: Iterable[Message], ) -> tuple[ArrayRecord | None, MetricRecord | None]: """Aggregate ArrayRecords and MetricRecords in the received Messages.""" if not validate_replies(replies, self.num_sampled_clients): return None, None # Aggregate aggregated_arrays, aggregated_metrics = self.strategy.aggregate_train( server_round, replies ) # Add Gaussian noise to the aggregated arrays if aggregated_arrays: aggregated_arrays = self._add_noise_to_aggregated_arrays(aggregated_arrays) return aggregated_arrays, aggregated_metrics
def validate_replies(replies: Iterable[Message], num_sampled_clients: int) -> bool: """Validate replies and log errors/warnings. Arguments ---------- replies : Iterable[Message] The replies to validate. num_sampled_clients : int The expected number of sampled clients. Returns ------- bool True if replies are valid for aggregation, False otherwise. """ num_errors = 0 num_replies_with_content = 0 for msg in replies: if msg.has_error(): log( INFO, "Received error in reply from node %d: %s", msg.metadata.src_node_id, msg.error, ) num_errors += 1 else: num_replies_with_content += 1 # Errors are not allowed if num_errors: log( INFO, "aggregate_train: Some clients reported errors. Skipping aggregation.", ) return False log( INFO, "aggregate_train: Received %s results and %s failures", num_replies_with_content, num_errors, ) if num_replies_with_content != num_sampled_clients: log( WARNING, CLIENTS_DISCREPANCY_WARNING, num_replies_with_content, num_sampled_clients, ) return True