Use strategies

Flower allows full customization of the learning process through the Strategy abstraction. A number of built-in strategies are provided in the core framework.

There are four ways to customize the way Flower orchestrates the learning process on the server side:

  • Use an existing strategy, for example, FedAvg

  • Customize an existing strategy by passing callback functions to its start method

  • Customize an existing strategy by overriding one or more of its methods (see the sketch after this list)

  • Implement a novel strategy from scratch
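
As an illustration of the third option, here is a minimal sketch that overrides a single method of a built-in strategy. It assumes FedAvg exposes a configure_train method with the signature shown below; see the Flower Strategy Abstraction explainer for the exact method names and signatures:

from flwr.app import ArrayRecord, ConfigRecord
from flwr.serverapp import Grid
from flwr.serverapp.strategy import FedAvg


class CustomFedAvg(FedAvg):
    """FedAvg variant that adds the round number to the outgoing config."""

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ):
        # Inject the current round into the config sent to ClientApps,
        # then defer to the built-in FedAvg logic
        config["server-round"] = server_round
        return super().configure_train(server_round, arrays, config, grid)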

Note

Flower built-in strategies communicate one ArrayRecord and one ConfigRecord in a Message to the ClientApps. The strategies expect replies containing one MetricRecord and, if it's a round where ClientApps do local training, one ArrayRecord as well. The Message abstraction allows for an unlimited number of records of any type; if you want to communicate multiple records, you need to either extend an existing strategy or implement one from scratch.

Use an existing strategy

Flower comes with a number of popular federated learning Strategies, which can be instantiated as part of a simple ServerApp as follows:

from flwr.app import ArrayRecord, Context
from flwr.serverapp import Grid, ServerApp
from flwr.serverapp.strategy import FedAvg

# Create ServerApp
app = ServerApp()


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""

    # Load global model (Net is the user's model, defined elsewhere in the project)
    global_model = Net()
    arrays = ArrayRecord(global_model.state_dict())

    # Initialize FedAvg strategy with default settings
    strategy = FedAvg()

    # Start strategy, run FedAvg for the default number of rounds
    result = strategy.start(
        grid=grid,
        initial_arrays=arrays,
    )

In the code above, instantiating FedAvg does not launch the logic built into the strategy (e.g., sampling nodes, communicating Messages, performing aggregation). To do so, we need to execute the strategy's start method.

The above ServerApp is very minimal: it uses the default settings for FedAvg and only passes the required arguments to the start method. Let's look in a bit more detail at the options we have when instantiating strategies and when launching them.

Parameterizing an existing strategy

The constructors of strategies accept different parameters depending, primarily, on the aggregation algorithm they implement. For example, FedAdam accepts additional arguments (e.g., to apply momentum during aggregation) compared to those FedAvg requires. However, common to all strategies are the settings that control how the nodes running ClientApp instances get sampled. Let's take a look at this set of arguments:

from flwr.serverapp.strategy import FedAvg

# Initialize FedAvg strategy
strategy = FedAvg(
    fraction_train=0.5,  # fraction of nodes to involve in a round of training
    fraction_evaluate=1.0,  # fraction of nodes to involve in a round of evaluation
    min_available_nodes=100,  # minimum connected nodes required before FL starts
)

For most applications, specifying one or all of the arguments shown above is sufficient. A strategy defined as above would wait for 100 nodes to be connected before any federated stage begins. Then, 50% of the connected nodes are involved in a stage of federated training, followed by another stage of federated evaluation in which all connected nodes participate. For finer control, you can also set the min_train_nodes and min_evaluate_nodes arguments, as shown in the sketch below.
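
For instance, a sketch combining all five sampling arguments (the values are illustrative, and it assumes min_train_nodes and min_evaluate_nodes act as lower bounds on the number of nodes sampled for each stage):

from flwr.serverapp.strategy import FedAvg

# Sampling configuration with finer-grained lower bounds
strategy = FedAvg(
    fraction_train=0.5,  # train on 50% of connected nodes...
    min_train_nodes=10,  # ...but never fewer than 10
    fraction_evaluate=1.0,  # evaluate on all connected nodes...
    min_evaluate_nodes=25,  # ...but never fewer than 25
    min_available_nodes=100,  # wait for 100 connected nodes before starting
)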

In addition to the arguments that customize how the strategy samples nodes, we can define at construction time which keys are used to communicate different information between the strategy in the ServerApp and the ClientApp. Note that these keys are used in both types of stages within the strategy's start logic, i.e., federated training and federated evaluation.

from flwr.serverapp.strategy import FedAvg

# Initialize FedAvg strategy
# Here we define our own keys instead of using the default
strategy = FedAvg(
    arrayrecord_key="my-arrays",
    configrecord_key="super-config",
    weighted_by_key="num-batches",
)

  • arrayrecord_key: the Message communicated to the ClientApp will contain an ArrayRecord containing the arrays of the global model under this key. By default the key is "arrays".

  • configrecord_key: the Message communicated to the ClientApp will contain a ConfigRecord containing config settings. By default the key is "config".

  • weighted_by_key: a key inside the MetricRecord that the ClientApp returns as part of its reply to the ServerApp. The value under this key is used to perform weighted aggregation of MetricRecords and, after a round of federated training, ArrayRecords. The default value is "num-examples". The sketch after this list illustrates the weighting.
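
To make the weighting concrete, here is a plain-Python illustration (not Flower API) of how a weighted average over two hypothetical client replies would be computed using the value stored under weighted_by_key:

# Plain-Python illustration of weighted aggregation (not Flower API)
replies = [
    {"loss": 0.4, "num-batches": 10},
    {"loss": 0.8, "num-batches": 30},
]
total_weight = sum(r["num-batches"] for r in replies)
weighted_loss = sum(r["loss"] * r["num-batches"] for r in replies) / total_weight
# -> (0.4 * 10 + 0.8 * 30) / 40 = 0.7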

With a strategy constructed with the custom keys shown earlier, the ClientApp should receive a Message with the following structure:

# The content of a Message arriving to the ClientApp will have
# the following structure and using the keys defined in the strategy
msg = Message(
    # ....
    content=RecordDict(
        {
            "my-arrays": ArrayRecord(...),
            "super-config": ConfigRecord(...),
        }
    )
)

# The reply Message should contain a MetricRecord and inside it
# an item associated with the key used to initialize the strategy
reply_msg_content = RecordDict(
    {
        "locally-updated-params": ArrayRecord(...),
        "local-metrics": MetricRecord(
            {
                "num-batches": N,
                # ... Other metrics
            }
        ),
    }
)

Note

While the strategies fix the keys used to communicate the ArrayRecord and ConfigRecord to the ClientApps, the replies the ClientApps send back to the ServerApp can use different keys. In the code snippet above we used "locally-updated-params" and "local-metrics". However, all ClientApps need to use the same keys in their reply Messages; otherwise, the aggregation of replies (ArrayRecord and MetricRecord) cannot be performed.
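
For reference, here is a minimal sketch of a ClientApp handler that reads the incoming records using the strategy's keys and builds its reply with the keys used above. The @app.train() handler and the Message(..., reply_to=msg) reply construction follow recent Flower examples; adapt them to your own ClientApp:

from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
from flwr.clientapp import ClientApp

app = ClientApp()


@app.train()
def train(msg: Message, context: Context) -> Message:
    # Read the records using the keys defined in the strategy
    arrays = msg.content["my-arrays"]
    config = msg.content["super-config"]

    # ... perform local training with `arrays` and `config` ...

    # Build the reply; all ClientApps must use the same reply keys
    content = RecordDict(
        {
            "locally-updated-params": ArrayRecord(...),  # locally updated arrays
            "local-metrics": MetricRecord({"num-batches": 32}),  # 32 is illustrative
        }
    )
    return Message(content=content, reply_to=msg)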

Finally, the strategy constructor also accepts two callbacks that control how the MetricRecords in the replies sent by the ClientApps are aggregated. Follow the Aggregate evaluation results guide for a walkthrough on how to define these callbacks.

Using the strategy’s start method

As mentioned earlier, it is the start method of the strategy that launches the federated learning process. Let’s see what each argument passed to this method represents.

Check the Flower Strategy Abstraction explainer for a deep dive into how the different stages implemented as part of the start method operate.

The only required arguments are the Grid and an ArrayRecord. The former is the object used to interface with the nodes running the ClientApp and to involve them in a round of training, evaluation, querying, or other stages. The latter contains the parameters of the model we want to federate. Therefore, a minimal execution of the start method looks like this:

# Start strategy
result = strategy.start(
    grid=grid,
    initial_arrays=ArrayRecord(...),
)

In most settings, we want to customize how the start method executes by also passing the number of rounds to run and a pair of ConfigRecord objects to be sent to the ClientApp during the training and evaluation stages, respectively.

# Define configs to send to ClientApp
train_cfg = ConfigRecord({"lr": 0.1, "optim": "adam"})
eval_cfg = ConfigRecord({"max-steps": 500, "local-checkpoint": True})

# Start strategy
result = strategy.start(
    grid=grid,
    initial_arrays=ArrayRecord(...),
    train_config=train_cfg,
    evaluate_config=eval_cfg,
    num_rounds=100,
)

The start method also allows you to limit how long the strategy waits for replies from the ClientApps before proceeding with the rest of the stages. This is controlled with the timeout argument, which defaults to 3600 s (i.e., 1 hour). For example, to increase the timeout to 2 hours:

# Define configs to send to ClientApp
train_cfg = ConfigRecord({"lr": 0.1, "optim": "adam"})
eval_cfg = ConfigRecord({"max-steps": 500, "local-checkpoint": True})

# Start strategy
result = strategy.start(
    grid=grid,
    initial_arrays=ArrayRecord(...),
    train_config=train_cfg,
    evaluate_config=eval_cfg,
    num_rounds=100,
    timeout=7200,  # 2 hours
)
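
The start method returns a result object summarizing the run. Assuming the final global model is exposed under its arrays attribute (as in recent Flower examples), saving the final model to disk could look like this:

import torch

# Save the final global model to disk
state_dict = result.arrays.to_torch_state_dict()
torch.save(state_dict, "final_model.pt")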

Finally, the last argument to start is evaluate_fn, which accepts a callback function that evaluates the aggregated model on local data the ServerApp might have access to. This callback is also useful if you want to save the global model at the end of every round (or every N rounds). Let's see what the signature of this callback is and how to use it:

import torch

from flwr.app import ArrayRecord, MetricRecord


# Callback definition. The function can have any name,
# but the arguments are fixed
def my_callback(server_round: int, arrays: ArrayRecord) -> MetricRecord:
    """Evaluate model on central data."""

    # Save checkpoint
    state_dict = arrays.to_torch_state_dict()
    torch.save(state_dict, f"model_at_round_{server_round}.pt")

    # Evaluate the model on local data
    # (MyModel and test are the user's model and evaluation function)
    model = MyModel()
    model.load_state_dict(state_dict)
    acc, loss = test(model, ...)

    # Return MetricRecord
    return MetricRecord({"acc": acc, "loss": loss})


# Pass the callback to the start method
strategy.start(..., evaluate_fn=my_callback)

Take a look at the quickstart-pytorch example on GitHub for a complete example using several of the concepts presented in this how-to guide.