Write your first Flower App with PyTorch

Welcome to the next part of the Flower collaborative AI tutorial!

In the previous tutorials, you created a simulated federation on SuperGrid, ran a Flower App, downloaded the @flwrlabs/demo app, and learned how ServerApp, ClientApp, strategies, and pyproject.toml fit together. In this tutorial, you will use the same workflow with a more realistic Flower App: a PyTorch app that trains a small image classifier on CIFAR-10.

Tip

Star Flower on GitHub ⭐️ and join the Flower community on Flower Discuss or Flower Slack to introduce yourself, ask questions, and get help.

Let’s get started! 🌼

Create the App

Use flwr new to fetch the PyTorch quickstart app from Flower Hub:

$ flwr new @flwrlabs/quickstart-pytorch

After running the command, a new directory named quickstart-pytorch will be created:

quickstart-pytorch
├── pytorchexample
│   ├── __init__.py
│   ├── client_app.py   # Defines your ClientApp   ├── server_app.py   # Defines your ServerApp   └── task.py         # Defines your model, training and data loading
├── pyproject.toml      # Project metadata like dependencies and configs
└── README.md

This app has the same Flower structure as the NumPy demo from the previous tutorial, but the workload is now a real PyTorch training task. The app trains a small convolutional neural network on CIFAR-10, an image classification dataset with ten classes such as airplane, automobile, bird, cat, dog, ship, and truck.

Quick App Overview

Note

A more detailed walkthrough of the app is available later in this tutorial.

Before running the app, it helps to know what each file is responsible for:

  • pytorchexample/task.py contains the PyTorch-specific code: the neural network, CIFAR-10 data loading and partitioning, the local training loop, the evaluation loop, and server-side evaluation helpers.

  • pytorchexample/client_app.py defines the ClientApp. Its @app.train() handler receives the current global model, loads one CIFAR-10 partition, trains the model locally, and replies with updated model parameters plus metrics. Its @app.evaluate() handler evaluates the received model on local validation data and replies with metrics.

  • pytorchexample/server_app.py defines the ServerApp. It creates the initial PyTorch model, wraps the model parameters in an ArrayRecord, creates a FedAvg strategy, and starts the federated learning run.

  • pyproject.toml declares the app metadata and dependencies, points Flower to the ServerApp and ClientApp objects, and defines run configuration values such as the number of server rounds, batch size, local epochs, learning rate, and evaluation settings.

The important idea is the same as before: the ServerApp starts the run, FedAvg coordinates each federated learning round, and each ClientApp trains or evaluates the model using the data available on its SuperNode.

This app uses Flower Datasets to download CIFAR-10 and split it into partitions, one for each simulated client. This is ideal for simulations because it lets you experiment with federated learning even when you start from a single centralized dataset. In a typical Flower App that runs outside of simulation, you usually do not create artificial partitions. Instead, each ClientApp loads the data already available on the SuperNode where it runs.

Run the App on SuperGrid

Note

If you have not already done so, complete the first tutorial to create a SuperGrid account and a simulated federation.

Open a terminal, activate your Python environment, and run the following command to first login to SuperGrid:

# This will open a browser window where you can enter your SuperGrid credentials.
$ flwr login

Once you are logged in, run the following command to run the app on SuperGrid and across the federation you created in the previous tutorial:

# Navigate to the directory of the app you want to run
$ cd /path/to/quickstart-pytorch
# Run the app across the federation you created in the previous tutorial
$ flwr run . --federation @<username>/<federation-name>
# for example
# flwr run . --federation @peter123/my-first-federation

SuperGrid will start a new run for this app. Open the SuperGrid dashboard, select your federation, and click the new run to follow its progress and inspect the logs.

In the logs, you should see Flower start the FedAvg strategy and run several rounds of federated learning. Each round includes local training on selected ClientApp instances, aggregation in the ServerApp, and evaluation metrics such as eval_loss and eval_acc.

You can override values from pyproject.toml at run time. For example:

# Run the app for five rounds instead of the default three rounds
$ flwr run . --federation @<username>/<federation-name> \
    --run-config "num-server-rounds=5"

# Run the app for five rounds and a smaller batch size
$ flwr run . --federation @<username>/<federation-name> \
    --run-config "num-server-rounds=5" \
    --run-config "batch-size=16"

Run the App Locally

