# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ArrayRecord and Array."""
from __future__ import annotations
import gc
import sys
from collections import OrderedDict
from dataclasses import dataclass
from io import BytesIO
from logging import WARN
from typing import TYPE_CHECKING, Any, cast, overload
import numpy as np
from ..constant import GC_THRESHOLD, SType
from ..logger import log
from ..typing import NDArray
from .typeddict import TypedDict
if TYPE_CHECKING:
import torch
def _raise_array_init_error() -> None:
raise TypeError(
f"Invalid arguments for {Array.__qualname__}. Expected either a "
"PyTorch tensor, a NumPy ndarray, or explicit"
" dtype/shape/stype/data values."
)
def _raise_array_record_init_error() -> None:
raise TypeError(
f"Invalid arguments for {ArrayRecord.__qualname__}. Expected either "
"a list of NumPy ndarrays, a PyTorch state_dict, or a dictionary of Arrays. "
"The `keep_input` argument is keyword-only."
)
@dataclass
class Array:
"""Array type.
A dataclass containing serialized data from an array-like or tensor-like object
along with metadata about it. The class can be initialized in one of three ways:
1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
2. By providing a NumPy ndarray (via the `ndarray` argument).
3. By providing a PyTorch tensor (via the `torch_tensor` argument).
In scenarios (2)-(3), the `dtype`, `shape`, `stype`, and `data` are automatically
derived from the input. In scenario (1), these fields must be specified manually.
Parameters
----------
dtype : Optional[str] (default: None)
A string representing the data type of the serialized object (e.g. `"float32"`).
Only required if you are not passing in a ndarray or a tensor.
shape : Optional[list[int]] (default: None)
A list representing the shape of the unserialized array-like object. Only
required if you are not passing in a ndarray or a tensor.
stype : Optional[str] (default: None)
A string indicating the serialization mechanism used to generate the bytes in
`data` from an array-like or tensor-like object. Only required if you are not
passing in a ndarray or a tensor.
data : Optional[bytes] (default: None)
A buffer of bytes containing the data. Only required if you are not passing in
a ndarray or a tensor.
ndarray : Optional[NDArray] (default: None)
A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
fields are derived automatically from it.
torch_tensor : Optional[torch.Tensor] (default: None)
A PyTorch tensor. If provided, it will be **detached and moved to CPU**
before conversion, and the `dtype`, `shape`, `stype`, and `data` fields
will be derived automatically from it.
Examples
--------
Initializing by specifying all fields directly:
>>> arr1 = Array(
>>> dtype="float32",
>>> shape=[3, 3],
>>> stype="numpy.ndarray",
>>> data=b"serialized_data...",
>>> )
Initializing with a NumPy ndarray:
>>> import numpy as np
>>> arr2 = Array(np.random.randn(3, 3))
Initializing with a PyTorch tensor:
>>> import torch
>>> arr3 = Array(torch.randn(3, 3))
"""
dtype: str
shape: list[int]
stype: str
data: bytes
@overload
def __init__( # noqa: E704
self, dtype: str, shape: list[int], stype: str, data: bytes
) -> None: ...
@overload
def __init__(self, ndarray: NDArray) -> None: ... # noqa: E704
@overload
def __init__(self, torch_tensor: torch.Tensor) -> None: ... # noqa: E704
def __init__( # pylint: disable=too-many-arguments, too-many-locals
self,
*args: Any,
dtype: str | None = None,
shape: list[int] | None = None,
stype: str | None = None,
data: bytes | None = None,
ndarray: NDArray | None = None,
torch_tensor: torch.Tensor | None = None,
) -> None:
# Determine the initialization method and validate input arguments.
# Support three initialization formats:
# 1. Array(dtype: str, shape: list[int], stype: str, data: bytes)
# 2. Array(ndarray: NDArray)
# 3. Array(torch_tensor: torch.Tensor)
# Initialize all arguments
# If more than 4 positional arguments are provided, raise an error.
if len(args) > 4:
_raise_array_init_error()
all_args = [None] * 4
for i, arg in enumerate(args):
all_args[i] = arg
init_method: str | None = None # Track which init method is being used
# Try to assign a value to all_args[index] if it's not already set.
# If an initialization method is provided, update init_method.
def _try_set_arg(index: int, arg: Any, method: str) -> None:
# Skip if arg is None
if arg is None:
return
# Raise an error if all_args[index] is already set
if all_args[index] is not None:
_raise_array_init_error()
# Raise an error if a different initialization method is already set
nonlocal init_method
if init_method is not None and init_method != method:
_raise_array_init_error()
# Set init_method and all_args[index]
if init_method is None:
init_method = method
all_args[index] = arg
# Try to set keyword arguments in all_args
_try_set_arg(0, dtype, "direct")
_try_set_arg(1, shape, "direct")
_try_set_arg(2, stype, "direct")
_try_set_arg(3, data, "direct")
_try_set_arg(0, ndarray, "ndarray")
_try_set_arg(0, torch_tensor, "torch_tensor")
        # Compact the arguments: drop unset (None) entries before validation
        all_args = [arg for arg in all_args if arg is not None]
