Source code for flwr.client.mod.centraldp_mods

# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Clipping modifiers for central DP with client-side clipping."""


from logging import INFO

from flwr.client.typing import ClientAppCallable
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
from flwr.common import recordset_compat as compat
from flwr.common.constant import MessageType
from flwr.common.context import Context
from flwr.common.differential_privacy import (
    compute_adaptive_clip_model_update,
    compute_clip_model_update,
)
from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM, KEY_NORM_BIT
from flwr.common.logger import log
from flwr.common.message import Message


[docs] def fixedclipping_mod( msg: Message, ctxt: Context, call_next: ClientAppCallable ) -> Message: """Client-side fixed clipping modifier. This mod needs to be used with the DifferentialPrivacyClientSideFixedClipping server-side strategy wrapper. The wrapper sends the clipping_norm value to the client. This mod clips the client model updates before sending them to the server. It operates on messages of type `MessageType.TRAIN`. Notes ----- Consider the order of mods when using multiple. Typically, fixedclipping_mod should be the last to operate on params. """ if msg.metadata.message_type != MessageType.TRAIN: return call_next(msg, ctxt) fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True) if KEY_CLIPPING_NORM not in fit_ins.config: raise KeyError( f"The {KEY_CLIPPING_NORM} value is not supplied by the " f"DifferentialPrivacyClientSideFixedClipping wrapper at" f" the server side." ) clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM]) server_to_client_params = parameters_to_ndarrays(fit_ins.parameters) # Call inner app out_msg = call_next(msg, ctxt) # Check if the msg has error if out_msg.has_error(): return out_msg fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True) client_to_server_params = parameters_to_ndarrays(fit_res.parameters) # Clip the client update compute_clip_model_update( client_to_server_params, server_to_client_params, clipping_norm, ) log( INFO, "fixedclipping_mod: parameters are clipped by value: %.4f.", clipping_norm ) fit_res.parameters = ndarrays_to_parameters(client_to_server_params) out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True) return out_msg
[docs] def adaptiveclipping_mod( msg: Message, ctxt: Context, call_next: ClientAppCallable ) -> Message: """Client-side adaptive clipping modifier. This mod needs to be used with the DifferentialPrivacyClientSideAdaptiveClipping server-side strategy wrapper. The wrapper sends the clipping_norm value to the client. This mod clips the client model updates before sending them to the server. It also sends KEY_NORM_BIT to the server for computing the new clipping value. It operates on messages of type `MessageType.TRAIN`. Notes ----- Consider the order of mods when using multiple. Typically, adaptiveclipping_mod should be the last to operate on params. """ if msg.metadata.message_type != MessageType.TRAIN: return call_next(msg, ctxt) fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True) if KEY_CLIPPING_NORM not in fit_ins.config: raise KeyError( f"The {KEY_CLIPPING_NORM} value is not supplied by the " f"DifferentialPrivacyClientSideFixedClipping wrapper at" f" the server side." ) if not isinstance(fit_ins.config[KEY_CLIPPING_NORM], float): raise ValueError(f"{KEY_CLIPPING_NORM} should be a float value.") clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM]) server_to_client_params = parameters_to_ndarrays(fit_ins.parameters) # Call inner app out_msg = call_next(msg, ctxt) # Check if the msg has error if out_msg.has_error(): return out_msg fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True) client_to_server_params = parameters_to_ndarrays(fit_res.parameters) # Clip the client update norm_bit = compute_adaptive_clip_model_update( client_to_server_params, server_to_client_params, clipping_norm, ) log( INFO, "adaptiveclipping_mod: parameters are clipped by value: %.4f.", clipping_norm, ) fit_res.parameters = ndarrays_to_parameters(client_to_server_params) fit_res.metrics[KEY_NORM_BIT] = norm_bit out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True) return out_msg