Example: PyTorch - From Centralized To Federated

This tutorial will show you how to use Flower to build a federated version of an existing machine learning workload. We use PyTorch to train a convolutional neural network on the CIFAR-10 dataset. First, we introduce this machine learning task with a centralized training approach based on the Deep Learning with PyTorch tutorial. Then, we build upon the centralized training code to run the training in a federated fashion.

Centralized Training

We begin with a brief description of the centralized CNN training code. If you want a more in-depth explanation of what’s going on, have a look at the official PyTorch tutorial.

Let’s create a new file called cifar.py with all the components required for a traditional (centralized) training on CIFAR-10. First, all required packages (such as torch and torchvision) need to be imported. You can see that we do not import any package for federated learning. You can keep all these imports as they are even when we add the federated learning components at a later point.

from typing import Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import Tensor
from torchvision.datasets import CIFAR10

As already mentioned, we will use the CIFAR-10 dataset for this machine learning workload. The model architecture (a very simple convolutional neural network) is defined in the class Net.

class Net(nn.Module):

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
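
The value 16 * 5 * 5 used by fc1 and by x.view() follows from the layer sizes above. The sketch below retraces that arithmetic in plain Python, assuming 32x32 input images (the CIFAR-10 image size) and the default stride and padding of the layers in Net. The helper names conv_out and flattened_features are made up for this illustration and are not part of cifar.py.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial size after a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(image_size: int = 32) -> int:
    """Follow Net.forward() and return the number of features fed into fc1."""
    size = conv_out(image_size, kernel=5)      # conv1: 32 -> 28
    size = conv_out(size, kernel=2, stride=2)  # pool:  28 -> 14
    size = conv_out(size, kernel=5)            # conv2: 14 -> 10
    size = conv_out(size, kernel=2, stride=2)  # pool:  10 -> 5
    channels = 16                              # output channels of conv2
    return channels * size * size


print(flattened_features())  # 400, i.e. 16 * 5 * 5
```

If you change a kernel size or the input resolution, this number changes too, and fc1 as well as the x.view() call have to be adapted accordingly.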

The load_data() function loads the CIFAR-10 training and test sets. The transform normalizes the data after loading.

DATA_ROOT = "~/data/cifar-10"


def load_data() -> (
    Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, Dict]
):
    """Load CIFAR-10 (training and test set)."""
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10(DATA_ROOT, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
    testset = CIFAR10(DATA_ROOT, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
    num_examples = {"trainset": len(trainset), "testset": len(testset)}
    return trainloader, testloader, num_examples
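
To see what the transform does to a single value: ToTensor() scales pixel values to the range [0, 1], and Normalize then computes (x - mean) / std per channel. With a mean and standard deviation of 0.5, this maps the data to [-1, 1]. The following is a minimal sketch of that arithmetic only; the normalize helper is hypothetical and works on one float rather than on a tensor.

```python
def normalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-channel arithmetic applied by transforms.Normalize: (x - mean) / std."""
    return (value - mean) / std


# Darkest, middle, and brightest pixel value after ToTensor()
for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
```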

We now need to define the training (function train()) which loops over the training set, measures the loss, backpropagates it, and then takes one optimizer step for each batch of training examples.

The evaluation of the model is defined in the function test(). The function loops over all test samples and measures the loss and accuracy of the model on the test dataset.

def train(
    net: Net,
    trainloader: torch.utils.data.DataLoader,
    epochs: int,
    device: torch.device,
) -> None:
    """Train the network."""
    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each")

    # Train the network
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            images, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 100))
                running_loss = 0.0


def test(
    net: Net,
    testloader: torch.utils.data.DataLoader,
    device: torch.device,
) -> Tuple[float, float]:
    """Validate the network on the entire test set."""
    criterion = nn.CrossEntropyLoss()
    correct = 0
    total = 0
    loss = 0.0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return loss, accuracy
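
Note that test() adds up one loss value per batch, so the returned loss grows with the number of batches. nn.CrossEntropyLoss averages over the examples in a batch by default, which means that dividing the returned sum by the number of batches gives a rough per-batch mean. The sketch below shows this arithmetic; the helper names are hypothetical, the loss value is made up, and the batch count assumes the default DataLoader behavior of keeping the last, smaller batch.

```python
import math


def num_batches(num_examples: int, batch_size: int) -> int:
    """Number of batches when the last, smaller batch is kept."""
    return math.ceil(num_examples / batch_size)


def mean_batch_loss(summed_loss: float, num_examples: int, batch_size: int) -> float:
    """Turn the summed loss returned by test() into a mean over batches."""
    return summed_loss / num_batches(num_examples, batch_size)


# CIFAR-10 has 10,000 test images; load_data() uses batch_size=32
print(num_batches(10_000, 32))             # 313
print(mean_batch_loss(626.0, 10_000, 32))  # 2.0 (626.0 is a made-up summed loss)
```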

Having defined the data loading, model architecture, training, and evaluation, we can put everything together and train our CNN on CIFAR-10.

def main():
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Centralized PyTorch training")
    print("Load data")
    trainloader, testloader, _ = load_data()
    print("Start training")
    net = Net().to(DEVICE)
    train(net=net, trainloader=trainloader, epochs=2, device=DEVICE)
    print("Evaluate model")
    loss, accuracy = test(net=net, testloader=testloader, device=DEVICE)
    print("Loss: ", loss)
    print("Accuracy: ", accuracy)


if __name__ == "__main__":
    main()

You can now run your machine learning workload:

python3 cifar.py

So far, this should all look fairly familiar if you have used PyTorch before. Let’s take the next step and use what we’ve built to create a simple federated learning system consisting of one server and two clients.

Federated Training

The simple machine learning project discussed in the previous section trains the model on a single dataset (CIFAR-10); we call this centralized learning. This concept of centralized learning, as shown in the previous section, is probably known to most of you, and many of you have used it before. Normally, if you want to run machine learning workloads in a federated fashion, you have to change most of your code and set everything up from scratch, which can be a considerable effort.

However, with Flower you can evolve your pre-existing code into a federated learning setup without the need for a major rewrite.

The concept is easy to understand. We have to start a server and then use the code in cifar.py for the clients that are connected to the server. The server sends model parameters to the clients. The clients run the training and update the parameters. The updated parameters are sent back to the server which averages all received parameter updates. This describes one round of the federated learning process and we repeat this for multiple rounds.
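
To make the averaging step concrete, here is a minimal sketch in plain Python. It assumes that the server weights each client by the number of training examples the client reports, which is the idea behind Federated Averaging; the text above only says that the updates are averaged. Flat lists of floats stand in for the per-layer NumPy arrays, and the function name weighted_average is chosen for this illustration rather than taken from Flower.

```python
from typing import List, Tuple

Weights = List[List[float]]  # one flat list of numbers per layer


def weighted_average(results: List[Tuple[Weights, int]]) -> Weights:
    """Average client weights layer by layer, weighted by example count."""
    total_examples = sum(num_examples for _, num_examples in results)
    first_weights = results[0][0]
    averaged: Weights = []
    for layer_idx, first_layer in enumerate(first_weights):
        layer = [
            sum(weights[layer_idx][i] * n for weights, n in results) / total_examples
            for i in range(len(first_layer))
        ]
        averaged.append(layer)
    return averaged


# Two hypothetical clients with a two-layer "model" each
client_a = ([[1.0, 2.0], [0.0]], 100)  # trained on 100 examples
client_b = ([[3.0, 4.0], [1.0]], 300)  # trained on 300 examples
print(weighted_average([client_a, client_b]))  # [[2.5, 3.5], [0.75]]
```

The client with more examples pulls the result towards its own values. In this tutorial both clients load the full dataset, so under this assumption the average would reduce to a plain mean.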

Our example consists of one server and two clients. Let’s set up server.py first. The server needs to import the Flower package flwr. Next, we use the start_server function to start a server and tell it to perform three rounds of federated learning.

import flwr as fl

if __name__ == "__main__":
    fl.server.start_server(
        server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3)
    )

We can already start the server:

python3 server.py

Finally, we will define our client logic in client.py and build upon the previously defined centralized training in cifar.py. Our client needs to import flwr, but also torch to update the parameters on our PyTorch model:

from collections import OrderedDict
from typing import Dict, List, Tuple

import numpy as np
import torch

import cifar
import flwr as fl

DEVICE: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Implementing a Flower client basically means implementing a subclass of either flwr.client.Client or flwr.client.NumPyClient. Our implementation will be based on flwr.client.NumPyClient and we’ll call it CifarClient. NumPyClient is slightly easier to implement than Client if you use a framework with good NumPy interoperability (like PyTorch or TensorFlow/Keras) because it avoids some of the boilerplate that would otherwise be necessary. CifarClient needs to implement four methods, two methods for getting/setting model parameters, one method for training the model, and one method for testing the model:

  1. set_parameters
    • set the model parameters on the local model that are received from the server

    • loop over the list of model parameters received as NumPy ndarrays (think list of neural network layers)

  2. get_parameters
    • get the model parameters and return them as a list of NumPy ndarrays (which is what flwr.client.NumPyClient expects)

  3. fit
    • update the parameters of the local model with the parameters received from the server

    • train the model on the local training set

    • get the updated local model weights and return them to the server

  4. evaluate
    • update the parameters of the local model with the parameters received from the server

    • evaluate the updated model on the local test set

    • return the local loss and accuracy to the server
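
The two parameter methods form a round trip: get_parameters drops the layer names and keeps only the values, and set_parameters re-attaches the local layer names by position. The sketch below illustrates that idea with plain Python lists in place of NumPy arrays and an ordinary OrderedDict in place of a real state dict. The keys and values are made up, and the length check is an addition for this illustration.

```python
from collections import OrderedDict
from typing import List, Mapping

Values = List[float]  # stand-in for one NumPy ndarray


def get_parameters(state_dict: Mapping[str, Values]) -> List[Values]:
    """Drop the layer names and keep the values in state-dict order."""
    return [list(values) for values in state_dict.values()]


def set_parameters(state_dict: Mapping[str, Values], parameters: List[Values]) -> OrderedDict:
    """Re-attach the local layer names to the received values by position."""
    if len(parameters) != len(state_dict):
        raise ValueError("expected one entry per state-dict key")
    return OrderedDict(zip(state_dict.keys(), parameters))


# Hypothetical, heavily shortened state dict
local_state = OrderedDict(
    [("conv1.weight", [0.1, 0.2]), ("conv1.bias", [0.0]), ("fc3.bias", [0.5])]
)

sent = get_parameters(local_state)            # list the client would send
restored = set_parameters(local_state, sent)  # rebuilt from a received list
print(restored == local_state)
```

Because the pairing is purely positional, this only works when the server and all clients use the same model architecture, so that the keys appear in the same order everywhere.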

The two NumPyClient methods fit and evaluate make use of the functions train() and test() previously defined in cifar.py. So what we really do here is we tell Flower through our NumPyClient subclass which of our already defined functions to call for training and evaluation. We included type annotations to give you a better understanding of the data types that get passed around.

class CifarClient(fl.client.NumPyClient):
    """Flower client implementing CIFAR-10 image classification using
    PyTorch."""

    def __init__(
        self,
        model: cifar.Net,
        trainloader: torch.utils.data.DataLoader,
        testloader: torch.utils.data.DataLoader,
        num_examples: Dict,
    ) -> None:
        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader
        self.num_examples = num_examples

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Set model parameters from a list of NumPy ndarrays
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def fit(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters, train model, return updated model parameters
        self.set_parameters(parameters)
        cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE)
        return self.get_parameters(config={}), self.num_examples["trainset"], {}

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[float, int, Dict]:
        # Set model parameters, evaluate model on local test dataset, return result
        self.set_parameters(parameters)
        loss, accuracy = cifar.test(self.model, self.testloader, device=DEVICE)
        return float(loss), self.num_examples["testset"], {"accuracy": float(accuracy)}

All that’s left to do is to define a function that loads both model and data, creates a CifarClient, and starts this client. You load your data and model by using cifar.py. Start CifarClient with the function fl.client.start_client() by pointing it at the same IP address we used in server.py:

def main() -> None:
    """Load data, start CifarClient."""

    # Load model and data
    model = cifar.Net()
    model.to(DEVICE)
    trainloader, testloader, num_examples = cifar.load_data()

    # Start client
    client = CifarClient(model, trainloader, testloader, num_examples)
    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())


if __name__ == "__main__":
    main()

You can now open two additional terminal windows and run

python3 client.py

in each window (make sure that the server is running before you do so) and see your (previously centralized) PyTorch project run federated learning across two clients. Congratulations!

Next Steps

The full source code for this example: PyTorch: From Centralized To Federated (Code). Our example is, of course, somewhat over-simplified because both clients load the exact same dataset, which isn’t realistic. You’re now prepared to explore this topic further. How about using different subsets of CIFAR-10 on each client? How about adding more clients?
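
As a starting point for the first question, the sketch below shows one possible way to split the training set into disjoint, equally sized index lists, one per client. It is not part of the example code: the function name partition_indices and the seed parameter are assumptions made for this illustration. Each client could then wrap its index list in torch.utils.data.Subset before building its DataLoader.

```python
import random
from typing import List


def partition_indices(num_examples: int, num_clients: int, seed: int = 0) -> List[List[int]]:
    """Shuffle all example indices once and deal them out round-robin."""
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)  # same seed gives the same split on every client
    return [indices[client_id::num_clients] for client_id in range(num_clients)]


# CIFAR-10 has 50,000 training images; the tutorial uses two clients
parts = partition_indices(num_examples=50_000, num_clients=2)
print([len(part) for part in parts])  # [25000, 25000]
```

This yields an IID split, since every client sees a random sample of all classes. Skewing the class distribution per client would be a natural follow-up experiment.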