# Source code for flwr.client.client_app

# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower ClientApp."""


import inspect
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Callable, Optional

from flwr.client.client import Client
from flwr.client.message_handler.message_handler import (
    handle_legacy_message_from_msgtype,
)
from flwr.client.mod.utils import make_ffn
from flwr.client.typing import ClientFnExt, Mod
from flwr.common import Context, Message, MessageType
from flwr.common.logger import warn_deprecated_feature
from flwr.common.message import validate_message_type

from .typing import ClientAppCallable

# Action name used when routing a message whose type has no ".<action>" suffix,
# and the default `action` for `@app.train()` / `@app.evaluate()` / `@app.query()`
# registrations.
DEFAULT_ACTION = "default"


def _alert_erroneous_client_fn() -> None:
    raise ValueError(
        "A `ClientApp` cannot make use of a `client_fn` that does "
        "not have a signature in the form: `def client_fn(context: "
        "Context)`. You can import the `Context` like this: "
        "`from flwr.common import Context`"
    )


def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
    """Validate `client_fn` and wrap it if it uses the deprecated signature.

    A `client_fn` must take exactly one parameter. If that parameter is
    annotated as `str` or is named `cid`, the legacy `client_fn(cid: str)`
    signature is assumed. In that case a deprecation warning is emitted and the
    function is wrapped in an adaptor that accepts a `Context` instead.

    Parameters
    ----------
    client_fn : ClientFnExt
        The user-provided function that builds a `Client`.

    Returns
    -------
    ClientFnExt
        `client_fn` itself if it already uses the new signature. Otherwise an
        adaptor with the signature `(context: Context) -> Client`.

    Raises
    ------
    ValueError
        If `client_fn` does not take exactly one parameter.
    """
    client_fn_args = inspect.signature(client_fn).parameters

    if len(client_fn_args) != 1:
        _alert_erroneous_client_fn()

    first_arg = next(iter(client_fn_args))
    first_arg_type = client_fn_args[first_arg].annotation

    if first_arg_type is str or first_arg == "cid":
        # The previous `client_fn(cid: str)` signature seems to be used: warn
        warn_deprecated_feature(
            "`client_fn` now expects a signature `def client_fn(context: Context)`. "
            "The provided `client_fn` has signature: "
            f"{dict(client_fn_args.items())}. You can import the `Context` like this:"
            " `from flwr.common import Context`"
        )

        # Wrap the deprecated client_fn inside a function with the expected
        # signature. `functools.wraps` is deliberately not used: it would make
        # `inspect.signature` report the legacy signature again.
        def adaptor_fn(context: Context) -> Client:
            """Call the legacy `client_fn` with a string ID derived from `context`."""
            # If `partition-id` is defined, pass it. Otherwise pass `node_id`,
            # which is expected to always be set during Context init.
            cid = context.node_config.get("partition-id", context.node_id)
            return client_fn(str(cid))  # type: ignore

        return adaptor_fn

    return client_fn


@contextmanager
def _empty_lifespan(_: Context) -> Iterator[None]:
    """Provide a no-op lifespan: nothing runs before or after the app executes.

    `ClientApp.__init__` installs this until a lifespan function is registered.
    """
    yield None


class ClientAppException(Exception):
    """Exception raised when an exception is raised while executing a ClientApp.

    The formatted text is stored in ``message`` and is also the exception's
    single positional argument, so ``str(exc)`` returns it.
    """

    def __init__(self, message: str):
        # Prefix with the concrete (sub)class name so subclasses report themselves
        prefix = f"\nException {type(self).__name__} occurred. Message: "
        self.message = prefix + message
        super().__init__(self.message)


