Exemple : PyTorch - De la centralisation à la fédération#

Ce tutoriel te montrera comment utiliser Flower pour construire une version fédérée d’une charge de travail d’apprentissage automatique existante. Nous utilisons PyTorch pour entraîner un réseau neuronal convolutif sur l’ensemble de données CIFAR-10. Tout d’abord, nous présentons cette tâche d’apprentissage automatique avec une approche d’entraînement centralisée basée sur le tutoriel Deep Learning with PyTorch. Ensuite, nous nous appuyons sur le code d’entraînement centralisé pour exécuter l’entraînement de manière fédérée.

Formation centralisée#

Nous commençons par une brève description du code d’entraînement CNN centralisé. Si tu veux une explication plus approfondie de ce qui se passe, jette un coup d’œil au tutoriel officiel PyTorch.

Créons un nouveau fichier appelé cifar.py avec tous les composants requis pour une formation traditionnelle (centralisée) sur le CIFAR-10. Tout d’abord, tous les paquets requis (tels que torch et torchvision) doivent être importés. Tu peux voir que nous n’importons aucun paquet pour l’apprentissage fédéré. Tu peux conserver toutes ces importations telles quelles même lorsque nous ajouterons les composants d’apprentissage fédéré à un moment ultérieur.

from typing import Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import Tensor
from torchvision.datasets import CIFAR10

Comme nous l’avons déjà mentionné, nous utiliserons l’ensemble de données CIFAR-10 pour cette charge de travail d’apprentissage automatique. L’architecture du modèle (un réseau neuronal convolutif très simple) est définie dans class Net().

class Net(nn.Module):

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

La fonction load_data() charge les ensembles d’entraînement et de test CIFAR-10. La fonction transform normalise les données après leur chargement.

DATA_ROOT = "~/data/cifar-10"

