Source code for flwr.clientapp.mod.centraldp_mods

# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Clipping modifiers for central DP with client-side clipping."""


from collections import OrderedDict
from logging import ERROR, INFO
from typing import cast

from flwr.app import Error
from flwr.clientapp.typing import ClientAppCallable
from flwr.common import (
    Array,
    ArrayRecord,
    ConfigRecord,
    Context,
    Message,
    MetricRecord,
    log,
)
from flwr.common.constant import ErrorCode
from flwr.common.differential_privacy import (
    compute_adaptive_clip_model_update,
    compute_clip_model_update,
)
from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM, KEY_NORM_BIT


# pylint: disable=too-many-return-statements
[docs] def fixedclipping_mod( msg: Message, ctxt: Context, call_next: ClientAppCallable ) -> Message: """Client-side fixed clipping modifier. This mod needs to be used with the `DifferentialPrivacyClientSideFixedClipping` server-side strategy wrapper. The wrapper sends the clipping_norm value to the client. This mod clips the client model updates before sending them to the server. It operates on messages of type `MessageType.TRAIN`. Notes ----- Consider the order of mods when using multiple. Typically, fixedclipping_mod should be the last to operate on params. """ if len(msg.content.array_records) != 1: return _handle_multi_record_err("fixedclipping_mod", msg, ArrayRecord) if len(msg.content.config_records) != 1: return _handle_multi_record_err("fixedclipping_mod", msg, ConfigRecord) # Get keys in the single ConfigRecord keys_in_config = set(next(iter(msg.content.config_records.values())).keys()) if KEY_CLIPPING_NORM not in keys_in_config: return _handle_no_key_err("fixedclipping_mod", msg) # Record array record communicated to client and clipping norm original_array_record = next(iter(msg.content.array_records.values())) clipping_norm = cast( float, next(iter(msg.content.config_records.values()))[KEY_CLIPPING_NORM] ) # Call inner app out_msg = call_next(msg, ctxt) # Check if the msg has error if out_msg.has_error(): return out_msg # Ensure reply has a single ArrayRecord if len(out_msg.content.array_records) != 1: return _handle_multi_record_err("fixedclipping_mod", out_msg, ArrayRecord) new_array_record_key, client_to_server_arrecord = next( iter(out_msg.content.array_records.items()) ) # Ensure keys in returned ArrayRecord match those in the one sent from server if list(original_array_record.keys()) != list(client_to_server_arrecord.keys()): return _handle_array_key_mismatch_err("fixedclipping_mod", out_msg) client_to_server_ndarrays = client_to_server_arrecord.to_numpy_ndarrays() # Clip the client update compute_clip_model_update( param1=client_to_server_ndarrays, param2=original_array_record.to_numpy_ndarrays(), clipping_norm=clipping_norm, ) log( INFO, "fixedclipping_mod: parameters are clipped by value: %.4f.", clipping_norm ) # Replace outgoing ArrayRecord's Array while preserving their keys out_msg.content.array_records[new_array_record_key] = ArrayRecord( OrderedDict( { k: Array(v) for k, v in zip( client_to_server_arrecord.keys(), client_to_server_ndarrays ) } ) ) return out_msg
[docs] def adaptiveclipping_mod( msg: Message, ctxt: Context, call_next: ClientAppCallable ) -> Message: """Client-side adaptive clipping modifier. This mod needs to be used with the DifferentialPrivacyClientSideAdaptiveClipping server-side strategy wrapper. The wrapper sends the clipping_norm value to the client. This mod clips the client model updates before sending them to the server. It also sends KEY_NORM_BIT to the server for computing the new clipping value. It operates on messages of type `MessageType.TRAIN`. Notes ----- Consider the order of mods when using multiple. Typically, adaptiveclipping_mod should be the last to operate on params. """ if len(msg.content.array_records) != 1: return _handle_multi_record_err("adaptiveclipping_mod", msg, ArrayRecord) if len(msg.content.config_records) != 1: return _handle_multi_record_err("adaptiveclipping_mod", msg, ConfigRecord) # Get keys in the single ConfigRecord keys_in_config = set(next(iter(msg.content.config_records.values())).keys()) if KEY_CLIPPING_NORM not in keys_in_config: return _handle_no_key_err("adaptiveclipping_mod", msg) # Record array record communicated to client and clipping norm original_array_record = next(iter(msg.content.array_records.values())) clipping_norm = cast( float, next(iter(msg.content.config_records.values()))[KEY_CLIPPING_NORM] ) # Call inner app out_msg = call_next(msg, ctxt) # Check if the msg has error if out_msg.has_error(): return out_msg # Ensure reply has a single ArrayRecord if len(out_msg.content.array_records) != 1: return _handle_multi_record_err("adaptiveclipping_mod", out_msg, ArrayRecord) # Ensure reply has a single MetricRecord if len(out_msg.content.metric_records) != 1: return _handle_multi_record_err("adaptiveclipping_mod", out_msg, MetricRecord) new_array_record_key, client_to_server_arrecord = next( iter(out_msg.content.array_records.items()) ) # Ensure keys in returned ArrayRecord match those in the one sent from server if list(original_array_record.keys()) != list(client_to_server_arrecord.keys()): return _handle_array_key_mismatch_err("adaptiveclipping_mod", out_msg) client_to_server_ndarrays = client_to_server_arrecord.to_numpy_ndarrays() # Clip the client update norm_bit = compute_adaptive_clip_model_update( client_to_server_ndarrays, original_array_record.to_numpy_ndarrays(), clipping_norm, ) log( INFO, "adaptiveclipping_mod: ndarrays are clipped by value: %.4f.", clipping_norm, ) # Replace outgoing ArrayRecord's Array while preserving their keys out_msg.content.array_records[new_array_record_key] = ArrayRecord( OrderedDict( { k: Array(v) for k, v in zip( client_to_server_arrecord.keys(), client_to_server_ndarrays ) } ) ) # Add to the MetricRecords the norm bit (recall reply messages only contain # one MetricRecord) metric_record_key = list(out_msg.content.metric_records.keys())[0] # We cast it to `int` because MetricRecord does not support `bool` values out_msg.content.metric_records[metric_record_key][KEY_NORM_BIT] = int(norm_bit) return out_msg
def _handle_err(msg: Message, reason: str) -> Message: """Log and return error message.""" log(ERROR, reason) return Message( Error(code=ErrorCode.MOD_FAILED_PRECONDITION, reason=reason), reply_to=msg, ) def _handle_multi_record_err(mod_name: str, msg: Message, record_type: type) -> Message: """Log and return multi-record error.""" cnt = sum(isinstance(_, record_type) for _ in msg.content.values()) return _handle_err( msg, f"{mod_name} expects exactly one {record_type.__name__}, " f"but found {cnt} {record_type.__name__}(s).", ) def _handle_no_key_err(mod_name: str, msg: Message) -> Message: """Log and return no-key error.""" return _handle_err( msg, f"{mod_name} requires the key '{KEY_CLIPPING_NORM}' to be present in the " "ConfigRecord, but it was not found. " "Please ensure the `DifferentialPrivacyClientSideFixedClipping` wrapper " "is used in the ServerApp.", ) def _handle_array_key_mismatch_err(mod_name: str, msg: Message) -> Message: """Create array-key-mismatch error reasons.""" return _handle_err( msg, f"{mod_name} expects the keys in the ArrayRecord of the reply message to match " "those from the ArrayRecord that the ClientApp received, but they do not.", )