Running on SuperGrid is the recommended way to run collaborative AI workflows with Flower. However, it is also useful to run the same app locally while you are developing or debugging.

From the quickstart-pytorch directory, install the app and its dependencies into your Python environment:

$ cd /path/to/quickstart-pytorch
$ pip install -e .

Then run the app locally with the command below. Flower will start a managed local SuperLink – a distilled version of SuperGrid – and execute the app with simulated SuperNodes on your machine. The first run can take longer because the app needs to download CIFAR-10. With the flag --stream, you can see the logs from the local run in your terminal.

$ flwr run . local --stream

The streamed output should include logs similar to this:

INFO :      Starting FedAvg strategy:
INFO :          ├── Number of rounds: 3
INFO :      ...
INFO :      [ROUND 1/3]
INFO :      configure_train: Sampled 5 SuperNodes (out of 10)
INFO :      aggregate_train: Received 5 results and 0 failures
INFO :          └──> Aggregated MetricRecord: {'train_loss': 2.149280}
INFO :      configure_evaluate: Sampled 10 SuperNodes (out of 10)
INFO :      aggregate_evaluate: Received 10 results and 0 failures
INFO :          └──> Aggregated MetricRecord: {'eval_loss': 2.31319, 'eval_acc': 0.13004}
INFO :      [ROUND 2/3]
INFO :      ...
INFO :      [ROUND 3/3]
INFO :      ...
INFO :      Strategy execution finished

Note

In the above flwr run command you are not specifying a federation, this is because for local prototyping there is only one federation available. Because of this, the --federation flag is not required.

Note

