# Source code for flwr.client.numpy_client

# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower client app."""


from abc import ABC
from typing import Callable

from flwr.client.client import Client
from flwr.common import (
    Config,
    Context,
    NDArrays,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.common.logger import warn_deprecated_feature_with_example
from flwr.common.typing import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    GetPropertiesIns,
    GetPropertiesRes,
    Status,
)

# Message of the TypeError raised when a user-defined ``NumPyClient.fit`` returns
# something other than the expected ``(parameters, num_examples, metrics)`` triple.
# The leading empty entry and the two trailing ones reproduce the blank lines that
# surround the text.
EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT = "\n".join(
    (
        "",
        "NumPyClient.fit did not return a tuple with 3 elements.",
        "The returned values should have the following type signature:",
        "",
        "    Tuple[NDArrays, int, Dict[str, Scalar]]",
        "",
        "Example",
        "-------",
        "",
        '    model.get_weights(), 10, {"accuracy": 0.95}',
        "",
        "",
    )
)

# Message of the TypeError raised when a user-defined ``NumPyClient.evaluate``
# returns something other than the expected ``(loss, num_examples, metrics)``
# triple. The leading empty entry and the two trailing ones reproduce the blank
# lines that surround the text.
EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE = "\n".join(
    (
        "",
        "NumPyClient.evaluate did not return a tuple with 3 elements.",
        "The returned values should have the following type signature:",
        "",
        "    Tuple[float, int, Dict[str, Scalar]]",
        "",
        "Example",
        "-------",
        "",
        '    0.5, 10, {"accuracy": 0.95}',
        "",
        "",
    )
)


class NumPyClient(ABC):
    """Abstract base class for Flower clients using NumPy.

    Every method has a neutral default implementation that returns an empty
    result, so a subclass only overrides the operations it supports.
    """

    _context: Context

    def get_properties(self, config: Config) -> dict[str, Scalar]:
        """Return a client's set of properties.

        Parameters
        ----------
        config : Config
            Configuration parameters requested by the server.
            This can be used to tell the client which properties
            are needed along with some Scalar attributes.

        Returns
        -------
        properties : Dict[str, Scalar]
            A dictionary mapping arbitrary string keys to values of type
            bool, bytes, float, int, or str. It can be used to communicate
            arbitrary property values back to the server. The default
            implementation returns an empty dictionary.
        """
        _ = self, config  # intentionally unused by the default implementation
        properties: dict[str, Scalar] = {}
        return properties

    def get_parameters(self, config: dict[str, Scalar]) -> NDArrays:
        """Return the current local model parameters.

        Parameters
        ----------
        config : Dict[str, Scalar]
            Configuration parameters requested by the server.
            This can be used to tell the client which parameters
            are needed along with some Scalar attributes.

        Returns
        -------
        parameters : NDArrays
            The local model parameters as a list of NumPy ndarrays. The
            default implementation returns an empty list.
        """
        _ = self, config  # intentionally unused by the default implementation
        ndarrays: NDArrays = []
        return ndarrays

    def fit(
        self, parameters: NDArrays, config: dict[str, Scalar]
    ) -> tuple[NDArrays, int, dict[str, Scalar]]:
        """Train the provided parameters using the locally held dataset.

        Parameters
        ----------
        parameters : NDArrays
            The current (global) model parameters.
        config : Dict[str, Scalar]
            Configuration parameters which allow the
            server to influence training on the client. It can be used to
            communicate arbitrary values from the server to the client, for
            example, to set the number of (local) training epochs.

        Returns
        -------
        parameters : NDArrays
            The locally updated model parameters.
        num_examples : int
            The number of examples used for training.
        metrics : Dict[str, Scalar]
            A dictionary mapping arbitrary string keys to values of type
            bool, bytes, float, int, or str. It can be used to communicate
            arbitrary values back to the server.

        Notes
        -----
        The default implementation performs no training and returns
        ``([], 0, {})``.
        """
        _ = self, parameters, config  # intentionally unused by the default
        updated_parameters: NDArrays = []
        num_examples = 0
        metrics: dict[str, Scalar] = {}
        return updated_parameters, num_examples, metrics

    def evaluate(
        self, parameters: NDArrays, config: dict[str, Scalar]
    ) -> tuple[float, int, dict[str, Scalar]]:
        """Evaluate the provided parameters using the locally held dataset.

        Parameters
        ----------
        parameters : NDArrays
            The current (global) model parameters.
        config : Dict[str, Scalar]
            Configuration parameters which allow the server to influence
            evaluation on the client. It can be used to communicate
            arbitrary values from the server to the client, for example,
            to influence the number of examples used for evaluation.

        Returns
        -------
        loss : float
            The evaluation loss of the model on the local dataset.
        num_examples : int
            The number of examples used for evaluation.
        metrics : Dict[str, Scalar]
            A dictionary mapping arbitrary string keys to values of
            type bool, bytes, float, int, or str. It can be used to
            communicate arbitrary values back to the server.

        Warning
        -------
        The previous return type format (int, float, float) and the
        extended format (int, float, float, Dict[str, Scalar]) have been
        deprecated and removed since Flower 0.19.

        Notes
        -----
        The default implementation performs no evaluation and returns
        ``(0.0, 0, {})``.
        """
        _ = self, parameters, config  # intentionally unused by the default
        loss = 0.0
        num_examples = 0
        metrics: dict[str, Scalar] = {}
        return loss, num_examples, metrics

    @property
    def context(self) -> Context:
        """Getter for `Context` client attribute.

        Every read emits a deprecation warning before the stored context is
        returned.
        """
        deprecation_notice = (
            "Accessing the context via the client's attribute is deprecated."
        )
        migration_hint = (
            "Instead, pass it to the client's constructor in your "
            "`client_fn()` which already receives a context object."
        )
        migration_snippet = (
            "def client_fn(context: Context) -> Client:\n"
            "\n"
            "\t\t# Your existing client_fn\n"
            "\n"
            "\t\t# Pass `context` to the constructor\n"
            "\t\treturn FlowerClient(context).to_client()"
        )
        warn_deprecated_feature_with_example(
            deprecation_notice,
            example_message=migration_hint,
            code_example=migration_snippet,
        )
        return self._context

    @context.setter
    def context(self, context: Context) -> None:
        """Setter for `Context` client attribute."""
        self._context = context

    def get_context(self) -> Context:
        """Get the run context from this client.

        This reads the `context` property, so the deprecation warning of
        that property is emitted as well.
        """
        return self.context

    def set_context(self, context: Context) -> None:
        """Apply a run context to this client."""
        self.context = context

    def to_client(self) -> Client:
        """Convert this object to the `Client` type and return it."""
        wrapped: Client = _wrap_numpy_client(client=self)
        return wrapped