def load_data() -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, Dict]:
    """Load CIFAR-10 (training and test set)."""
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10(DATA_ROOT, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
    testset = CIFAR10(DATA_ROOT, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
    num_examples = {"trainset" : len(trainset), "testset" : len(testset)}
    return trainloader, testloader, num_examples

Nous devons maintenant définir la formation (fonction train()) qui passe en boucle sur l’ensemble de la formation, mesure la perte, la rétropropage, puis effectue une étape d’optimisation pour chaque lot d’exemples de formation.

L’évaluation du modèle est définie dans la fonction test(). La fonction boucle sur tous les échantillons de test et mesure la perte du modèle en fonction de l’ensemble des données de test.

def train(
    net: Net,
    trainloader: torch.utils.data.DataLoader,
    epochs: int,
    device: torch.device,
) -> None:
    """Train the network."""
    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each")

    # Train the network
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            images, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0


def test(
    net: Net,
    testloader: torch.utils.data.DataLoader,
    device: torch.device,
) -> Tuple[float, float]:
    """Validate the network on the entire test set."""
    criterion = nn.CrossEntropyLoss()
    correct = 0
    total = 0
    loss = 0.0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return loss, accuracy

Après avoir défini le chargement des données, l’architecture du modèle, la formation et l’évaluation, nous pouvons tout mettre ensemble et former notre CNN sur CIFAR-10.

def main():
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Centralized PyTorch training")
    print("Load data")
    trainloader, testloader, _ = load_data()
    print("Start training")
    net=Net().to(DEVICE)
    train(net=net, trainloader=trainloader, epochs=2, device=DEVICE)
    print("Evaluate model")
    loss, accuracy = test(net=net, testloader=testloader, device=DEVICE)
    print("Loss: ", loss)
    print("Accuracy: ", accuracy)


if __name__ == "__main__":
    main()

Tu peux maintenant exécuter ta charge de travail d’apprentissage automatique :

python3 cifar.py

Jusqu’à présent, tout cela devrait te sembler assez familier si tu as déjà utilisé PyTorch. Passons à l’étape suivante et utilisons ce que nous avons construit pour créer un simple système d’apprentissage fédéré composé d’un serveur et de deux clients.

Formation fédérée#

Le projet simple d’apprentissage automatique discuté dans la section précédente entraîne le modèle sur un seul ensemble de données (CIFAR-10), nous appelons cela l’apprentissage centralisé. Ce concept d’apprentissage centralisé, comme le montre la section précédente, est probablement connu de la plupart d’entre vous, et beaucoup d’entre vous l’ont déjà utilisé. Normalement, si tu veux exécuter des charges de travail d’apprentissage automatique de manière fédérée, tu dois alors changer la plupart de ton code et tout mettre en place à partir de zéro, ce qui peut représenter un effort considérable.

Cependant, avec Flower, tu peux faire évoluer ton code préexistant vers une configuration d’apprentissage fédéré sans avoir besoin d’une réécriture majeure.

Le concept est facile à comprendre. Nous devons démarrer un serveur et utiliser le code dans cifar.py pour les clients qui sont connectés au serveur. Le serveur envoie les paramètres du modèle aux clients. Les clients exécutent la formation et mettent à jour les paramètres. Les paramètres mis à jour sont renvoyés au serveur qui fait la moyenne de toutes les mises à jour de paramètres reçues. Ceci décrit un tour du processus d’apprentissage fédéré et nous répétons cette opération pour plusieurs tours.

Notre exemple consiste en un serveur et deux clients. Commençons par configurer server.py. Le serveur doit importer le paquet Flower flwr. Ensuite, nous utilisons la fonction start_server pour démarrer un serveur et lui demander d’effectuer trois cycles d’apprentissage fédéré.

import flwr as fl

if __name__ == "__main__":
    fl.server.start_server(server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3))

Nous pouvons déjà démarrer le serveur :

python3 server.py

Enfin, nous allons définir notre logique client dans client.py et nous appuyer sur la formation centralisée définie précédemment dans cifar.py. Notre client doit importer flwr, mais aussi torch pour mettre à jour les paramètres de notre modèle PyTorch :

from collections import OrderedDict
from typing import Dict, List, Tuple

import numpy as np
import torch

import cifar
import flwr as fl

DEVICE: str = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Implementing a Flower client basically means implementing a subclass of either flwr.client.Client or flwr.client.NumPyClient. Our implementation will be based on flwr.client.NumPyClient and we’ll call it CifarClient. NumPyClient is slightly easier to implement than Client if you use a framework with good NumPy interoperability (like PyTorch or TensorFlow/Keras) because it avoids some of the boilerplate that would otherwise be necessary. CifarClient needs to implement four methods, two methods for getting/setting model parameters, one method for training the model, and one method for testing the model:

  1. set_parameters
    • règle les paramètres du modèle local reçus du serveur

    • boucle sur la liste des paramètres du modèle reçus sous forme de NumPy ndarray’s (pensez à la liste des couches du réseau neuronal)

  2. get_parameters
    • récupère les paramètres du modèle et les renvoie sous forme de liste de ndarray NumPy (ce qui correspond à ce que flwr.client.NumPyClient attend)

  3. fit
    • mettre à jour les paramètres du modèle local avec les paramètres reçus du serveur

    • entraîne le modèle sur l’ensemble d’apprentissage local

    • récupère les poids du modèle local mis à jour et les renvoie au serveur

  4. évaluer
    • mettre à jour les paramètres du modèle local avec les paramètres reçus du serveur

    • évaluer le modèle mis à jour sur l’ensemble de test local

    • renvoie la perte locale et la précision au serveur

Les deux méthodes NumPyClient fit et evaluate utilisent les fonctions train() et test() définies précédemment dans cifar.py. Ce que nous faisons vraiment ici, c’est que nous indiquons à Flower, par le biais de notre sous-classe NumPyClient, laquelle de nos fonctions déjà définies doit être appelée pour l’entraînement et l’évaluation. Nous avons inclus des annotations de type pour te donner une meilleure compréhension des types de données qui sont transmis.

class CifarClient(fl.client.NumPyClient):
    """Flower client implementing CIFAR-10 image classification using
    PyTorch."""

    def __init__(
        self,
        model: cifar.Net,
        trainloader: torch.utils.data.DataLoader,
        testloader: torch.utils.data.DataLoader,
        num_examples: Dict,
    ) -> None:
        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader
        self.num_examples = num_examples

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Set model parameters from a list of NumPy ndarrays
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def fit(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters, train model, return updated model parameters
        self.set_parameters(parameters)
        cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE)
        return self.get_parameters(config={}), self.num_examples["trainset"], {}

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[float, int, Dict]:
        # Set model parameters, evaluate model on local test dataset, return result
        self.set_parameters(parameters)
        loss, accuracy = cifar.test(self.model, self.testloader, device=DEVICE)
        return float(loss), self.num_examples["testset"], {"accuracy": float(accuracy)}

All that’s left to do it to define a function that loads both model and data, creates a CifarClient, and starts this client. You load your data and model by using cifar.py. Start CifarClient with the function fl.client.start_client() by pointing it at the same IP address we used in server.py:

def main() -> None:
    """Load data, start CifarClient."""

    # Load model and data
    model = cifar.Net()
    model.to(DEVICE)
    trainloader, testloader, num_examples = cifar.load_data()

    # Start client
    client = CifarClient(model, trainloader, testloader, num_examples)
    fl.client.start_client(server_address="0.0.0.0:8080", client.to_client())


if __name__ == "__main__":
    main()

Tu peux maintenant ouvrir deux autres fenêtres de terminal et exécuter les commandes suivantes

python3 client.py

dans chaque fenêtre (assure-toi que le serveur fonctionne avant de le faire) et tu verras ton projet PyTorch (auparavant centralisé) exécuter l’apprentissage fédéré sur deux clients. Félicitations !

Prochaines étapes#

The full source code for this example: PyTorch: From Centralized To Federated (Code). Our example is, of course, somewhat over-simplified because both clients load the exact same dataset, which isn’t realistic. You’re now prepared to explore this topic further. How about using different subsets of CIFAR-10 on each client? How about adding more clients?