# Handle direct field initialization
if not init_method or init_method == "direct":
if (
len(all_args) == 4 # pylint: disable=too-many-boolean-expressions
and isinstance(all_args[0], str)
and isinstance(all_args[1], list)
and all(isinstance(i, int) for i in all_args[1])
and isinstance(all_args[2], str)
and isinstance(all_args[3], bytes)
):
self.dtype, self.shape, self.stype, self.data = all_args
return
# Handle NumPy array
if not init_method or init_method == "ndarray":
if len(all_args) == 1 and isinstance(all_args[0], np.ndarray):
self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
return
# Handle PyTorch tensor
if not init_method or init_method == "torch_tensor":
if (
len(all_args) == 1
and "torch" in sys.modules
and isinstance(all_args[0], sys.modules["torch"].Tensor)
):
self.__dict__.update(self.from_torch_tensor(all_args[0]).__dict__)
return
_raise_array_init_error()
@classmethod
def from_numpy_ndarray(cls, ndarray: NDArray) -> Array:
"""Create Array from NumPy ndarray."""
assert isinstance(
ndarray, np.ndarray
), f"Expected NumPy ndarray, got {type(ndarray)}"
buffer = BytesIO()
# WARNING: NEVER set allow_pickle to true.
# Reason: loading pickled data can execute arbitrary code
# Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
np.save(buffer, ndarray, allow_pickle=False)
data = buffer.getvalue()
return Array(
dtype=str(ndarray.dtype),
shape=list(ndarray.shape),
stype=SType.NUMPY,
data=data,
)
@classmethod
def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
"""Create Array from PyTorch tensor."""
if not (torch := sys.modules.get("torch")):
raise RuntimeError(
f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
)
assert isinstance(
tensor, torch.Tensor
), f"Expected PyTorch Tensor, got {type(tensor)}"
return cls.from_numpy_ndarray(tensor.detach().cpu().numpy())
def numpy(self) -> NDArray:
"""Return the array as a NumPy array."""
if self.stype != SType.NUMPY:
raise TypeError(
f"Unsupported serialization type for numpy conversion: '{self.stype}'"
)
bytes_io = BytesIO(self.data)
# WARNING: NEVER set allow_pickle to true.
# Reason: loading pickled data can execute arbitrary code
# Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
return cast(NDArray, ndarray_deserialized)


def _check_key(key: str) -> None:
    """Check if key is of expected type."""
    if not isinstance(key, str):
        raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.")


def _check_value(value: Array) -> None:
    """Check if value is of expected type."""
    if not isinstance(value, Array):
        raise TypeError(
            f"Value must be of type `{Array}` but `{type(value)}` was passed."
        )
