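Example: PyTorch - From Centralized To Federated¶
This tutorial shows you how to use Flower to build a federated version of an existing machine learning workload. We use PyTorch to train a Convolutional Neural Network on the CIFAR-10 dataset. First, we introduce this machine learning task with a centralized training approach based on the Deep Learning with PyTorch tutorial (https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html). Then, we build upon the centralized training code to run the training in a federated fashion.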
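Centralized Training¶
We begin with a brief description of the centralized CNN training code. If you want a more in-depth explanation of what's going on, have a look at the official PyTorch tutorial (https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).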
Let's create a new file called cifar.py with all the components required for a traditional (centralized) training on CIFAR-10. First, all required packages (such as torch and torchvision) need to be imported. Note that we do not import any package for federated learning. You can keep all these imports as they are even when we add the federated learning components at a later point.
from typing import Tuple, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import Tensor
from torchvision.datasets import CIFAR10
As already mentioned, we will use the CIFAR-10 dataset for this machine learning workload. The model architecture (a very simple Convolutional Neural Network) is defined in the class Net.
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
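The 16 * 5 * 5 input size of fc1 follows from the layer shapes. The sketch below traces the spatial size of a 32x32 CIFAR-10 image through the two convolution and pooling stages, assuming the settings used in Net (stride 1 and no padding for the convolutions, kernel size 2 and stride 2 for the pooling layer). The helper out_size is a hypothetical name introduced only for this illustration.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                   # CIFAR-10 images are 32x32 pixels
size = out_size(size, kernel=5)             # conv1: 32 -> 28
size = out_size(size, kernel=2, stride=2)   # pool:  28 -> 14
size = out_size(size, kernel=5)             # conv2: 14 -> 10
size = out_size(size, kernel=2, stride=2)   # pool:  10 -> 5

flattened = 16 * size * size                # conv2 produces 16 channels
print(flattened)                            # 400, the in_features of fc1
```

If you change a kernel size or add padding, the same arithmetic tells you which value to use in fc1 and in the x.view call.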
The load_data() function loads the CIFAR-10 training and test sets. The transform converts each image to a tensor and normalizes it after loading.
DATA_ROOT = "~/data/cifar-10"


def load_data() -> (
    Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, Dict]
):
    """Load CIFAR-10 (training and test set)."""
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10(DATA_ROOT, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
    testset = CIFAR10(DATA_ROOT, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
    num_examples = {"trainset": len(trainset), "testset": len(testset)}
    return trainloader, testloader, num_examples
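To make the effect of the transform concrete: transforms.ToTensor() scales pixel values to the range [0, 1], and transforms.Normalize with mean 0.5 and standard deviation 0.5 then computes (x - 0.5) / 0.5 per channel, which maps that range to [-1, 1]. The following NumPy sketch reproduces only this arithmetic on a few sample values. The function name normalize is illustrative and not part of torchvision.

```python
import numpy as np


def normalize(x: np.ndarray, mean: float = 0.5, std: float = 0.5) -> np.ndarray:
    """Apply the per-channel formula (x - mean) / std."""
    return (x - mean) / std


# Raw 8-bit pixel values, scaled to [0, 1] as ToTensor would do
pixels = np.array([0, 64, 128, 255], dtype=np.float32) / 255.0
normalized = normalize(pixels)

# The darkest pixel maps to -1 and the brightest to 1
print(normalized.min(), normalized.max())
```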
We now need to define the training function train(), which loops over the training set, measures the loss, backpropagates it, and then takes one optimizer step for each batch of training examples.
The evaluation of the model is defined in the function test(). The function loops over all test samples and measures the loss and accuracy of the model on the test dataset.
def train(
    net: Net,
    trainloader: torch.utils.data.DataLoader,
    epochs: int,
    device: torch.device,
) -> None:
    """Train the network."""
    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each")
    # Train the network
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            images, labels = data[0].to(device), data[1].to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                # average over the 100 batches accumulated since the last print
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 100))
                running_loss = 0.0
def test(
    net: Net,
    testloader: torch.utils.data.DataLoader,
    device: torch.device,
) -> Tuple[float, float]:
    """Validate the network on the entire test set."""
    criterion = nn.CrossEntropyLoss()
    correct = 0
    total = 0
    loss = 0.0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return loss, accuracy
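Two details of test() are easy to miss: the predicted class is the index of the largest logit, and the returned loss is the sum of one mean loss per batch, so it grows with the number of batches. The NumPy sketch below illustrates both points with made-up logits, labels, and batch losses. It is not a replacement for the PyTorch code above.

```python
import numpy as np

# Hypothetical logits for 4 samples and 3 classes, plus their true labels
logits = np.array([
    [2.0, 0.1, -1.0],
    [0.2, 1.5, 0.3],
    [0.0, 0.4, 0.1],
    [1.2, 0.3, 2.2],
])
labels = np.array([0, 1, 2, 2])

# Index of the largest logit per row, the role torch.max(outputs, 1) plays
predicted = logits.argmax(axis=1)
correct = int((predicted == labels).sum())
accuracy = correct / len(labels)

# test() adds up one mean loss per batch; dividing by the batch count
# gives a value that does not depend on the size of the test set
batch_losses = [2.31, 2.28, 2.25]  # made-up per-batch losses
summed_loss = sum(batch_losses)
mean_loss = summed_loss / len(batch_losses)

print(accuracy, round(mean_loss, 2))
```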
Having defined the data loading, model architecture, training, and evaluation, we can put everything together and train our CNN on CIFAR-10.
def main():
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Centralized PyTorch training")
    print("Load data")
    trainloader, testloader, _ = load_data()
    print("Start training")
    net = Net().to(DEVICE)
    train(net=net, trainloader=trainloader, epochs=2, device=DEVICE)
    print("Evaluate model")
    loss, accuracy = test(net=net, testloader=testloader, device=DEVICE)
    print("Loss: ", loss)
    print("Accuracy: ", accuracy)


if __name__ == "__main__":
    main()
You can now run your machine learning workload:
python3 cifar.py
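So far, this should all look fairly familiar if you've used PyTorch before. Let's take the next step and use what we've built to create a simple federated learning system consisting of one server and two clients.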
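Federated Training¶
The simple machine learning project discussed in the previous section trains the model on a single dataset (CIFAR-10). We call this centralized learning. This concept is probably familiar to most readers, and many have used it before. Normally, if you wanted to run a machine learning workload in a federated fashion, you would have to change most of your code and set everything up from scratch, which can be a considerable effort.
With Flower, however, you can evolve your existing code into a federated learning setup without a major rewrite.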
The concept is easy to understand. We have to start a server and then use the code in cifar.py for the clients that are connected to the server. The server sends model parameters to the clients. The clients run the training and update the parameters. The updated parameters are sent back to the server, which averages all received parameter updates. This describes one round of the federated learning process, and we repeat this for multiple rounds.
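The averaging step can be sketched in a few lines of NumPy. The example below assumes that the server weights each client's parameters by the number of training examples that client reports, which is the idea behind federated averaging (FedAvg). The function weighted_average and the two toy clients are hypothetical. This is a simplified illustration, not Flower's actual implementation.

```python
from typing import List, Tuple

import numpy as np


def weighted_average(
    results: List[Tuple[List[np.ndarray], int]],
) -> List[np.ndarray]:
    """Average per-layer parameters, weighting each client by its example count."""
    total_examples = sum(num_examples for _, num_examples in results)
    num_layers = len(results[0][0])
    return [
        sum(params[layer] * num_examples for params, num_examples in results)
        / total_examples
        for layer in range(num_layers)
    ]


# Two hypothetical clients, each with a one-layer "model"
client_a = ([np.array([1.0, 2.0])], 100)
client_b = ([np.array([3.0, 4.0])], 300)

averaged = weighted_average([client_a, client_b])
print(averaged)  # [array([2.5, 3.5])]
```

Because client_b reports three times as many examples, the result lies closer to its parameters than to those of client_a.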
Our example consists of one server and two clients. Let's set up server.py first. The server needs to import the Flower package flwr. Next, we use the start_server function to start a server and tell it to perform three rounds of federated learning.
import flwr as fl

if __name__ == "__main__":
    fl.server.start_server(
        server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3)
    )
We can already start the server:
python3 server.py
Finally, we will define our client logic in client.py and build upon the previously defined centralized training in cifar.py. Our client needs to import flwr, but also torch to update the parameters on our PyTorch model:
from collections import OrderedDict
from typing import Dict, List, Tuple
import numpy as np
import torch
import cifar
import flwr as fl
DEVICE: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Implementing a Flower client basically means implementing a subclass of either flwr.client.Client or flwr.client.NumPyClient. Our implementation will be based on flwr.client.NumPyClient and we'll call it CifarClient. NumPyClient is slightly easier to implement than Client if you use a framework with good NumPy interoperability (like PyTorch or TensorFlow/Keras) because it avoids some of the boilerplate that would otherwise be necessary. CifarClient needs to implement four methods: two methods for getting/setting model parameters, one method for training the model, and one method for testing the model:
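1. set_parameters
   - set the model parameters received from the server on the local model
   - loop over the list of model parameters received as NumPy ndarrays (think list of neural network layers)
2. get_parameters
   - get the model parameters and return them as a list of NumPy ndarrays (which is what flwr.client.NumPyClient expects)
3. fit
   - update the parameters of the local model with the parameters received from the server
   - train the model on the local training set
   - get the updated local model parameters and return them to the server
4. evaluate
   - update the parameters of the local model with the parameters received from the server
   - evaluate the updated model on the local test set
   - return the local loss and accuracy to the server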
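The first two methods are a round trip between a keyed state dictionary and a plain list of arrays. The list carries no layer names, so the order of the arrays is what ties each one back to its layer. The sketch below illustrates this with an OrderedDict of NumPy arrays standing in for model.state_dict(). The layer names and the free functions are hypothetical and exist only to keep the example independent of PyTorch.

```python
from collections import OrderedDict
from typing import List

import numpy as np

# Stand-in for model.state_dict(): layer names mapped to weight arrays
state_dict = OrderedDict(
    [("fc.weight", np.ones((2, 3))), ("fc.bias", np.zeros(2))]
)


def get_parameters(state: "OrderedDict[str, np.ndarray]") -> List[np.ndarray]:
    """Drop the keys and keep copies of the arrays in state-dict order."""
    return [value.copy() for value in state.values()]


def set_parameters(
    state: "OrderedDict[str, np.ndarray]", parameters: List[np.ndarray]
) -> "OrderedDict[str, np.ndarray]":
    """Re-attach the keys by position; the list order must match the key order."""
    return OrderedDict(zip(state.keys(), parameters))


params = get_parameters(state_dict)
params[1] = params[1] + 1.0  # pretend the server changed the bias
restored = set_parameters(state_dict, params)
print(list(restored.keys()), restored["fc.bias"])
```

If the list were reordered in transit, the keys would be paired with the wrong arrays, which is why both sides rely on the same model definition.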
The two NumPyClient methods fit and evaluate make use of the functions train() and test() previously defined in cifar.py. So what we really do here is tell Flower, through our NumPyClient subclass, which of our already defined functions to call for training and evaluation. We included type annotations to give you a better understanding of the data types that get passed around.
class CifarClient(fl.client.NumPyClient):
    """Flower client implementing CIFAR-10 image classification using
    PyTorch."""

    def __init__(
        self,
        model: cifar.Net,
        trainloader: torch.utils.data.DataLoader,
        testloader: torch.utils.data.DataLoader,
        num_examples: Dict,
    ) -> None:
        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader
        self.num_examples = num_examples

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Set model parameters from a list of NumPy ndarrays
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def fit(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters, train model, return updated model parameters
        self.set_parameters(parameters)
        cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE)
        return self.get_parameters(config={}), self.num_examples["trainset"], {}

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[float, int, Dict]:
        # Set model parameters, evaluate model on local test dataset, return result
        self.set_parameters(parameters)
        loss, accuracy = cifar.test(self.model, self.testloader, device=DEVICE)
        return float(loss), self.num_examples["testset"], {"accuracy": float(accuracy)}
All that's left to do is to define a function that loads both model and data, creates a CifarClient, and starts this client. You load your data and model by using cifar.py. Start CifarClient with the function fl.client.start_client() by pointing it at the same address we used in server.py:
def main() -> None:
    """Load data, start CifarClient."""
    # Load model and data
    model = cifar.Net()
    model.to(DEVICE)
    trainloader, testloader, num_examples = cifar.load_data()
    # Start client (both arguments are passed by keyword; a positional
    # argument after a keyword argument would be a syntax error)
    client = CifarClient(model, trainloader, testloader, num_examples)
    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())


if __name__ == "__main__":
    main()
And that's it. You can now open two additional terminal windows and run
python3 client.py
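in each window. Provided the server is running, you will see your (previously centralized) PyTorch project run federated learning across two clients. Congratulations!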
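Next Steps¶
The full source code for this example is available as PyTorch: From Centralized To Federated. Our example is, of course, somewhat over-simplified because both clients load the exact same dataset, which isn't realistic. You're now prepared to explore this topic further. How about using a different subset of CIFAR-10 on each client? How about adding more clients?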
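As a starting point for the first question, the sketch below splits the training-set indices into disjoint shards, one per client. It assumes a simple random (IID) split. The function partition_indices and its parameters are hypothetical and not part of Flower or PyTorch.

```python
import random
from typing import List


def partition_indices(
    num_samples: int, num_clients: int, seed: int = 42
) -> List[List[int]]:
    """Shuffle sample indices and split them into disjoint, near-equal shards."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    # Client i takes every num_clients-th index, starting at offset i
    return [indices[i::num_clients] for i in range(num_clients)]


# CIFAR-10 has 50,000 training images; split them across two clients
shards = partition_indices(num_samples=50_000, num_clients=2)
print([len(shard) for shard in shards])  # [25000, 25000]
```

One way to use this would be to have load_data() accept a client index and wrap trainset in torch.utils.data.Subset(trainset, shards[client_index]) before creating the DataLoader, so that each client trains on its own shard.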