PyTorch: From Centralized To Federated

Dr. Maria Börner
Program Manager at Adap

Flower was built with a strong focus on compatibility with existing infrastructure. This blog post shows how to use Flower to federate an existing PyTorch project that trains a convolutional neural network (CNN) on CIFAR-10. We first build a centralized PyTorch training setup and then use it as the basis for a simple federated version with one server and two clients.

PyTorch, the PyTorch logo and any related marks are trademarks of Facebook, Inc.

Centralized PyTorch Training

PyTorch is one of the most popular machine learning frameworks. It lets users build many kinds of deep learning models and customize them extensively for specific use cases. The centralized training part of this blog post follows the basic PyTorch CIFAR-10 tutorial and consists of four components: loading the data via load_data, defining the CNN in a class called Net, training the CNN with the train function, and finally evaluating the trained model with the test function.

The following code snippet shows how these components are called in a centralized training setup. Full implementation details of each of those components can be seen in

First, the CIFAR-10 dataset is loaded and used to create one DataLoader for the training set and another DataLoader for the test set. The training set is used to train the simple CNN that is defined in the Net class. After the training is done, the test set is used to evaluate the previously trained CNN.

import torch

# load_data, Net, train, and test are defined in the same module as main()
def main():
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Central PyTorch Training")
    print("Load CIFAR-10")
    trainloader, testloader = load_data()
    print("Define CNN model")
    net = Net().to(DEVICE)
    print("Train model")
    train(net=net, trainloader=trainloader, epochs=2, device=DEVICE)
    print("Evaluate model")
    loss, accuracy = test(net=net, testloader=testloader, device=DEVICE)
    print("Test set loss: ", loss)
    print("Test set accuracy: ", accuracy)

Federated PyTorch Training

We can now build upon this centralized machine learning process ( and evolve it to build a Federated Learning system.

Let's start with the server (e.g., in a script called, which can start out as a simple two-liner:

import flwr as fl
fl.server.start_server(config={"num_rounds": 3})

This Flower server performs three rounds of Federated Averaging once at least two clients are connected to it. Our client (e.g. in a script called will now connect Flower to the code we previously used in the centralized setup.
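Federated Averaging combines the clients' results into a new global model by averaging each parameter array, weighted by the number of training examples each client reports. The NumPy sketch below illustrates that aggregation step in isolation. It is a simplified illustration of the idea, not Flower's internal implementation, and the function name fedavg is hypothetical.

```python
from typing import List, Tuple

import numpy as np


def fedavg(results: List[Tuple[List[np.ndarray], int]]) -> List[np.ndarray]:
    """Weighted average of per-client parameter lists.

    Each entry in `results` is (parameters, num_examples), where `parameters`
    is a list of NumPy arrays in the same order on every client.
    """
    total_examples = sum(num_examples for _, num_examples in results)
    num_layers = len(results[0][0])
    aggregated = []
    for layer in range(num_layers):
        # Sum this layer across clients, each scaled by its number of examples
        weighted_sum = sum(params[layer] * num_examples for params, num_examples in results)
        aggregated.append(weighted_sum / total_examples)
    return aggregated


# Two hypothetical clients, each holding one weight matrix and one bias vector
client_a = ([np.ones((2, 2)), np.zeros(2)], 100)
client_b = ([np.full((2, 2), 3.0), np.ones(2)], 300)
global_params = fedavg([client_a, client_b])
print(global_params[0])  # every entry: (1 * 100 + 3 * 300) / 400 = 2.5
print(global_params[1])  # every entry: (0 * 100 + 1 * 300) / 400 = 0.75
```

Weighting by example count is the reason each client reports how many training examples it used: a client with three times as much data pulls the average three times as strongly.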

Let's start with the required imports, most notably Flower (flwr), PyTorch (torch and torchvision), and our old friend

from collections import OrderedDict
from typing import Dict, List, Tuple

import numpy as np
import torch
import torchvision

import flwr as fl

from . import cifar

DEVICE: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Next, let's build the Flower client CifarClient, which is derived from Flower's convenience class NumPyClient. The client class is where we connect Flower to our custom PyTorch training code. We will implement four methods: get_parameters(), set_parameters(), fit(), and evaluate().

The Flower client stores both our CNN (class Net from and the training and test set DataLoader as instance variables. The two methods set_parameters() and get_parameters() are used to update the parameters (or weights) of our CNN. In Federated Learning, we usually update the local model with the global model parameters received from the server (i.e., set_parameters()) and then train the model on the local data. Afterwards, the locally updated model parameters are returned back to the server (i.e., get_parameters()).

The fit() method does exactly that: it receives the CNN model weights (argument parameters) from the server as a list of NumPy ndarrays, updates the local model (via self.set_parameters(parameters)), and then trains on the local training set by calling the train function (cifar.train() from our centralized training project). Finally, it returns the updated model weights using self.get_parameters(), along with the number of training examples and a dictionary of optional metrics (empty here).

The evaluate() method works in a similar way, but instead of training, it computes the loss and accuracy on the local test set using the previously defined cifar.test(). It returns the loss, the number of test examples, and a dictionary containing the accuracy.

class CifarClient(fl.client.NumPyClient):
    def __init__(
        self,
        model: cifar.Net,
        trainloader: torch.utils.data.DataLoader,
        testloader: torch.utils.data.DataLoader,
    ) -> None:
        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader

    def get_parameters(self) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Set model parameters from a list of NumPy ndarrays
        # (keys and arrays are paired by position, so the order must match get_parameters)
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def fit(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters received from the server
        self.set_parameters(parameters)

        # Train model
        cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE)

        # Return the updated model parameters, number of training examples, and (empty) metrics
        return self.get_parameters(), len(self.trainloader.dataset), {}

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[float, int, Dict[str, float]]:
        # Use provided parameters to update the local model
        self.set_parameters(parameters)

        # Evaluate the updated model on the local test set
        loss, accuracy = cifar.test(self.model, self.testloader, device=DEVICE)

        # Return the test set loss, number of evaluation examples, and test set accuracy
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
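The get_parameters()/set_parameters() pair relies on one invariant: state_dict() yields its entries in the same order every time, so zipping its keys with the incoming list of arrays puts each array back into the right tensor. The standalone check below sketches that round trip; the tiny two-layer model and the free-function versions of the two methods are hypothetical stand-ins for cifar.Net and the client methods.

```python
from collections import OrderedDict
from typing import List

import numpy as np
import torch
import torch.nn as nn


def get_parameters(model: nn.Module) -> List[np.ndarray]:
    # Mirrors CifarClient.get_parameters(): tensors -> NumPy arrays, in state_dict order
    return [val.cpu().numpy() for _, val in model.state_dict().items()]


def set_parameters(model: nn.Module, parameters: List[np.ndarray]) -> None:
    # Mirrors CifarClient.set_parameters(): pair keys with arrays by position
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)


def make_model() -> nn.Module:
    # Hypothetical stand-in for cifar.Net()
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


torch.manual_seed(0)
source = make_model()
target = make_model()  # starts from a different random initialization

set_parameters(target, get_parameters(source))

# After the round trip, both models hold identical weights
same = all(
    torch.equal(a, b)
    for a, b in zip(source.state_dict().values(), target.state_dict().values())
)
print(same)  # True
```

Because the pairing is positional, arrays sent in a different order would be loaded into the wrong tensors: load_state_dict() raises an error when shapes do not match, but same-shaped tensors would be swapped silently.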

The main() function of looks quite similar to our centralized PyTorch example. It uses the definitions from, for example, the model and the data loading. It then creates a CifarClient() object and calls start_numpy_client(). This starts the client and tells it to connect to the Flower server, which in our case runs on the same machine on port 8080, so we use the IPv6 address "[::]:8080".

def main() -> None:
    """Load data, start CifarClient."""

    # Load model and data
    model = cifar.Net().to(DEVICE)
    trainloader, testloader = cifar.load_data()

    # Start client and connect it to the Flower server on this machine
    client = CifarClient(model, trainloader, testloader)
    fl.client.start_numpy_client("[::]:8080", client)


if __name__ == "__main__":
    main()

We can now start the server and two clients in separate terminal windows:

$ python
$ python
$ python

Congratulations! You've learned how to federate an existing centralized PyTorch project with Flower. We didn't even have to change the existing project: we just added the federated learning code in two additional files that reuse its existing building blocks. You can now start experimenting with your own project. How about starting more clients? For a more realistic simulation, each client should also train on a different subset of the data.
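Giving each client its own slice of the data can be sketched as follows. This assumes one simple approach, an IID random split of example indices into disjoint shards; the helper name partition_indices is hypothetical. Each resulting index array could then be wrapped with torch.utils.data.Subset(trainset, indices) before building that client's DataLoader.

```python
from typing import List

import numpy as np


def partition_indices(num_examples: int, num_clients: int, seed: int = 0) -> List[np.ndarray]:
    """Shuffle example indices and split them into disjoint, near-equal shards."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(num_examples)
    return np.array_split(shuffled, num_clients)


# CIFAR-10 has 50,000 training images; split them across two clients
shards = partition_indices(50_000, num_clients=2)
print([len(s) for s in shards])  # [25000, 25000]
```

Splitting by label instead of at random would produce non-IID shards, which is closer to many real federated settings.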

Feel free to take a look at the advanced TensorFlow example for more ideas.