Use a federated learning strategy

Welcome to the next part of the federated learning tutorial. In the previous part of this tutorial, we introduced federated learning with PyTorch and Flower (part 1).

In part 2, we'll begin to customize the federated learning system we built in part 1 using the Flower framework, Flower Datasets, and PyTorch.

Tip

Star Flower on GitHub ⭐️ and join the Flower community on Flower Discuss and the Flower Slack to connect, ask questions, and get help:

  • Join Flower Discuss: we'd love to hear from you in the Introduction topic! If anything is unclear, post in Flower Help - Beginners.

  • Join Flower Slack: we'd love to hear from you in the #introductions channel! If anything is unclear, head over to the #questions channel.

Let's move beyond FedAvg with Flower strategies! 🌼

Preparation

Before we begin with the actual code, let's make sure that we have everything we need.

Installing dependencies

Note

If you've completed part 1 of the tutorial, you can skip this step.

First, we install the Flower package flwr:

# In a new Python environment
$ pip install -U "flwr[simulation]"

Then, we create a new Flower app called flower-tutorial using the PyTorch template. We also specify a username (flwrlabs) for the project:

$ flwr new flower-tutorial --framework pytorch --username flwrlabs

After running the command, a new directory called flower-tutorial will be created. It should have the following structure:

flower-tutorial
├── flower_tutorial
│   ├── __init__.py
│   ├── client_app.py   # Defines your ClientApp
│   ├── server_app.py   # Defines your ServerApp
│   └── task.py         # Defines your model, training and data loading
├── pyproject.toml      # Project metadata like dependencies and configs
└── README.md

Next, we install the project and its dependencies, which are specified in the pyproject.toml file:

$ cd flower-tutorial
$ pip install -e .

So far, everything should look familiar if you've worked through the introductory tutorial. With that, we're ready to introduce a number of new features.

Choosing a different strategy

In part 1, we created a ServerApp (in server_app.py). In it, we defined the strategy and the model to train in a federated way, and then launched the strategy by calling its start method.

The strategy encapsulates the federated learning approach/algorithm, for example, FedAvg. Let's try to use a different strategy this time. Modify the following lines in your server_app.py to switch from FedAvg to FedAdagrad.

from flwr.serverapp.strategy import FedAdagrad


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""

    # Read run config
    fraction_train: float = context.run_config["fraction-train"]
    num_rounds: int = context.run_config["num-server-rounds"]
    lr: float = context.run_config["lr"]

    # Load global model
    global_model = Net()
    arrays = ArrayRecord(global_model.state_dict())

    # Initialize FedAdagrad strategy
    strategy = FedAdagrad(fraction_train=fraction_train)

    # Start strategy, run FedAdagrad for `num_rounds`
    result = strategy.start(
        grid=grid,
        initial_arrays=arrays,
        train_config=ConfigRecord({"lr": lr}),
        num_rounds=num_rounds,
    )

    # Save final model to disk
    print("\nSaving final model to disk...")
    state_dict = result.arrays.to_torch_state_dict()
    torch.save(state_dict, "final_model.pt")

Next, run the training with the following command:

$ flwr run .
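To build some intuition for what changes when we swap strategies, here is a minimal pure-Python sketch of the two server-side update rules. It is not Flower's implementation: the helper names weighted_average and fedadagrad_step are hypothetical, the toy weights are made up, and the hyperparameter values (eta, beta_1, tau) are illustrative assumptions rather than Flower's defaults. FedAvg adopts the example-weighted average of the client models as the new global model, while FedAdagrad treats the difference between that average and the current global model as a pseudo-gradient and applies an Adagrad-style step on the server.

```python
import math


def weighted_average(client_weights, num_examples):
    """FedAvg-style aggregation: average each coordinate, weighted by dataset size."""
    total = sum(num_examples)
    dim = len(client_weights[0])
    return [
        sum(w[i] * n for w, n in zip(client_weights, num_examples)) / total
        for i in range(dim)
    ]


def fedadagrad_step(global_weights, averaged, m, v, eta=0.1, beta_1=0.0, tau=1e-3):
    """One Adagrad-style server update using (averaged - global) as a pseudo-gradient."""
    new_weights, new_m, new_v = [], [], []
    for x, avg, m_i, v_i in zip(global_weights, averaged, m, v):
        delta = avg - x                             # pseudo-gradient for this coordinate
        m_i = beta_1 * m_i + (1 - beta_1) * delta   # momentum term
        v_i = v_i + delta * delta                   # accumulated squared updates
        new_weights.append(x + eta * m_i / (math.sqrt(v_i) + tau))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_weights, new_m, new_v


# Toy example: two clients, a two-parameter "model" (all values are made up)
global_weights = [0.0, 0.0]
client_weights = [[1.0, 2.0], [3.0, 4.0]]
num_examples = [10, 30]

averaged = weighted_average(client_weights, num_examples)
stepped, m, v = fedadagrad_step(global_weights, averaged, [0.0, 0.0], [0.0, 0.0])

print("FedAvg would move the global model to:", averaged)
print("FedAdagrad-style step moves it to:    ", stepped)
```

In this toy run, the adaptive step is much smaller than the jump FedAvg would take, because the update is scaled by the server learning rate and normalized by the accumulated squared updates.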

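Server-side parameter evaluation

Flower can evaluate the aggregated model on the server side or on the client side. Client-side and server-side evaluation are similar in some ways, but different in others.

Centralized evaluation (or server-side evaluation) is conceptually simple: it works the same way as evaluation in centralized machine learning. If a server-side dataset is available for evaluation purposes, that's great. We can evaluate the newly aggregated model after each round of training without having to send the model to clients. We're also fortunate in that our entire evaluation dataset is available at all times.

Federated evaluation (or client-side evaluation) is more complex, but also more powerful: it doesn't require a centralized dataset and allows us to evaluate models over a larger set of data, which often yields more realistic evaluation results. In fact, many scenarios require federated evaluation if we want representative evaluation results at all. But this power comes at a cost: once we start to evaluate on the client side, we should be aware that our evaluation dataset can change over consecutive rounds of learning if those clients are not always available. Moreover, the dataset held by each client can also change over consecutive rounds. This can lead to unstable evaluation results, so even if we don't change the model, we'd see the evaluation results fluctuate over consecutive rounds.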
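The fluctuation described above can be illustrated with a small, self-contained simulation. This is only a sketch under stated assumptions: the 100 clients, their dataset sizes, and their per-client accuracies are randomly generated stand-ins (not real results), and federated_accuracy is a hypothetical helper that weights each client's accuracy by its number of examples. The model is kept fixed throughout, and only the set of sampled clients changes from round to round.

```python
import random


def federated_accuracy(results):
    """Aggregate client accuracies, weighting each by its number of examples."""
    total = sum(n for n, _ in results)
    return sum(n * acc for n, acc in results) / total


rng = random.Random(0)

# 100 hypothetical clients: (num_examples, accuracy of ONE fixed model on local data)
clients = [(rng.randint(20, 200), rng.uniform(0.4, 0.9)) for _ in range(100)]

# Reference value: evaluate on every client's data in every round
full = federated_accuracy(clients)

# Federated evaluation: a different random sample of 10 clients in each round
per_round = [federated_accuracy(rng.sample(clients, 10)) for _ in range(5)]

print(f"all clients: {full:.4f}")
for rnd, acc in enumerate(per_round, start=1):
    print(f"round {rnd}:     {acc:.4f}")
```

Even though the model never changes, the per-round values differ from each other and from the all-clients value, purely because a different subset of clients is evaluated each time.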

We've seen how federated evaluation works on the client side (i.e., by implementing a function wrapped with the @app.evaluate decorator in your ClientApp). Now let's see how we can evaluate the aggregated model parameters on the server side.

To do so, we create a new function in task.py, which we can name central_evaluate. This function is a callback that will be passed to the start method of our strategy. The strategy will then call it after every round of federated learning, passing two arguments: the current round number and the aggregated model parameters.

Our central_evaluate function performs the following steps:

  1. Load the aggregated model parameters into a PyTorch model

  2. Load the entire CIFAR10 test dataset

  3. Evaluate the model on the test dataset

  4. Return the evaluation metrics as a MetricRecord

from datasets import load_dataset
from flwr.app import ArrayRecord, MetricRecord


def central_evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
    """Evaluate model on the server side."""

    # Load the model and initialize it with the received weights
    model = Net()
    model.load_state_dict(arrays.to_torch_state_dict())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load the entire CIFAR10 test dataset
    # It's a huggingface dataset, so we can load it directly and apply transforms
    cifar10_test = load_dataset("cifar10", split="test")
    pytorch_transforms = Compose(
        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    # Define transforms and construct DataLoader for the test set
    def apply_transforms(batch):
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    testset = cifar10_test.with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=64)

    # Evaluate the model on the test set
    loss, accuracy = test(model, testloader, device)

    # Return the evaluation metrics
    return MetricRecord({"accuracy": accuracy, "loss": loss})

Recall that central_evaluate is meant to be called by the strategy. For that to happen, we pass it to the strategy's start method through the evaluate_fn argument, as shown below.

from flower_tutorial.task import central_evaluate


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""

    # ... unchanged

    # Start strategy, run FedAdagrad for `num_rounds`
    result = strategy.start(
        grid=grid,
        initial_arrays=arrays,
        train_config=ConfigRecord({"lr": lr}),
        num_rounds=num_rounds,
        evaluate_fn=central_evaluate,
    )

    # ... unchanged

Finally, we run the simulation.

$ flwr run .

You'll note that the server logs the metrics returned by the callback after each round. Also, at the end of the run, note the ServerApp-side Evaluate Metrics shown:

INFO :          ServerApp-side Evaluate Metrics:
INFO :          { 0: {'accuracy': '1.0000e-01', 'loss': '2.3053e+00'},
INFO :            1: {'accuracy': '1.0000e-01', 'loss': '2.3203e+00'},
INFO :            2: {'accuracy': '2.3230e-01', 'loss': '2.0144e+00'},
INFO :            3: {'accuracy': '2.5720e-01', 'loss': '1.9258e+00'}}

Sending configurations to clients from strategies

In some situations, we want to configure client-side execution (training, evaluation) from the server side. One example is the server asking the clients to train with a different learning rate depending on the current round number. Flower provides a way to send configuration values from the server to the clients as part of the Message that the ClientApp receives. Let's see how we can do this.

We are already passing a ConfigRecord that specifies the initial learning rate to the start method of our strategy. This ConfigRecord is sent to the clients in every Message addressed to the @app.train() function of the ClientApp. Say we want to decrease the learning rate by a factor of 0.5 every 5 rounds. To do that, we override the configure_train method of our strategy and embed that logic there.

We create a new class that inherits from FedAdagrad and overrides the configure_train method, and then use this new strategy in our ServerApp. Let's see what this looks like in code. Create a new file called custom_strategy.py in the flower_tutorial directory and add the following code:

from typing import Iterable
from flwr.serverapp import Grid
from flwr.serverapp.strategy import FedAdagrad
from flwr.app import ArrayRecord, ConfigRecord, Message


class CustomFedAdagrad(FedAdagrad):
    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated training and maybe do LR decay."""
        # Decrease learning rate by a factor of 0.5 every 5 rounds
        if server_round % 5 == 0 and server_round > 0:
            config["lr"] *= 0.5
            print("LR decreased to:", config["lr"])
        # Pass the updated config and the rest of arguments to the parent class
        return super().configure_train(server_round, arrays, config, grid)

Next, we use this new strategy in our ServerApp: import CustomFedAdagrad from flower_tutorial.custom_strategy in your server_app.py and use it in place of the standard FedAdagrad.

Finally, run the training with the following command. Here we increase the number of rounds to 15 to see the learning rate decay in action.

$ flwr run . --run-config="num-server-rounds=15"

You'll note that in the configure_train stage of every fifth round (rounds 5 and 10, and also round 15 if rounds are numbered from 1 to 15), the learning rate is decreased by a factor of 0.5 and the new learning rate is printed to the terminal.
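To see the resulting schedule at a glance, here is a small standalone sketch that replays the same decay condition outside of Flower. The function lr_schedule is a hypothetical helper, the initial learning rate of 0.1 is an assumed value (use whatever lr is set in your pyproject.toml), and the sketch assumes rounds are numbered from 1 to num_rounds. Under that assumption, the condition server_round % 5 == 0 also holds in the final round 15.

```python
def lr_schedule(initial_lr, num_rounds, factor=0.5, every=5):
    """Return the learning rate sent to clients in each round (1-indexed)."""
    lr = initial_lr
    schedule = {}
    for server_round in range(1, num_rounds + 1):
        # Same condition as in CustomFedAdagrad.configure_train
        if server_round % every == 0 and server_round > 0:
            lr *= factor
        schedule[server_round] = lr
    return schedule


schedule = lr_schedule(initial_lr=0.1, num_rounds=15)
for server_round, lr in schedule.items():
    print(f"round {server_round:2d}: lr = {lr}")
```

Because the ConfigRecord is modified in place, the decay accumulates: each reduction starts from the already reduced value, not from the initial learning rate.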

How do we know the ClientApp is using that new learning rate? Recall that in client_app.py, we are reading the learning rate from the Message received by the @app.train() function:

@app.train()
def train(msg: Message, context: Context):

    # ... setup

    # Call the training function
    train_loss = train_fn(
        model,
        trainloader,
        context.run_config["local-epochs"],
        msg.content["config"]["lr"],
        device,
    )

    # ... prepare reply Message
    return Message(content=content, reply_to=msg)

Congratulations! You have created your first custom strategy adding dynamism to the ConfigRecord that is sent to clients.

Scaling federated learning

As a last step in this tutorial, let's see how we can use Flower to experiment with a large number of clients. In the pyproject.toml, increase the number of SuperNodes to 1000:

[tool.flwr.federations.local-simulation]
options.num-supernodes = 1000

Note that we can reuse the ClientApp for different num-supernodes since the Context carries the num-partitions key and for simulations with Flower, the number of partitions is equal to the number of SuperNodes.

We now have 1000 partitions, each holding 45 training and 5 validation examples. Given that the number of training examples on each client is quite small, we should probably train the model a bit longer, so we configure the clients to perform 3 local training epochs. We should also adjust the fraction of clients selected for training during each round (we don't want all 1000 clients participating in every round), so we adjust fraction_train to 0.025, which means that only 2.5% of available clients (so 25 clients) will be selected for training each round. We update the fraction-train value in the pyproject.toml:

[tool.flwr.app.config]
fraction-train = 0.025

Then, we update the initialization of our strategy in server_app.py to the following:

@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""

    # ... unchanged
    # Initialize FedAdagrad strategy
    strategy = CustomFedAdagrad(
        fraction_train=fraction_train,
        fraction_evaluate=0.05,  # Evaluate on 50 clients (each round)
        min_train_nodes=20,  # Optional config
        min_evaluate_nodes=40,  # Optional config
        min_available_nodes=1000,  # Optional config
    )

    # ... rest unchanged
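As a quick sanity check on these numbers, the following sketch computes how many nodes would be selected per round. It assumes the strategy selects the given fraction of available nodes (rounded down) but never fewer than the configured minimum. That rule is our reading of the parameters, not a statement of Flower's exact implementation, and sample_size is a hypothetical helper.

```python
def sample_size(num_available, fraction, min_nodes):
    """Nodes selected per round: a fraction of the available nodes, but at least min_nodes."""
    return max(int(num_available * fraction), min_nodes)


# Values from the configuration above, with 1000 SuperNodes
train_nodes = sample_size(1000, fraction=0.025, min_nodes=20)
evaluate_nodes = sample_size(1000, fraction=0.05, min_nodes=40)
print("train nodes per round:   ", train_nodes)
print("evaluate nodes per round:", evaluate_nodes)

# With fewer nodes available, the minimum takes over (under the same assumption)
small_federation = sample_size(100, fraction=0.025, min_nodes=20)
print("train nodes with 100 SuperNodes:", small_federation)
```

With 1000 SuperNodes, the fractions determine the sample sizes (25 for training, 50 for evaluation), and the minimums only matter if the federation shrinks.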

Finally, run the simulation with the following command:

$ flwr run .

Recap

In this tutorial, we've seen how we can gradually enhance our system by customizing the strategy, choosing a different strategy, applying learning rate decay at the strategy level, and evaluating models on the server side. That's quite a bit of flexibility with so little code, right?

In the later sections, we've seen how we can communicate arbitrary values between server and clients to fully customize client-side execution. With that capability, we built a large-scale Federated Learning simulation using the Flower Virtual Client Engine and ran an experiment involving 1000 clients in the same workload — all in the same Flower project!

Next steps

Before you continue, make sure to join the Flower community on Flower Discuss and on Slack.

There's a dedicated #questions Slack channel if you need help, but we'd also love to hear who you are in #introductions!

The Flower Federated Learning Tutorial - Part 3 shows how to build a fully custom Strategy from scratch.