# Copyright 2020 Flower Labs GmbH. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Parameter conversion."""fromioimportBytesIOfromtypingimportcastimportnumpyasnpfrom.typingimportNDArray,NDArrays,Parameters
def ndarrays_to_parameters(ndarrays: NDArrays) -> Parameters:
    """Pack a sequence of NumPy ndarrays into a ``Parameters`` object.

    Each array is serialized on its own with ``ndarray_to_bytes``. The
    resulting byte strings keep the order of the input, and the container
    is tagged with ``tensor_type="numpy.ndarray"``.

    Args:
        ndarrays: The arrays to serialize.

    Returns:
        A ``Parameters`` instance holding one serialized tensor per array.
    """
    serialized = list(map(ndarray_to_bytes, ndarrays))
    return Parameters(tensors=serialized, tensor_type="numpy.ndarray")
def parameters_to_ndarrays(parameters: Parameters) -> NDArrays:
    """Unpack a ``Parameters`` object into a list of NumPy ndarrays.

    Every entry of ``parameters.tensors`` is decoded with
    ``bytes_to_ndarray``, and the input order is preserved.
    ``parameters.tensor_type`` is not inspected here.

    Args:
        parameters: Container whose ``tensors`` are serialized arrays.

    Returns:
        A new list with one decoded ndarray per tensor.
    """
    decoded: NDArrays = []
    for blob in parameters.tensors:
        decoded.append(bytes_to_ndarray(blob))
    return decoded
def ndarray_to_bytes(ndarray: NDArray) -> bytes:
    """Encode a NumPy ndarray as ``.npy``-format bytes.

    Args:
        ndarray: The array to serialize.

    Returns:
        The bytes that ``np.save`` writes for ``ndarray``.

    Raises:
        ValueError: Raised by NumPy when the array would need pickling
            (e.g. object dtype), because pickling is disabled.
    """
    with BytesIO() as buffer:
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        np.save(buffer, ndarray, allow_pickle=False)
        payload = buffer.getvalue()
    return payload
def bytes_to_ndarray(tensor: bytes) -> NDArray:
    """Decode ``.npy``-format bytes back into a NumPy ndarray.

    This is the inverse of ``ndarray_to_bytes``; ``tensor`` is expected
    to be a payload produced by it.

    Args:
        tensor: Serialized array bytes.

    Returns:
        The deserialized ndarray.

    Raises:
        ValueError: Raised by NumPy when the payload holds pickled
            (object) data, because pickling is disabled.
    """
    # WARNING: NEVER set allow_pickle to true.
    # Reason: loading pickled data can execute arbitrary code
    # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
    loaded = np.load(BytesIO(tensor), allow_pickle=False)
    return cast(NDArray, loaded)