class ClientApp:
    """Flower ClientApp.

    A ``ClientApp`` handles messages in one of two mutually exclusive ways:

    - through a legacy ``client_fn``, or
    - through functions registered with ``@app.train()``, ``@app.evaluate()``
      or ``@app.query()``.

    Parameters
    ----------
    client_fn : Optional[ClientFnExt] (default: None)
        Function that builds a `Client` from a `Context`. Kept only for
        backward compatibility; it cannot be combined with registered functions.
    mods : Optional[list[Mod]] (default: None)
        Modifiers wrapped around every message handler of this ``ClientApp``.

    Examples
    --------
    Assuming a typical `Client` implementation named `FlowerClient`, you can
    wrap it in a `ClientApp` as follows:

    >>> class FlowerClient(NumPyClient):
    >>>     # ...
    >>>
    >>> def client_fn(context: Context):
    >>>     return FlowerClient().to_client()
    >>>
    >>> app = ClientApp(client_fn)
    """

    def __init__(
        self,
        client_fn: Optional[ClientFnExt] = None,  # Only for backward compatibility
        mods: Optional[list[Mod]] = None,
    ) -> None:
        # App-level mods: used for the `client_fn` path and prepended to the
        # function-specific mods of every registered function
        self._mods: list[Mod] = mods if mods is not None else []
        # Maps "<category>.<action>" to the mod-wrapped registered function
        self._registered_funcs: dict[str, ClientAppCallable] = {}

        # Create wrapper function for `handle`
        self._call: Optional[ClientAppCallable] = None
        if client_fn is not None:

            client_fn = _inspect_maybe_adapt_client_fn_signature(client_fn)

            def ffn(
                message: Message,
                context: Context,
            ) -> Message:  # pylint: disable=invalid-name
                out_message = handle_legacy_message_from_msgtype(
                    client_fn=client_fn, message=message, context=context
                )
                return out_message

            # Wrap mods around the wrapped handle function
            self._call = make_ffn(ffn, self._mods)

        # Lifespan function
        self._lifespan = _empty_lifespan

    def __call__(self, message: Message, context: Context) -> Message:
        """Execute `ClientApp`.

        The handler runs inside the registered lifespan. It is the legacy
        ``client_fn`` if one was given; otherwise it is the function registered
        for the message's ``"<category>.<action>"`` type.

        Parameters
        ----------
        message : Message
            The incoming message to handle.
        context : Context
            Passed to the lifespan function and to the handler.

        Returns
        -------
        Message
            The reply produced by the handler.

        Raises
        ------
        ValueError
            If the message type is invalid or no function is registered for it.
        """
        with self._lifespan(context):
            # Execute message using `client_fn`
            if self._call:
                return self._call(message, context)

            # Get the category and the action
            # A valid message type is of the form "<category>" or "<category>.<action>",
            # where <category> must be "train"/"evaluate"/"query", and <action> is a
            # valid Python identifier
            if not validate_message_type(message.metadata.message_type):
                raise ValueError(
                    f"Invalid message type: {message.metadata.message_type}"
                )

            category, action = message.metadata.message_type, DEFAULT_ACTION
            if "." in category:
                category, action = category.split(".")

            # Check if the function is registered
            if (full_name := f"{category}.{action}") in self._registered_funcs:
                return self._registered_funcs[full_name](message, context)

        raise ValueError(f"No {category} function registered with name '{action}'")

    def train(
        self, action: str = DEFAULT_ACTION, *, mods: Optional[list[Mod]] = None
    ) -> Callable[[ClientAppCallable], ClientAppCallable]:
        """Register a train function with the ``ClientApp``.

        Parameters
        ----------
        action : str (default: "default")
            The action name used to route messages. Defaults to "default".
        mods : Optional[list[Mod]] (default: None)
            A list of function-specific modifiers.

        Returns
        -------
        Callable[[ClientAppCallable], ClientAppCallable]
            A decorator that registers a train function with the ``ClientApp``.

        Raises
        ------
        ValueError
            Raised immediately if this ``ClientApp`` was created with a
            ``client_fn``. Raised when the returned decorator is applied if
            ``action`` is not a valid Python identifier or is already
            registered for train.

        Examples
        --------
        Registering a train function:

        >>> app = ClientApp()
        >>>
        >>> @app.train()
        >>> def train(message: Message, context: Context) -> Message:
        >>>     print("Executing default train function")
        >>>     # Create and return an echo reply message
        >>>     return Message(message.content, reply_to=message)

        Registering a train function with a custom action name:

        >>> app = ClientApp()
        >>>
        >>> # Messages with `message_type="train.custom_action"` will be
        >>> # routed to this function.
        >>> @app.train("custom_action")
        >>> def custom_action(message: Message, context: Context) -> Message:
        >>>     print("Executing train function for custom action")
        >>>     return Message(message.content, reply_to=message)

        Registering a train function with a function-specific Flower Mod:

        >>> from flwr.client.mod import message_size_mod
        >>>
        >>> app = ClientApp()
        >>>
        >>> # Using the `mods` argument to apply a function-specific mod.
        >>> @app.train(mods=[message_size_mod])
        >>> def train(message: Message, context: Context) -> Message:
        >>>     print("Executing train function with message size mod")
        >>>     # Create and return an echo reply message
        >>>     return Message(message.content, reply_to=message)
        """
        return _get_decorator(self, MessageType.TRAIN, action, mods)

    def evaluate(
        self, action: str = DEFAULT_ACTION, *, mods: Optional[list[Mod]] = None
    ) -> Callable[[ClientAppCallable], ClientAppCallable]:
        """Register an evaluate function with the ``ClientApp``.

        Parameters
        ----------
        action : str (default: "default")
            The action name used to route messages. Defaults to "default".
        mods : Optional[list[Mod]] (default: None)
            A list of function-specific modifiers.

        Returns
        -------
        Callable[[ClientAppCallable], ClientAppCallable]
            A decorator that registers an evaluate function with the ``ClientApp``.

        Raises
        ------
        ValueError
            Raised immediately if this ``ClientApp`` was created with a
            ``client_fn``. Raised when the returned decorator is applied if
            ``action`` is not a valid Python identifier or is already
            registered for evaluate.

        Examples
        --------
        Registering an evaluate function:

        >>> app = ClientApp()
        >>>
        >>> @app.evaluate()
        >>> def evaluate(message: Message, context: Context) -> Message:
        >>>     print("Executing default evaluate function")
        >>>     # Create and return an echo reply message
        >>>     return Message(message.content, reply_to=message)

        Registering an evaluate function with a custom action name:

        >>> app = ClientApp()
        >>>
        >>> # Messages with `message_type="evaluate.custom_action"` will be
        >>> # routed to this function.
        >>> @app.evaluate("custom_action")
        >>> def custom_action(message: Message, context: Context) -> Message:
        >>>     print("Executing evaluate function for custom action")
        >>>     return Message(message.content, reply_to=message)

        Registering an evaluate function with a function-specific Flower Mod:

        >>> from flwr.client.mod import message_size_mod
        >>>
        >>> app = ClientApp()
        >>>
        >>> # Using the `mods` argument to apply a function-specific mod.
        >>> @app.evaluate(mods=[message_size_mod])
        >>> def evaluate(message: Message, context: Context) -> Message:
        >>>     print("Executing evaluate function with message size mod")
        >>>     # Create and return an echo reply message
        >>>     return Message(message.content, reply_to=message)
        """
        return _get_decorator(self, MessageType.EVALUATE, action, mods)

    def query(
        self, action: str = DEFAULT_ACTION, *, mods: Optional[list[Mod]] = None
    ) -> Callable[[ClientAppCallable], ClientAppCallable]:
        """Register a query function with the ``ClientApp``.

        Parameters
        ----------
        action : str (default: "default")
            The action name used to route messages. Defaults to "default".
        mods : Optional[list[Mod]] (default: None)
            A list of function-specific modifiers.

        Returns
        -------
        Callable[[ClientAppCallable], ClientAppCallable]
            A decorator that registers a query function with the ``ClientApp``.

        Raises
        ------
        ValueError
            Raised immediately if this ``ClientApp`` was created with a
            ``client_fn``. Raised when the returned decorator is applied if
            ``action`` is not a valid Python identifier or is already
            registered for query.

        Examples
        --------
        Registering a query function:

        >>> app = ClientApp()
        >>>
        >>> @app.query()
        >>> def query(message: Message, context: Context) -> Message:
        >>>     print("Executing default query function")
        >>>     # Create and return an echo reply message
        >>>     return Message(message.content, reply_to=message)

        Registering a query function with a custom action name:

        >>> app = ClientApp()
        >>>
        >>> # Messages with `message_type="query.custom_action"` will be
        >>> # routed to this function.
        >>> @app.query("custom_action")
        >>> def custom_action(message: Message, context: Context) -> Message:
        >>>     print("Executing query function for custom action")
        >>>     return Message(message.content, reply_to=message)

        Registering a query function with a function-specific Flower Mod:

        >>> from flwr.client.mod import message_size_mod
        >>>
        >>> app = ClientApp()
        >>>
        >>> # Using the `mods` argument to apply a function-specific mod.
        >>> @app.query(mods=[message_size_mod])
        >>> def query(message: Message, context: Context) -> Message:
        >>>     print("Executing query function with message size mod")
        >>>     # Create and return an echo reply message
        >>>     return Message(message.content, reply_to=message)
        """
        return _get_decorator(self, MessageType.QUERY, action, mods)

    def lifespan(
        self,
    ) -> Callable[
        [Callable[[Context], Iterator[None]]], Callable[[Context], Iterator[None]]
    ]:
        """Return a decorator that registers the lifespan fn with the client app.

        The decorated function should accept a `Context` object and use `yield`
        exactly once: code before the `yield` runs before the message is
        handled, and code after it runs afterwards. The code after `yield` runs
        even if handling raised. That exception is not passed to the lifespan
        function.

        Raises
        ------
        RuntimeError
            Raised when the app runs if the lifespan function yields zero times
            or more than once.

        Examples
        --------
        >>> app = ClientApp()
        >>>
        >>> @app.lifespan()
        >>> def lifespan(context: Context) -> Iterator[None]:
        >>>     # Perform initialization tasks before the app starts
        >>>     print("Initializing ClientApp")
        >>>
        >>>     yield  # ClientApp is running
        >>>
        >>>     # Perform cleanup tasks after the app stops
        >>>     print("Cleaning up ClientApp")
        """

        def lifespan_decorator(
            lifespan_fn: Callable[[Context], Iterator[None]]
        ) -> Callable[[Context], Iterator[None]]:
            """Register the lifespan fn with the ClientApp object."""

            @contextmanager
            def decorated_lifespan(context: Context) -> Iterator[None]:
                # Execute the code before `yield` in lifespan_fn
                try:
                    # A non-iterator result is routed into the same
                    # "yield at least once" error as an empty generator
                    if not isinstance(it := lifespan_fn(context), Iterator):
                        raise StopIteration
                    next(it)
                except StopIteration:
                    raise RuntimeError(
                        "lifespan function should yield at least once."
                    ) from None

                try:
                    # Enter the context
                    yield
                finally:
                    # Always resume lifespan_fn, even if handling the message raised
                    try:
                        # Execute the code after `yield` in lifespan_fn
                        next(it)
                    except StopIteration:
                        pass
                    else:
                        raise RuntimeError("lifespan function should only yield once.")

            # Register provided function with the ClientApp object
            # Ignore mypy error because of different argument names (`_` vs `context`)
            self._lifespan = decorated_lifespan  # type: ignore

            # Return provided function unmodified
            return lifespan_fn

        return lifespan_decorator
