Flower Strategy Abstraction¶
The strategy abstraction enables the implementation of fully custom federated learning
strategies. In Flower, a strategy is essentially the federated learning algorithm that
runs inside the ServerApp. Strategies define how to:
Sample clients
Configure instructions for training and evaluation
Aggregate updates and metrics
Evaluate models
Flower ships with a number of built-in strategies, all following the same API described below. You can also implement your own strategies with full access to the same capabilities.
The Strategy abstraction¶
All strategy implementations must derive from the abstract base class Strategy.
This includes both built-in strategies and third-party/custom strategies. By extending
this base class, user-defined strategies gain the exact same power and flexibility as
the built-in ones.
The Strategy base class defines a start method and requires subclasses to
implement several abstract methods:
class Strategy(ABC):
    """Abstract base class for server strategy implementations."""

    @abstractmethod
    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training."""

    @abstractmethod
    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate training results from client nodes."""

    @abstractmethod
    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of evaluation."""

    @abstractmethod
    def aggregate_evaluate(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> Optional[MetricRecord]:
        """Aggregate evaluation metrics from client nodes."""

    @abstractmethod
    def summary(self) -> None:
        """Log a summary of the strategy configuration."""

    def start(
        self,
        grid: Grid,
        initial_arrays: ArrayRecord,
        num_rounds: int = 3,
        timeout: float = 3600,
        train_config: Optional[ConfigRecord] = None,
        evaluate_config: Optional[ConfigRecord] = None,
        evaluate_fn: Optional[
            Callable[[int, ArrayRecord], Optional[MetricRecord]]
        ] = None,
    ) -> Result:
        """Execute the federated learning strategy."""
        # Implementation details
        pass
Creating a new strategy¶
You can customize an existing strategy (e.g., FedAvg) by overriding one or
several of its methods. For full flexibility, you can also implement a strategy from
scratch. To implement a brand new strategy, simply define a class that derives from
Strategy and implement the abstract methods:
class SotaStrategy(Strategy):
    def configure_train(self, server_round, arrays, config, grid):
        # Your implementation here
        pass

    def aggregate_train(self, server_round, replies):
        # Your implementation here
        pass

    def configure_evaluate(self, server_round, arrays, config, grid):
        # Your implementation here
        pass

    def aggregate_evaluate(self, server_round, replies):
        # Your implementation here
        pass

    def summary(self):
        print("SotaStrategy: This is the state-of-the-art strategy!")
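Customizing an existing strategy can be as small as overriding a single method. The sketch below illustrates the pattern; `_FedAvgLike` is a hypothetical stand-in for a built-in strategy such as FedAvg (only the overridden method is shown), and the clipping logic is an illustrative example, not a Flower API:

```python
# Sketch of customizing an existing strategy by overriding one method.
# _FedAvgLike is a hypothetical stand-in for a built-in strategy (e.g. FedAvg);
# replies are simplified to plain numbers for illustration.

class _FedAvgLike:
    def aggregate_train(self, server_round, replies):
        # Stand-in for the built-in aggregation (details omitted).
        return sum(replies) / len(replies), {"strategy": "plain average"}


class ClippedFedAvg(_FedAvgLike):
    """Same as the base strategy, but clips each update before aggregating."""

    def __init__(self, clip=1.0):
        self.clip = clip

    def aggregate_train(self, server_round, replies):
        # Clip each client update into [-clip, clip], then defer to the base.
        clipped = [max(-self.clip, min(self.clip, r)) for r in replies]
        return super().aggregate_train(server_round, clipped)
```

All other methods (client sampling, evaluation, and so on) are inherited unchanged, which is why overriding is usually preferable to writing a strategy from scratch when only one behavior needs to change.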
The start method is already implemented in the base class and typically does not
need to be overridden. It orchestrates the federated learning process by invoking the
abstract methods in sequence.
Understanding the start method¶
The start method of the Strategy base class follows this workflow:
1. Call evaluate_fn (if provided) to evaluate the initial model on the ServerApp side.
2. Call configure_train to generate training messages for ClientApps.
3. Send training messages to ClientApps.
4. ClientApps run their @app.train() function and return training replies.
5. Call aggregate_train to aggregate the training replies.
6. Call configure_evaluate to generate evaluation messages for ClientApps.
7. Send evaluation messages to ClientApps.
8. ClientApps run their @app.evaluate() function and return evaluation replies.
9. Call aggregate_evaluate to aggregate the evaluation replies.
10. Call evaluate_fn (if provided) to evaluate the aggregated model on the ServerApp side.
11. Repeat steps 2-10 for the specified number of rounds.
12. Return the final Result, which contains the final model and metrics history.
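The workflow above can be sketched as a plain-Python loop. This is a simplified illustration of the orchestration logic, not Flower's actual implementation; `run`, the dict-based config, and the `send_and_receive` helper are stand-in names introduced for this sketch:

```python
# Simplified sketch of the start() workflow (illustration only, not Flower code).
# `strategy` provides the four abstract methods; `grid.send_and_receive` stands
# in for sending messages to ClientApps and collecting their replies.

def run(strategy, grid, arrays, num_rounds, evaluate_fn=None):
    history = []
    if evaluate_fn is not None:
        history.append(evaluate_fn(0, arrays))  # initial server-side evaluation
    for server_round in range(1, num_rounds + 1):
        # Training phase: configure, send, collect, aggregate
        train_msgs = strategy.configure_train(server_round, arrays, {}, grid)
        train_replies = grid.send_and_receive(train_msgs)
        arrays, train_metrics = strategy.aggregate_train(server_round, train_replies)
        # Evaluation phase: configure, send, collect, aggregate
        eval_msgs = strategy.configure_evaluate(server_round, arrays, {}, grid)
        eval_replies = grid.send_and_receive(eval_msgs)
        eval_metrics = strategy.aggregate_evaluate(server_round, eval_replies)
        # Optional server-side evaluation of the aggregated model
        if evaluate_fn is not None:
            history.append(evaluate_fn(server_round, arrays))
    return arrays, history  # final model and metrics history
```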
The following diagram illustrates the flow.
Note
The sequence diagram below shows the interaction between ServerApp, Strategy
(inside ServerApp), and ClientApp. In reality, they do not communicate
directly over the network—Flower infrastructure (SuperLink and SuperNode)
transparently manages all communication. You can read more about it in the
Flower Network Communication guide.
The configure_train method¶
The configure_train method is responsible for preparing the next round of training.
But what does configure mean in this context? It means selecting which clients should
participate in the round and deciding what instructions they should receive.
Here is the method signature:
@abstractmethod
def configure_train(
    self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
    """Configure the next round of training."""
This method takes four arguments:
server_round: The current round number
arrays: The current global model parameters
config: A configuration dictionary for the round
grid: The object responsible for managing communication with clients
The return value is an iterable of Message objects, where each message contains the
instructions to be sent to a specific client. A typical implementation of
configure_train will:
Use the grid to randomly sample a subset (or all) of the available clients
Create one Message per selected client, containing the global model parameters and configuration values
More advanced strategies can implement custom client selection logic by using the
capabilities of grid. A client only participates in a round if configure_train
generates a message for its node ID.
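A typical configure_train can be sketched as follows. The grid is assumed to expose a get_node_ids() method for discovering available clients; the message structure is simplified to a plain dict for illustration (Flower's real Message carries records and a destination node ID), and the `fraction` parameter is a name introduced for this sketch:

```python
import random

# Sketch of a configure_train implementation (illustration only).
# The grid is assumed to expose get_node_ids(); messages are simplified
# to dicts standing in for Flower's Message objects.

def configure_train(server_round, arrays, config, grid, fraction=0.5):
    node_ids = list(grid.get_node_ids())
    num_selected = max(1, int(len(node_ids) * fraction))
    selected = random.sample(node_ids, num_selected)  # uniform random sampling
    messages = []
    for node_id in selected:
        messages.append(
            {
                "dst_node_id": node_id,  # which client receives this message
                "arrays": arrays,        # current global model parameters
                "config": dict(config, server_round=server_round),
            }
        )
    return messages
```

Clients whose node IDs receive no message simply sit out the round, which is how the per-client return value doubles as the selection mechanism.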
Note
Because the return value is defined per client, strategies can easily implement heterogeneous configurations. For example, different clients can receive different models or hyperparameters, enabling highly customized training behaviors.
The aggregate_train method¶
The aggregate_train method is responsible for aggregating the training results
received from the clients selected in configure_train.
Here is the method signature:
@abstractmethod
def aggregate_train(
    self,
    server_round: int,
    replies: Iterable[Message],
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
    """Aggregate training results from client nodes."""
This method takes two arguments:
server_round: The current round number
replies: An iterable of Message objects from the participating clients
It returns a tuple consisting of:
ArrayRecord: The updated global model parameters
MetricRecord: Aggregated training metrics (such as loss or accuracy)
If aggregation cannot be performed (e.g., if too many clients failed during the round),
the method may decide to return (None, None) instead.
Hint
You can use Message.has_error() to check if a reply contains an error and decide
how to handle it during aggregation.
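As an illustration of this step, the sketch below computes a FedAvg-style weighted average over client updates, skipping replies that report an error. The reply structure (a weights list plus an example count and loss) is a simplified stand-in for Flower's Message and ArrayRecord types:

```python
# FedAvg-style weighted averaging over client replies (illustration only).
# Each reply is a dict standing in for a Message: model weights, the number
# of local examples used to produce them, and the local training loss.

def aggregate_train(server_round, replies):
    valid = [r for r in replies if not r.get("error")]  # cf. Message.has_error()
    if not valid:
        return None, None  # aggregation impossible, e.g. all clients failed
    total_examples = sum(r["num_examples"] for r in valid)
    num_params = len(valid[0]["weights"])
    aggregated = [0.0] * num_params
    for r in valid:
        factor = r["num_examples"] / total_examples  # weight by dataset size
        for i, w in enumerate(r["weights"]):
            aggregated[i] += factor * w
    avg_loss = sum(r["loss"] * r["num_examples"] for r in valid) / total_examples
    return aggregated, {"train_loss": avg_loss}
```

Weighting by the number of local examples is what distinguishes FedAvg from a plain mean: clients that trained on more data contribute proportionally more to the global model.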
The configure_evaluate method¶
The configure_evaluate method is responsible for preparing the next round of
evaluation. Similar to configure_train, this involves selecting which clients should
participate and deciding what instructions they should receive for evaluation.
Here is the method signature:
@abstractmethod
def configure_evaluate(
    self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
    """Configure the next round of evaluation."""
This method takes four arguments:
server_round: The current round number
arrays: The current global model parameters to be evaluated
config: A configuration dictionary for evaluation
grid: The object that manages communication with clients
The return value is an iterable of Message objects, one for each selected client.
Each message typically contains the current global model parameters and any evaluation
configuration.
A typical implementation of configure_evaluate will:
Use grid to select a subset (or all) of the available clients
Create one Message per selected client containing the global model and evaluation configuration
As with training, more advanced strategies may apply custom client selection logic or send different evaluation configurations to different clients.
Note
Because each client receives its own message, strategies can implement heterogeneous evaluation setups. For example, some clients might evaluate on larger test sets, while others might use specialized metrics.
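One way to sketch such a heterogeneous setup: each selected client gets its own config, here varying whether it evaluates on its full test set. As before, the grid is assumed to expose get_node_ids(), the message structure is a simplified dict stand-in, and the `full_test_set` key is a name invented for this example:

```python
# Sketch of a configure_evaluate sending heterogeneous instructions
# (illustration only): even-numbered nodes evaluate on the full test set,
# odd-numbered nodes on a sample. Messages are dict stand-ins.

def configure_evaluate(server_round, arrays, config, grid):
    messages = []
    for node_id in grid.get_node_ids():  # here: evaluate on all clients
        per_client = dict(config, server_round=server_round)
        per_client["full_test_set"] = node_id % 2 == 0  # per-client setting
        messages.append(
            {"dst_node_id": node_id, "arrays": arrays, "config": per_client}
        )
    return messages
```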
The aggregate_evaluate method¶
The aggregate_evaluate method is responsible for aggregating the evaluation results
received from the clients selected in configure_evaluate.
Here is the method signature:
@abstractmethod
def aggregate_evaluate(
    self,
    server_round: int,
    replies: Iterable[Message],
) -> Optional[MetricRecord]:
    """Aggregate evaluation metrics from client nodes."""
This method takes two arguments:
server_round: The current round number
replies: An iterable of Message objects returned by the clients after they executed evaluation
It returns a single MetricRecord that represents the aggregated evaluation metrics
across all participating clients. If aggregation cannot be performed (for example, due
to excessive client failures or missing metrics), the method may return None.
Hint
As with training, Message.has_error() can be used to detect and handle client
errors during aggregation.