Source code for flwr.client.mod.secure_aggregation.secaggplus_mod

# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Modifier for the SecAgg+ protocol."""


import os
from dataclasses import dataclass, field
from logging import DEBUG, WARNING
from typing import Any, cast

from flwr.client.typing import ClientAppCallable
from flwr.common import (
    ConfigsRecord,
    Context,
    Message,
    Parameters,
    RecordSet,
    ndarray_to_bytes,
    parameters_to_ndarrays,
)
from flwr.common import recordset_compat as compat
from flwr.common.constant import MessageType
from flwr.common.logger import log
from flwr.common.secure_aggregation.crypto.shamir import create_shares
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
    bytes_to_private_key,
    bytes_to_public_key,
    decrypt,
    encrypt,
    generate_key_pairs,
    generate_shared_key,
    private_key_to_bytes,
    public_key_to_bytes,
)
from flwr.common.secure_aggregation.ndarrays_arithmetic import (
    factor_combine,
    parameters_addition,
    parameters_mod,
    parameters_multiply,
    parameters_subtraction,
)
from flwr.common.secure_aggregation.quantization import quantize
from flwr.common.secure_aggregation.secaggplus_constants import (
    RECORD_KEY_CONFIGS,
    RECORD_KEY_STATE,
    Key,
    Stage,
)
from flwr.common.secure_aggregation.secaggplus_utils import (
    pseudo_rand_gen,
    share_keys_plaintext_concat,
    share_keys_plaintext_separate,
)
from flwr.common.typing import ConfigsRecordValues


@dataclass
# pylint: disable-next=too-many-instance-attributes
class SecAggPlusState:
    """State of the SecAgg+ protocol."""

    current_stage: str = Stage.UNMASK

    nid: int = 0
    sample_num: int = 0
    share_num: int = 0
    threshold: int = 0
    clipping_range: float = 0.0
    target_range: int = 0
    mod_range: int = 0
    max_weight: float = 0.0

    # Secret key (sk) and public key (pk)
    sk1: bytes = b""
    pk1: bytes = b""
    sk2: bytes = b""
    pk2: bytes = b""

    # Random seed for generating the private mask
    rd_seed: bytes = b""

    rd_seed_share_dict: dict[int, bytes] = field(default_factory=dict)
    sk1_share_dict: dict[int, bytes] = field(default_factory=dict)
    # The dict of the shared secrets from sk2
    ss2_dict: dict[int, bytes] = field(default_factory=dict)
    public_keys_dict: dict[int, tuple[bytes, bytes]] = field(default_factory=dict)

    def __init__(self, **kwargs: ConfigsRecordValues) -> None:
        for k, v in kwargs.items():
            if k.endswith(":V"):
                continue
            new_v: Any = v
            if k.endswith(":K"):
                k = k[:-2]
                keys = cast(list[int], v)
                values = cast(list[bytes], kwargs[f"{k}:V"])
                if len(values) > len(keys):
                    updated_values = [
                        tuple(values[i : i + 2]) for i in range(0, len(values), 2)
                    ]
                    new_v = dict(zip(keys, updated_values))
                else:
                    new_v = dict(zip(keys, values))
            self.__setattr__(k, new_v)

    def to_dict(self) -> dict[str, ConfigsRecordValues]:
        """Convert the state to a dictionary."""
        ret = vars(self)
        for k in list(ret.keys()):
            if isinstance(ret[k], dict):
                # Replace dict with two lists
                v = cast(dict[str, Any], ret.pop(k))
                ret[f"{k}:K"] = list(v.keys())
                if k == "public_keys_dict":
                    v_list: list[bytes] = []
                    for b1_b2 in cast(list[tuple[bytes, bytes]], v.values()):
                        v_list.extend(b1_b2)
                    ret[f"{k}:V"] = v_list
                else:
                    ret[f"{k}:V"] = list(v.values())
        return ret