class ArrayRecord(TypedDict[str, Array]):
"""Array record.
A typed dictionary (``str`` to :class:`Array`) that can store named arrays,
including model parameters, gradients, embeddings or non-parameter arrays.
Internally, this behaves similarly to an ``OrderedDict[str, Array]``.
An ``ArrayRecord`` can be viewed as an equivalent to PyTorch's ``state_dict``,
but it holds arrays in a serialized form.
This object is one of the record types supported by :class:`RecordDict` and can
therefore be stored in the ``content`` of a :class:`Message` or the ``state``
of a :class:`Context`.
This class can be instantiated in multiple ways:
1. By providing nothing (empty container).
2. By providing a dictionary of :class:`Array` (via the ``array_dict`` argument).
3. By providing a list of NumPy ``ndarray`` (via the ``numpy_ndarrays`` argument).
4. By providing a PyTorch ``state_dict`` (via the ``torch_state_dict`` argument).
Parameters
----------
array_dict : Optional[OrderedDict[str, Array]] (default: None)
An existing dictionary containing named :class:`Array` instances. If
provided, these entries will be used directly to populate the record.
numpy_ndarrays : Optional[list[NDArray]] (default: None)
A list of NumPy arrays. Each array will be automatically converted
into an :class:`Array` and stored in this record with generated keys.
torch_state_dict : Optional[OrderedDict[str, torch.Tensor]] (default: None)
A PyTorch ``state_dict`` (``str`` keys to ``torch.Tensor`` values). Each
tensor will be converted into an :class:`Array` and stored in this record.
keep_input : bool (default: True)
If ``False``, entries from the input are removed after being added to
this record to free up memory. If ``True``, the input remains unchanged.
Regardless of this value, no duplicate memory is used if the input is a
dictionary of :class:`Array`, i.e., ``array_dict``.
Examples
--------
Initializing an empty ArrayRecord:
>>> record = ArrayRecord()
Initializing with a dictionary of :class:`Array`:
>>> arr = Array("float32", [5, 5], "numpy.ndarray", b"serialized_data...")
>>> record = ArrayRecord({"weight": arr})
Initializing with a list of NumPy arrays:
>>> import numpy as np
>>> arr1 = np.random.randn(3, 3)
>>> arr2 = np.random.randn(2, 2)
>>> record = ArrayRecord([arr1, arr2])
Initializing with a PyTorch model state_dict:
>>> import torch.nn as nn
>>> model = nn.Linear(10, 5)
>>> record = ArrayRecord(model.state_dict())
Initializing with a TensorFlow model weights (a list of NumPy arrays):
>>> import tensorflow as tf
>>> model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(10,))])
>>> record = ArrayRecord(model.get_weights())
"""
@overload
def __init__(self) -> None: ... # noqa: E704
@overload
def __init__( # noqa: E704
self, array_dict: OrderedDict[str, Array], *, keep_input: bool = True
) -> None: ...
@overload
def __init__( # noqa: E704
self, numpy_ndarrays: list[NDArray], *, keep_input: bool = True
) -> None: ...
@overload
def __init__( # noqa: E704
self,
torch_state_dict: OrderedDict[str, torch.Tensor],
*,
keep_input: bool = True,
) -> None: ...
def __init__( # pylint: disable=too-many-arguments
self,
*args: Any,
numpy_ndarrays: list[NDArray] | None = None,
torch_state_dict: OrderedDict[str, torch.Tensor] | None = None,
array_dict: OrderedDict[str, Array] | None = None,
keep_input: bool = True,
) -> None:
super().__init__(_check_key, _check_value)
        # Determine the initialization method and validate input arguments.
        # Support the following initialization formats:
        # 1. cls(array_dict: OrderedDict[str, Array], keep_input: bool)
        # 2. cls(numpy_ndarrays: list[NDArray], keep_input: bool)
        # 3. cls(torch_state_dict: dict[str, torch.Tensor], keep_input: bool)

        # Initialize the positional argument
