# Source code for flwr.server.server_app

# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower ServerApp."""


import inspect
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Callable, Optional

from flwr.common import Context
from flwr.common.logger import warn_deprecated_feature_with_example
from flwr.server.strategy import Strategy

from .client_manager import ClientManager
from .compat import start_grid
from .grid import Driver, Grid
from .server import Server
from .server_config import ServerConfig
from .typing import ServerAppCallable, ServerFn

# Code example embedded in the deprecation warning that `ServerApp.__init__`
# emits when `server`, `config`, `strategy` or `client_manager` is passed
# directly. The keyword is `config` because `ServerApp.__call__` reads
# `components.config` from the object returned by `server_fn`.
SERVER_FN_USAGE_EXAMPLE = """

        def server_fn(context: Context):
            server_config = ServerConfig(num_rounds=3)
            strategy = FedAvg()
            return ServerAppComponents(
                strategy=strategy,
                config=server_config,
            )

        app = ServerApp(server_fn=server_fn)
"""

# Code example shown in the deprecation warning that steers users from a
# `Driver`-based main function towards the `Grid`-based signature.
GRID_USAGE_EXAMPLE = """
                app = ServerApp()

                @app.main()
                def main(grid: Grid, context: Context) -> None:
                    # Your existing ServerApp code ...
"""

# Deprecation and guidance texts passed, together with GRID_USAGE_EXAMPLE, to
# `warn_deprecated_feature_with_example` when a main function still uses `Driver`.
DRIVER_DEPRECATION_MSG = """
            The `Driver` class is deprecated, it will be removed in a future release.
"""
DRIVER_EXAMPLE_MSG = """
            Instead, use `Grid` in the signature of your `ServerApp`. For example:
"""


@contextmanager
def _empty_lifespan(_: Context) -> Iterator[None]:
    """Provide a no-op lifespan: nothing runs on enter or on exit.

    `ServerApp` installs this as its default lifespan; the `Context`
    argument is accepted only to match the lifespan call signature and
    is ignored.
    """
    yield None