If you’re on Windows and see unexpected terminal output, for example □[32m□[1m, check this FAQ entry.

For more details on using the Flower CLI against a locally running SuperLink, including how to list your runs and view their logs, see Run Flower Locally with a Managed SuperLink.

A Deeper Dive into the App

The @flwrlabs/quickstart-pytorch app demonstrates a simple federated learning workflow. In federated learning, the server sends global model parameters to the client, and the client updates the local model with parameters received from the server. It then trains the model on the local data (which changes the model parameters locally) and sends the updated/changed model parameters back to the server (or, alternatively, it sends just the gradients back to the server, not the full model parameters).

Define the Flower ClientApp

Federated learning systems consist of a server and multiple clients (SuperNodes). In Flower, we create a ServerApp and a ClientApp to run the server-side and client-side code, respectively.

The core functionality of the ClientApp is to perform some action with the local data that the SuperNode it runs on (e.g. an edge device, a server in a data center, or a laptop) has access to. In this tutorial such action is to train and evaluate the small CNN model defined earlier using the local training and validation data.

Loading the data

This app trains a small convolutional neural network on CIFAR-10. Since the tutorial uses the Simulation Runtime, all data starts from one centralized dataset and is split into partitions, one for each simulated SuperNode.

The load_data() function in task.py uses Flower Datasets to load one partition, split it into training and validation data, apply the PyTorch transforms, and return two DataLoader objects:

def load_data(partition_id: int, num_partitions: int, batch_size: int):
    """Load partition CIFAR10 data."""
    # Only initialize `FederatedDataset` once
    global fds
    if fds is None:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(
            dataset="uoft-cs/cifar10",
            partitioners={"train": partitioner},
        )
    partition = fds.load_partition(partition_id)
    # Divide data on each SuperNode: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = Compose(
        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        """Apply transforms to the partition from FederatedDataset."""
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=batch_size, shuffle=True
    )
    testloader = DataLoader(partition_train_test["test"], batch_size=batch_size)
    return trainloader, testloader

This partitioning is only needed for simulation. In deployment, each SuperNode would usually load its own local data directly, for example from a path provided through --node-config.

Training

We can define how the ClientApp performs training by wrapping a function with the @app.train() decorator. In this case we name this function train because we’ll use it to train the model on the local data. The function always expects two arguments:

  • A Message: The message received from the server. It contains the model parameters and any other configuration information sent by the server.

  • A Context: The context object that contains information about the SuperNode executing the ClientApp and about the current run.

Through the context you can retrieve the config settings defined in the pyproject.toml of your app. The context can be used to persist the state of the client across multiple calls to train or evaluate. In Flower, ClientApps are ephemeral objects that get instantiated for the execution of one Message and destroyed when a reply is communicated back to the server.

Let’s see an implementation of ClientApp that uses the previously defined PyTorch CNN model, applies the parameters received from the ServerApp via the message, loads its local data, trains the model with it (using the train_fn function), and generates a reply Message containing the updated model parameters as well as some metrics of interest.

from pytorchexample.task import train as train_fn

# Flower ClientApp
app = ClientApp()


@app.train()
def train(msg: Message, context: Context):
    """Train the model on local data."""

    # Load the model and initialize it with the received weights
    model = Net()
    model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load the data
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    batch_size = context.run_config["batch-size"]
    trainloader, _ = load_data(partition_id, num_partitions, batch_size)

    # Call the training function
    train_loss = train_fn(
        model,
        trainloader,
        context.run_config["local-epochs"],
        msg.content["config"]["lr"],
        device,
    )

    # Construct and return reply Message
    model_record = ArrayRecord(model.state_dict())
    metrics = {
        "train_loss": train_loss,
        "num-examples": len(trainloader.dataset),
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict({"arrays": model_record, "metrics": metric_record})
    return Message(content=content, reply_to=msg)

Note

The partition-id and num-partitions values shown above are provided by the Simulation Runtime. In a deployment setting, the ClientApp would usually load data that already exists on the SuperNode. For example, you could pass the path to that data when starting the SuperNode with --node-config "data-path=/path/to/data" and then call load_data with context.node_config["data-path"].

Note that the train_fn is simply an alias name pointing to the train function defined earlier in this tutorial (where we defined the PyTorch training loop and optimizer). To this function we pass the model we want to train locally and the data loader, but also the number of local epochs and the learning rate (lr) to use. Note how in this case the local-epochs setting is read from the run config via the Context while the lr is read from the ConfigRecord sent by the server via the Message. This can be used to adjust the learning rate on each round from the server. When this dynamism isn’t needed, reading the lr from the run config via the Context is also perfectly valid.

Once training is completed, the ClientApp constructs a reply Message. This reply typically includes a RecordDict with two records:

  • An ArrayRecord containing the updated model parameters

  • A MetricRecord with relevant metrics (in this case, the training loss and the number of examples used for training)

Note

Returning the number of examples under the "num-examples" key is required, because strategies such as FedAvg used by the ServerApp rely on this key to aggregate both models and metrics by default, unless you override the weighted_by_key argument (for example: FedAvg(weighted_by_key="my-different-key")).

After constructing the reply Message, the ClientApp returns it. Flower then handles sending the reply back to the server automatically.

Evaluation

In a typical federated learning setup, the ClientApp would also implement an @app.evaluate() function to evaluate the model received from the ServerApp on local validation data. This is especially useful to monitor the performance of the global model on each client during training. The implementation of the evaluate function is very similar to the train function, except that it calls the test_fn function defined earlier in this tutorial (which implements the PyTorch evaluation loop) and it returns a Message containing only a MetricRecord with the evaluation metrics (no ArrayRecord because the model parameters are not updated during evaluation). Here’s how the evaluate function looks like:

from pytorchexample.task import test as test_fn


@app.evaluate()
def evaluate(msg: Message, context: Context):
    """Evaluate the model on local data."""

    # Load the model and initialize it with the received weights
    model = Net()
    model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load the data
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    batch_size = context.run_config["batch-size"]
    _, valloader = load_data(partition_id, num_partitions, batch_size)

    # Call the evaluation function
    eval_loss, eval_acc = test_fn(
        model,
        valloader,
        device,
    )

    # Construct and return reply Message
    metrics = {
        "eval_loss": eval_loss,
        "eval_acc": eval_acc,
        "num-examples": len(valloader.dataset),
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict({"metrics": metric_record})
    return Message(content=content, reply_to=msg)

As you can see the evaluate implementation is near identical to the train implementation, except that it calls the test_fn function instead of the train_fn function and it returns a Message containing only a MetricRecord with metrics relevant to evaluation (eval_loss, eval_acc – both scalars). We also need to include the num-examples key in the metrics so the server can aggregate the evaluation metrics correctly.

Define the Flower ServerApp

On the server side, we need to configure a strategy which encapsulates the federated learning approach/algorithm, for example, Federated Averaging (FedAvg). Flower has a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this tutorial, we use the built-in FedAvg implementation and customize it slightly by specifying the fraction of connected SuperNodes to involve in a round of training.

To construct a ServerApp, we define its @app.main() method. This method receives as input arguments:

  • a Grid object that will be used to interface with the SuperNodes running the ClientApp to involve them in a round of train/evaluate/query or other.

  • a Context object that provides access to the run configuration.

Before launching the strategy via the start method, we want to initialize the global model. This will be the model that gets sent to the ClientApp running on the clients in the first round of federated learning. We can do this by creating an instance of the model (Net), extracting the parameters in its state_dict, and constructing an ArrayRecord with them. We can then make it available to the strategy via the initial_arrays argument of the start() method.

We can also optionally pass to the start() method a ConfigRecord containing settings that we would like to communicate to the clients. These will be sent as part of the Message that also carries the model parameters.

app = ServerApp()


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""

    # Read run config
    fraction_evaluate: float = context.run_config["fraction-evaluate"]
    num_rounds: int = context.run_config["num-server-rounds"]
    lr: float = context.run_config["learning-rate"]

    # Load global model
    global_model = Net()
    arrays = ArrayRecord(global_model.state_dict())

    # Initialize FedAvg strategy
    strategy = FedAvg(fraction_evaluate=fraction_evaluate)

    # Start strategy, run FedAvg for `num_rounds`
    result = strategy.start(
        grid=grid,
        initial_arrays=arrays,
        train_config=ConfigRecord({"lr": lr}),
        num_rounds=num_rounds,
        evaluate_fn=global_evaluate,
    )

    # Save final model to disk
    print("\nSaving final model to disk...")
    state_dict = result.arrays.to_torch_state_dict()
    torch.save(state_dict, "final_model.pt")

Most of the execution of the ServerApp happens inside the strategy.start() method. After the specified number of rounds (num_rounds), the start() method returns a Result object containing the final model parameters and metrics received from the clients or generated by the strategy itself. We can then save the final model to disk for later use.

Behind the scenes

So how does this work? How does Flower execute this simulation?

When we execute flwr run against the default local connection configuration, Flower submits the run to the managed local SuperLink. By default, the local SuperLink will configure the simulation runtime to use 10 clients. Each will run an instance of the ClientApp we defined earlier.

The local SuperLink then starts the ServerApp and asks it to issue instructions to those SuperNodes using the FedAvg strategy. In this example, FedAvg is configured with two key parameters:

  • fraction-train=0.5 → select 50% of the available clients for training

  • fraction-evaluate=1.0 → select 100% of the available clients for evaluation

This means in our example, 5 out of 10 clients will be selected for training, and all 10 clients will later participate in evaluation.

A typical round looks like this:

  • Training

    1. FedAvg randomly selects 5 clients (50% of 10).

    2. Flower sends a TRAIN message to each selected ClientApp.

    3. Each ClientApp calls the function decorated with @app.train(), then returns a Message containing an ArrayRecord (the updated model parameters) and a MetricRecord (the training loss and number of examples).

    4. The ServerApp receives all replies.

    5. FedAvg aggregates all ArrayRecord into a new ArrayRecord representing the new global model and combines all MetricRecord.

  • Evaluation

    1. FedAvg selects all 10 clients (100%).

    2. Flower sends an EVALUATE message to each ClientApp.

    3. Each ClientApp calls the function decorated with @app.evaluate() and returns a Message containing a MetricRecord (the evaluation loss, accuracy, and number of examples).

    4. The ServerApp receives all replies.

    5. FedAvg aggregates all MetricRecord.

Once both training and evaluation are done, the next round begins: another training step, then another evaluation step, and so on, until the configured number of rounds is reached.

Final remarks

You have now run a PyTorch Flower App on SuperGrid and locally. Compared with the NumPy demo, this app uses a real model, a real dataset, and real local training, but the Flower structure is the same: ServerApp, ClientApp, strategy, and pyproject.toml.

In the next tutorial, you will customize the federated learning strategy to change how the server coordinates training and evaluation.

Next steps

Before you continue, make sure to join the Flower community on Flower Discuss (Join Flower Discuss) and on Slack (Join Slack).

There’s a dedicated #questions Slack channel if you need help, but we’d also love to hear who you are in #introductions!

The Flower Collaborative AI Tutorial - Part 4: Use a federated learning strategy goes into more depth about strategies and the advanced behavior you can build with them.