예제: 파이토치 - 중앙 집중식에서 연합식으로

이 튜토리얼에서는 Flower를 사용해 기존 머신 러닝 워크로드의 연합 버전을 구축하는 방법을 보여드립니다. 여기서는 PyTorch를 사용해 CIFAR-10 데이터 세트에서 컨볼루션 신경망을 훈련합니다. 먼저, ‘PyTorch로 딥 러닝 <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_ 튜토리얼을 기반으로 centralized 학습 접근 방식을 사용하여 이 머신 러닝 작업을 소개합니다. 그런 다음 centralized 훈련 코드를 기반으로 연합 방식 훈련을 실행합니다.

중앙 집중식 훈련

중앙 집중식 CNN 트레이닝 코드에 대한 간략한 설명부터 시작하겠습니다. 무슨 일이 일어나고 있는지 더 자세히 설명하려면 공식 `PyTorch 튜토리얼 <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_을 참조하세요.

Let’s create a new file called cifar.py with all the components required for a traditional (centralized) training on CIFAR-10. First, all required packages (such as torch and torchvision) need to be imported. You can see that we do not import any package for federated learning. You can keep all these imports as they are even when we add the federated learning components at a later point.

from typing import Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import Tensor
from torchvision.datasets import CIFAR10

As already mentioned we will use the CIFAR-10 dataset for this machine learning workload. The model architecture (a very simple Convolutional Neural Network) is defined in class Net().

class Net(nn.Module):

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

The load_data() function loads the CIFAR-10 training and test sets. The transform normalized the data after loading.

DATA_ROOT = "~/data/cifar-10"


def load_data() -> (
    Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, Dict]
):
    """Load CIFAR-10 (training and test set)."""
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10(DATA_ROOT, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
    testset = CIFAR10(DATA_ROOT, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
    num_examples = {"trainset": len(trainset), "testset": len(testset)}
    return trainloader, testloader, num_examples

We now need to define the training (function train()) which loops over the training set, measures the loss, backpropagates it, and then takes one optimizer step for each batch of training examples.

The evaluation of the model is defined in the function test(). The function loops over all test samples and measures the loss of the model based on the test dataset.

def train(
    net: Net,
    trainloader: torch.utils.data.DataLoader,
    epochs: int,
    device: torch.device,
) -> None:
    """Train the network."""
    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each")

    # Train the network
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            images, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0


def test(
    net: Net,
    testloader: torch.utils.data.DataLoader,
    device: torch.device,
) -> Tuple[float, float]:
    """Validate the network on the entire test set."""
    criterion = nn.CrossEntropyLoss()
    correct = 0
    total = 0
    loss = 0.0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return loss, accuracy

데이터 로딩, 모델 아키텍처, 훈련 및 평가를 정의했으면 모든 것을 종합하여 CIFAR-10에서 CNN을 훈련할 수 있습니다.

def main():
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Centralized PyTorch training")
    print("Load data")
    trainloader, testloader, _ = load_data()
    print("Start training")
    net = Net().to(DEVICE)
    train(net=net, trainloader=trainloader, epochs=2, device=DEVICE)
    print("Evaluate model")
    loss, accuracy = test(net=net, testloader=testloader, device=DEVICE)
    print("Loss: ", loss)
    print("Accuracy: ", accuracy)


if __name__ == "__main__":
    main()

이제 머신 러닝 워크로드를 실행할 수 있습니다:

python3 cifar.py

지금까지는 파이토치를 사용해 본 적이 있다면 상당히 익숙하게 보일 것입니다. 다음 단계로 넘어가서 구축한 것을 사용하여 하나의 서버와 두 개의 클라이언트로 구성된 간단한 연합 학습 시스템을 만들어 보겠습니다.

연합 훈련

이전 섹션에서 설명한 간단한 머신 러닝 프로젝트는 단일 데이터 세트(CIFAR-10)로 모델을 학습시키는데, 이를 중앙 집중식 학습이라고 부릅니다. 이전 섹션에서 설명한 중앙 집중식 학습의 개념은 대부분 알고 계실 것이며, 많은 분들이 이전에 사용해 보셨을 것입니다. 일반적으로 머신 러닝 워크로드를 연합 방식으로 실행하려면 대부분의 코드를 변경하고 모든 것을 처음부터 다시 설정해야 합니다. 이는 상당한 노력이 필요할 수 있습니다.

하지만 Flower를 사용하면 대대적인 재작성 없이도 기존 코드를 연합 학습 설정으로 발전시킬 수 있습니다.

The concept is easy to understand. We have to start a server and then use the code in cifar.py for the clients that are connected to the server. The server sends model parameters to the clients. The clients run the training and update the parameters. The updated parameters are sent back to the server which averages all received parameter updates. This describes one round of the federated learning process and we repeat this for multiple rounds.

Our example consists of one server and two clients. Let’s set up server.py first. The server needs to import the Flower package flwr. Next, we use the start_server function to start a server and tell it to perform three rounds of federated learning.

import flwr as fl

if __name__ == "__main__":
    fl.server.start_server(
        server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3)
    )

이미 *서버*를 시작할 수 있습니다:

python3 server.py

Finally, we will define our client logic in client.py and build upon the previously defined centralized training in cifar.py. Our client needs to import flwr, but also torch to update the parameters on our PyTorch model:

from collections import OrderedDict
from typing import Dict, List, Tuple

import numpy as np
import torch

import cifar
import flwr as fl

DEVICE: str = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Implementing a Flower client basically means implementing a subclass of either flwr.client.Client or flwr.client.NumPyClient. Our implementation will be based on flwr.client.NumPyClient and we’ll call it CifarClient. NumPyClient is slightly easier to implement than Client if you use a framework with good NumPy interoperability (like PyTorch or TensorFlow/Keras) because it avoids some of the boilerplate that would otherwise be necessary. CifarClient needs to implement four methods, two methods for getting/setting model parameters, one method for training the model, and one method for testing the model:

  1. set_parameters
    • 서버에서 수신한 로컬 모델의 모델 파라미터를 설정합니다

    • loop over the list of model parameters received as NumPy ndarray’s (think list of neural network layers)

  2. get_parameters
    • get the model parameters and return them as a list of NumPy ndarray’s (which is what flwr.client.NumPyClient expects)

  3. fit
    • 서버에서 받은 파라미터로 로컬 모델의 파라미터를 업데이트합니다

    • 로컬 훈련 세트에서 모델을 훈련합니다

    • 업데이트된 로컬 모델 가중치를 가져와 서버로 반환합니다

  4. evaluate
    • 서버에서 받은 파라미터로 로컬 모델의 파라미터를 업데이트합니다

    • 로컬 테스트 세트에서 업데이트된 모델을 평가합니다

    • 로컬 손실 및 정확도를 서버에 반환합니다

The two NumPyClient methods fit and evaluate make use of the functions train() and test() previously defined in cifar.py. So what we really do here is we tell Flower through our NumPyClient subclass which of our already defined functions to call for training and evaluation. We included type annotations to give you a better understanding of the data types that get passed around.

class CifarClient(fl.client.NumPyClient):
    """Flower client implementing CIFAR-10 image classification using
    PyTorch."""

    def __init__(
        self,
        model: cifar.Net,
        trainloader: torch.utils.data.DataLoader,
        testloader: torch.utils.data.DataLoader,
        num_examples: Dict,
    ) -> None:
        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader
        self.num_examples = num_examples

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Set model parameters from a list of NumPy ndarrays
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def fit(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters, train model, return updated model parameters
        self.set_parameters(parameters)
        cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE)
        return self.get_parameters(config={}), self.num_examples["trainset"], {}

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[float, int, Dict]:
        # Set model parameters, evaluate model on local test dataset, return result
        self.set_parameters(parameters)
        loss, accuracy = cifar.test(self.model, self.testloader, device=DEVICE)
        return float(loss), self.num_examples["testset"], {"accuracy": float(accuracy)}

All that’s left to do it to define a function that loads both model and data, creates a CifarClient, and starts this client. You load your data and model by using cifar.py. Start CifarClient with the function fl.client.start_client() by pointing it at the same IP address we used in server.py:

def main() -> None:
    """Load data, start CifarClient."""

    # Load model and data
    model = cifar.Net()
    model.to(DEVICE)
    trainloader, testloader, num_examples = cifar.load_data()

    # Start client
    client = CifarClient(model, trainloader, testloader, num_examples)
    fl.client.start_client(server_address="0.0.0.0:8080", client.to_client())


if __name__ == "__main__":
    main()

여기까지입니다. 이제 두 개의 터미널 창을 추가로 열고 다음을 실행할 수 있습니다

python3 client.py

를 입력하고(그 전에 서버가 실행 중인지 확인하세요) (이전에는 중앙 집중식) PyTorch 프로젝트가 두 클라이언트에서 연합 학습을 실행하는 것을 확인합니다. 축하합니다!

다음 단계

이 예제의 전체 소스 코드: 파이토치: 중앙 Centralized에서 Federated으로 (코드). 물론 이 예제는 두 클라이언트가 완전히 동일한 데이터 세트를 로드하기 때문에 다소 지나치게 단순화되어 있으며, 이는 현실적이지 않습니다. 이제 이 주제를 더 자세히 살펴볼 준비가 되셨습니다. 각 클라이언트에서 서로 다른 CIFAR-10의 하위 집합을 사용해 보는 것은 어떨까요? 클라이언트를 더 추가하는 것은 어떨까요?