class ServerApp:  # pylint: disable=too-many-instance-attributes
    """Flower ServerApp.

    Parameters
    ----------
    server : Optional[Server] (default: None)
        Deprecated. Forwarded to `start_grid` in compatibility mode.
    config : Optional[ServerConfig] (default: None)
        Deprecated. Forwarded to `start_grid` in compatibility mode.
    strategy : Optional[Strategy] (default: None)
        Deprecated. Forwarded to `start_grid` in compatibility mode.
    client_manager : Optional[ClientManager] (default: None)
        Deprecated. Forwarded to `start_grid` in compatibility mode.
    server_fn : Optional[ServerFn] (default: None)
        Function called with the `Context` when the app is executed. The
        object it returns must expose `server`, `config`, `strategy` and
        `client_manager` attributes.

    Raises
    ------
    ValueError
        If `server_fn` is passed together with any of the deprecated
        arguments.

    Examples
    --------
    Use the ``ServerApp`` with an existing ``Strategy``:

    >>> def server_fn(context: Context):
    >>>     server_config = ServerConfig(num_rounds=3)
    >>>     strategy = FedAvg()
    >>>     return ServerAppComponents(
    >>>         strategy=strategy,
    >>>         config=server_config,
    >>>     )
    >>>
    >>> app = ServerApp(server_fn=server_fn)

    Use the ``ServerApp`` with a custom main function:

    >>> app = ServerApp()
    >>>
    >>> @app.main()
    >>> def main(grid: Grid, context: Context) -> None:
    >>>    print("ServerApp running")
    """

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        server: Optional[Server] = None,
        config: Optional[ServerConfig] = None,
        strategy: Optional[Strategy] = None,
        client_manager: Optional[ClientManager] = None,
        server_fn: Optional[ServerFn] = None,
    ) -> None:
        if any([server, config, strategy, client_manager]):
            warn_deprecated_feature_with_example(
                deprecation_message="Passing either `server`, `config`, `strategy` or "
                "`client_manager` directly to the ServerApp "
                "constructor is deprecated.",
                example_message="Pass `ServerApp` arguments wrapped "
                "in a `flwr.server.ServerAppComponents` object that gets "
                "returned by a function passed as the `server_fn` argument "
                "to the `ServerApp` constructor. For example: ",
                code_example=SERVER_FN_USAGE_EXAMPLE,
            )

            # The deprecated arguments and `server_fn` are mutually exclusive
            if server_fn:
                raise ValueError(
                    "Passing `server_fn` is incompatible with passing the "
                    "other arguments (now deprecated) to ServerApp. "
                    "Use `server_fn` exclusively."
                )

        self._server = server
        self._config = config
        self._strategy = strategy
        self._client_manager = client_manager
        self._server_fn = server_fn
        self._main: Optional[ServerAppCallable] = None
        # No-op lifespan until one is registered via `lifespan()`
        self._lifespan = _empty_lifespan

    def __call__(self, grid: Grid, context: Context) -> None:
        """Execute `ServerApp`.

        The whole run happens inside the registered lifespan context. If no
        main function has been registered, the app runs in compatibility
        mode: `server_fn` (when given) is called to obtain the components,
        which are then handed to `start_grid`. Otherwise the registered main
        function is called with `grid` and `context`.
        """
        with self._lifespan(context):
            # Compatibility mode
            if not self._main:
                if self._server_fn:
                    # Execute server_fn()
                    components = self._server_fn(context)
                    self._server = components.server
                    self._config = components.config
                    self._strategy = components.strategy
                    self._client_manager = components.client_manager
                start_grid(
                    server=self._server,
                    config=self._config,
                    strategy=self._strategy,
                    client_manager=self._client_manager,
                    grid=grid,
                )
                return

            # New execution mode
            self._main(grid, context)

    def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
        """Return a decorator that registers the main fn with the server app.

        The decorator raises a `ValueError` if a server, config, strategy or
        client manager is already set on the app. If the first parameter of
        the decorated function is named ``driver`` or annotated with
        `Driver`, a deprecation warning is emitted.

        Examples
        --------
        >>> app = ServerApp()
        >>>
        >>> @app.main()
        >>> def main(grid: Grid, context: Context) -> None:
        >>>    print("ServerApp running")
        """

        def main_decorator(main_fn: ServerAppCallable) -> ServerAppCallable:
            """Register the main fn with the ServerApp object."""
            if self._server or self._config or self._strategy or self._client_manager:
                raise ValueError(
                    """Use either a custom main function or a `Strategy`, but not both.

                    Use the `ServerApp` with an existing `Strategy`:

                    >>> server_config = ServerConfig(num_rounds=3)
                    >>> strategy = FedAvg()
                    >>>
                    >>> app = ServerApp(
                    >>>     config=server_config,
                    >>>     strategy=strategy,
                    >>> )

                    Use the `ServerApp` with a custom main function:

                    >>> app = ServerApp()
                    >>>
                    >>> @app.main()
                    >>> def main(grid: Grid, context: Context) -> None:
                    >>>    print("ServerApp running")
                    """,
                )

            sig = inspect.signature(main_fn)
            # A function without parameters has nothing to inspect; skip the
            # deprecation check instead of failing with an IndexError.
            first_param = next(iter(sig.parameters.values()), None)
            # Check if parameter name or the annotation should be updated
            if first_param is not None and (
                first_param.name == "driver" or first_param.annotation is Driver
            ):
                warn_deprecated_feature_with_example(
                    deprecation_message=DRIVER_DEPRECATION_MSG,
                    example_message=DRIVER_EXAMPLE_MSG,
                    code_example=GRID_USAGE_EXAMPLE,
                )

            # Register provided function with the ServerApp object
            self._main = main_fn

            # Return provided function unmodified
            return main_fn

        return main_decorator

    def lifespan(
        self,
    ) -> Callable[
        [Callable[[Context], Iterator[None]]], Callable[[Context], Iterator[None]]
    ]:
        """Return a decorator that registers the lifespan fn with the server app.

        The decorated function should accept a `Context` object and use `yield`
        to define enter and exit behavior.

        When the app is executed, a `RuntimeError` is raised if the function
        does not return an iterator, does not yield at least once, or yields
        more than once.

        Examples
        --------
        >>> app = ServerApp()
        >>>
        >>> @app.lifespan()
        >>> def lifespan(context: Context) -> Iterator[None]:
        >>>     # Perform initialization tasks before the app starts
        >>>     print("Initializing ServerApp")
        >>>
        >>>     yield  # ServerApp is running
        >>>
        >>>     # Perform cleanup tasks after the app stops
        >>>     print("Cleaning up ServerApp")
        """

        def lifespan_decorator(
            lifespan_fn: Callable[[Context], Iterator[None]],
        ) -> Callable[[Context], Iterator[None]]:
            """Register the lifespan fn with the ServerApp object."""

            @contextmanager
            def decorated_lifespan(context: Context) -> Iterator[None]:
                # Execute the code before `yield` in lifespan_fn
                try:
                    if not isinstance(it := lifespan_fn(context), Iterator):
                        # Not a generator/iterator: report it like a missing yield
                        raise StopIteration
                    next(it)
                except StopIteration:
                    raise RuntimeError(
                        "lifespan function should yield at least once."
                    ) from None

                try:
                    # Enter the context
                    yield
                finally:
                    try:
                        # Execute the code after `yield` in lifespan_fn
                        next(it)
                    except StopIteration:
                        pass
                    else:
                        raise RuntimeError("lifespan function should only yield once.")

            # Register provided function with the ServerApp object
            # Ignore mypy error because of different argument names (`_` vs `context`)
            self._lifespan = decorated_lifespan  # type: ignore

            # Return provided function unmodified
            return lifespan_fn

        return lifespan_decorator
class LoadServerAppError(Exception):
    """Raised when a `ServerApp` cannot be loaded."""