Example: Walk-Through PyTorch & MNIST

In this tutorial we will learn how to train a Convolutional Neural Network on MNIST using Flower and PyTorch.

Our example consists of one server and two clients, all having the same model.

Clients are responsible for generating individual weight-updates for the model based on their local datasets. These updates are then sent to the server which will aggregate them to produce a better model. Finally, the server sends this improved version of the model back to each client. A complete cycle of weight updates is called a round.
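
To make the aggregation step of a round concrete, here is a minimal sketch in plain Python. It assumes the server combines the client weights with an average weighted by each client's number of training examples (the FedAvg scheme, which is Flower's default strategy). The `aggregate` function and the toy weight lists are illustrative names, not part of Flower's API.

```python
from typing import List, Tuple

# Hypothetical stand-in: a model's weights as a flat list of floats.
Weights = List[float]


def aggregate(results: List[Tuple[Weights, int]]) -> Weights:
    """Average client weights, weighted by each client's example count."""
    total_examples = sum(num_examples for _, num_examples in results)
    num_params = len(results[0][0])
    return [
        sum(weights[i] * num_examples for weights, num_examples in results)
        / total_examples
        for i in range(num_params)
    ]


# Two clients report their updated weights and how many examples they used.
client_updates = [([1.0, 2.0], 100), ([3.0, 6.0], 300)]
global_weights = aggregate(client_updates)
print(global_weights)  # [2.5, 5.0]
```

The client that trained on more examples pulls the result further toward its own weights, which is why each client reports its example count along with its update.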

Now that we have a rough idea of what is going on, let’s get started. We first need to install Flower. You can do this by running:

$ pip install flwr

Since we want to use PyTorch to solve a computer vision task, let’s go ahead and install PyTorch and the torchvision library:

$ pip install torch torchvision

Ready… Set… Train!

Now that we have all our dependencies installed, let’s run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on PyTorch’s Basic MNIST Example. This will allow you to see how easy it is to wrap your code with Flower and begin training in a federated way. We provide you with two helper scripts, namely run-server.sh and run-clients.sh. Don’t be afraid to look inside; they are simple enough =).

Go ahead and launch the run-server.sh script in a terminal first, as follows:

$ bash ./run-server.sh

Now that the server is up and running, go ahead and launch the clients:

$ bash ./run-clients.sh

Et voilà! You should be seeing the training procedure and, after a few iterations, the test accuracy for each client.

Train Epoch: 10 [30000/30016 (100%)] Loss: 0.007014

Train Epoch: 10 [30000/30016 (100%)] Loss: 0.000403

Train Epoch: 11 [30000/30016 (100%)] Loss: 0.001280

Train Epoch: 11 [30000/30016 (100%)] Loss: 0.000641

Train Epoch: 12 [30000/30016 (100%)] Loss: 0.006784

Train Epoch: 12 [30000/30016 (100%)] Loss: 0.007134

Client 1 - Evaluate on 5000 samples: Average loss: 0.0290, Accuracy: 99.16%

Client 0 - Evaluate on 5000 samples: Average loss: 0.0328, Accuracy: 99.14%

Now, let’s see what is really happening inside.

Flower Server

Inside the server helper script run-server.sh, you will find the following command, which runs server.py:

python -m flwr_example.quickstart-pytorch.server

We can go a bit deeper and see that server.py simply launches a server that will coordinate three rounds of training. Flower Servers are very customizable, but for simple workloads, we can start a server using the start_server function and leave all the configuration possibilities at their default values, as seen below.

import flwr as fl

fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3))
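
To picture what coordinating three rounds means, here is a toy simulation in plain Python. It does not use Flower: the one-parameter "model", the `local_update` rule, and the client data are all invented for illustration, and the server is assumed to take a simple unweighted mean of the client updates.

```python
from statistics import mean
from typing import List


def local_update(weight: float, data: List[float], lr: float = 0.5) -> float:
    """One gradient step on squared error toward the client's local mean."""
    gradient = weight - mean(data)
    return weight - lr * gradient


def run_rounds(client_data: List[List[float]], num_rounds: int = 3) -> List[float]:
    """Return the global weight after each round, starting from 0.0."""
    global_weight = 0.0
    history = []
    for _ in range(num_rounds):
        # Each client refines the current global weight on its own data.
        updates = [local_update(global_weight, data) for data in client_data]
        # The server aggregates the updates into the next global weight.
        global_weight = mean(updates)
        history.append(global_weight)
    return history


print(run_rounds([[1.0, 3.0], [5.0, 7.0]]))  # [2.0, 3.0, 3.5]
```

With each round the global weight moves closer to 4.0, the mean over both clients' data, even though neither client ever shares its raw samples.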


Flower Client

Next, let’s take a look at the run-clients.sh file. You will see that it contains the main loop that starts a set of clients.

python -m flwr_example.quickstart-pytorch.client \
  --cid=$i \
  --server_address=$SERVER_ADDRESS \
  --nb_clients=$NUM_CLIENTS

  • cid: The client ID, an integer that uniquely identifies the client.

  • server_address: A string that identifies the IP address and port of the server.

  • nb_clients: This defines the number of clients being created. This piece of information is not required by the client, but it helps us partition the original MNIST dataset to make sure that every client is working on unique subsets of both training and test sets.
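
As a rough sketch of how these three arguments could be consumed, the snippet below parses them with argparse and derives a disjoint slice of sample indices for each client. The contiguous, equal-sized split in `partition_indices`, the default of two clients, and the `localhost:8080` address are assumptions made for illustration; the example's own mnist.load_data may partition the data differently.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the three client arguments described above."""
    parser = argparse.ArgumentParser(description="Toy Flower client arguments")
    parser.add_argument("--cid", type=int, required=True, help="Client ID")
    parser.add_argument(
        "--server_address", type=str, required=True, help="IP and port of the server"
    )
    parser.add_argument(
        "--nb_clients", type=int, default=2, help="Total number of clients"
    )
    return parser.parse_args(argv)


def partition_indices(num_samples: int, cid: int, nb_clients: int) -> range:
    """Return the contiguous block of sample indices owned by client `cid`."""
    if not 0 <= cid < nb_clients:
        raise ValueError("cid must be in [0, nb_clients)")
    shard_size = num_samples // nb_clients  # leftover samples are dropped
    return range(cid * shard_size, (cid + 1) * shard_size)


# Hypothetical invocation; the address is a placeholder.
args = parse_args(["--cid=1", "--server_address=localhost:8080", "--nb_clients=2"])
shard = partition_indices(num_samples=60000, cid=args.cid, nb_clients=args.nb_clients)
print(len(shard), shard[0], shard[-1])  # 30000 30000 59999
```

Because every client computes its shard from the same two numbers, no coordination is needed to keep the subsets from overlapping.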

Again, we can go deeper and look inside flwr_example/quickstart-pytorch/client.py. After going through the argument parsing code at the beginning of our main function, you will find a call to mnist.load_data. This function is responsible for partitioning the original MNIST datasets (training and test) and returning a torch.utils.data.DataLoader for each of them. We then instantiate a PytorchMNISTClient object with our client ID, our DataLoaders, the number of epochs in each round, and which device we want to use for training (CPU or GPU).

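client = mnist.PytorchMNISTClient(
    cid=args.cid,
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=args.epochs,
    device=device,
)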

The PytorchMNISTClient object is finally passed to fl.client.start_client, along with the server’s address, and the training process begins.

A Closer Look

Now, let’s look closely into the PytorchMNISTClient inside flwr_example.quickstart-pytorch.mnist and see what it is doing:

class PytorchMNISTClient(fl.client.Client):
    """Flower client implementing MNIST handwritten classification using PyTorch."""
    def __init__(
        self,
        cid: int,
        train_loader: datasets,
        test_loader: datasets,
        epochs: int,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        self.model = MNISTNet().to(device)
        self.cid = cid
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.epochs = epochs

    def get_weights(self) -> fl.common.NDArrays:
        """Get model weights as a list of NumPy ndarrays."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_weights(self, weights: fl.common.NDArrays) -> None:
        """Set model weights from a list of NumPy ndarrays.

        Parameters
        ----------
        weights: fl.common.NDArrays
            Weights received from the server and set on the local model.
        """
        state_dict = OrderedDict(
            {
                k: torch.tensor(v)
                for k, v in zip(self.model.state_dict().keys(), weights)
            }
        )
        self.model.load_state_dict(state_dict, strict=True)

    def get_parameters(self, config) -> fl.common.ParametersRes:
        """Encapsulate the weights as Flower Parameters."""
        weights: fl.common.NDArrays = self.get_weights()
        parameters = fl.common.ndarrays_to_parameters(weights)
        return fl.common.ParametersRes(parameters=parameters)

    def fit(self, ins: fl.common.FitIns) -> fl.common.FitRes:
        """Train the model on the local dataset.

        Parameters
        ----------
        ins: fl.common.FitIns
            Parameters sent by the server to be used during training.

        Returns
        -------
        fl.common.FitRes
            The new set of weights and information about the client's training run.
        """
        weights: fl.common.NDArrays = fl.common.parameters_to_ndarrays(ins.parameters)
        fit_begin = timeit.default_timer()

        # Set model parameters/weights
        self.set_weights(weights)

        # Train model
        num_examples_train: int = train(
            self.model, self.train_loader, epochs=self.epochs, device=self.device
        )

        # Return the refined weights and the number of examples used for training
        weights_prime: fl.common.NDArrays = self.get_weights()
        params_prime = fl.common.ndarrays_to_parameters(weights_prime)
        fit_duration = timeit.default_timer() - fit_begin
        return fl.common.FitRes(
            parameters=params_prime,
            num_examples=num_examples_train,
            # Remaining fields (e.g. the measured fit_duration) depend on the Flower version.
        )

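    def evaluate(self, ins: fl.common.EvaluateIns) -> fl.common.EvaluateRes:
        """Evaluate the provided weights on the local test set.

        Parameters
        ----------
        ins: fl.common.EvaluateIns
            Parameters sent by the server to be used during testing.

        Returns
        -------
        fl.common.EvaluateRes
            Information about the client's testing results.
        """
        weights = fl.common.parameters_to_ndarrays(ins.parameters)

        # Use the provided weights to update the local model
        self.set_weights(weights)

        num_examples_test, test_loss, accuracy = test(
            self.model, self.test_loader, device=self.device
        )
        print(
            f"Client {self.cid} - Evaluate on {num_examples_test} samples: "
            f"Average loss: {test_loss:.4f}, Accuracy: {100 * accuracy:.2f}%"
        )

        return fl.common.EvaluateRes(
            loss=float(test_loss),
            num_examples=num_examples_test,
            # Remaining fields (e.g. the accuracy) depend on the Flower version.
        )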
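
The pairing of get_weights and set_weights relies on the state dict keeping a stable key order: weights travel as a bare list, and the keys are re-attached by position on the way back. The sketch below shows that round trip with a plain OrderedDict of Python lists standing in for a real PyTorch state dict. The layer names and the two helper functions are made up for this illustration.

```python
from collections import OrderedDict
from typing import Iterable, List, Mapping

# Hypothetical stand-in for model.state_dict(): parameter name -> values.
# Plain lists replace tensors so the sketch runs without PyTorch.
state = OrderedDict([("conv1.weight", [0.1, 0.2]), ("fc1.bias", [0.3])])


def to_weight_list(state_dict: Mapping[str, List[float]]) -> List[List[float]]:
    """Drop the keys and keep the values in the dict's iteration order."""
    return [list(values) for values in state_dict.values()]


def from_weight_list(keys: Iterable[str], weights: List[List[float]]) -> OrderedDict:
    """Re-attach keys to weights purely by position."""
    keys = list(keys)
    if len(keys) != len(weights):
        raise ValueError("number of weight arrays does not match number of keys")
    return OrderedDict(zip(keys, weights))


sent = to_weight_list(state)                      # what a client would send
restored = from_weight_list(state.keys(), sent)   # what set_weights rebuilds
print(restored == state)  # True
```

This is also why all participants need the same architecture: a model with different layers would produce a list whose positions no longer line up with the receiver's keys.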

The first thing to notice is that PytorchMNISTClient instantiates a CNN model inside its constructor:

class PytorchMNISTClient(fl.client.Client):
    """Flower client implementing MNIST handwritten classification using PyTorch."""

    def __init__(
        self,
        cid: int,
        train_loader: datasets,
        test_loader: datasets,
        epochs: int,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        self.model = MNISTNet().to(device)

The code for the CNN is available under quickstart-pytorch.mnist and is reproduced below. It is the same network found in PyTorch’s Basic MNIST Example.

class MNISTNet(nn.Module):
    """Simple CNN adapted from Pytorch's 'Basic MNIST Example'."""

    def __init__(self) -> None:
        super(MNISTNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: Tensor) -> Tensor:
        """Compute forward pass.

        Parameters
        ----------
        x: Tensor
            Mini-batch of shape (N, 1, 28, 28) containing images from the MNIST dataset.

        Returns
        -------
        output: Tensor
            Log-probabilities of each class given the input.
        """

        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
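
The 9216 input features of fc1 follow directly from the layer shapes. Here is a short worked calculation, assuming 28x28 MNIST inputs, no padding, and the kernel sizes and strides shown in the listing above. The helper `conv_output_size` is a name introduced for this sketch.

```python
def conv_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                           # MNIST images are 28x28
size = conv_output_size(size, kernel=3)             # after conv1: 26
size = conv_output_size(size, kernel=3)             # after conv2: 24
size = conv_output_size(size, kernel=2, stride=2)   # after max_pool2d(x, 2): 12
flattened = 64 * size * size                        # conv2 outputs 64 channels
print(size, flattened)  # 12 9216
```

If you change the input resolution or the convolution parameters, rerunning this arithmetic tells you the new input size for fc1.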

The second thing to notice is that the PytorchMNISTClient class inherits from fl.client.Client, and hence it must implement the following methods:

from abc import ABC, abstractmethod

from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, ParametersRes

class Client(ABC):
    """Abstract base class for Flower clients."""

    @abstractmethod
    def get_parameters(self, config) -> ParametersRes:
        """Return the current local model parameters."""

    @abstractmethod
    def fit(self, ins: FitIns) -> FitRes:
        """Refine the provided weights using the locally held dataset."""

    @abstractmethod
    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        """Evaluate the provided weights using the locally held dataset."""

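Python enforces this kind of contract at instantiation time when the base class methods are decorated with @abstractmethod. The sketch below uses a simplified stand-in, ToyClient, rather than Flower's real Client class, and plain dicts in place of the FitIns/FitRes types, to show what happens when a method is missing.

```python
from abc import ABC, abstractmethod


class ToyClient(ABC):
    """Simplified stand-in for an abstract client base class (not Flower's)."""

    @abstractmethod
    def get_parameters(self, config: dict) -> list:
        """Return the current local model parameters."""

    @abstractmethod
    def fit(self, ins: dict) -> dict:
        """Refine the provided weights using local data."""

    @abstractmethod
    def evaluate(self, ins: dict) -> dict:
        """Evaluate the provided weights using local data."""


class IncompleteClient(ToyClient):
    """Implements only one of the three required methods."""

    def get_parameters(self, config: dict) -> list:
        return [0.0]


class CompleteClient(IncompleteClient):
    """Adds the two missing methods, so it can be instantiated."""

    def fit(self, ins: dict) -> dict:
        return {"parameters": ins["parameters"], "num_examples": 0}

    def evaluate(self, ins: dict) -> dict:
        return {"loss": 0.0, "num_examples": 0}


try:
    IncompleteClient()
except TypeError as err:
    print("Cannot instantiate:", err)

client = CompleteClient()
print(client.get_parameters({}))  # [0.0]
```
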
When comparing the abstract class to its derived class PytorchMNISTClient, you will notice that fit calls a train function and that evaluate calls a test function.

These functions can both be found inside the same quickstart-pytorch.mnist module:

def train(
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    epochs: int,
    device: torch.device = torch.device("cpu"),
) -> int:
    """Train routine based on 'Basic MNIST Example'.

    Parameters
    ----------
    model: torch.nn.Module
        Neural network model used in this example.
    train_loader: torch.utils.data.DataLoader
        DataLoader used in training.
    epochs: int
        Number of epochs to run in each round.
    device: torch.device
        (Default value = torch.device("cpu"))
        Device where the network will be trained within a client.

    Returns
    -------
    num_examples_train: int
        Number of total samples used during training.
    """

    model.train()
    optimizer = optim.Adadelta(model.parameters(), lr=1.0)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    print(f"Training {epochs} epoch(s) w/ {len(train_loader)} mini-batches each")
    for epoch in range(epochs):  # loop over the dataset multiple times
        loss_epoch: float = 0.0
        num_examples_train: int = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            # Grab mini-batch and transfer to device
            data, target = data.to(device), target.to(device)
            num_examples_train += len(data)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass, backward pass, and parameter update
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

            loss_epoch += loss.item()
            if batch_idx % 10 == 8:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}\t\t\t\t".format(
                        epoch,
                        num_examples_train,
                        len(train_loader) * train_loader.batch_size,
                        100.0
                        * num_examples_train
                        / len(train_loader)
                        / train_loader.batch_size,
                        loss.item(),
                    )
                )
        # Decay the learning rate once per epoch
        scheduler.step()
    return num_examples_train

def test(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    device: torch.device = torch.device("cpu"),
) -> Tuple[int, float, float]:
    """Test routine based on 'Basic MNIST Example'.

    Parameters
    ----------
    model: torch.nn.Module
        Neural network model used in this example.
    test_loader: torch.utils.data.DataLoader
        DataLoader used in test.
    device: torch.device
        (Default value = torch.device("cpu"))
        Device where the network will be tested within a client.

    Returns
    -------
    Tuple[int, float, float]
        Tuple containing the total number of test samples, the test_loss, and the accuracy evaluated on the test set.
    """
    model.eval()

    test_loss: float = 0
    correct: int = 0
    num_test_samples: int = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            num_test_samples += len(data)
            output = model(data)
            test_loss += F.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= num_test_samples

    return (num_test_samples, test_loss, correct / num_test_samples)
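
One detail of train worth spelling out is the learning-rate schedule. StepLR multiplies the learning rate by gamma every step_size epochs, so with lr=1.0, step_size=1, and gamma=0.7 the rate decays geometrically within one call to train, assuming the scheduler is stepped once per epoch. Because the optimizer and scheduler are created inside train, the schedule starts again from 1.0 in every round. The helper `step_lr` below is a hypothetical re-implementation of that formula in plain Python, not PyTorch's own code.

```python
def step_lr(base_lr: float, epoch: int, step_size: int = 1, gamma: float = 0.7) -> float:
    """Learning rate after `epoch` scheduler steps under a StepLR-style decay."""
    return base_lr * gamma ** (epoch // step_size)


for epoch in range(4):
    print(epoch, round(step_lr(1.0, epoch), 4))
# 0 1.0
# 1 0.7
# 2 0.49
# 3 0.343
```
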

Observe that these functions encapsulate regular training and test loops and provide fit and evaluate with final statistics for each round. You could substitute them with your custom train and test loops and change the network architecture, and the entire example would still work flawlessly. As a matter of fact, why not try and modify the code to an example of your liking?

Give It a Try

Looking through the quickstart code description above should have given you a good understanding of how clients and servers work in Flower, how to run a simple experiment, and the internals of a client wrapper. Here are a few things you could try on your own to get more experience with Flower:

  • Try and change PytorchMNISTClient so it can accept different architectures.

  • Modify the train function so that it accepts different optimizers.

  • Modify the test function so that it reports not only the top-1 (regular) accuracy but also the top-5 accuracy.

  • Go larger! Try to adapt the code to larger images and datasets. Why not try training on ImageNet with a ResNet-50?
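
As a starting point for the top-5 exercise, here is one way top-k accuracy could be computed, written in plain Python so the logic is easy to follow. The function name and the list-of-scores input format are assumptions for this sketch; inside test you would more likely work on the output tensor directly (for example with torch.topk).

```python
from typing import Sequence


def top_k_accuracy(
    scores: Sequence[Sequence[float]], targets: Sequence[int], k: int = 5
) -> float:
    """Fraction of samples whose true class is among the k highest scores."""
    if len(scores) != len(targets):
        raise ValueError("scores and targets must have the same length")
    hits = 0
    for row, target in zip(scores, targets):
        # Class indices sorted by descending score; keep the first k.
        top_k = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += target in top_k
    return hits / len(targets)


scores = [
    [0.1, 0.7, 0.2],  # highest score: class 1
    [0.5, 0.3, 0.2],  # highest score: class 0
]
targets = [1, 1]
print(top_k_accuracy(scores, targets, k=1))  # 0.5
print(top_k_accuracy(scores, targets, k=2))  # 1.0
```

With k=1 this reduces to the regular accuracy that test already returns, which makes it easy to check a new implementation against the existing one.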

You are ready now. Enjoy learning in a federated way!