Démarrage rapide de PyTorch#

Dans ce tutoriel, nous allons apprendre à entraîner un réseau neuronal convolutif sur CIFAR10 à l’aide de Flower et PyTorch.

First of all, it is recommended to create a virtual environment and run everything within a virtualenv.

Notre exemple consiste en un serveur et deux clients ayant tous le même modèle.

Les clients sont chargés de générer des mises à jour de poids individuelles pour le modèle en fonction de leurs ensembles de données locales. Ces mises à jour sont ensuite envoyées au serveur qui les agrège pour produire un meilleur modèle. Enfin, le serveur renvoie cette version améliorée du modèle à chaque client. Un cycle complet de mises à jour de poids s’appelle un round.

Maintenant que nous avons une idée générale de ce qui se passe, commençons. Nous devons d’abord installer Flower. Tu peux le faire en exécutant :

$ pip install flwr

Puisque nous voulons utiliser PyTorch pour résoudre une tâche de vision par ordinateur, allons-y et installons PyTorch et la bibliothèque torchvision :

$ pip install torch torchvision

Client de la fleur#

Maintenant que nous avons installé toutes nos dépendances, lançons une formation distribuée simple avec deux clients et un serveur. Notre procédure de formation et l’architecture de notre réseau sont basées sur Deep Learning with PyTorch de PyTorch.

Dans un fichier appelé client.py, importe Flower et les paquets liés à PyTorch :

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

import flwr as fl

En outre, nous définissons l’attribution des appareils dans PyTorch avec :

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Nous utilisons PyTorch pour charger CIFAR10, un ensemble de données de classification d’images colorées populaire pour l’apprentissage automatique. Le DataLoader() de PyTorch télécharge les données d’entraînement et de test qui sont ensuite normalisées.

