Source code for flwr.common.record.arrayrecord

# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ArrayRecord."""


from __future__ import annotations

import gc
import json
import sys
from collections import OrderedDict
from logging import WARN
from typing import TYPE_CHECKING, Any, cast, overload

import numpy as np

from ..constant import GC_THRESHOLD
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
from ..logger import log
from ..typing import NDArray
from .array import Array
from .typeddict import TypedDict

if TYPE_CHECKING:
    import torch


def _raise_array_record_init_error() -> None:
    """Raise a ``TypeError`` describing the inputs ``ArrayRecord`` accepts.

    This helper never returns normally; it is used by ``ArrayRecord.__init__``
    whenever the supplied arguments do not match a supported init format.
    """
    record_cls_name = ArrayRecord.__qualname__
    message = (
        f"Invalid arguments for {record_cls_name}. Expected either "
        "a list of NumPy ndarrays, a PyTorch state_dict, or a dictionary of Arrays. "
        "The `keep_input` argument is keyword-only."
    )
    raise TypeError(message)


def _check_key(key: str) -> None:
    """Check if key is of expected type."""
    if not isinstance(key, str):
        raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.")


def _check_value(value: Array) -> None:
    """Validate that ``value`` is an :class:`Array`.

    Raises
    ------
    TypeError
        If ``value`` is not an instance of :class:`Array`.
    """
    if isinstance(value, Array):
        return
    raise TypeError(
        f"Value must be of type `{Array}` but `{type(value)}` was passed."
    )