if len(args) > 1:
_raise_array_record_init_error()
arg = args[0] if args else None
init_method: str | None = None # Track which init method is being used
# Try to assign a value to arg if it's not already set.
# If an initialization method is provided, update init_method.
def _try_set_arg(_arg: Any, method: str) -> None:
# Skip if _arg is None
if _arg is None:
return
nonlocal arg, init_method
# Raise an error if arg is already set
if arg is not None:
_raise_array_record_init_error()
# Raise an error if a different initialization method is already set
if init_method is not None:
_raise_array_record_init_error()
# Set init_method and arg
if init_method is None:
init_method = method
arg = _arg
# Try to set keyword arguments
_try_set_arg(array_dict, "array_dict")
_try_set_arg(numpy_ndarrays, "numpy_ndarrays")
_try_set_arg(torch_state_dict, "state_dict")
# If no arguments are provided, return and keep self empty
if arg is None:
return
# Handle dictionary of Arrays
if not init_method or init_method == "array_dict":
# Type check the input
if (
isinstance(arg, dict)
and all(isinstance(k, str) for k in arg.keys())
and all(isinstance(v, Array) for v in arg.values())
):
array_dict = cast(OrderedDict[str, Array], arg)
converted = self.from_array_dict(array_dict, keep_input=keep_input)
self.__dict__.update(converted.__dict__)
return
# Handle NumPy ndarrays
if not init_method or init_method == "numpy_ndarrays":
# Type check the input
# pylint: disable-next=not-an-iterable
if isinstance(arg, list) and all(isinstance(v, np.ndarray) for v in arg):
numpy_ndarrays = cast(list[NDArray], arg)
converted = self.from_numpy_ndarrays(
numpy_ndarrays, keep_input=keep_input
)
self.__dict__.update(converted.__dict__)
return
# Handle PyTorch state_dict
if not init_method or init_method == "state_dict":
# Type check the input
if (
(torch := sys.modules.get("torch")) is not None
and isinstance(arg, dict)
and all(isinstance(k, str) for k in arg.keys())
and all(isinstance(v, torch.Tensor) for v in arg.values())
):
torch_state_dict = cast(
OrderedDict[str, torch.Tensor], arg # type: ignore
)
converted = self.from_torch_state_dict(
torch_state_dict, keep_input=keep_input
)
self.__dict__.update(converted.__dict__)
return
_raise_array_record_init_error()
@classmethod
def from_array_dict(
cls,
array_dict: OrderedDict[str, Array],
*,
keep_input: bool = True,
) -> ArrayRecord:
"""Create ArrayRecord from a dictionary of :class:`Array`."""
record = ArrayRecord()
for k, v in array_dict.items():
record[k] = Array(
dtype=v.dtype, shape=list(v.shape), stype=v.stype, data=v.data
)
if not keep_input:
array_dict.clear()
return record
@classmethod
def from_numpy_ndarrays(
cls,
ndarrays: list[NDArray],
*,
keep_input: bool = True,
) -> ArrayRecord:
"""Create ArrayRecord from a list of NumPy ``ndarray``."""
record = ArrayRecord()
total_serialized_bytes = 0
for i in range(len(ndarrays)): # pylint: disable=C0200
record[str(i)] = Array.from_numpy_ndarray(ndarrays[i])
if not keep_input:
# Remove the reference
ndarrays[i] = None # type: ignore
total_serialized_bytes += len(record[str(i)].data)
# If total serialized data exceeds the threshold, trigger GC
if total_serialized_bytes > GC_THRESHOLD:
total_serialized_bytes = 0
gc.collect()
if not keep_input:
# Clear the entire list to remove all references and force GC
ndarrays.clear()
gc.collect()
return record
@classmethod
def from_torch_state_dict(
cls,
state_dict: OrderedDict[str, torch.Tensor],
*,
keep_input: bool = True,
) -> ArrayRecord:
"""Create ArrayRecord from PyTorch ``state_dict``."""
if "torch" not in sys.modules:
raise RuntimeError(
f"PyTorch is required to use {cls.from_torch_state_dict.__name__}"
)
record = ArrayRecord()
for k in list(state_dict.keys()):
v = state_dict[k] if keep_input else state_dict.pop(k)
record[k] = Array.from_numpy_ndarray(v.detach().cpu().numpy())
return record
def to_numpy_ndarrays(self, *, keep_input: bool = True) -> list[NDArray]:
"""Return the ArrayRecord as a list of NumPy ``ndarray``."""
if keep_input:
return [v.numpy() for v in self.values()]
# Clear the record and return the list of NumPy arrays
ret: list[NDArray] = []
total_serialized_bytes = 0
for k in list(self.keys()):
arr = self.pop(k)
ret.append(arr.numpy())
total_serialized_bytes += len(arr.data)
del arr
# If total serialized data exceeds the threshold, trigger GC
if total_serialized_bytes > GC_THRESHOLD:
total_serialized_bytes = 0
gc.collect()
if not keep_input:
# Force GC
gc.collect()
return ret
def to_torch_state_dict(
self, *, keep_input: bool = True
) -> OrderedDict[str, torch.Tensor]:
"""Return the ArrayRecord as a PyTorch ``state_dict``."""
if not (torch := sys.modules.get("torch")):
raise RuntimeError(
f"PyTorch is required to use {self.to_torch_state_dict.__name__}"
)
state_dict = OrderedDict()
for k in list(self.keys()):
arr = self[k] if keep_input else self.pop(k)
state_dict[k] = torch.from_numpy(arr.numpy())
return state_dict
def count_bytes(self) -> int:
"""Return number of Bytes stored in this object.
Note that a small amount of Bytes might also be included in this counting that
correspond to metadata of the serialized object (e.g. of NumPy array) needed for
deseralization.
"""
num_bytes = 0
for k, v in self.items():
num_bytes += len(v.data)
# We also count the bytes footprint of the keys
num_bytes += len(k)
return num_bytes
class ParametersRecord(ArrayRecord):
"""Deprecated class ``ParametersRecord``, use ``ArrayRecord`` instead.
This class exists solely for backward compatibility with legacy
code that previously used ``ParametersRecord``. It has been renamed
to ``ArrayRecord``.
.. warning::
``ParametersRecord`` is deprecated and will be removed in a future release.
Use ``ArrayRecord`` instead.
Examples
--------
Legacy (deprecated) usage::
from flwr.common import ParametersRecord
record = ParametersRecord()
Updated usage::
from flwr.common import ArrayRecord
record = ArrayRecord()
"""
_warning_logged = False
    def __init__(self, *args: Any, **kwargs: Any) -> None:
if not ParametersRecord._warning_logged:
ParametersRecord._warning_logged = True
log(
WARN,
"The `ParametersRecord` class has been renamed to `ArrayRecord`. "
"Support for `ParametersRecord` will be removed in a future release. "
"Please update your code accordingly.",
)
super().__init__(*args, **kwargs)