# Source code for flwr.client.mod.secure_aggregation.secaggplus_mod
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Modifier for the SecAgg+ protocol."""


import os
from dataclasses import MISSING, dataclass, field, fields
from logging import DEBUG, WARNING
from typing import Any, cast

from flwr.client.typing import ClientAppCallable
from flwr.common import (
    ConfigRecord,
    Context,
    Message,
    Parameters,
    RecordDict,
    ndarray_to_bytes,
    parameters_to_ndarrays,
)
from flwr.common import recorddict_compat as compat
from flwr.common.constant import MessageType
from flwr.common.logger import log
from flwr.common.secure_aggregation.crypto.shamir import create_shares
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
    bytes_to_private_key,
    bytes_to_public_key,
    decrypt,
    encrypt,
    generate_key_pairs,
    generate_shared_key,
    private_key_to_bytes,
    public_key_to_bytes,
)
from flwr.common.secure_aggregation.ndarrays_arithmetic import (
    factor_combine,
    parameters_addition,
    parameters_mod,
    parameters_multiply,
    parameters_subtraction,
)
from flwr.common.secure_aggregation.quantization import quantize
from flwr.common.secure_aggregation.secaggplus_constants import (
    RECORD_KEY_CONFIGS,
    RECORD_KEY_STATE,
    Key,
    Stage,
)
from flwr.common.secure_aggregation.secaggplus_utils import (
    pseudo_rand_gen,
    share_keys_plaintext_concat,
    share_keys_plaintext_separate,
)
from flwr.common.typing import ConfigRecordValues


@dataclass
# pylint: disable-next=too-many-instance-attributes
class SecAggPlusState:
    """State of the SecAgg+ protocol.

    The state is rebuilt from, and serialized back to, a flat mapping of
    ``ConfigRecordValues`` (see ``__init__`` and ``to_dict``). Every ``dict``
    attribute ``x`` is stored as two parallel lists: ``"x:K"`` holds the keys
    and ``"x:V"`` holds the values. For ``public_keys_dict`` the
    ``(bytes, bytes)`` value pairs are additionally flattened into a single
    list of bytes.
    """

    current_stage: str = Stage.UNMASK

    nid: int = 0
    sample_num: int = 0
    share_num: int = 0
    threshold: int = 0
    clipping_range: float = 0.0
    target_range: int = 0
    mod_range: int = 0
    max_weight: float = 0.0

    # Secret key (sk) and public key (pk)
    sk1: bytes = b""
    pk1: bytes = b""
    sk2: bytes = b""
    pk2: bytes = b""

    # Random seed for generating the private mask
    rd_seed: bytes = b""

    rd_seed_share_dict: dict[int, bytes] = field(default_factory=dict)
    sk1_share_dict: dict[int, bytes] = field(default_factory=dict)
    # The dict of the shared secrets from sk2
    ss2_dict: dict[int, bytes] = field(default_factory=dict)
    public_keys_dict: dict[int, tuple[bytes, bytes]] = field(default_factory=dict)

    def __init__(self, **kwargs: ConfigRecordValues) -> None:
        """Rebuild the state from the flat mapping produced by ``to_dict``.

        Parameters
        ----------
        **kwargs : ConfigRecordValues
            Plain attributes map directly to their values. A dict attribute
            ``x`` is given as ``"x:K"`` (list of keys) plus ``"x:V"`` (list of
            values). Any attribute not supplied keeps its dataclass default.
        """
        # `@dataclass` keeps this hand-written `__init__` and deletes the class
        # attributes of `default_factory` fields. Without this loop, a fresh
        # instance would have no `*_dict` attributes at all.
        for fld in fields(self):
            if fld.default_factory is not MISSING:
                setattr(self, fld.name, fld.default_factory())

        for k, v in kwargs.items():
            # ":V" entries are consumed together with their ":K" partner.
            if k.endswith(":V"):
                continue
            new_v: Any = v
            if k.endswith(":K"):
                k = k[:-2]
                keys = cast(list[int], v)
                values = cast(list[bytes], kwargs[f"{k}:V"])
                if k == "public_keys_dict":
                    # Inverse of the flattening in `to_dict`:
                    # [pk1_a, pk2_a, pk1_b, pk2_b, ...] -> [(pk1_a, pk2_a), ...]
                    pairs = [
                        tuple(values[i : i + 2]) for i in range(0, len(values), 2)
                    ]
                    new_v = dict(zip(keys, pairs))
                else:
                    new_v = dict(zip(keys, values))
            setattr(self, k, new_v)

    def to_dict(self) -> dict[str, ConfigRecordValues]:
        """Convert the state to a dictionary.

        Only attributes set on the instance are included; attributes still at
        their class-level default are omitted. Dict attributes are replaced by
        their ``":K"``/``":V"`` list pair. The instance itself is not modified.
        """
        # `vars()` returns the live instance `__dict__`. Copy it so the pops
        # below do not strip the dict attributes from `self`.
        ret = dict(vars(self))
        for k in list(ret.keys()):
            if isinstance(ret[k], dict):
                # Replace dict with two lists
                v = cast(dict[int, Any], ret.pop(k))
                ret[f"{k}:K"] = list(v.keys())
                if k == "public_keys_dict":
                    v_list: list[bytes] = []
                    for b1_b2 in cast(list[tuple[bytes, bytes]], v.values()):
                        v_list.extend(b1_b2)
                    ret[f"{k}:V"] = v_list
                else:
                    ret[f"{k}:V"] = list(v.values())
        return ret
def secaggplus_mod(
    msg: Message,
    ctxt: Context,
    call_next: ClientAppCallable,
) -> Message:
    """Handle incoming message and return results, following the SecAgg+ protocol.

    Messages whose ``metadata.message_type`` is not ``MessageType.TRAIN`` are
    handed to ``call_next`` unchanged. For training messages this mod:

    1. restores the ``SecAggPlusState`` kept in
       ``ctxt.state.config_records[RECORD_KEY_STATE]``, creating an empty
       record on first use;
    2. validates the requested stage and pops ``Key.STAGE`` from the incoming
       ``RECORD_KEY_CONFIGS`` record (this mutates ``msg``'s content);
    3. validates the remaining configs and runs the matching stage handler.
       Only the collect-masked-vectors stage calls ``call_next``; after
       masking, it clears the array records of the returned content;
    4. writes the updated state back to ``ctxt.state`` and replies with the
       handler's results stored under ``RECORD_KEY_CONFIGS``.

    Raises
    ------
    ValueError
        If the stage is not one of the known SecAgg+ stages. The validation
        helpers and stage handlers may raise their own errors as well.
    """
    # Anything other than a training message bypasses secure aggregation.
    if msg.metadata.message_type != MessageType.TRAIN:
        return call_next(msg, ctxt)

    # Load (or lazily create) the persisted protocol state.
    if RECORD_KEY_STATE not in ctxt.state.config_records:
        ctxt.state.config_records[RECORD_KEY_STATE] = ConfigRecord({})
    state = SecAggPlusState(**ctxt.state.config_records[RECORD_KEY_STATE])

    incoming = msg.content.config_records[RECORD_KEY_CONFIGS]

    # Validate the transition, then consume the stage entry so that only the
    # stage-specific configs remain in `incoming`.
    check_stage(state.current_stage, incoming)
    state.current_stage = cast(str, incoming.pop(Key.STAGE))
    check_configs(state.current_stage, incoming)

    stage = state.current_stage
    reply_content = RecordDict()
    if stage == Stage.SETUP:
        state.nid = msg.metadata.dst_node_id
        results = _setup(state, incoming)
    elif stage == Stage.SHARE_KEYS:
        results = _share_keys(state, incoming)
    elif stage == Stage.COLLECT_MASKED_VECTORS:
        # Run the downstream app first. Its parameters are masked, and the
        # raw arrays are then wiped from the reply.
        reply_content = call_next(msg, ctxt).content
        fit_res = compat.recorddict_to_fitres(reply_content, keep_input=True)
        results = _collect_masked_vectors(
            state, incoming, fit_res.num_examples, fit_res.parameters
        )
        for array_record in reply_content.array_records.values():
            array_record.clear()
    elif stage == Stage.UNMASK:
        results = _unmask(state, incoming)
    else:
        raise ValueError(f"Unknown SecAgg/SecAgg+ stage: {state.current_stage}")

    # Persist the state for the next stage.
    ctxt.state.config_records[RECORD_KEY_STATE] = ConfigRecord(state.to_dict())

    reply_content.config_records[RECORD_KEY_CONFIGS] = ConfigRecord(results, False)
    return Message(reply_content, reply_to=msg)
def check_stage(current_stage: str, configs: ConfigRecord) -> None:
    """Check the validity of the next stage.

    Parameters
    ----------
    current_stage : str
        The stage this client executed last.
    configs : ConfigRecord
        Incoming configs; must contain ``Key.STAGE`` with a ``str`` value.

    Raises
    ------
    KeyError
        If ``Key.STAGE`` is missing.
    TypeError
        If the value of ``Key.STAGE`` is not a ``str``.
    ValueError
        If the requested stage is neither ``Stage.SETUP`` nor the stage that
        follows ``current_stage`` in ``Stage.all()`` (wrapping around).
    """
    # Check the existence of Config.STAGE
    if Key.STAGE not in configs:
        raise KeyError(
            f"The required key '{Key.STAGE}' is missing from the ConfigRecord."
        )

    # Check the value type of the Config.STAGE
    next_stage = configs[Key.STAGE]
    if not isinstance(next_stage, str):
        raise TypeError(
            f"The value for the key '{Key.STAGE}' must be of type {str}, "
            f"but got {type(next_stage)} instead."
        )

    # Check the validity of the next stage
    if next_stage == Stage.SETUP:
        # Setup is always accepted; warn when the last executed stage was
        # anything other than Stage.UNMASK.
        if current_stage != Stage.UNMASK:
            log(WARNING, "Restart from the setup stage")
    # If stage is not "setup",
    # the stage from configs should be the expected next stage
    else:
        stages = Stage.all()
        expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
        if next_stage != expected_next_stage:
            raise ValueError(
                "Abort secure aggregation: "
                f"expect {expected_next_stage} stage, but receive {next_stage} stage"
            )


def _check_list_configs(
    stage: str,
    configs: ConfigRecord,
    key_type_pairs: list[tuple[str, type]],
) -> None:
    """Check that each key exists and maps to a list of exactly-typed items.

    Raises
    ------
    KeyError
        If a required key is missing.
    TypeError
        If a value is not a list, or any element's type is not exactly the
        expected type.
    """
    for key, expected_type in key_type_pairs:
        if key not in configs:
            raise KeyError(
                f"Stage {stage}: "
                f"the required key '{key}' is "
                "missing from the ConfigRecord."
            )
        value = configs[key]
        # Test each element's type directly. Filtering the mismatches and
        # testing *their* truthiness would let falsy ones (0.0, b"", "", False)
        # slip through. `type(...) is` is used because bool subclasses int.
        if not isinstance(value, list) or any(
            # pylint: disable-next=unidiomatic-typecheck
            type(elm) is not expected_type
            for elm in cast(list[Any], value)
        ):
            raise TypeError(
                f"Stage {stage}: "
                f"the value for the key '{key}' "
                f"must be of type List[{expected_type.__name__}]"
            )


# pylint: disable-next=too-many-branches
def check_configs(stage: str, configs: ConfigRecord) -> None:
    """Check the validity of the configs for the given stage.

    Parameters
    ----------
    stage : str
        The stage about to be executed.
    configs : ConfigRecord
        The stage-specific configs (``Key.STAGE`` already removed).

    Raises
    ------
    KeyError
        If a required key is missing.
    TypeError
        If a value has the wrong type.
    ValueError
        If ``stage`` is not a known stage.
    """
    # Check configs for the setup stage
    if stage == Stage.SETUP:
        key_type_pairs = [
            (Key.SAMPLE_NUMBER, int),
            (Key.SHARE_NUMBER, int),
            (Key.THRESHOLD, int),
            (Key.CLIPPING_RANGE, float),
            (Key.TARGET_RANGE, int),
            (Key.MOD_RANGE, int),
        ]
        for key, expected_type in key_type_pairs:
            if key not in configs:
                raise KeyError(
                    f"Stage {Stage.SETUP}: the required key '{key}' is "
                    "missing from the ConfigRecord."
                )
            # Bool is a subclass of int in Python,
            # so `isinstance(v, int)` will return True even if v is a boolean.
            # pylint: disable-next=unidiomatic-typecheck
            if type(configs[key]) is not expected_type:
                raise TypeError(
                    f"Stage {Stage.SETUP}: The value for the key '{key}' "
                    f"must be of type {expected_type}, "
                    f"but got {type(configs[key])} instead."
                )
        # `_setup` also reads MAX_WEIGHT. Either int or float is accepted
        # (bool is rejected), so it is checked outside the strict loop above.
        if Key.MAX_WEIGHT not in configs:
            raise KeyError(
                f"Stage {Stage.SETUP}: the required key '{Key.MAX_WEIGHT}' is "
                "missing from the ConfigRecord."
            )
        max_weight_type = type(configs[Key.MAX_WEIGHT])
        if max_weight_type not in (int, float):
            raise TypeError(
                f"Stage {Stage.SETUP}: The value for the key '{Key.MAX_WEIGHT}' "
                f"must be of type {float} or {int}, "
                f"but got {max_weight_type} instead."
            )
    elif stage == Stage.SHARE_KEYS:
        for key, value in configs.items():
            if (
                not isinstance(value, list)
                or len(value) != 2
                or not isinstance(value[0], bytes)
                or not isinstance(value[1], bytes)
            ):
                raise TypeError(
                    f"Stage {Stage.SHARE_KEYS}: "
                    f"the value for the key '{key}' must be a list of two bytes."
                )
    elif stage == Stage.COLLECT_MASKED_VECTORS:
        _check_list_configs(
            Stage.COLLECT_MASKED_VECTORS,
            configs,
            [(Key.CIPHERTEXT_LIST, bytes), (Key.SOURCE_LIST, int)],
        )
    elif stage == Stage.UNMASK:
        _check_list_configs(
            Stage.UNMASK,
            configs,
            [(Key.ACTIVE_NODE_ID_LIST, int), (Key.DEAD_NODE_ID_LIST, int)],
        )
    else:
        raise ValueError(f"Unknown secagg stage: {stage}")


def _setup(
    state: SecAggPlusState, configs: ConfigRecord
) -> dict[str, ConfigRecordValues]:
    """Run the setup stage (stage 0).

    The stage proceeds in four steps:

    1. Copy the protocol parameters from ``configs`` into ``state``.
    2. Reset the share and shared-secret dicts.
    3. Generate two fresh key pairs.
    4. Return both public keys.
    """
    log(DEBUG, "Node %d: starting stage 0...", state.nid)

    # Assigning parameter values to object fields
    state.sample_num = cast(int, configs[Key.SAMPLE_NUMBER])
    state.share_num = cast(int, configs[Key.SHARE_NUMBER])
    state.threshold = cast(int, configs[Key.THRESHOLD])
    state.clipping_range = cast(float, configs[Key.CLIPPING_RANGE])
    state.target_range = cast(int, configs[Key.TARGET_RANGE])
    state.mod_range = cast(int, configs[Key.MOD_RANGE])
    state.max_weight = cast(float, configs[Key.MAX_WEIGHT])

    # Dictionaries containing node IDs as keys
    # and their respective secret shares as values.
    state.rd_seed_share_dict = {}
    state.sk1_share_dict = {}
    # Dictionary containing node IDs as keys
    # and their respective shared secrets (with this client) as values.
    state.ss2_dict = {}

    # Create 2 sets of private/public key pairs:
    # one for creating pairwise masks,
    # one for encrypting messages to distribute shares.
    sk1, pk1 = generate_key_pairs()
    sk2, pk2 = generate_key_pairs()

    state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
    state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
    log(DEBUG, "Node %d: stage 0 completes. uploading public keys...", state.nid)
    return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}


# pylint: disable-next=too-many-locals
def _share_keys(
    state: SecAggPlusState, configs: ConfigRecord
) -> dict[str, ConfigRecordValues]:
    """Run the share-keys stage (stage 1).

    The stage proceeds in these steps:

    1. Record the neighbours' public keys (config keys are node IDs as
       strings).
    2. Validate them.
    3. Split ``rd_seed`` and ``sk1`` into shares with ``create_shares``.
    4. For every other node, encrypt its pair of shares with the key derived
       from this node's ``sk2`` and that node's second public key.

    The node's own shares are stored locally.

    Raises
    ------
    ValueError
        If there are fewer neighbours than ``threshold``, if public keys
        repeat, or if this node's own keys are missing or differ.
    """
    named_bytes_tuples = cast(dict[str, tuple[bytes, bytes]], configs)
    key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
    log(DEBUG, "Node %d: starting stage 1...", state.nid)
    state.public_keys_dict = key_dict

    # Check if the size is larger than threshold
    if len(state.public_keys_dict) < state.threshold:
        raise ValueError("Available neighbours number smaller than threshold")

    # Check if all public keys are unique
    pk_list: list[bytes] = [
        pk for pk_pair in state.public_keys_dict.values() for pk in pk_pair
    ]
    if len(set(pk_list)) != len(pk_list):
        raise ValueError("Some public keys are identical")

    # Check that this client's own public keys were relayed unchanged. A
    # missing entry is reported the same way instead of as a bare KeyError.
    if state.public_keys_dict.get(state.nid) != (state.pk1, state.pk2):
        raise ValueError(
            "Own public keys are displayed in dict incorrectly, should not happen!"
        )

    # Generate the private mask seed
    state.rd_seed = os.urandom(32)

    # Create shares for the private mask seed and the first private key
    b_shares = create_shares(state.rd_seed, state.threshold, state.share_num)
    sk1_shares = create_shares(state.sk1, state.threshold, state.share_num)

    srcs, dsts, ciphertexts = [], [], []
    sk2_private = bytes_to_private_key(state.sk2)

    # Distribute shares
    for idx, (nid, (_, pk2)) in enumerate(state.public_keys_dict.items()):
        if nid == state.nid:
            state.rd_seed_share_dict[state.nid] = b_shares[idx]
            state.sk1_share_dict[state.nid] = sk1_shares[idx]
        else:
            shared_key = generate_shared_key(sk2_private, bytes_to_public_key(pk2))
            state.ss2_dict[nid] = shared_key
            plaintext = share_keys_plaintext_concat(
                state.nid, nid, b_shares[idx], sk1_shares[idx]
            )
            ciphertext = encrypt(shared_key, plaintext)
            srcs.append(state.nid)
            dsts.append(nid)
            ciphertexts.append(ciphertext)

    log(DEBUG, "Node %d: stage 1 completes. uploading key shares...", state.nid)
    return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}


# pylint: disable-next=too-many-locals
def _collect_masked_vectors(
    state: SecAggPlusState,
    configs: ConfigRecord,
    num_examples: int,
    updated_parameters: Parameters,
) -> dict[str, ConfigRecordValues]:
    """Run the collect-masked-vectors stage (stage 2).

    The stage proceeds in these steps:

    1. Decrypt and verify the shares received from other nodes.
    2. Scale the parameters by the quantized ``num_examples / max_weight``
       ratio.
    3. Quantize them.
    4. Add the private mask plus one pairwise mask per verified source node.
    5. Reduce modulo ``mod_range``.
    6. Return the masked arrays as bytes.

    Raises
    ------
    ValueError
        If the source and ciphertext lists differ in length, if there are too
        few neighbours, or if a decrypted message has an unexpected source or
        destination.
    """
    log(DEBUG, "Node %d: starting stage 2...", state.nid)
    available_clients: list[int] = []
    ciphertexts = cast(list[bytes], configs[Key.CIPHERTEXT_LIST])
    srcs = cast(list[int], configs[Key.SOURCE_LIST])
    # `zip` below would silently drop any unmatched entries.
    if len(srcs) != len(ciphertexts):
        raise ValueError(
            f"Node {state.nid}: received {len(ciphertexts)} ciphertexts "
            f"but {len(srcs)} source node IDs."
        )
    if len(ciphertexts) + 1 < state.threshold:
        raise ValueError("Not enough available neighbour clients.")

    # Decrypt ciphertexts, verify their sources, and store shares.
    for src, ciphertext in zip(srcs, ciphertexts):
        shared_key = state.ss2_dict[src]
        plaintext = decrypt(shared_key, ciphertext)
        actual_src, dst, rd_seed_share, sk1_share = share_keys_plaintext_separate(
            plaintext
        )
        if src != actual_src:
            raise ValueError(
                f"Node {state.nid}: received ciphertext "
                f"from {actual_src} instead of {src}."
            )
        if dst != state.nid:
            raise ValueError(
                f"Node {state.nid}: received an encrypted message "
                f"for Node {dst} from Node {src}."
            )
        available_clients.append(src)
        state.rd_seed_share_dict[src] = rd_seed_share
        state.sk1_share_dict[src] = sk1_share

    # Fit
    ratio = num_examples / state.max_weight
    if ratio > 1:
        log(
            WARNING,
            "Potential overflow warning: the provided weight (%s) exceeds the specified"
            " max_weight (%s). This may lead to overflow issues.",
            num_examples,
            state.max_weight,
        )
    q_ratio = round(ratio * state.target_range)
    dq_ratio = q_ratio / state.target_range

    parameters = parameters_to_ndarrays(updated_parameters)
    parameters = parameters_multiply(parameters, dq_ratio)

    # Quantize parameter update (vector)
    quantized_parameters = quantize(
        parameters, state.clipping_range, state.target_range
    )

    quantized_parameters = factor_combine(q_ratio, quantized_parameters)

    dimensions_list: list[tuple[int, ...]] = [a.shape for a in quantized_parameters]

    # Add private mask
    private_mask = pseudo_rand_gen(state.rd_seed, state.mod_range, dimensions_list)
    quantized_parameters = parameters_addition(quantized_parameters, private_mask)

    sk1_private = bytes_to_private_key(state.sk1)
    for node_id in available_clients:
        # Add pairwise masks. The node with the larger ID adds the mask and the
        # smaller one subtracts it. Presumably both peers derive the same
        # shared key, so each pair's masks cancel in the aggregate.
        shared_key = generate_shared_key(
            sk1_private,
            bytes_to_public_key(state.public_keys_dict[node_id][0]),
        )
        pairwise_mask = pseudo_rand_gen(shared_key, state.mod_range, dimensions_list)
        if state.nid > node_id:
            quantized_parameters = parameters_addition(
                quantized_parameters, pairwise_mask
            )
        else:
            quantized_parameters = parameters_subtraction(
                quantized_parameters, pairwise_mask
            )

    # Take mod of final weight update vector and return to server
    quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
    log(DEBUG, "Node %d: stage 2 completed, uploading masked parameters...", state.nid)
    return {
        Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
    }


def _unmask(
    state: SecAggPlusState, configs: ConfigRecord
) -> dict[str, ConfigRecordValues]:
    """Run the unmask stage (stage 3).

    Returns the private-mask-seed share held for every active node, followed
    by the ``sk1`` share held for every dead node. ``Key.NODE_ID_LIST`` lists
    the node IDs in the same order as the shares.

    Raises
    ------
    ValueError
        If fewer than ``threshold`` nodes are active.
    """
    log(DEBUG, "Node %d: starting stage 3...", state.nid)

    active_nids = cast(list[int], configs[Key.ACTIVE_NODE_ID_LIST])
    dead_nids = cast(list[int], configs[Key.DEAD_NODE_ID_LIST])
    # Send the private mask seed share for every available client (including
    # itself). Send the first private key share, used to rebuild pairwise
    # masks, for every dropped client.
    if len(active_nids) < state.threshold:
        raise ValueError("Available neighbours number smaller than threshold")

    all_nids = active_nids + dead_nids
    shares = [state.rd_seed_share_dict[nid] for nid in active_nids]
    shares += [state.sk1_share_dict[nid] for nid in dead_nids]

    log(DEBUG, "Node %d: stage 3 completes. uploading key shares...", state.nid)
    return {Key.NODE_ID_LIST: all_nids, Key.SHARE_LIST: shares}