# Copyright 2024 Flower Labs GmbH. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Clipping modifiers for central DP with client-side clipping."""fromloggingimportINFOfromflwr.client.typingimportClientAppCallablefromflwr.commonimportndarrays_to_parameters,parameters_to_ndarraysfromflwr.commonimportrecordset_compatascompatfromflwr.common.constantimportMessageTypefromflwr.common.contextimportContextfromflwr.common.differential_privacyimport(compute_adaptive_clip_model_update,compute_clip_model_update,)fromflwr.common.differential_privacy_constantsimportKEY_CLIPPING_NORM,KEY_NORM_BITfromflwr.common.loggerimportlogfromflwr.common.messageimportMessage
def fixedclipping_mod(
    msg: Message, ctxt: Context, call_next: ClientAppCallable
) -> Message:
    """Clip client model updates with a server-provided fixed norm.

    Pair this mod with the server-side strategy wrapper
    DifferentialPrivacyClientSideFixedClipping, which places the clipping
    norm in the training config. Only messages of type
    `MessageType.TRAIN` are processed; any other message is handed to
    `call_next` untouched.

    Parameters
    ----------
    msg : Message
        Incoming message from the server.
    ctxt : Context
        Context forwarded unchanged to `call_next`.
    call_next : ClientAppCallable
        The next mod / client app in the chain.

    Returns
    -------
    Message
        The reply from `call_next`. For a successful training reply, its
        content is rebuilt from the clipped parameters.

    Raises
    ------
    KeyError
        If the training config does not contain `KEY_CLIPPING_NORM`.

    Notes
    -----
    Mod ordering matters when several are combined. Typically,
    fixedclipping_mod should be the last one to operate on params.
    """
    # Anything that is not a training round passes straight through.
    if msg.metadata.message_type != MessageType.TRAIN:
        return call_next(msg, ctxt)

    train_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
    config = train_ins.config
    if KEY_CLIPPING_NORM not in config:
        raise KeyError(
            f"The {KEY_CLIPPING_NORM} value is not supplied by the "
            f"DifferentialPrivacyClientSideFixedClipping wrapper at"
            f" the server side."
        )
    clipping_norm = float(config[KEY_CLIPPING_NORM])

    # Decode the parameters sent by the server before running the inner app;
    # they are the reference point passed to the clipping helper below.
    global_weights = parameters_to_ndarrays(train_ins.parameters)

    # Run the wrapped app; a failed reply is returned as-is, without clipping.
    reply = call_next(msg, ctxt)
    if reply.has_error():
        return reply

    train_res = compat.recordset_to_fitres(reply.content, keep_input=True)
    local_weights = parameters_to_ndarrays(train_res.parameters)

    # The helper's return value is ignored and `local_weights` is re-encoded
    # afterwards, so this relies on the helper updating the arrays in place.
    compute_clip_model_update(local_weights, global_weights, clipping_norm)
    log(
        INFO,
        "fixedclipping_mod: parameters are clipped by value: %.4f.",
        clipping_norm,
    )

    # Write the clipped parameters back into the reply.
    train_res.parameters = ndarrays_to_parameters(local_weights)
    reply.content = compat.fitres_to_recordset(train_res, keep_input=True)
    return reply
def adaptiveclipping_mod(
    msg: Message, ctxt: Context, call_next: ClientAppCallable
) -> Message:
    """Client-side adaptive clipping modifier.

    This mod needs to be used with the
    DifferentialPrivacyClientSideAdaptiveClipping server-side strategy
    wrapper. The wrapper sends the clipping_norm value to the client.

    This mod clips the client model updates before sending them to the
    server. It also sends KEY_NORM_BIT to the server for computing the new
    clipping value.

    It operates on messages of type `MessageType.TRAIN`; any other message
    is forwarded to `call_next` unchanged.

    Parameters
    ----------
    msg : Message
        Incoming message from the server.
    ctxt : Context
        Context forwarded unchanged to `call_next`.
    call_next : ClientAppCallable
        The next mod / client app in the chain.

    Returns
    -------
    Message
        The reply from `call_next`. For a successful training reply, its
        content is rebuilt from the clipped parameters and the metrics
        gain a `KEY_NORM_BIT` entry.

    Raises
    ------
    KeyError
        If the training config does not contain `KEY_CLIPPING_NORM`.
    ValueError
        If the `KEY_CLIPPING_NORM` config value is not a `float`.

    Notes
    -----
    Consider the order of mods when using multiple. Typically,
    adaptiveclipping_mod should be the last to operate on params.
    """
    # Non-training messages carry no model update to clip.
    if msg.metadata.message_type != MessageType.TRAIN:
        return call_next(msg, ctxt)

    fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)

    if KEY_CLIPPING_NORM not in fit_ins.config:
        # Name the adaptive wrapper here: it is the server-side counterpart
        # of this mod, so it is the component that failed to send the norm.
        raise KeyError(
            f"The {KEY_CLIPPING_NORM} value is not supplied by the "
            f"DifferentialPrivacyClientSideAdaptiveClipping wrapper at"
            f" the server side."
        )
    if not isinstance(fit_ins.config[KEY_CLIPPING_NORM], float):
        raise ValueError(f"{KEY_CLIPPING_NORM} should be a float value.")
    clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM])

    # Decode the parameters sent by the server before running the inner app;
    # they are the reference point passed to the clipping helper below.
    server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)

    # Call inner app
    out_msg = call_next(msg, ctxt)

    # A failed reply is returned as-is, without clipping.
    if out_msg.has_error():
        return out_msg

    fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)
    client_to_server_params = parameters_to_ndarrays(fit_res.parameters)

    # Clip the client update. Only the norm bit is taken from the return
    # value and `client_to_server_params` is re-encoded afterwards, so this
    # relies on the helper updating the arrays in place.
    norm_bit = compute_adaptive_clip_model_update(
        client_to_server_params,
        server_to_client_params,
        clipping_norm,
    )
    log(
        INFO,
        "adaptiveclipping_mod: parameters are clipped by value: %.4f.",
        clipping_norm,
    )

    fit_res.parameters = ndarrays_to_parameters(client_to_server_params)

    # Report the norm bit back to the server via the metrics.
    fit_res.metrics[KEY_NORM_BIT] = norm_bit

    out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
    return out_msg