Flower Strategy Abstraction¶

The strategy abstraction enables the implementation of fully custom federated learning strategies. In Flower, a strategy is essentially the federated learning algorithm that runs inside the ServerApp. Strategies define how to:

  • Sample clients

  • Configure instructions for training and evaluation

  • Aggregate updates and metrics

  • Evaluate models

Flower ships with a number of built-in strategies, all following the same API described below. You can also implement your own strategies with full access to the same capabilities.

The Strategy abstraction¶

All strategy implementations must derive from the abstract base class Strategy. This includes both built-in strategies and third-party/custom strategies. By extending this base class, user-defined strategies gain the exact same power and flexibility as the built-in ones.

The Strategy base class defines a start method and requires subclasses to implement several abstract methods:

class Strategy(ABC):
    """Abstract base class for server strategy implementations."""

    @abstractmethod
    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training."""

    @abstractmethod
    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate training results from client nodes."""

    @abstractmethod
    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of evaluation."""

    @abstractmethod
    def aggregate_evaluate(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> Optional[MetricRecord]:
        """Aggregate evaluation metrics from client nodes."""

    @abstractmethod
    def summary(self) -> None:
        """Log a summary of the strategy configuration."""

    def start(
        self,
        grid: Grid,
        initial_arrays: ArrayRecord,
        num_rounds: int = 3,
        timeout: float = 3600,
        train_config: Optional[ConfigRecord] = None,
        evaluate_config: Optional[ConfigRecord] = None,
        evaluate_fn: Optional[
            Callable[[int, ArrayRecord], Optional[MetricRecord]]
        ] = None,
    ) -> Result:
        """Execute the federated learning strategy."""
        # Implementation details
        pass

Creating a new strategy¶

You can customize an existing strategy (e.g., FedAvg) by overriding one or more of its methods. For full flexibility, you can also implement a strategy from scratch: define a class that derives from Strategy and implement the abstract methods:

class SotaStrategy(Strategy):

    def configure_train(self, server_round, arrays, config, grid):
        # Your implementation here
        pass

    def aggregate_train(self, server_round, replies):
        # Your implementation here
        pass

    def configure_evaluate(self, server_round, arrays, config, grid):
        # Your implementation here
        pass

    def aggregate_evaluate(self, server_round, replies):
        # Your implementation here
        pass

    def summary(self):
        print("SotaStrategy: This is the state-of-the-art strategy!")

The start method is already implemented in the base class and typically does not need to be overridden. It orchestrates the federated learning process by invoking the abstract methods in sequence.

Understanding the start method¶

The start method of the Strategy base class follows this workflow:

  1. Call evaluate_fn (if provided) to evaluate the initial model on the ServerApp side.

  2. Call configure_train to generate training messages for ClientApps.

  3. Send training messages to ClientApps.

  4. ClientApps run their @app.train() function and return training replies.

  5. Call aggregate_train to aggregate the training replies.

  6. Call configure_evaluate to generate evaluation messages for ClientApps.

  7. Send evaluation messages to ClientApps.

  8. ClientApps run their @app.evaluate() function and return evaluation replies.

  9. Call aggregate_evaluate to aggregate the evaluation replies.

  10. Call evaluate_fn (if provided) to evaluate the aggregated model on the ServerApp side.

  11. Repeat steps 2-10 for the specified number of rounds.

  12. Return the final Result, which contains the final model and metrics history.
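The workflow above can be sketched in plain Python. The snippet below is a minimal, framework-free illustration of the orchestration loop; the list-based "model", the `toy_start` function, and the simulated client replies are all hypothetical stand-ins, not Flower APIs (real code would use ArrayRecord, MetricRecord, and Message objects):

```python
from typing import Callable, Optional

# Hypothetical stand-ins: a "model" is a list of floats and a "reply"
# is an (updated_model, num_examples) tuple.
Model = list[float]
Reply = tuple[Model, int]


def toy_start(
    initial_model: Model,
    num_rounds: int,
    train_on_clients: Callable[[Model], list[Reply]],
    evaluate_fn: Optional[Callable[[int, Model], float]] = None,
) -> Model:
    """Mimic the start() workflow: evaluate, then loop train -> aggregate."""
    model = initial_model
    if evaluate_fn is not None:
        evaluate_fn(0, model)  # step 1: evaluate the initial model
    for rnd in range(1, num_rounds + 1):
        replies = train_on_clients(model)  # steps 2-4: configure, send, train
        # Step 5: aggregate replies with an example-weighted average
        total = sum(n for _, n in replies)
        model = [
            sum(m[i] * n for m, n in replies) / total
            for i in range(len(model))
        ]
        if evaluate_fn is not None:
            evaluate_fn(rnd, model)  # step 10: evaluate the aggregated model
    return model  # step 12: return the final model


# Usage: two simulated clients nudge the model in opposite directions
def fake_clients(model: Model) -> list[Reply]:
    return [([w + 1.0 for w in model], 1), ([w - 1.0 for w in model], 1)]


final = toy_start([0.0, 0.0], num_rounds=3, train_on_clients=fake_clients)
print(final)  # equal weights cancel out, so the model stays at [0.0, 0.0]
```

The evaluation phase (steps 6-9) is elided here for brevity; it follows the same configure/send/aggregate pattern as training.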

The following diagram illustrates the flow.

Note

The sequence diagram below shows the interaction between ServerApp, Strategy (inside ServerApp), and ClientApp. In reality, they do not communicate directly over the network—Flower infrastructure (SuperLink and SuperNode) transparently manages all communication. You can read more about it in the Flower Network Communication guide.

sequenceDiagram
    participant SA as ServerApp
    participant ST as Strategy
    participant CA as ClientApps
    SA->>ST: start(num_rounds, ...)
    opt
        ST->>ST: evaluate_fn()
    end
    loop rounds 1..N
        Note over ST: --- Training Phase ---
        ST->>ST: configure_train()
        ST->>CA: train_messages
        CA->>CA: @app.train() callback
        CA-->>ST: train_replies
        ST->>ST: aggregate_train(train_replies)
        Note over ST: --- Evaluation Phase ---
        ST->>ST: configure_evaluate()
        ST->>CA: evaluate_messages
        CA->>CA: @app.evaluate() callback
        CA-->>ST: evaluate_replies
        ST->>ST: aggregate_evaluate(evaluate_replies)
        opt
            ST->>ST: evaluate_fn()
        end
    end
    ST-->>SA: final Result

The configure_train method¶

The configure_train method is responsible for preparing the next round of training. But what does configure mean in this context? It means selecting which clients should participate in the round and deciding what instructions they should receive.

Here is the method signature:

@abstractmethod
def configure_train(
    self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
    """Configure the next round of training."""

This method takes four arguments:

  • server_round: The current round number

  • arrays: The current global model parameters

  • config: A configuration dictionary for the round

  • grid: The object responsible for managing communication with clients

The return value is an iterable of Message objects, where each message contains the instructions to be sent to a specific client. A typical implementation of configure_train will:

  • Use the grid to randomly sample a subset (or all) of the available clients

  • Create one Message per selected client, containing the global model parameters and configuration values

More advanced strategies can implement custom client selection logic by using the capabilities of grid. A client only participates in a round if configure_train generates a message for its node ID.

Note

Because the return value is defined per client, strategies can easily implement heterogeneous configurations. For example, different clients can receive different models or hyperparameters, enabling highly customized training behaviors.

The aggregate_train method¶

The aggregate_train method is responsible for aggregating the training results received from the clients selected in configure_train.

Here is the method signature:

@abstractmethod
def aggregate_train(
    self,
    server_round: int,
    replies: Iterable[Message],
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
    """Aggregate training results from client nodes."""

This method takes two arguments:

  • server_round: The current round number

  • replies: An iterable of Message objects from the participating clients

It returns a tuple consisting of:

  1. ArrayRecord: The updated global model parameters

  2. MetricRecord: Aggregated training metrics (such as loss or accuracy)

If aggregation cannot be performed (e.g., if too many clients failed during the round), the method may decide to return (None, None) instead.
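The core of a typical aggregate_train implementation is an example-weighted average. The sketch below operates on plain (weights, num_examples, train_loss) tuples rather than Message and ArrayRecord objects, and a `None` entry stands in for a reply whose Message.has_error() check failed; all names here are illustrative:

```python
from typing import Optional

# Hypothetical reply shape: (weights, num_examples, train_loss),
# or None for a failed client.
Reply = Optional[tuple[list[float], int, float]]


def aggregate_train_sketch(
    replies: list[Reply],
) -> tuple[Optional[list[float]], Optional[dict[str, float]]]:
    valid = [r for r in replies if r is not None]  # drop failed clients
    if not valid:
        return None, None  # aggregation impossible: return (None, None)
    total = sum(n for _, n, _ in valid)
    dim = len(valid[0][0])
    # Example-weighted average of the model weights (FedAvg-style)
    arrays = [sum(w[i] * n for w, n, _ in valid) / total for i in range(dim)]
    # Example-weighted average of the training loss
    metrics = {"train_loss": sum(loss * n for _, n, loss in valid) / total}
    return arrays, metrics


# Usage: two healthy replies and one failed one
arrays, metrics = aggregate_train_sketch(
    [([1.0, 2.0], 10, 0.5), ([3.0, 4.0], 30, 0.3), None]
)
print(arrays)  # [2.5, 3.5]
print(metrics)
```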

Tip

You can use Message.has_error() to check if a reply contains an error and decide how to handle it during aggregation.

The configure_evaluate method¶

The configure_evaluate method is responsible for preparing the next round of evaluation. Similar to configure_train, this involves selecting which clients should participate and deciding what instructions they should receive for evaluation.

Here is the method signature:

@abstractmethod
def configure_evaluate(
    self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
    """Configure the next round of evaluation."""

This method takes four arguments:

  • server_round: The current round number

  • arrays: The current global model parameters to be evaluated

  • config: A configuration dictionary for evaluation

  • grid: The object that manages communication with clients

The return value is an iterable of Message objects, one for each selected client. Each message typically contains the current global model parameters and any evaluation configuration.

A typical implementation of configure_evaluate will:

  • Use grid to select a subset (or all) of the available clients

  • Create one Message per selected client containing the global model and evaluation configuration

As with training, more advanced strategies may apply custom client selection logic or send different evaluation configurations to different clients.

Note

Because each client receives its own message, strategies can implement heterogeneous evaluation setups. For example, some clients might evaluate on larger test sets, while others might use specialized metrics.
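The per-client messages mentioned above can be sketched with plain dictionaries standing in for Message and ConfigRecord objects. The keys and the alternating extra-metric scheme below are purely illustrative:

```python
def build_evaluate_messages(
    node_ids: list[int], base_config: dict[str, object]
) -> list[dict[str, object]]:
    """Build one 'message' (here: a dict) per selected client.

    Even-numbered nodes get an extra metric to compute, showing how
    heterogeneous per-client evaluation falls out of the per-message
    return value of configure_evaluate.
    """
    messages = []
    for node_id in node_ids:
        config = dict(base_config)  # shared evaluation settings
        config["dst_node_id"] = node_id
        config["extra_metric"] = "f1" if node_id % 2 == 0 else None
        messages.append(config)
    return messages


# Usage: three selected clients share a batch size but differ in metrics
msgs = build_evaluate_messages([0, 1, 2], {"batch_size": 32})
print(msgs)
```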

The aggregate_evaluate method¶

The aggregate_evaluate method is responsible for aggregating the evaluation results received from the clients selected in configure_evaluate.

Here is the method signature:

@abstractmethod
def aggregate_evaluate(
    self,
    server_round: int,
    replies: Iterable[Message],
) -> Optional[MetricRecord]:
    """Aggregate evaluation metrics from client nodes."""

This method takes two arguments:

  • server_round: The current round number

  • replies: An iterable of Message objects returned by the clients after they executed evaluation

It returns a single MetricRecord that represents the aggregated evaluation metrics across all participating clients. If aggregation cannot be performed (for example, due to excessive client failures or missing metrics), the method may return None.
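The metric aggregation itself is usually an example-weighted average per metric. Below is a minimal sketch operating on plain (metrics_dict, num_examples) tuples, each standing in for one reply's MetricRecord and its example count; the function name and metric keys are illustrative:

```python
from typing import Optional


def aggregate_evaluate_sketch(
    replies: list[tuple[dict[str, float], int]],
) -> Optional[dict[str, float]]:
    """Example-weighted average of each metric across clients."""
    if not replies:
        return None  # no usable replies: return None
    total = sum(n for _, n in replies)
    keys = replies[0][0].keys()
    return {
        k: sum(metrics[k] * n for metrics, n in replies) / total for k in keys
    }


# Usage: the larger client (60 examples) dominates the averages
agg = aggregate_evaluate_sketch(
    [({"eval_loss": 1.0, "accuracy": 0.8}, 20),
     ({"eval_loss": 2.0, "accuracy": 0.6}, 60)]
)
print(agg)
```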

Tip

As with training, Message.has_error() can be used to detect and handle client errors during aggregation.