[docs]def secaggplus_mod( msg: Message, ctxt: Context, call_next: ClientAppCallable, ) -> Message: """Handle incoming message and return results, following the SecAgg+ protocol.""" # Ignore non-fit messages if msg.metadata.message_type != MessageType.TRAIN: return call_next(msg, ctxt) # Retrieve local state if RECORD_KEY_STATE not in ctxt.state.configs_records: ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord({}) state_dict = ctxt.state.configs_records[RECORD_KEY_STATE] state = SecAggPlusState(**state_dict) # Retrieve incoming configs configs = msg.content.configs_records[RECORD_KEY_CONFIGS] # Check the validity of the next stage check_stage(state.current_stage, configs) # Update the current stage state.current_stage = cast(str, configs.pop(Key.STAGE)) # Check the validity of the configs based on the current stage check_configs(state.current_stage, configs) # Execute out_content = RecordSet() if state.current_stage == Stage.SETUP: state.nid = msg.metadata.dst_node_id res = _setup(state, configs) elif state.current_stage == Stage.SHARE_KEYS: res = _share_keys(state, configs) elif state.current_stage == Stage.COLLECT_MASKED_VECTORS: out_msg = call_next(msg, ctxt) out_content = out_msg.content fitres = compat.recordset_to_fitres(out_content, keep_input=True) res = _collect_masked_vectors( state, configs, fitres.num_examples, fitres.parameters ) for p_record in out_content.parameters_records.values(): p_record.clear() elif state.current_stage == Stage.UNMASK: res = _unmask(state, configs) else: raise ValueError(f"Unknown SecAgg/SecAgg+ stage: {state.current_stage}") # Save state ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict()) # Return message out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False) return msg.create_reply(out_content)
def check_stage(current_stage: str, configs: ConfigsRecord) -> None: """Check the validity of the next stage.""" # Check the existence of Config.STAGE if Key.STAGE not in configs: raise KeyError( f"The required key '{Key.STAGE}' is missing from the ConfigsRecord." ) # Check the value type of the Config.STAGE next_stage = configs[Key.STAGE] if not isinstance(next_stage, str): raise TypeError( f"The value for the key '{Key.STAGE}' must be of type {str}, " f"but got {type(next_stage)} instead." ) # Check the validity of the next stage if next_stage == Stage.SETUP: if current_stage != Stage.UNMASK: log(WARNING, "Restart from the setup stage") # If stage is not "setup", # the stage from configs should be the expected next stage else: stages = Stage.all() expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)] if next_stage != expected_next_stage: raise ValueError( "Abort secure aggregation: " f"expect {expected_next_stage} stage, but receive {next_stage} stage" ) # pylint: disable-next=too-many-branches def check_configs(stage: str, configs: ConfigsRecord) -> None: """Check the validity of the configs.""" # Check configs for the setup stage if stage == Stage.SETUP: key_type_pairs = [ (Key.SAMPLE_NUMBER, int), (Key.SHARE_NUMBER, int), (Key.THRESHOLD, int), (Key.CLIPPING_RANGE, float), (Key.TARGET_RANGE, int), (Key.MOD_RANGE, int), ] for key, expected_type in key_type_pairs: if key not in configs: raise KeyError( f"Stage {Stage.SETUP}: the required key '{key}' is " "missing from the ConfigsRecord." ) # Bool is a subclass of int in Python, # so `isinstance(v, int)` will return True even if v is a boolean. # pylint: disable-next=unidiomatic-typecheck if type(configs[key]) is not expected_type: raise TypeError( f"Stage {Stage.SETUP}: The value for the key '{key}' " f"must be of type {expected_type}, " f"but got {type(configs[key])} instead." ) elif stage == Stage.SHARE_KEYS: for key, value in configs.items(): if ( not isinstance(value, list) or len(value) != 2 or not isinstance(value[0], bytes) or not isinstance(value[1], bytes) ): raise TypeError( f"Stage {Stage.SHARE_KEYS}: " f"the value for the key '{key}' must be a list of two bytes." ) elif stage == Stage.COLLECT_MASKED_VECTORS: key_type_pairs = [ (Key.CIPHERTEXT_LIST, bytes), (Key.SOURCE_LIST, int), ] for key, expected_type in key_type_pairs: if key not in configs: raise KeyError( f"Stage {Stage.COLLECT_MASKED_VECTORS}: " f"the required key '{key}' is " "missing from the ConfigsRecord." ) if not isinstance(configs[key], list) or any( elm for elm in cast(list[Any], configs[key]) # pylint: disable-next=unidiomatic-typecheck if type(elm) is not expected_type ): raise TypeError( f"Stage {Stage.COLLECT_MASKED_VECTORS}: " f"the value for the key '{key}' " f"must be of type List[{expected_type.__name__}]" ) elif stage == Stage.UNMASK: key_type_pairs = [ (Key.ACTIVE_NODE_ID_LIST, int), (Key.DEAD_NODE_ID_LIST, int), ] for key, expected_type in key_type_pairs: if key not in configs: raise KeyError( f"Stage {Stage.UNMASK}: " f"the required key '{key}' is " "missing from the ConfigsRecord." ) if not isinstance(configs[key], list) or any( elm for elm in cast(list[Any], configs[key]) # pylint: disable-next=unidiomatic-typecheck if type(elm) is not expected_type ): raise TypeError( f"Stage {Stage.UNMASK}: " f"the value for the key '{key}' " f"must be of type List[{expected_type.__name__}]" ) else: raise ValueError(f"Unknown secagg stage: {stage}") def _setup( state: SecAggPlusState, configs: ConfigsRecord ) -> dict[str, ConfigsRecordValues]: # Assigning parameter values to object fields sec_agg_param_dict = configs state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER]) log(DEBUG, "Node %d: starting stage 0...", state.nid) state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER]) state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD]) state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE]) state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE]) state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE]) state.max_weight = cast(float, sec_agg_param_dict[Key.MAX_WEIGHT]) # Dictionaries containing node IDs as keys # and their respective secret shares as values. state.rd_seed_share_dict = {} state.sk1_share_dict = {} # Dictionary containing node IDs as keys # and their respective shared secrets (with this client) as values. state.ss2_dict = {} # Create 2 sets private public key pairs # One for creating pairwise masks # One for encrypting message to distribute shares sk1, pk1 = generate_key_pairs() sk2, pk2 = generate_key_pairs() state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1) state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2) log(DEBUG, "Node %d: stage 0 completes. uploading public keys...", state.nid) return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2} # pylint: disable-next=too-many-locals def _share_keys( state: SecAggPlusState, configs: ConfigsRecord ) -> dict[str, ConfigsRecordValues]: named_bytes_tuples = cast(dict[str, tuple[bytes, bytes]], configs) key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()} log(DEBUG, "Node %d: starting stage 1...", state.nid) state.public_keys_dict = key_dict # Check if the size is larger than threshold if len(state.public_keys_dict) < state.threshold: raise ValueError("Available neighbours number smaller than threshold") # Check if all public keys are unique pk_list: list[bytes] = [] for pk1, pk2 in state.public_keys_dict.values(): pk_list.append(pk1) pk_list.append(pk2) if len(set(pk_list)) != len(pk_list): raise ValueError("Some public keys are identical") # Check if public keys of this client are correct in the dictionary if ( state.public_keys_dict[state.nid][0] != state.pk1 or state.public_keys_dict[state.nid][1] != state.pk2 ): raise ValueError( "Own public keys are displayed in dict incorrectly, should not happen!" ) # Generate the private mask seed state.rd_seed = os.urandom(32) # Create shares for the private mask seed and the first private key b_shares = create_shares(state.rd_seed, state.threshold, state.share_num) sk1_shares = create_shares(state.sk1, state.threshold, state.share_num) srcs, dsts, ciphertexts = [], [], [] # Distribute shares for idx, (nid, (_, pk2)) in enumerate(state.public_keys_dict.items()): if nid == state.nid: state.rd_seed_share_dict[state.nid] = b_shares[idx] state.sk1_share_dict[state.nid] = sk1_shares[idx] else: shared_key = generate_shared_key( bytes_to_private_key(state.sk2), bytes_to_public_key(pk2), ) state.ss2_dict[nid] = shared_key plaintext = share_keys_plaintext_concat( state.nid, nid, b_shares[idx], sk1_shares[idx] ) ciphertext = encrypt(shared_key, plaintext) srcs.append(state.nid) dsts.append(nid) ciphertexts.append(ciphertext) log(DEBUG, "Node %d: stage 1 completes. uploading key shares...", state.nid) return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts} # pylint: disable-next=too-many-locals def _collect_masked_vectors( state: SecAggPlusState, configs: ConfigsRecord, num_examples: int, updated_parameters: Parameters, ) -> dict[str, ConfigsRecordValues]: log(DEBUG, "Node %d: starting stage 2...", state.nid) available_clients: list[int] = [] ciphertexts = cast(list[bytes], configs[Key.CIPHERTEXT_LIST]) srcs = cast(list[int], configs[Key.SOURCE_LIST]) if len(ciphertexts) + 1 < state.threshold: raise ValueError("Not enough available neighbour clients.") # Decrypt ciphertexts, verify their sources, and store shares. for src, ciphertext in zip(srcs, ciphertexts): shared_key = state.ss2_dict[src] plaintext = decrypt(shared_key, ciphertext) actual_src, dst, rd_seed_share, sk1_share = share_keys_plaintext_separate( plaintext ) available_clients.append(src) if src != actual_src: raise ValueError( f"Node {state.nid}: received ciphertext " f"from {actual_src} instead of {src}." ) if dst != state.nid: raise ValueError( f"Node {state.nid}: received an encrypted message" f"for Node {dst} from Node {src}." ) state.rd_seed_share_dict[src] = rd_seed_share state.sk1_share_dict[src] = sk1_share # Fit ratio = num_examples / state.max_weight if ratio > 1: log( WARNING, "Potential overflow warning: the provided weight (%s) exceeds the specified" " max_weight (%s). This may lead to overflow issues.", num_examples, state.max_weight, ) q_ratio = round(ratio * state.target_range) dq_ratio = q_ratio / state.target_range parameters = parameters_to_ndarrays(updated_parameters) parameters = parameters_multiply(parameters, dq_ratio) # Quantize parameter update (vector) quantized_parameters = quantize( parameters, state.clipping_range, state.target_range ) quantized_parameters = factor_combine(q_ratio, quantized_parameters) dimensions_list: list[tuple[int, ...]] = [a.shape for a in quantized_parameters] # Add private mask private_mask = pseudo_rand_gen(state.rd_seed, state.mod_range, dimensions_list) quantized_parameters = parameters_addition(quantized_parameters, private_mask) for node_id in available_clients: # Add pairwise masks shared_key = generate_shared_key( bytes_to_private_key(state.sk1), bytes_to_public_key(state.public_keys_dict[node_id][0]), ) pairwise_mask = pseudo_rand_gen(shared_key, state.mod_range, dimensions_list) if state.nid > node_id: quantized_parameters = parameters_addition( quantized_parameters, pairwise_mask ) else: quantized_parameters = parameters_subtraction( quantized_parameters, pairwise_mask ) # Take mod of final weight update vector and return to server quantized_parameters = parameters_mod(quantized_parameters, state.mod_range) log(DEBUG, "Node %d: stage 2 completed, uploading masked parameters...", state.nid) return { Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters] } def _unmask( state: SecAggPlusState, configs: ConfigsRecord ) -> dict[str, ConfigsRecordValues]: log(DEBUG, "Node %d: starting stage 3...", state.nid) active_nids = cast(list[int], configs[Key.ACTIVE_NODE_ID_LIST]) dead_nids = cast(list[int], configs[Key.DEAD_NODE_ID_LIST]) # Send private mask seed share for every avaliable client (including itself) # Send first private key share for building pairwise mask for every dropped client if len(active_nids) < state.threshold: raise ValueError("Available neighbours number smaller than threshold") all_nids, shares = [], [] all_nids = active_nids + dead_nids shares += [state.rd_seed_share_dict[nid] for nid in active_nids] shares += [state.sk1_share_dict[nid] for nid in dead_nids] log(DEBUG, "Node %d: stage 3 completes. uploading key shares...", state.nid) return {Key.NODE_ID_LIST: all_nids, Key.SHARE_LIST: shares}