class LoadClientAppError(Exception):
    """Raised when loading a `ClientApp` fails."""


def _get_decorator(
    app: ClientApp, category: str, action: str, mods: Optional[list[Mod]]
) -> Callable[[ClientAppCallable], ClientAppCallable]:
    """Build the decorator that registers a function under `category.action`.

    Parameters
    ----------
    app : ClientApp
        The app to register the function with.
    category : str
        Message category, e.g. the value of `MessageType.TRAIN`.
    action : str
        Action name; it must be a valid Python identifier.
    mods : Optional[list[Mod]]
        Function-specific mods. They are appended after the app-level mods.

    Returns
    -------
    Callable[[ClientAppCallable], ClientAppCallable]
        A decorator that registers the function and returns it unchanged.

    Raises
    ------
    ValueError
        Raised immediately if `app` was built with a `client_fn`. Raised by the
        decorator if `action` is not an identifier or is already registered.
    """
    # pylint: disable=protected-access
    # A `client_fn`-based app cannot also register per-category functions
    if app._call:
        raise _registration_error(category)

    full_name = f"{category}.{action}"  # Full name of the message type

    def register(handler: ClientAppCallable) -> ClientAppCallable:
        # Guard: the action must be usable as a Python function name
        if not action.isidentifier():
            raise ValueError(
                f"Cannot register {category} function with name '{action}'. "
                "The name must follow Python's function naming rules."
            )

        # Guard: each "<category>.<action>" can be registered only once
        if full_name in app._registered_funcs:
            raise ValueError(
                f"Cannot register {category} function with name '{action}'. "
                f"A {category} function with the name '{action}' is already registered."
            )

        # Store the handler wrapped in app-level mods followed by its own mods
        app._registered_funcs[full_name] = make_ffn(handler, app._mods + (mods or []))

        # Hand the original function back so the decorated name stays usable
        return handler

    # pylint: enable=protected-access
    return register


def _registration_error(fn_name: str) -> ValueError:
    """Build the error for mixing `client_fn` with `@app.<fn_name>()` registration."""
    return ValueError(
        f"""Use either `@app.{fn_name}()` or `client_fn`, but not both.

        Use the `ClientApp` with an existing `client_fn`:

        >>> class FlowerClient(NumPyClient):
        >>>     # ...
        >>>
        >>> def client_fn(context: Context):
        >>>     return FlowerClient().to_client()
        >>>
        >>> app = ClientApp(
        >>>     client_fn=client_fn,
        >>> )

        Use the `ClientApp` with a custom {fn_name} function:

        >>> app = ClientApp()
        >>>
        >>> @app.{fn_name}()
        >>> def {fn_name}(message: Message, context: Context) -> Message:
        >>>    print("ClientApp {fn_name} running")
        >>>    # Create and return an echo reply message
        >>>    return Message(message.content, reply_to=message)
        """
    )