def load_data():
    """Load CIFAR-10 (training and test set)."""
    transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10(".", train=True, download=True, transform=transform)
    testset = CIFAR10(".", train=False, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
    testloader = DataLoader(testset, batch_size=32)
    num_examples = {"trainset" : len(trainset), "testset" : len(testset)}
    return trainloader, testloader, num_examples

Définis la perte et l’optimiseur avec PyTorch L’entraînement de l’ensemble de données se fait en bouclant sur l’ensemble de données, en mesurant la perte correspondante et en l’optimisant.

def train(net, trainloader, epochs):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for _ in range(epochs):
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()

Définis ensuite la validation du réseau d’apprentissage automatique. Nous passons en boucle sur l’ensemble de test et mesurons la perte et la précision de l’ensemble de test.

def test(net, testloader):
    """Validate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return loss, accuracy

Après avoir défini l’entraînement et le test d’un modèle d’apprentissage automatique PyTorch, nous utilisons les fonctions pour les clients Flower.

Les clients de Flower utiliseront un CNN simple adapté de « PyTorch : A 60 Minute Blitz » :

class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Load model and data
net = Net().to(DEVICE)
trainloader, testloader, num_examples = load_data()

Après avoir chargé l’ensemble des données avec load_data(), nous définissons l’interface Flower.

Le serveur Flower interagit avec les clients par le biais d’une interface appelée Client. Lorsque le serveur sélectionne un client particulier pour la formation, il envoie des instructions de formation sur le réseau. Le client reçoit ces instructions et appelle l’une des méthodes Client pour exécuter ton code (c’est-à-dire pour former le réseau neuronal que nous avons défini plus tôt).

Flower fournit une classe de commodité appelée NumPyClient qui facilite la mise en œuvre de l’interface Client lorsque ta charge de travail utilise PyTorch. Mettre en œuvre NumPyClient signifie généralement définir les méthodes suivantes (set_parameters est cependant facultatif) :

  1. get_parameters
    • renvoie le poids du modèle sous la forme d’une liste de ndarrays NumPy

  2. set_parameters (optionnel)
    • mettre à jour les poids du modèle local avec les paramètres reçus du serveur

  3. fit
    • fixe les poids du modèle local

    • entraîne le modèle local

    • recevoir les poids du modèle local mis à jour

  4. évaluer
    • teste le modèle local

qui peut être mis en œuvre de la manière suivante :

class CifarClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in net.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(net, trainloader, epochs=1)
        return self.get_parameters(config={}), num_examples["trainset"], {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        return float(loss), num_examples["testset"], {"accuracy": float(accuracy)}

Nous pouvons maintenant créer une instance de notre classe CifarClient et ajouter une ligne pour exécuter ce client :

fl.client.start_client(server_address="[::]:8080", client=CifarClient().to_client())

That’s it for the client. We only have to implement Client or NumPyClient and call fl.client.start_client(). If you implement a client of type NumPyClient you’ll need to first call its to_client() method. The string "[::]:8080" tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use "[::]:8080". If we run a truly federated workload with the server and clients running on different machines, all that needs to change is the server_address we point the client at.

Serveur de Flower#

Pour les charges de travail simples, nous pouvons démarrer un serveur Flower et laisser toutes les possibilités de configuration à leurs valeurs par défaut. Dans un fichier nommé server.py, importe Flower et démarre le serveur :

import flwr as fl

fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3))

Entraîne le modèle, fédéré !#

Le client et le serveur étant prêts, nous pouvons maintenant tout exécuter et voir l’apprentissage fédéré en action. Les systèmes FL ont généralement un serveur et plusieurs clients. Nous devons donc commencer par démarrer le serveur :

$ python server.py

Une fois que le serveur fonctionne, nous pouvons démarrer les clients dans différents terminaux. Ouvre un nouveau terminal et démarre le premier client :

$ python client.py

Ouvre un autre terminal et démarre le deuxième client :

$ python client.py

Chaque client aura son propre ensemble de données. Tu devrais maintenant voir comment la formation se déroule dans le tout premier terminal (celui qui a démarré le serveur) :

INFO flower 2021-02-25 14:00:27,227 | app.py:76 | Flower server running (insecure, 3 rounds)
INFO flower 2021-02-25 14:00:27,227 | server.py:72 | Getting initial parameters
INFO flower 2021-02-25 14:01:15,881 | server.py:74 | Evaluating initial parameters
INFO flower 2021-02-25 14:01:15,881 | server.py:87 | [TIME] FL starting
DEBUG flower 2021-02-25 14:01:41,310 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:02:00,256 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:00,262 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:02:03,047 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:03,049 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:02:23,908 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:23,915 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:02:27,120 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:27,122 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:03:04,660 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:03:04,671 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:03:09,273 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:03:09,273 | server.py:122 | [TIME] FL finished in 113.39180790000046
INFO flower 2021-02-25 14:03:09,274 | app.py:109 | app_fit: losses_distributed [(1, 650.9747924804688), (2, 526.2535400390625), (3, 473.76959228515625)]
INFO flower 2021-02-25 14:03:09,274 | app.py:110 | app_fit: accuracies_distributed []
INFO flower 2021-02-25 14:03:09,274 | app.py:111 | app_fit: losses_centralized []
INFO flower 2021-02-25 14:03:09,274 | app.py:112 | app_fit: accuracies_centralized []
DEBUG flower 2021-02-25 14:03:09,276 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:03:11,852 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:03:11,852 | app.py:121 | app_evaluate: federated loss: 473.76959228515625
INFO flower 2021-02-25 14:03:11,852 | app.py:122 | app_evaluate: results [('ipv6:[::1]:36602', EvaluateRes(loss=351.4906005859375, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.6067})), ('ipv6:[::1]:36604', EvaluateRes(loss=353.92742919921875, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.6005}))]
INFO flower 2021-02-25 14:03:27,514 | app.py:127 | app_evaluate: failures []

Congratulations! You’ve successfully built and run your first federated learning system. The full source code for this example can be found in examples/quickstart-pytorch.