def has_get_properties(client: NumPyClient) -> bool:
    """Check if NumPyClient implements get_properties.

    Returns True when the class of `client` provides a `get_properties` that
    differs from the default defined on `NumPyClient`.
    """
    return type(client).get_properties != NumPyClient.get_properties


def has_get_parameters(client: NumPyClient) -> bool:
    """Check if NumPyClient implements get_parameters.

    Returns True when the class of `client` provides a `get_parameters` that
    differs from the default defined on `NumPyClient`.
    """
    return type(client).get_parameters != NumPyClient.get_parameters


def has_fit(client: NumPyClient) -> bool:
    """Check if NumPyClient implements fit.

    Returns True when the class of `client` provides a `fit` that differs
    from the default defined on `NumPyClient`.
    """
    return type(client).fit != NumPyClient.fit


def has_evaluate(client: NumPyClient) -> bool:
    """Check if NumPyClient implements evaluate.

    Returns True when the class of `client` provides an `evaluate` that
    differs from the default defined on `NumPyClient`.
    """
    return type(client).evaluate != NumPyClient.evaluate


def _has_three_elements(results: object) -> bool:
    """Return True if `results` is a sized object with exactly three elements.

    Unsized values (for example `None`, which is what a user method without a
    `return` statement produces) yield False instead of letting `len()` raise,
    so that callers can report their own descriptive error message.
    """
    try:
        return len(results) == 3  # type: ignore
    except TypeError:
        return False


def _constructor(self: Client, numpy_client: NumPyClient) -> None:
    """Store the wrapped NumPyClient on the wrapper instance."""
    self.numpy_client = numpy_client  # type: ignore