class ArrayRecord(TypedDict[str, Array], InflatableObject):
    """Array record.

    A typed dictionary (``str`` to :class:`Array`) that can store named arrays,
    including model parameters, gradients, embeddings or non-parameter arrays.
    Internally, this behaves similarly to an ``OrderedDict[str, Array]``.
    An ``ArrayRecord`` can be viewed as an equivalent to PyTorch's ``state_dict``,
    but it holds arrays in a serialized form.

    This object is one of the record types supported by :class:`RecordDict` and can
    therefore be stored in the ``content`` of a :class:`Message` or the ``state``
    of a :class:`Context`.

    This class can be instantiated in multiple ways:

    1. By providing nothing (empty container).
    2. By providing a dictionary of :class:`Array` (via the ``array_dict`` argument).
    3. By providing a list of NumPy ``ndarray`` (via the ``numpy_ndarrays`` argument).
    4. By providing a PyTorch ``state_dict`` (via the ``torch_state_dict`` argument).

    Parameters
    ----------
    array_dict : Optional[OrderedDict[str, Array]] (default: None)
        An existing dictionary containing named :class:`Array` instances. If
        provided, these entries will be used directly to populate the record.
    numpy_ndarrays : Optional[list[NDArray]] (default: None)
        A list of NumPy arrays. Each array will be automatically converted
        into an :class:`Array` and stored in this record with generated keys.
    torch_state_dict : Optional[OrderedDict[str, torch.Tensor]] (default: None)
        A PyTorch ``state_dict`` (``str`` keys to ``torch.Tensor`` values). Each
        tensor will be converted into an :class:`Array` and stored in this record.
    keep_input : bool (default: True)
        If ``False``, entries from the input are removed after being added to
        this record to free up memory. If ``True``, the input remains unchanged.
        Regardless of this value, no duplicate memory is used if the input is a
        dictionary of :class:`Array`, i.e., ``array_dict``.

    Examples
    --------
    Initializing an empty ArrayRecord::

        record = ArrayRecord()

    Initializing with a dictionary of :class:`Array`::

        arr = Array("float32", [5, 5], "numpy.ndarray", b"serialized_data...")
        record = ArrayRecord({"weight": arr})

    Initializing with a list of NumPy arrays::

        import numpy as np
        arr1 = np.random.randn(3, 3)
        arr2 = np.random.randn(2, 2)
        record = ArrayRecord([arr1, arr2])

    Initializing with a PyTorch model state_dict::

        import torch.nn as nn
        model = nn.Linear(10, 5)
        record = ArrayRecord(model.state_dict())

    Initializing with TensorFlow model weights (a list of NumPy arrays)::

        import tensorflow as tf
        model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(10,))])
        record = ArrayRecord(model.get_weights())
    """

    @overload
    def __init__(self) -> None: ...  # noqa: E704

    @overload
    def __init__(  # noqa: E704
        self, array_dict: OrderedDict[str, Array], *, keep_input: bool = True
    ) -> None: ...

    @overload
    def __init__(  # noqa: E704
        self, numpy_ndarrays: list[NDArray], *, keep_input: bool = True
    ) -> None: ...

    @overload
    def __init__(  # noqa: E704
        self,
        torch_state_dict: OrderedDict[str, torch.Tensor],
        *,
        keep_input: bool = True,
    ) -> None: ...

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *args: Any,
        numpy_ndarrays: list[NDArray] | None = None,
        torch_state_dict: OrderedDict[str, torch.Tensor] | None = None,
        array_dict: OrderedDict[str, Array] | None = None,
        keep_input: bool = True,
    ) -> None:
        super().__init__(_check_key, _check_value)

        # Determine the initialization method and validate input arguments.
        # Supported initialization formats:
        #   1. cls(array_dict: OrderedDict[str, Array], keep_input: bool)
        #   2. cls(numpy_ndarrays: list[NDArray], keep_input: bool)
        #   3. cls(torch_state_dict: dict[str, torch.Tensor], keep_input: bool)

        # At most one positional argument (the data source) is accepted
        if len(args) > 1:
            _raise_array_record_init_error()
        arg = args[0] if args else None
        init_method: str | None = None  # Track which init method is being used

        # Record a keyword-supplied data source in ``arg``/``init_method``.
        # Only a single data source may be given overall, either positionally
        # or through exactly one keyword argument.
        def _try_set_arg(_arg: Any, method: str) -> None:
            nonlocal arg, init_method
            # Skip keyword arguments that were not provided
            if _arg is None:
                return
            # Raise if a data source (positional or keyword) is already set
            if arg is not None or init_method is not None:
                _raise_array_record_init_error()
            init_method = method
            arg = _arg

        # Try to set keyword arguments
        _try_set_arg(array_dict, "array_dict")
        _try_set_arg(numpy_ndarrays, "numpy_ndarrays")
        _try_set_arg(torch_state_dict, "state_dict")

        # If no arguments are provided, return and keep self empty
        if arg is None:
            return

        # Handle dictionary of Arrays (tried for positional input as well)
        if not init_method or init_method == "array_dict":
            # Type check the input
            if (
                isinstance(arg, dict)
                and all(isinstance(k, str) for k in arg.keys())
                and all(isinstance(v, Array) for v in arg.values())
            ):
                array_dict = cast(OrderedDict[str, Array], arg)
                converted = self.from_array_dict(array_dict, keep_input=keep_input)
                # Adopt the freshly built record's internal state
                self.__dict__.update(converted.__dict__)
                return

        # Handle NumPy ndarrays
        if not init_method or init_method == "numpy_ndarrays":
            # Type check the input
            # pylint: disable-next=not-an-iterable
            if isinstance(arg, list) and all(isinstance(v, np.ndarray) for v in arg):
                numpy_ndarrays = cast(list[NDArray], arg)
                converted = self.from_numpy_ndarrays(
                    numpy_ndarrays, keep_input=keep_input
                )
                self.__dict__.update(converted.__dict__)
                return

        # Handle PyTorch state_dict; only checked if torch is already imported
        if not init_method or init_method == "state_dict":
            # Type check the input
            if (
                (torch := sys.modules.get("torch")) is not None
                and isinstance(arg, dict)
                and all(isinstance(k, str) for k in arg.keys())
                and all(isinstance(v, torch.Tensor) for v in arg.values())
            ):
                torch_state_dict = cast(
                    OrderedDict[str, torch.Tensor], arg  # type: ignore
                )
                converted = self.from_torch_state_dict(
                    torch_state_dict, keep_input=keep_input
                )
                self.__dict__.update(converted.__dict__)
                return

        # No supported format matched the provided input
        _raise_array_record_init_error()

    @classmethod
    def from_array_dict(
        cls,
        array_dict: OrderedDict[str, Array],
        *,
        keep_input: bool = True,
    ) -> ArrayRecord:
        """Create ArrayRecord from a dictionary of :class:`Array`.

        Each entry is re-wrapped in a new :class:`Array` built from the input
        array's ``dtype``, ``shape``, ``stype`` and ``data``; the ``data`` object
        is passed through rather than copied here. If ``keep_input`` is
        ``False``, ``array_dict`` is cleared afterwards.
        """
        record = ArrayRecord()
        for k, v in array_dict.items():
            record[k] = Array(
                dtype=v.dtype, shape=list(v.shape), stype=v.stype, data=v.data
            )
        if not keep_input:
            array_dict.clear()
        return record

    @classmethod
    def from_numpy_ndarrays(
        cls,
        ndarrays: list[NDArray],
        *,
        keep_input: bool = True,
    ) -> ArrayRecord:
        """Create ArrayRecord from a list of NumPy ``ndarray``.

        Arrays are stored under the keys ``"0"``, ``"1"``, ... following their
        position in ``ndarrays``. If ``keep_input`` is ``False``, each list slot
        is released as soon as it has been converted and the list is cleared at
        the end. Garbage collection is triggered whenever the accumulated
        serialized size exceeds ``GC_THRESHOLD`` bytes.
        """
        record = ArrayRecord()
        total_serialized_bytes = 0

        # Index-based loop so that the list slot can be overwritten in place
        for i in range(len(ndarrays)):  # pylint: disable=C0200
            record[str(i)] = Array.from_numpy_ndarray(ndarrays[i])

            if not keep_input:
                # Remove the reference
                ndarrays[i] = None  # type: ignore

            total_serialized_bytes += len(record[str(i)].data)

            # If total serialized data exceeds the threshold, trigger GC
            if total_serialized_bytes > GC_THRESHOLD:
                total_serialized_bytes = 0
                gc.collect()

        if not keep_input:
            # Clear the entire list to remove all references and force GC
            ndarrays.clear()
            gc.collect()

        return record

    @classmethod
    def from_torch_state_dict(
        cls,
        state_dict: OrderedDict[str, torch.Tensor],
        *,
        keep_input: bool = True,
    ) -> ArrayRecord:
        """Create ArrayRecord from PyTorch ``state_dict``.

        Each tensor is detached, moved to CPU and converted through NumPy. If
        ``keep_input`` is ``False``, entries are popped from ``state_dict``.

        Raises
        ------
        RuntimeError
            If ``torch`` has not been imported.
        """
        if "torch" not in sys.modules:
            raise RuntimeError(
                f"PyTorch is required to use {cls.from_torch_state_dict.__name__}"
            )

        record = ArrayRecord()

        # Iterate over a snapshot of the keys since entries may be popped
        for k in list(state_dict.keys()):
            v = state_dict[k] if keep_input else state_dict.pop(k)
            record[k] = Array.from_numpy_ndarray(v.detach().cpu().numpy())

        return record

    def to_numpy_ndarrays(self, *, keep_input: bool = True) -> list[NDArray]:
        """Return the ArrayRecord as a list of NumPy ``ndarray``.

        Arrays are returned in the record's iteration order. If ``keep_input``
        is ``False``, entries are popped from this record while converting
        (leaving it empty) and garbage collection is triggered periodically.
        """
        if keep_input:
            return [v.numpy() for v in self.values()]

        # Clear the record and return the list of NumPy arrays
        ret: list[NDArray] = []
        total_serialized_bytes = 0
        for k in list(self.keys()):
            arr = self.pop(k)
            ret.append(arr.numpy())
            total_serialized_bytes += len(arr.data)
            del arr

            # If total serialized data exceeds the threshold, trigger GC
            if total_serialized_bytes > GC_THRESHOLD:
                total_serialized_bytes = 0
                gc.collect()

        # Force GC (``keep_input`` is necessarily ``False`` at this point)
        gc.collect()
        return ret

    def to_torch_state_dict(
        self, *, keep_input: bool = True
    ) -> OrderedDict[str, torch.Tensor]:
        """Return the ArrayRecord as a PyTorch ``state_dict``.

        If ``keep_input`` is ``False``, entries are popped from this record.

        Raises
        ------
        RuntimeError
            If ``torch`` has not been imported.
        """
        if not (torch := sys.modules.get("torch")):
            raise RuntimeError(
                f"PyTorch is required to use {self.to_torch_state_dict.__name__}"
            )

        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
        for k in list(self.keys()):
            arr = self[k] if keep_input else self.pop(k)
            state_dict[k] = torch.from_numpy(arr.numpy())

        return state_dict

    def count_bytes(self) -> int:
        """Return number of Bytes stored in this object.

        This is the sum of each array's serialized ``data`` length plus the
        UTF-8 encoded length of each key. Note that a small amount of Bytes
        might also be included in this counting that correspond to metadata of
        the serialized object (e.g. of NumPy array) needed for deserialization.
        """
        num_bytes = 0

        for k, v in self.items():
            num_bytes += len(v.data)

            # We also count the bytes footprint of the keys. Measure the encoded
            # size rather than ``len(k)``, which counts code points and would
            # under-count non-ASCII keys.
            num_bytes += len(k.encode("utf-8"))

        return num_bytes

    @property
    def children(self) -> dict[str, InflatableObject]:
        """Return a dictionary of Arrays with their Object IDs as keys."""
        return {arr.object_id: arr for arr in self.values()}

    def deflate(self) -> bytes:
        """Deflate the ArrayRecord.

        The body is a UTF-8 JSON object mapping each array name to the object
        ID of its :class:`Array`; the arrays themselves are not embedded.
        """
        # array_name: array_object_id mapping
        array_refs: dict[str, str] = {}

        for array_name, array in self.items():
            array_refs[array_name] = array.object_id

        # Serialize references dict
        object_body = json.dumps(array_refs).encode("utf-8")
        return add_header_to_object_body(object_body=object_body, obj=self)

    @classmethod
    def inflate(
        cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
    ) -> ArrayRecord:
        """Inflate an ArrayRecord from bytes.

        Parameters
        ----------
        object_content : bytes
            The deflated object content of the ArrayRecord.

        children : Optional[dict[str, InflatableObject]] (default: None)
            Dictionary of children InflatableObjects mapped to their Object IDs.
            These children enable the full inflation of the ArrayRecord.

        Returns
        -------
        ArrayRecord
            The inflated ArrayRecord.

        Raises
        ------
        ValueError
            If the Object IDs in ``children`` do not exactly match those
            referenced by ``object_content``, or if any child is not an
            :class:`Array`.
        """
        if children is None:
            children = {}

        # Inflate mapping of array_names (keys in the ArrayRecord) to Arrays' object IDs
        obj_body = get_object_body(object_content, cls)
        array_refs: dict[str, str] = json.loads(obj_body.decode(encoding="utf-8"))

        # Several names may reference the same Array, so compare as sets
        unique_arrays = set(array_refs.values())
        children_obj_ids = set(children.keys())
        if unique_arrays != children_obj_ids:
            raise ValueError(
                "Unexpected set of `children`. "
                f"Expected {unique_arrays} but got {children_obj_ids}."
            )

        # Ensure children are of type Array
        if not all(isinstance(arr, Array) for arr in children.values()):
            raise ValueError("`Children` are expected to be of type `Array`.")

        # Instantiate new ArrayRecord
        return ArrayRecord(
            OrderedDict(
                {name: children[object_id] for name, object_id in array_refs.items()}
            )
        )

    @property
    def object_id(self) -> str:
        """Get object ID."""
        ret = super().object_id
        self.is_dirty = False  # Reset dirty flag
        return ret

    @property
    def is_dirty(self) -> bool:
        """Check if the object is dirty after the last deflation.

        A record that has never had its flag set is treated as dirty. A clean
        record is reported dirty again as soon as any contained Array is dirty.
        """
        # NOTE(review): the flag lives directly in ``__dict__``, presumably to
        # bypass attribute handling in a base class — confirm.
        if "_is_dirty" not in self.__dict__:
            self.__dict__["_is_dirty"] = True

        if not self.__dict__["_is_dirty"]:
            if any(v.is_dirty for v in self.values()):
                # If any Array is dirty, mark the record as dirty
                self.__dict__["_is_dirty"] = True
        return cast(bool, self.__dict__["_is_dirty"])

    @is_dirty.setter
    def is_dirty(self, value: bool) -> None:
        """Set the dirty flag."""
        self.__dict__["_is_dirty"] = value

    def __setitem__(self, key: str, value: Array) -> None:
        """Set item and mark the record as dirty."""
        self.is_dirty = True  # Mark as dirty when setting an item
        super().__setitem__(key, value)

    def __delitem__(self, key: str) -> None:
        """Delete item and mark the record as dirty."""
        self.is_dirty = True  # Mark as dirty when deleting an item
        super().__delitem__(key)
class ParametersRecord(ArrayRecord):
    """Deprecated class ``ParametersRecord``, use ``ArrayRecord`` instead.

    This class exists solely for backward compatibility with legacy
    code that previously used ``ParametersRecord``. It has been renamed
    to ``ArrayRecord``.

    .. warning::
        ``ParametersRecord`` is deprecated and will be removed in a future release.
        Use ``ArrayRecord`` instead.

    Examples
    --------
    Legacy (deprecated) usage::

        from flwr.common import ParametersRecord

        record = ParametersRecord()

    Updated usage::

        from flwr.common import ArrayRecord

        record = ArrayRecord()
    """

    # Class-level flag so the deprecation warning is logged only on the first
    # instantiation of ``ParametersRecord``.
    _warning_logged = False

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # ``**kwargs: Any`` annotates each keyword value (e.g. ``keep_input``),
        # matching the keyword arguments accepted by ``ArrayRecord.__init__``.
        if not ParametersRecord._warning_logged:
            ParametersRecord._warning_logged = True
            log(
                WARN,
                "The `ParametersRecord` class has been renamed to `ArrayRecord`. "
                "Support for `ParametersRecord` will be removed in a future release. "
                "Please update your code accordingly.",
            )
        super().__init__(*args, **kwargs)