# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower client (abstract base class)."""
# Needed so that `Client` methods can be annotated as returning `Client`
# (not needed in py3.11+)
from __future__ import annotations
from abc import ABC
from flwr.common import (
Code,
Context,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetParametersIns,
GetParametersRes,
GetPropertiesIns,
GetPropertiesRes,
Parameters,
Status,
)
from flwr.common.logger import warn_deprecated_feature_with_example
[docs]
class Client(ABC):
    """Abstract base class for Flower clients.

    Every request-handling method ships with a default implementation that
    returns a result whose status carries the matching ``*_NOT_IMPLEMENTED``
    code and an otherwise empty payload. Subclasses can override the methods
    they support.
    """

    # Declared only; no value is assigned here. It is populated through the
    # `context` setter (and therefore also through `set_context`).
    _context: Context

    def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
        """Return set of client's properties.

        Parameters
        ----------
        ins : GetPropertiesIns
            The get properties instructions received from the server containing
            a dictionary of configuration values.

        Returns
        -------
        GetPropertiesRes
            The current client properties. This default implementation returns
            an empty property dictionary with the status code
            `Code.GET_PROPERTIES_NOT_IMPLEMENTED`.
        """
        # The arguments are intentionally unused by the default implementation;
        # referencing them keeps the signature intact for overriding classes.
        _ = (self, ins)
        not_implemented = Status(
            code=Code.GET_PROPERTIES_NOT_IMPLEMENTED,
            message="Client does not implement `get_properties`",
        )
        return GetPropertiesRes(status=not_implemented, properties={})

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        """Return the current local model parameters.

        Parameters
        ----------
        ins : GetParametersIns
            The get parameters instructions received from the server containing
            a dictionary of configuration values.

        Returns
        -------
        GetParametersRes
            The current local model parameters. This default implementation
            returns empty `Parameters` with the status code
            `Code.GET_PARAMETERS_NOT_IMPLEMENTED`.
        """
        # Arguments are intentionally unused by the default implementation.
        _ = (self, ins)
        not_implemented = Status(
            code=Code.GET_PARAMETERS_NOT_IMPLEMENTED,
            message="Client does not implement `get_parameters`",
        )
        return GetParametersRes(
            status=not_implemented,
            parameters=Parameters(tensor_type="", tensors=[]),
        )

    def fit(self, ins: FitIns) -> FitRes:
        """Refine the provided parameters using the locally held dataset.

        Parameters
        ----------
        ins : FitIns
            The training instructions containing (global) model parameters
            received from the server and a dictionary of configuration values
            used to customize the local training process.

        Returns
        -------
        FitRes
            The training result containing updated parameters and other details
            such as the number of local training examples used for training.
            This default implementation returns empty `Parameters`, zero
            examples, no metrics and the status code `Code.FIT_NOT_IMPLEMENTED`.
        """
        # Arguments are intentionally unused by the default implementation.
        _ = (self, ins)
        not_implemented = Status(
            code=Code.FIT_NOT_IMPLEMENTED,
            message="Client does not implement `fit`",
        )
        return FitRes(
            status=not_implemented,
            parameters=Parameters(tensor_type="", tensors=[]),
            num_examples=0,
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        """Evaluate the provided parameters using the locally held dataset.

        Parameters
        ----------
        ins : EvaluateIns
            The evaluation instructions containing (global) model parameters
            received from the server and a dictionary of configuration values
            used to customize the local evaluation process.

        Returns
        -------
        EvaluateRes
            The evaluation result containing the loss on the local dataset and
            other details such as the number of local data examples used for
            evaluation. This default implementation returns a loss of `0.0`,
            zero examples, no metrics and the status code
            `Code.EVALUATE_NOT_IMPLEMENTED`.
        """
        # Arguments are intentionally unused by the default implementation.
        _ = (self, ins)
        not_implemented = Status(
            code=Code.EVALUATE_NOT_IMPLEMENTED,
            message="Client does not implement `evaluate`",
        )
        return EvaluateRes(
            status=not_implemented,
            loss=0.0,
            num_examples=0,
            metrics={},
        )

    @property
    def context(self) -> Context:
        """Getter for `Context` client attribute.

        Every read first calls `warn_deprecated_feature_with_example` with a
        hint on passing the context through `client_fn()` instead, and then
        returns the stored `_context`.
        """
        warn_deprecated_feature_with_example(
            "Accessing the context via the client's attribute is deprecated.",
            example_message="Instead, pass it to the client's "
            "constructor in your `client_fn()` which already "
            "receives a context object.",
            code_example="def client_fn(context: Context) -> Client:\n\n"
            "\t\t# Your existing client_fn\n\n"
            "\t\t# Pass `context` to the constructor\n"
            "\t\treturn FlowerClient(context).to_client()",
        )
        return self._context

    @context.setter
    def context(self, context: Context) -> None:
        """Setter for `Context` client attribute."""
        self._context = context

    def get_context(self) -> Context:
        """Get the run context from this client.

        The value is read through the `context` property, so the deprecation
        notice issued by that getter is triggered here as well.
        """
        return self.context

    def set_context(self, context: Context) -> None:
        """Apply a run context to this client.

        Parameters
        ----------
        context : Context
            The context to store; it is assigned through the `context` setter.
        """
        self.context = context

    def to_client(self) -> Client:
        """Return client (itself)."""
        return self
def has_get_properties(client: Client) -> bool:
    """Report whether the client's class overrides `get_properties`.

    Parameters
    ----------
    client : Client
        The client instance to inspect.

    Returns
    -------
    bool
        `True` if `get_properties` looked up on `type(client)` differs from
        the one defined on `Client`, `False` otherwise.
    """
    resolved = type(client).get_properties
    return resolved != Client.get_properties
def has_get_parameters(client: Client) -> bool:
    """Report whether the client's class overrides `get_parameters`.

    Parameters
    ----------
    client : Client
        The client instance to inspect.

    Returns
    -------
    bool
        `True` if `get_parameters` looked up on `type(client)` differs from
        the one defined on `Client`, `False` otherwise.
    """
    resolved = type(client).get_parameters
    return resolved != Client.get_parameters
def has_fit(client: Client) -> bool:
    """Report whether the client's class overrides `fit`.

    Parameters
    ----------
    client : Client
        The client instance to inspect.

    Returns
    -------
    bool
        `True` if `fit` looked up on `type(client)` differs from the one
        defined on `Client`, `False` otherwise.
    """
    resolved = type(client).fit
    return resolved != Client.fit
def has_evaluate(client: Client) -> bool:
    """Report whether the client's class overrides `evaluate`.

    Parameters
    ----------
    client : Client
        The client instance to inspect.

    Returns
    -------
    bool
        `True` if `evaluate` looked up on `type(client)` differs from the one
        defined on `Client`, `False` otherwise.
    """
    resolved = type(client).evaluate
    return resolved != Client.evaluate
def maybe_call_get_properties(
    client: Client, get_properties_ins: GetPropertiesIns
) -> GetPropertiesRes:
    """Dispatch to `client.get_properties` only when it is overridden.

    Parameters
    ----------
    client : Client
        The client that should handle the request.
    get_properties_ins : GetPropertiesIns
        The instructions forwarded to `client.get_properties`.

    Returns
    -------
    GetPropertiesRes
        The client's own result if `has_get_properties` reports an override;
        otherwise an empty property dictionary with the status code
        `Code.GET_PROPERTIES_NOT_IMPLEMENTED`.
    """
    # Overridden: hand the request to the client's implementation.
    if has_get_properties(client):
        return client.get_properties(get_properties_ins)

    # Not overridden: answer directly without calling into the client.
    return GetPropertiesRes(
        status=Status(
            code=Code.GET_PROPERTIES_NOT_IMPLEMENTED,
            message="Client does not implement `get_properties`",
        ),
        properties={},
    )
def maybe_call_get_parameters(
    client: Client, get_parameters_ins: GetParametersIns
) -> GetParametersRes:
    """Dispatch to `client.get_parameters` only when it is overridden.

    Parameters
    ----------
    client : Client
        The client that should handle the request.
    get_parameters_ins : GetParametersIns
        The instructions forwarded to `client.get_parameters`.

    Returns
    -------
    GetParametersRes
        The client's own result if `has_get_parameters` reports an override;
        otherwise empty `Parameters` with the status code
        `Code.GET_PARAMETERS_NOT_IMPLEMENTED`.
    """
    # Overridden: hand the request to the client's implementation.
    if has_get_parameters(client):
        return client.get_parameters(get_parameters_ins)

    # Not overridden: answer directly without calling into the client.
    return GetParametersRes(
        status=Status(
            code=Code.GET_PARAMETERS_NOT_IMPLEMENTED,
            message="Client does not implement `get_parameters`",
        ),
        parameters=Parameters(tensor_type="", tensors=[]),
    )
def maybe_call_fit(client: Client, fit_ins: FitIns) -> FitRes:
    """Dispatch to `client.fit` only when it is overridden.

    Parameters
    ----------
    client : Client
        The client that should handle the request.
    fit_ins : FitIns
        The training instructions forwarded to `client.fit`.

    Returns
    -------
    FitRes
        The client's own result if `has_fit` reports an override; otherwise
        empty `Parameters`, zero examples, no metrics and the status code
        `Code.FIT_NOT_IMPLEMENTED`.
    """
    # Overridden: hand the request to the client's implementation.
    if has_fit(client):
        return client.fit(fit_ins)

    # Not overridden: answer directly without calling into the client.
    return FitRes(
        status=Status(
            code=Code.FIT_NOT_IMPLEMENTED,
            message="Client does not implement `fit`",
        ),
        parameters=Parameters(tensor_type="", tensors=[]),
        num_examples=0,
        metrics={},
    )
def maybe_call_evaluate(client: Client, evaluate_ins: EvaluateIns) -> EvaluateRes:
    """Dispatch to `client.evaluate` only when it is overridden.

    Parameters
    ----------
    client : Client
        The client that should handle the request.
    evaluate_ins : EvaluateIns
        The evaluation instructions forwarded to `client.evaluate`.

    Returns
    -------
    EvaluateRes
        The client's own result if `has_evaluate` reports an override;
        otherwise a loss of `0.0`, zero examples, no metrics and the status
        code `Code.EVALUATE_NOT_IMPLEMENTED`.
    """
    # Overridden: hand the request to the client's implementation.
    if has_evaluate(client):
        return client.evaluate(evaluate_ins)

    # Not overridden: answer directly without calling into the client.
    return EvaluateRes(
        status=Status(
            code=Code.EVALUATE_NOT_IMPLEMENTED,
            message="Client does not implement `evaluate`",
        ),
        loss=0.0,
        num_examples=0,
        metrics={},
    )