# Copyright 2025 Flower Labs GmbH. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Clipping modifiers for central DP with client-side clipping."""fromcollectionsimportOrderedDictfromloggingimportINFO,WARNfromtypingimportcastfromflwr.client.typingimportClientAppCallablefromflwr.commonimportArray,ArrayRecord,Context,Message,MessageType,logfromflwr.common.differential_privacyimportcompute_clip_model_updatefromflwr.common.differential_privacy_constantsimportKEY_CLIPPING_NORM# pylint: disable=too-many-return-statements
def fixedclipping_mod(
    msg: Message, ctxt: Context, call_next: ClientAppCallable
) -> Message:
    """Client-side fixed clipping modifier.

    This mod needs to be used with the
    `DifferentialPrivacyClientSideFixedClipping` server-side strategy wrapper.

    The wrapper sends the clipping_norm value to the client.

    This mod clips the client model updates before sending them to the server.

    It operates on messages of type `MessageType.TRAIN`.

    Parameters
    ----------
    msg : Message
        Incoming message. It is only processed when it is a TRAIN message
        holding exactly one ArrayRecord and exactly one ConfigRecord;
        otherwise it is forwarded to `call_next` untouched.
    ctxt : Context
        Context forwarded unchanged to `call_next`.
    call_next : ClientAppCallable
        The next mod (or the ClientApp itself) in the chain.

    Returns
    -------
    Message
        The reply produced by `call_next`. When all checks pass, its single
        ArrayRecord is replaced by one built from the clipped arrays, stored
        under the same record key and with the same array keys.

    Raises
    ------
    KeyError
        If the single ConfigRecord of `msg` does not contain
        `KEY_CLIPPING_NORM`.

    Notes
    -----
    Consider the order of mods when using multiple.

    Typically, fixedclipping_mod should be the last to operate on params.
    """
    if msg.metadata.message_type != MessageType.TRAIN:
        return call_next(msg, ctxt)

    if len(msg.content.array_records) != 1:
        log(
            WARN,
            "fixedclipping_mod is designed to work with a single ArrayRecord. "
            "Skipping.",
        )
        return call_next(msg, ctxt)

    if len(msg.content.config_records) != 1:
        log(
            WARN,
            "fixedclipping_mod is designed to work with a single ConfigRecord. "
            "Skipping.",
        )
        return call_next(msg, ctxt)

    # Fetch the single ConfigRecord once; it is used for both the key check
    # and for reading the clipping norm.
    config_record = next(iter(msg.content.config_records.values()))
    keys_in_config = set(config_record.keys())
    if KEY_CLIPPING_NORM not in keys_in_config:
        raise KeyError(
            f"The {KEY_CLIPPING_NORM} value is not supplied by the "
            f"`DifferentialPrivacyClientSideFixedClipping` wrapper at"
            f" the server side."
        )

    # Record array record communicated to client and clipping norm
    original_array_record = next(iter(msg.content.array_records.values()))
    clipping_norm = cast(float, config_record[KEY_CLIPPING_NORM])

    # Call inner app
    out_msg = call_next(msg, ctxt)

    # Check if the msg has error
    if out_msg.has_error():
        return out_msg

    # Ensure there is a single ArrayRecord
    if len(out_msg.content.array_records) != 1:
        log(
            WARN,
            "fixedclipping_mod is designed to work with a single ArrayRecord. "
            "Skipping.",
        )
        return out_msg

    new_array_record_key, client_to_server_arrecord = next(
        iter(out_msg.content.array_records.items())
    )

    # Ensure keys in returned ArrayRecord match those in the one sent from server
    returned_keys = list(client_to_server_arrecord.keys())
    if set(original_array_record.keys()) != set(returned_keys):
        log(
            WARN,
            "fixedclipping_mod: Keys in ArrayRecord must match those from the model "
            "that the ClientApp received. Skipping.",
        )
        return out_msg

    client_to_server_ndarrays = client_to_server_arrecord.to_numpy_ndarrays()

    # The check above only guarantees the same *set* of keys, not the same
    # order. Look the original arrays up by key, in the returned record's
    # order, so each updated array is paired with its own original instead of
    # whatever happens to sit at the same position.
    original_ndarrays = [original_array_record[key].numpy() for key in returned_keys]

    # Clip the client update. The result is read back from
    # `client_to_server_ndarrays` below, i.e. the helper is relied upon to
    # update that list in place.
    compute_clip_model_update(
        param1=client_to_server_ndarrays,
        param2=original_ndarrays,
        clipping_norm=clipping_norm,
    )

    log(
        INFO,
        "fixedclipping_mod: parameters are clipped by value: %.4f.",
        clipping_norm,
    )

    # Replace outgoing ArrayRecord's Array while preserving their keys
    out_msg.content.array_records[new_array_record_key] = ArrayRecord(
        OrderedDict(
            {
                key: Array(ndarray)
                for key, ndarray in zip(returned_keys, client_to_server_ndarrays)
            }
        )
    )
    return out_msg