Example: PyTorch - From Centralized To Federated

This tutorial will show you how to use Flower to build a federated version of an existing machine learning workload. We are using PyTorch to train a Convolutional Neural Network on the CIFAR-10 dataset. First, we introduce this machine learning task with a centralized training approach based on the `Deep Learning with PyTorch <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_ tutorial. Then, we build upon the centralized training code to run the training in a federated fashion.

Centralized Training

We begin with a brief description of the centralized CNN training code. If you want a more in-depth explanation of what's going on, have a look at the official `PyTorch tutorial <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_.

Let's create a new file called cifar.py with all the components required for traditional (centralized) training on CIFAR-10. First, all required packages (such as torch and torchvision) need to be imported. You can see that we do not import any package for federated learning. You can keep all these imports as they are even when we add the federated learning components at a later point.

from typing import Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import Tensor
from torchvision.datasets import CIFAR10

As already mentioned, we will use the CIFAR-10 dataset for this machine learning workload. The model architecture (a very simple Convolutional Neural Network) is defined in class Net().

class Net(nn.Module):

    def __init__(self) -> None:
        super(Net, self).__init__()
        # Two convolutional layers with 2x2 max-pooling in between
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Three fully connected layers mapping down to the 10 CIFAR-10 classes
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten for the fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

The load_data() function loads the CIFAR-10 training and test sets. The transform normalizes the data after loading.

DATA_ROOT = "~/data/cifar-10"


def load_data() -> (
    Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, Dict]
):
    """Load CIFAR-10 (training and test set)."""
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10(DATA_ROOT, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
    testset = CIFAR10(DATA_ROOT, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
    num_examples = {"trainset": len(trainset), "testset": len(testset)}
    return trainloader, testloader, num_examples

We now need to define the training function train(), which loops over the training set, measures the loss, backpropagates it, and then takes one optimizer step for each batch of training examples.

The evaluation of the model is defined in the function test(). The function loops over all test samples and measures the loss of the model based on the test dataset.

def train(
    net: Net,
    trainloader: torch.utils.data.DataLoader,
    epochs: int,
    device: torch.device,
) -> None:
    """Train the network."""
    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each")

    # Train the network
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            images, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0


def test(
    net: Net,
    testloader: torch.utils.data.DataLoader,
    device: torch.device,
) -> Tuple[float, float]:
    """Validate the network on the entire test set."""
    criterion = nn.CrossEntropyLoss()
    correct = 0
    total = 0
    loss = 0.0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return loss, accuracy

Having defined the data loading, model architecture, training, and evaluation, we can put everything together and train our CNN on CIFAR-10.

def main():
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Centralized PyTorch training")
    print("Load data")
    trainloader, testloader, _ = load_data()
    print("Start training")
    net = Net().to(DEVICE)
    train(net=net, trainloader=trainloader, epochs=2, device=DEVICE)
    print("Evaluate model")
    loss, accuracy = test(net=net, testloader=testloader, device=DEVICE)
    print("Loss: ", loss)
    print("Accuracy: ", accuracy)


if __name__ == "__main__":
    main()

You can now run your machine learning workload:

python3 cifar.py

So far, this should all look fairly familiar if you've used PyTorch before. Let's take the next step and use what we've built to create a simple federated learning system consisting of one server and two clients.

Federated Training

The simple machine learning project discussed in the previous section trains the model on a single dataset (CIFAR-10); we call this centralized training. This concept of centralized training, as shown in the previous section, is probably known to most of you, and many of you have used it before. Normally, if you'd want to run machine learning workloads in a federated fashion, you'd have to change most of your code and set everything up from scratch. This can be a considerable effort.

However, with Flower you can evolve your pre-existing code into a federated learning setup without the need for a major rewrite.

The concept is easy to understand. We have to start a server and then use the code in cifar.py for the clients that are connected to the server. The server sends model parameters to the clients. The clients run the training and update the parameters. The updated parameters are sent back to the server which averages all received parameter updates. This describes one round of the federated learning process and we repeat this for multiple rounds.
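To make the averaging step concrete, here is a minimal sketch of the weighted parameter averaging (the FedAvg idea) the server performs in each round. This is for illustration only, not Flower's internal implementation, and federated_average is a hypothetical helper:

import numpy as np
from typing import List, Tuple


def federated_average(
    results: List[Tuple[List[np.ndarray], int]]
) -> List[np.ndarray]:
    """Average client updates, weighting each one by its number of examples."""
    total_examples = sum(num_examples for _, num_examples in results)
    num_layers = len(results[0][0])
    return [
        sum(weights[layer] * num_examples for weights, num_examples in results)
        / total_examples
        for layer in range(num_layers)
    ]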

Our example consists of one server and two clients. Let's set up server.py first. The server needs to import the Flower package flwr. Next, we use the start_server function to start a server and tell it to perform three rounds of federated learning.

import flwr as fl

if __name__ == "__main__":
    fl.server.start_server(
        server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3)
    )

We can already start the *server*:

python3 server.py
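By default, the Flower server aggregates client updates using the built-in FedAvg strategy. If you want to configure the aggregation explicitly, start_server also accepts a strategy argument. Below is a sketch of such a server.py, assuming Flower 1.x; the parameter values are illustrative assumptions for our two-client setup, not requirements:

import flwr as fl

if __name__ == "__main__":
    # Configure the (default) FedAvg strategy explicitly
    strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,         # sample all available clients for training
        fraction_evaluate=1.0,    # sample all available clients for evaluation
        min_available_clients=2,  # wait until both clients have connected
    )
    fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=3),
        strategy=strategy,
    )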

Finally, we will define our client logic in client.py and build upon the previously defined centralized training in cifar.py. Our client needs to import flwr, but also torch to update the parameters on our PyTorch model:

from collections import OrderedDict
from typing import Dict, List, Tuple

import numpy as np
import torch

import cifar
import flwr as fl

DEVICE: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Implementing a Flower client basically means implementing a subclass of either flwr.client.Client or flwr.client.NumPyClient. Our implementation will be based on flwr.client.NumPyClient and we'll call it CifarClient. NumPyClient is slightly easier to implement than Client if you use a framework with good NumPy interoperability (like PyTorch or TensorFlow/Keras) because it avoids some of the boilerplate that would otherwise be necessary. CifarClient needs to implement four methods, two methods for getting/setting model parameters, one method for training the model, and one method for testing the model:

  1. set_parameters
    • set the model parameters on the local model that are received from the server

    • loop over the list of model parameters received as NumPy ndarrays (think list of neural network layers)

  2. get_parameters
    • get the model parameters and return them as a list of NumPy ndarrays (which is what flwr.client.NumPyClient expects)

  3. fit
    • update the parameters of the local model with the parameters received from the server

    • train the model on the local training set

    • get the updated local model parameters and return them to the server

  4. evaluate
    • update the parameters of the local model with the parameters received from the server

    • evaluate the updated model on the local test set

    • return the local loss and accuracy to the server

The two NumPyClient methods fit and evaluate make use of the functions train() and test() previously defined in cifar.py. So what we really do here is we tell Flower through our NumPyClient subclass which of our already defined functions to call for training and evaluation. We included type annotations to give you a better understanding of the data types that get passed around.

class CifarClient(fl.client.NumPyClient):
    """Flower client implementing CIFAR-10 image classification using
    PyTorch."""

    def __init__(
        self,
        model: cifar.Net,
        trainloader: torch.utils.data.DataLoader,
        testloader: torch.utils.data.DataLoader,
        num_examples: Dict,
    ) -> None:
        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader
        self.num_examples = num_examples

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Set model parameters from a list of NumPy ndarrays
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def fit(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters, train model, return updated model parameters
        self.set_parameters(parameters)
        cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE)
        return self.get_parameters(config={}), self.num_examples["trainset"], {}

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[float, int, Dict]:
        # Set model parameters, evaluate model on local test dataset, return result
        self.set_parameters(parameters)
        loss, accuracy = cifar.test(self.model, self.testloader, device=DEVICE)
        return float(loss), self.num_examples["testset"], {"accuracy": float(accuracy)}

All that's left to do is to define a function that loads both model and data, creates a CifarClient, and starts this client. You load your data and model by using cifar.py. Start CifarClient with the function fl.client.start_client() by pointing it at the same IP address we used in server.py:

def main() -> None:
    """Load data, start CifarClient."""

    # Load model and data
    model = cifar.Net()
    model.to(DEVICE)
    trainloader, testloader, num_examples = cifar.load_data()

    # Start client
    client = CifarClient(model, trainloader, testloader, num_examples)
    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())


if __name__ == "__main__":
    main()

And that's it. You can now open two additional terminal windows and run

python3 client.py

in each window (make sure that the server is still running) and see your (previously centralized) PyTorch project run federated learning across two clients. Congratulations!

Next Steps

The full source code for this example: PyTorch: From Centralized To Federated. Our example is, of course, somewhat over-simplified because both clients load the exact same dataset, which isn't realistic. You're now prepared to explore this topic further. How about using a different subset of CIFAR-10 on each client? How about adding more clients?
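As a starting point for the first question, here is one way you might give each client a different slice of CIFAR-10. The load_partition function below is a hypothetical helper (not part of this example's source code), built on torch.utils.data.Subset:

import torch
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10

DATA_ROOT = "~/data/cifar-10"


def load_partition(partition_id: int, num_clients: int) -> DataLoader:
    """Load a disjoint slice of the CIFAR-10 training set for one client."""
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10(DATA_ROOT, train=True, download=True, transform=transform)
    partition_size = len(trainset) // num_clients
    start = partition_id * partition_size
    indices = list(range(start, start + partition_size))
    return DataLoader(Subset(trainset, indices), batch_size=32, shuffle=True)

Each client would then call load_partition with its own partition_id (for example, passed as a command-line argument) instead of cifar.load_data().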