# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Array."""
from __future__ import annotations
import sys
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Any, cast, overload
import numpy as np
from flwr.proto.recorddict_pb2 import Array as ArrayProto # pylint: disable=E0611
from ..constant import SType
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
from ..typing import NDArray
if TYPE_CHECKING:
import torch
def _raise_array_init_error() -> None:
    """Raise the ``TypeError`` used for every invalid ``Array`` constructor call.

    Raises
    ------
    TypeError
        Always. The message lists the accepted ways of constructing an ``Array``.
    """
    message = (
        f"Invalid arguments for {Array.__qualname__}. Expected either a "
        "PyTorch tensor, a NumPy ndarray, or explicit"
        " dtype/shape/stype/data values."
    )
    raise TypeError(message)
@dataclass
class Array(InflatableObject):
    """Array type.

    A dataclass containing serialized data from an array-like or tensor-like object
    along with metadata about it. The class can be initialized in one of three ways:

    1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
    2. By providing a NumPy ndarray (via the `ndarray` argument).
    3. By providing a PyTorch tensor (via the `torch_tensor` argument).

    In scenarios (2)-(3), the `dtype`, `shape`, `stype`, and `data` are automatically
    derived from the input. In scenario (1), these fields must be specified manually.

    Arguments may be passed positionally or by keyword, but the three scenarios
    cannot be mixed. A PyTorch tensor is only recognized when the `torch` module
    has already been imported by the calling program.

    Parameters
    ----------
    dtype : Optional[str] (default: None)
        A string representing the data type of the serialized object (e.g. `"float32"`).
        Only required if you are not passing in a ndarray or a tensor.
    shape : Optional[list[int]] (default: None)
        A list representing the shape of the unserialized array-like object. Only
        required if you are not passing in a ndarray or a tensor.
    stype : Optional[str] (default: None)
        A string indicating the serialization mechanism used to generate the bytes in
        `data` from an array-like or tensor-like object. Only required if you are not
        passing in a ndarray or a tensor.
    data : Optional[bytes] (default: None)
        A buffer of bytes containing the data. Only required if you are not passing in
        a ndarray or a tensor.
    ndarray : Optional[NDArray] (default: None)
        A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
        fields are derived automatically from it.
    torch_tensor : Optional[torch.Tensor] (default: None)
        A PyTorch tensor. If provided, it will be **detached and moved to CPU**
        before conversion, and the `dtype`, `shape`, `stype`, and `data` fields
        will be derived automatically from it.

    Raises
    ------
    TypeError
        If the arguments do not match exactly one of the three supported
        initialization scenarios.

    Examples
    --------
    Initializing by specifying all fields directly::

        arr1 = Array(
            dtype="float32",
            shape=[3, 3],
            stype="numpy.ndarray",
            data=b"serialized_data...",
        )

    Initializing with a NumPy ndarray::

        import numpy as np
        arr2 = Array(np.random.randn(3, 3))

    Initializing with a PyTorch tensor::

        import torch
        arr3 = Array(torch.randn(3, 3))
    """

    dtype: str
    shape: list[int]
    stype: str
    data: bytes

    @overload
    def __init__(  # noqa: E704
        self, dtype: str, shape: list[int], stype: str, data: bytes
    ) -> None: ...

    @overload
    def __init__(self, ndarray: NDArray) -> None: ...  # noqa: E704

    @overload
    def __init__(self, torch_tensor: torch.Tensor) -> None: ...  # noqa: E704

    def __init__(  # pylint: disable=too-many-arguments, too-many-locals
        self,
        *args: Any,
        dtype: str | None = None,
        shape: list[int] | None = None,
        stype: str | None = None,
        data: bytes | None = None,
        ndarray: NDArray | None = None,
        torch_tensor: torch.Tensor | None = None,
    ) -> None:
        # Determine the initialization method and validate input arguments.
        # Support three initialization formats:
        # 1. Array(dtype: str, shape: list[int], stype: str, data: bytes)
        # 2. Array(ndarray: NDArray)
        # 3. Array(torch_tensor: torch.Tensor)

        # At most four positional arguments are meaningful (the direct format).
        if len(args) > 4:
            _raise_array_init_error()

        # Positional arguments occupy the leading slots; the rest stay None
        # until a keyword argument (if any) fills them in.
        all_args: list[Any] = [None] * 4
        for i, arg in enumerate(args):
            all_args[i] = arg
        init_method: str | None = None  # Track which init method is being used

        def _try_set_arg(index: int, arg: Any, method: str) -> None:
            """Store keyword `arg` in slot `index` and record its init method."""
            # Keyword was not supplied
            if arg is None:
                return
            # The slot was already filled, positionally or by another keyword
            if all_args[index] is not None:
                _raise_array_init_error()
            # Keywords belonging to different init methods cannot be combined
            nonlocal init_method
            if init_method is not None and init_method != method:
                _raise_array_init_error()
            # Set init_method and all_args[index]
            if init_method is None:
                init_method = method
            all_args[index] = arg

        # Try to set keyword arguments in all_args
        _try_set_arg(0, dtype, "direct")
        _try_set_arg(1, shape, "direct")
        _try_set_arg(2, stype, "direct")
        _try_set_arg(3, data, "direct")
        _try_set_arg(0, ndarray, "ndarray")
        _try_set_arg(0, torch_tensor, "torch_tensor")

        # Keep only the slots that were actually provided
        all_args = [arg for arg in all_args if arg is not None]

        # When no keyword fixed the init method (`init_method` is None), every
        # format below is tried in turn against the positional arguments.

        # Handle direct field initialization
        if not init_method or init_method == "direct":
            if (
                len(all_args) == 4  # pylint: disable=too-many-boolean-expressions
                and isinstance(all_args[0], str)
                and isinstance(all_args[1], list)
                and all(isinstance(i, int) for i in all_args[1])
                and isinstance(all_args[2], str)
                and isinstance(all_args[3], bytes)
            ):
                self.dtype, self.shape, self.stype, self.data = all_args
                return

        # Handle NumPy array
        if not init_method or init_method == "ndarray":
            if len(all_args) == 1 and isinstance(all_args[0], np.ndarray):
                self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
                return

        # Handle PyTorch tensor
        if not init_method or init_method == "torch_tensor":
            # `torch` is looked up in `sys.modules` rather than imported, so
            # PyTorch stays an optional dependency of this module.
            if (
                len(all_args) == 1
                and "torch" in sys.modules
                and isinstance(all_args[0], sys.modules["torch"].Tensor)
            ):
                self.__dict__.update(self.from_torch_tensor(all_args[0]).__dict__)
                return

        _raise_array_init_error()

    @classmethod
    def from_numpy_ndarray(cls, ndarray: NDArray) -> Array:
        """Create Array from NumPy ndarray.

        Parameters
        ----------
        ndarray : NDArray
            The NumPy ndarray to serialize.

        Returns
        -------
        Array
            An Array whose `dtype` and `shape` are taken from `ndarray`, whose
            `stype` is `SType.NUMPY`, and whose `data` holds the bytes written
            by `np.save`.

        Raises
        ------
        TypeError
            If `ndarray` is not a NumPy ndarray.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(ndarray, np.ndarray):
            raise TypeError(f"Expected NumPy ndarray, got {type(ndarray)}")
        buffer = BytesIO()
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        np.save(buffer, ndarray, allow_pickle=False)
        data = buffer.getvalue()
        return Array(
            dtype=str(ndarray.dtype),
            shape=list(ndarray.shape),
            stype=SType.NUMPY,
            data=data,
        )

    @classmethod
    def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
        """Create Array from PyTorch tensor.

        The tensor is detached, moved to CPU and converted to a NumPy ndarray,
        which is then serialized via `from_numpy_ndarray`.

        Parameters
        ----------
        tensor : torch.Tensor
            The PyTorch tensor to serialize.

        Returns
        -------
        Array
            The Array built from the tensor's NumPy representation.

        Raises
        ------
        RuntimeError
            If the `torch` module has not been imported.
        TypeError
            If `tensor` is not a PyTorch tensor.
        """
        # Look `torch` up in `sys.modules` instead of importing it, so PyTorch
        # remains an optional dependency.
        if not (torch := sys.modules.get("torch")):
            raise RuntimeError(
                f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
            )
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(f"Expected PyTorch Tensor, got {type(tensor)}")
        return cls.from_numpy_ndarray(tensor.detach().cpu().numpy())

    def numpy(self) -> NDArray:
        """Return the array as a NumPy array.

        Returns
        -------
        NDArray
            The ndarray obtained by loading `data` with `np.load`.

        Raises
        ------
        TypeError
            If `stype` is not `SType.NUMPY`.
        """
        if self.stype != SType.NUMPY:
            raise TypeError(
                f"Unsupported serialization type for numpy conversion: '{self.stype}'"
            )
        bytes_io = BytesIO(self.data)
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
        ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
        return cast(NDArray, ndarray_deserialized)

    def deflate(self) -> bytes:
        """Deflate the Array.

        Returns
        -------
        bytes
            The deterministically serialized `ArrayProto` body of this Array,
            passed through `add_header_to_object_body`.
        """
        array_proto = ArrayProto(
            dtype=self.dtype,
            shape=self.shape,
            stype=self.stype,
            data=self.data,
        )
        # Deterministic serialization so equal Arrays yield identical bytes.
        obj_body = array_proto.SerializeToString(deterministic=True)
        return add_header_to_object_body(object_body=obj_body, cls=self)

    @classmethod
    def inflate(cls, object_content: bytes) -> Array:
        """Inflate an Array from bytes.

        Parameters
        ----------
        object_content : bytes
            The deflated object content of the Array.

        Returns
        -------
        Array
            The inflated Array.
        """
        obj_body = get_object_body(object_content, cls)
        proto_array = ArrayProto.FromString(obj_body)
        return cls(
            dtype=proto_array.dtype,
            # Copy the repeated proto field into a plain list
            shape=list(proto_array.shape),
            stype=proto_array.stype,
            data=proto_array.data,
        )