def _get_properties(self: Client, ins: GetPropertiesIns) -> GetPropertiesRes:
    """Return the current client properties.

    Forwards `ins.config` to the wrapped NumPyClient and packs the returned
    dictionary into a `GetPropertiesRes` with an OK status.
    """
    properties = self.numpy_client.get_properties(config=ins.config)  # type: ignore
    return GetPropertiesRes(
        status=Status(code=Code.OK, message="Success"),
        properties=properties,
    )


def _get_parameters(self: Client, ins: GetParametersIns) -> GetParametersRes:
    """Return the current local model parameters.

    Forwards `ins.config` to the wrapped NumPyClient, converts the returned
    ndarrays with `ndarrays_to_parameters` and packs them into a
    `GetParametersRes` with an OK status.
    """
    parameters = self.numpy_client.get_parameters(config=ins.config)  # type: ignore
    parameters_proto = ndarrays_to_parameters(parameters)
    return GetParametersRes(
        status=Status(code=Code.OK, message="Success"), parameters=parameters_proto
    )


def _fit(self: Client, ins: FitIns) -> FitRes:
    """Refine the provided parameters using the locally held dataset.

    Raises
    ------
    TypeError
        If the wrapped `fit` does not return three elements of type
        list, int and dict (in that order). This includes results that have
        no length at all, such as `None`.
    """
    # Deconstruct FitIns
    parameters: NDArrays = parameters_to_ndarrays(ins.parameters)

    # Train
    results = self.numpy_client.fit(parameters, ins.config)  # type: ignore
    if not (
        # Guarded length check: an unsized result must produce the descriptive
        # message below rather than the generic error raised by `len()`.
        _has_three_elements(results)
        and isinstance(results[0], list)
        and isinstance(results[1], int)
        and isinstance(results[2], dict)
    ):
        raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)

    # Return FitRes
    parameters_prime, num_examples, metrics = results
    parameters_prime_proto = ndarrays_to_parameters(parameters_prime)
    return FitRes(
        status=Status(code=Code.OK, message="Success"),
        parameters=parameters_prime_proto,
        num_examples=num_examples,
        metrics=metrics,
    )


def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
    """Evaluate the provided parameters using the locally held dataset.

    Raises
    ------
    TypeError
        If the wrapped `evaluate` does not return three elements of type
        float, int and dict (in that order). This includes results that have
        no length at all, such as `None`.
    """
    parameters: NDArrays = parameters_to_ndarrays(ins.parameters)

    results = self.numpy_client.evaluate(parameters, ins.config)  # type: ignore
    if not (
        # Guarded length check: an unsized result must produce the descriptive
        # message below rather than the generic error raised by `len()`.
        _has_three_elements(results)
        and isinstance(results[0], float)
        and isinstance(results[1], int)
        and isinstance(results[2], dict)
    ):
        raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)

    # Return EvaluateRes
    loss, num_examples, metrics = results
    return EvaluateRes(
        status=Status(code=Code.OK, message="Success"),
        loss=loss,
        num_examples=num_examples,
        metrics=metrics,
    )


def _get_context(self: Client) -> Context:
    """Return context of underlying NumPyClient."""
    return self.numpy_client.get_context()  # type: ignore


def _set_context(self: Client, context: Context) -> None:
    """Apply context to underlying NumPyClient."""
    self.numpy_client.set_context(context)  # type: ignore


def _wrap_numpy_client(client: NumPyClient) -> Client:
    """Wrap a NumPyClient in a dynamically created `Client` subclass.

    The wrapper class always receives the constructor and the context
    accessors. The four protocol methods are attached only when `client`
    overrides them; otherwise the wrapper keeps whatever `Client` itself
    defines for that method.
    """
    member_dict: dict[str, Callable] = {  # type: ignore
        "__init__": _constructor,
        "get_context": _get_context,
        "set_context": _set_context,
    }

    # Add wrapper type methods (if overridden)

    if has_get_properties(client=client):
        member_dict["get_properties"] = _get_properties

    if has_get_parameters(client=client):
        member_dict["get_parameters"] = _get_parameters

    if has_fit(client=client):
        member_dict["fit"] = _fit

    if has_evaluate(client=client):
        member_dict["evaluate"] = _evaluate

    # Create wrapper class
    wrapper_class = type("NumPyClientWrapper", (Client,), member_dict)

    # Create and return an instance of the newly created class
    return wrapper_class(numpy_client=client)  # type: ignore