Flower Strategy Abstraction
The strategy abstraction enables the implementation of fully custom federated learning strategies. In Flower, a strategy is essentially the federated learning algorithm that runs inside the ServerApp. Strategies define how to:
Sample clients
Configure instructions for training and evaluation
Aggregate updates and metrics
Evaluate models
Flower ships with a number of built-in strategies, all following the same API described below. You can also implement your own strategies with full access to the same capabilities.
The Strategy abstraction
All strategy implementations must derive from the abstract base class Strategy. This includes both built-in strategies and third-party or custom strategies. By extending this base class, user-defined strategies gain the same power and flexibility as the built-in ones.
The Strategy base class defines a start method and requires subclasses to implement several abstract methods:
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from typing import Optional

# ArrayRecord, ConfigRecord, Grid, Message, MetricRecord and Result are Flower types


class Strategy(ABC):
    """Abstract base class for server strategy implementations."""

    @abstractmethod
    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training."""

    @abstractmethod
    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate training results from client nodes."""

    @abstractmethod
    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of evaluation."""

    @abstractmethod
    def aggregate_evaluate(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> Optional[MetricRecord]:
        """Aggregate evaluation metrics from client nodes."""

    @abstractmethod
    def summary(self) -> None:
        """Log a summary of the strategy configuration."""

    def start(
        self,
        grid: Grid,
        initial_arrays: ArrayRecord,
        num_rounds: int = 3,
        timeout: float = 3600,
        train_config: Optional[ConfigRecord] = None,
        evaluate_config: Optional[ConfigRecord] = None,
        evaluate_fn: Optional[
            Callable[[int, ArrayRecord], Optional[MetricRecord]]
        ] = None,
    ) -> Result:
        """Execute the federated learning strategy."""
        # Implementation details omitted
        pass
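Because Strategy is an abstract base class, Python refuses to instantiate any subclass that leaves one of its abstract methods unimplemented. The minimal sketch below illustrates this with a hypothetical, cut-down MiniStrategy (three abstract methods instead of five, no Flower types) so that it runs without Flower installed; it is not Flower's actual class.

```python
from abc import ABC, abstractmethod


class MiniStrategy(ABC):
    """Hypothetical stand-in mirroring the shape of Flower's Strategy."""

    @abstractmethod
    def configure_train(self, server_round, arrays, config, grid):
        """Configure the next round of training."""

    @abstractmethod
    def aggregate_train(self, server_round, replies):
        """Aggregate training results."""

    @abstractmethod
    def summary(self):
        """Log a summary of the strategy configuration."""


class Incomplete(MiniStrategy):
    # Only one of the three abstract methods is implemented
    def summary(self):
        print("Incomplete strategy")


class Complete(MiniStrategy):
    # Every abstract method is implemented, so instantiation succeeds
    def configure_train(self, server_round, arrays, config, grid):
        return []

    def aggregate_train(self, server_round, replies):
        return None, None

    def summary(self):
        print("Complete strategy")


def can_instantiate(cls):
    """Return True if cls can be instantiated, False if abstract methods are missing."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(can_instantiate(Incomplete))  # False
print(can_instantiate(Complete))  # True
```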
Creating a new strategy
You can customize an existing strategy (e.g., FedAvg) by overriding one or several of its methods. For full flexibility, you can also implement a strategy from scratch. To implement a brand new strategy, define a class that derives from Strategy and implement the abstract methods:
class SotaStrategy(Strategy):
    def configure_train(self, server_round, arrays, config, grid):
        # Your implementation here
        pass

    def aggregate_train(self, server_round, replies):
        # Your implementation here
        pass

    def configure_evaluate(self, server_round, arrays, config, grid):
        # Your implementation here
        pass

    def aggregate_evaluate(self, server_round, replies):
        # Your implementation here
        pass

    def summary(self):
        print("SotaStrategy: This is the state-of-the-art strategy!")
The start method is already implemented in the base class and typically does not need to be overridden. It orchestrates the federated learning process by invoking the abstract methods in sequence.
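Customizing an existing strategy usually means subclassing it, overriding one method, and delegating to the parent implementation with super(). The sketch below shows that pattern with a hypothetical ToyFedAvg standing in for Flower's FedAvg (it uses an unweighted mean and plain dicts instead of Message and ArrayRecord objects), so it illustrates the override pattern rather than Flower's API.

```python
class ToyFedAvg:
    """Hypothetical stand-in for a built-in strategy; not Flower's FedAvg."""

    def aggregate_train(self, server_round, replies):
        # Plain (unweighted) element-wise mean of the reply parameter lists
        replies = list(replies)
        if not replies:
            return None, None
        n = len(replies)
        params = [sum(vals) / n for vals in zip(*(r["params"] for r in replies))]
        return params, {"num_replies": n}


class LoggingFedAvg(ToyFedAvg):
    """Overrides a single method and reuses the parent's logic via super()."""

    def __init__(self):
        self.history = []

    def aggregate_train(self, server_round, replies):
        params, metrics = super().aggregate_train(server_round, replies)
        # Extra behavior added by the subclass: remember each round's metrics
        self.history.append((server_round, metrics))
        return params, metrics


strategy = LoggingFedAvg()
params, metrics = strategy.aggregate_train(
    1, [{"params": [1.0, 2.0]}, {"params": [3.0, 4.0]}]
)
print(params)  # [2.0, 3.0]
print(strategy.history)  # [(1, {'num_replies': 2})]
```

Only aggregate_train is overridden here; every other method would be inherited unchanged from the parent class.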
Understanding the start method
The start method of the Strategy base class follows this workflow:
1. Call evaluate_fn (if provided) to evaluate the initial model on the ServerApp side.
2. Call configure_train to generate training messages for ClientApps.
3. Send training messages to ClientApps.
4. ClientApps run their @app.train() function and return training replies.
5. Call aggregate_train to aggregate the training replies.
6. Call configure_evaluate to generate evaluation messages for ClientApps.
7. Send evaluation messages to ClientApps.
8. ClientApps run their @app.evaluate() function and return evaluation replies.
9. Call aggregate_evaluate to aggregate the evaluation replies.
10. Call evaluate_fn (if provided) to evaluate the aggregated model on the ServerApp side.
11. Repeat steps 2-10 for the specified number of rounds.
12. Return the final Result, which contains the final model and metrics history.
The following diagram illustrates the flow.
Note
The sequence diagram below shows the interaction between ServerApp, Strategy (inside ServerApp), and ClientApp. In reality, they do not communicate directly over the network; the Flower infrastructure (SuperLink and SuperNode) transparently manages all communication. You can read more about it in the Flower Network Communication guide.
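To make the order of calls concrete, here is a heavily simplified, single-process sketch of the start workflow. ToyStrategy, ToyResult, toy_start and the fake client functions are hypothetical stand-ins. There is no networking, no timeout handling, and no train_config or evaluate_config, and a plain list of node IDs plays the role of the grid. It is not Flower's implementation of start.

```python
from dataclasses import dataclass, field


@dataclass
class ToyResult:
    """Hypothetical stand-in for Flower's Result: final arrays plus history."""

    arrays: list
    history: list = field(default_factory=list)


class ToyStrategy:
    """Implements the four hooks and records the order in which they are called."""

    def __init__(self):
        self.calls = []

    def configure_train(self, server_round, arrays, config, grid):
        self.calls.append("configure_train")
        return [{"node": node_id, "arrays": arrays} for node_id in grid]

    def aggregate_train(self, server_round, replies):
        self.calls.append("aggregate_train")
        replies = list(replies)
        if not replies:
            return None, None
        n = len(replies)
        arrays = [sum(vals) / n for vals in zip(*(r["arrays"] for r in replies))]
        return arrays, {"train-replies": n}

    def configure_evaluate(self, server_round, arrays, config, grid):
        self.calls.append("configure_evaluate")
        return [{"node": node_id, "arrays": arrays} for node_id in grid]

    def aggregate_evaluate(self, server_round, replies):
        self.calls.append("aggregate_evaluate")
        return {"eval-replies": len(list(replies))}


def fake_client_train(msg):
    # Hypothetical ClientApp training: shift every parameter by the node ID
    return {"arrays": [w + msg["node"] for w in msg["arrays"]]}


def fake_client_evaluate(msg):
    # Hypothetical ClientApp evaluation: report a constant loss
    return {"loss": 0.0}


def toy_start(strategy, grid, initial_arrays, num_rounds=3, evaluate_fn=None):
    """Simplified sketch of the start workflow; not Flower's implementation."""
    result = ToyResult(arrays=initial_arrays)
    if evaluate_fn is not None:  # Step 1: evaluate the initial model
        result.history.append({"round": 0, "central": evaluate_fn(0, result.arrays)})
    for rnd in range(1, num_rounds + 1):  # Step 11: repeat steps 2-10
        train_msgs = strategy.configure_train(rnd, result.arrays, {}, grid)  # Step 2
        train_replies = [fake_client_train(m) for m in train_msgs]  # Steps 3-4
        arrays, train_metrics = strategy.aggregate_train(rnd, train_replies)  # Step 5
        if arrays is not None:
            result.arrays = arrays
        eval_msgs = strategy.configure_evaluate(rnd, result.arrays, {}, grid)  # Step 6
        eval_replies = [fake_client_evaluate(m) for m in eval_msgs]  # Steps 7-8
        eval_metrics = strategy.aggregate_evaluate(rnd, eval_replies)  # Step 9
        entry = {"round": rnd, "train": train_metrics, "evaluate": eval_metrics}
        if evaluate_fn is not None:  # Step 10: evaluate the aggregated model
            entry["central"] = evaluate_fn(rnd, result.arrays)
        result.history.append(entry)
    return result  # Step 12: final model and metrics history


strategy = ToyStrategy()
result = toy_start(
    strategy,
    grid=[1, 2],
    initial_arrays=[0.0],
    num_rounds=2,
    evaluate_fn=lambda rnd, arrays: {"sum": sum(arrays)},
)
print(result.arrays)  # [3.0]
print(len(result.history))  # 3
```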
The configure_train method
The configure_train method is responsible for preparing the next round of training.
But what does configure mean in this context? It means selecting which clients should
participate in the round and deciding what instructions they should receive.
Here is the method signature:
@abstractmethod
def configure_train(
    self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
    """Configure the next round of training."""
This method takes four arguments:
server_round: The current round number
arrays: The current global model parameters
config: A configuration dictionary for the round
grid: The object responsible for managing communication with clients
The return value is an iterable of Message objects, where each message contains the instructions to be sent to a specific client. A typical implementation of configure_train will:
Use the grid to randomly sample a subset (or all) of the available clients
Create one Message per selected client, containing the global model parameters and configuration values
More advanced strategies can implement custom client selection logic by using the capabilities of grid. A client only participates in a round if configure_train generates a message for its node ID.
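The typical configure_train implementation described above can be sketched as follows. ToyGrid (with its get_node_ids method), ToyMessage, the fraction and seed parameters, and the "arrays", "config" and "server-round" keys are hypothetical stand-ins for Flower's Grid, Message, ArrayRecord and ConfigRecord, chosen so that the example runs on its own. Consult the Flower API reference for the real constructors.

```python
import random
from dataclasses import dataclass


@dataclass
class ToyMessage:
    """Hypothetical stand-in for a Flower Message addressed to a single node."""

    dst_node_id: int
    content: dict


class ToyGrid:
    """Hypothetical stand-in for Flower's Grid that only lists node IDs."""

    def __init__(self, node_ids):
        self._node_ids = list(node_ids)

    def get_node_ids(self):
        return list(self._node_ids)


def configure_train(server_round, arrays, config, grid, fraction=0.5, seed=None):
    """Randomly sample a fraction of nodes and build one message per sampled node."""
    node_ids = grid.get_node_ids()
    if not node_ids:
        return []
    # Sample at least one node and never more nodes than are available
    num_sampled = min(len(node_ids), max(1, int(len(node_ids) * fraction)))
    sampled = random.Random(seed).sample(node_ids, num_sampled)
    # Nodes without a message do not participate in this round
    return [
        ToyMessage(
            dst_node_id=node_id,
            content={"arrays": arrays, "config": {**config, "server-round": server_round}},
        )
        for node_id in sampled
    ]


grid = ToyGrid(node_ids=[101, 102, 103, 104])
messages = configure_train(1, arrays=[0.1, 0.2], config={"lr": 0.01}, grid=grid, seed=0)
print(len(messages))  # 2
```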
Note
Because the return value is defined per client, strategies can easily implement heterogeneous configurations. For example, different clients can receive different models or hyperparameters, enabling highly customized training behaviors.
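For instance, a strategy could give clients it considers slow fewer local epochs than the rest. The sketch below is hypothetical: plain dicts stand in for Message objects, and both the slow_nodes set and the "local-epochs" key are illustrative choices, not Flower conventions.

```python
def configure_train_heterogeneous(server_round, arrays, node_ids, slow_nodes):
    """Send fewer local epochs to nodes assumed to be slow (hypothetical policy)."""
    messages = []
    for node_id in node_ids:
        # Per-client hyperparameter: 1 epoch for slow nodes, 3 for everyone else
        local_epochs = 1 if node_id in slow_nodes else 3
        messages.append(
            {
                "dst_node_id": node_id,
                "arrays": arrays,
                "config": {"server-round": server_round, "local-epochs": local_epochs},
            }
        )
    return messages


msgs = configure_train_heterogeneous(1, [0.0], node_ids=[1, 2, 3], slow_nodes={2})
print([m["config"]["local-epochs"] for m in msgs])  # [3, 1, 3]
```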
The aggregate_train method
The aggregate_train method is responsible for aggregating the training results received from the clients selected in configure_train.
Here is the method signature:
@abstractmethod
def aggregate_train(
    self,
    server_round: int,
    replies: Iterable[Message],
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
    """Aggregate training results from client nodes."""
This method takes two arguments:
server_round: The current round number
replies: An iterable of Message objects from the participating clients
It returns a tuple consisting of:
ArrayRecord: The updated global model parameters
MetricRecord: Aggregated training metrics (such as loss or accuracy)
If aggregation cannot be performed (e.g., if too many clients failed during the round), the method may return (None, None) instead.
Hint
You can use Message.has_error() to check whether a reply contains an error and decide how to handle it during aggregation.
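Combining these pieces, the sketch below performs a FedAvg-style aggregation. It drops replies whose has_error() is true and returns (None, None) when fewer than a chosen fraction of clients succeeded. Otherwise, it averages the parameters and the loss, weighted by each client's number of examples. ToyReply, its fields, and min_success_fraction are hypothetical. In Flower, parameters and metrics would arrive inside the reply Message's content.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyReply:
    """Hypothetical stand-in for a training reply Message."""

    arrays: list = field(default_factory=list)
    num_examples: int = 0
    train_loss: float = 0.0
    error: Optional[str] = None

    def has_error(self) -> bool:
        # Plays the role of Message.has_error()
        return self.error is not None


def aggregate_train(server_round, replies, min_success_fraction=0.5):
    """Example-weighted averaging of successful replies (FedAvg-style sketch)."""
    replies = list(replies)
    ok = [r for r in replies if not r.has_error()]
    # Skip aggregation if too many clients failed in this round
    if not ok or len(ok) < min_success_fraction * len(replies):
        return None, None
    total = sum(r.num_examples for r in ok)
    if total == 0:
        return None, None
    arrays = [
        sum(r.num_examples * r.arrays[i] for r in ok) / total
        for i in range(len(ok[0].arrays))
    ]
    loss = sum(r.num_examples * r.train_loss for r in ok) / total
    return arrays, {"train_loss": loss, "num_failures": len(replies) - len(ok)}


replies = [
    ToyReply(arrays=[1.0, 2.0], num_examples=10, train_loss=0.5),
    ToyReply(arrays=[3.0, 4.0], num_examples=30, train_loss=0.25),
    ToyReply(error="client timed out"),
]
print(aggregate_train(1, replies))
# ([2.5, 3.5], {'train_loss': 0.3125, 'num_failures': 1})
print(aggregate_train(2, replies[2:]))  # (None, None)
```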
The configure_evaluate method
The configure_evaluate method is responsible for preparing the next round of evaluation. Similar to configure_train, this involves selecting which clients should participate and deciding what instructions they should receive for evaluation.
Here is the method signature:
@abstractmethod
def configure_evaluate(
    self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
    """Configure the next round of evaluation."""
This method takes four arguments:
server_round: The current round number
arrays: The current global model parameters to be evaluated
config: A configuration dictionary for evaluation
grid: The object that manages communication with clients
The return value is an iterable of Message objects, one for each selected client. Each message typically contains the current global model parameters and any evaluation configuration.
A typical implementation of configure_evaluate will:
Use grid to select a subset (or all) of the available clients
Create one Message per selected client containing the global model and evaluation configuration
As with training, more advanced strategies may apply custom client selection logic or send different evaluation configurations to different clients.
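As one example of custom selection logic, the hypothetical sketch below rotates through the available nodes so that each round evaluates a different group of k clients. To keep it self-contained, it takes a plain list of node IDs (which a real strategy would obtain from the grid) and returns plain dicts instead of Message objects.

```python
class RoundRobinEvaluation:
    """Hypothetical policy: evaluate k nodes per round, cycling through all nodes."""

    def __init__(self, k):
        self.k = k

    def configure_evaluate(self, server_round, arrays, config, node_ids):
        node_ids = sorted(node_ids)
        if not node_ids:
            return []
        count = min(self.k, len(node_ids))
        # Shift the starting position by k nodes each round (rounds start at 1)
        start = ((server_round - 1) * self.k) % len(node_ids)
        selected = [node_ids[(start + i) % len(node_ids)] for i in range(count)]
        return [
            {
                "dst_node_id": node_id,
                "arrays": arrays,
                "config": {**config, "server-round": server_round},
            }
            for node_id in selected
        ]


policy = RoundRobinEvaluation(k=2)
for rnd in (1, 2, 3):
    msgs = policy.configure_evaluate(rnd, [0.0], {}, node_ids=[1, 2, 3, 4, 5])
    print(rnd, [m["dst_node_id"] for m in msgs])
# 1 [1, 2]
# 2 [3, 4]
# 3 [5, 1]
```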
Note
Because each client receives its own message, strategies can implement heterogeneous evaluation setups. For example, some clients might evaluate on larger test sets, while others might use specialized metrics.
The aggregate_evaluate method
The aggregate_evaluate method is responsible for aggregating the evaluation results received from the clients selected in configure_evaluate.
Here is the method signature:
@abstractmethod
def aggregate_evaluate(
    self,
    server_round: int,
    replies: Iterable[Message],
) -> Optional[MetricRecord]:
    """Aggregate evaluation metrics from client nodes."""
This method takes two arguments:
server_round: The current round number
replies: An iterable of Message objects returned by the clients after they executed evaluation
It returns a single MetricRecord that represents the aggregated evaluation metrics across all participating clients. If aggregation cannot be performed (for example, due to excessive client failures or missing metrics), the method may return None.
Hint
As with training, Message.has_error() can be used to detect and handle client errors during aggregation.
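A minimal aggregate_evaluate along these lines computes an example-weighted average of each metric reported by every successful client. It returns None when nothing can be aggregated. Plain dicts stand in for the reply Message objects and for the returned MetricRecord, and the "num_examples", "metrics" and "error" keys are hypothetical.

```python
def aggregate_evaluate(server_round, replies):
    """Example-weighted average of the metrics shared by all successful replies."""
    ok = [r for r in replies if r.get("error") is None]
    total = sum(r["num_examples"] for r in ok)
    if total == 0:
        return None  # No successful replies (or no examples) to aggregate
    # Only average metrics that every successful client reported
    shared = set.intersection(*(set(r["metrics"]) for r in ok))
    if not shared:
        return None
    return {
        key: sum(r["num_examples"] * r["metrics"][key] for r in ok) / total
        for key in sorted(shared)
    }


replies = [
    {"num_examples": 20, "metrics": {"loss": 0.5, "accuracy": 0.75}},
    {"num_examples": 60, "metrics": {"loss": 0.25, "accuracy": 0.5}},
    {"error": "client crashed"},
]
print(aggregate_evaluate(1, replies))  # {'accuracy': 0.5625, 'loss': 0.3125}
print(aggregate_evaluate(2, replies[2:]))  # None
```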