Get started with Flower¶
Welcome to the Flower federated learning tutorial!
In this notebook, we’ll build a federated learning system using the Flower framework, Flower Datasets and PyTorch. In part 1, we use PyTorch for the model training pipeline and data loading. In part 2, we federate the PyTorch project using Flower.
Star Flower on GitHub ⭐️ and join the Flower community on Flower Discuss and the Flower Slack to connect, ask questions, and get help:

- Join Flower Discuss We’d love to hear from you in the Introduction topic! If anything is unclear, post in Flower Help - Beginners.
- Join Flower Slack We’d love to hear from you in the #introductions channel! If anything is unclear, head over to the #questions channel.
Let’s get started! 🌼
Step 0: Preparation¶
Before we begin with any actual code, let's make sure that we have everything we need.
Install dependencies¶
Next, we install the necessary packages for PyTorch (torch and torchvision), Flower Datasets (flwr-datasets) and Flower (flwr):
[ ]:
!pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision matplotlib
Now that we have all dependencies installed, we can import everything we need for this tutorial:
[ ]:
from collections import OrderedDict
from typing import List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader
import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
DEVICE = torch.device("cpu") # Try "cuda" to train on GPU
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")
disable_progress_bar()
It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: Runtime > Change runtime type > Hardware accelerator: GPU > Save). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting DEVICE = torch.device("cpu"). If the runtime has GPU acceleration enabled, you should see the output Training on cuda, otherwise it’ll say Training on cpu.
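Rather than editing DEVICE by hand when you switch runtimes, you can also let PyTorch pick the device at runtime. This is a small convenience sketch, not part of the original notebook:

```python
import torch

# Use a GPU when one is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")
```

With this, the same notebook runs unchanged on both CPU-only and GPU-enabled runtimes.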
Load the data¶
Federated learning can be applied to many different types of tasks across different domains. In this tutorial, we introduce federated learning by training a simple convolutional neural network (CNN) on the popular CIFAR-10 dataset. CIFAR-10 can be used to train image classifiers that distinguish between images from ten different classes: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', and 'truck'.
We simulate having multiple datasets from multiple organizations (also called the "cross-silo" setting in federated learning) by splitting the original CIFAR-10 dataset into multiple partitions. Each partition will represent the data from a single organization. We’re doing this purely for experimentation purposes, in the real world there’s no need for data splitting because each organization already has their own data (the data is naturally partitioned).
Each organization will act as a client in the federated learning system. Having ten organizations participate in a federation means having ten clients connected to the federated learning server.
We use the Flower Datasets library (flwr-datasets) to partition CIFAR-10 into ten partitions using FederatedDataset. We will create a small training and test set for each of the ten organizations and wrap each of these into a PyTorch DataLoader:
[ ]:
NUM_CLIENTS = 10
BATCH_SIZE = 32
def load_datasets(partition_id: int):
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": NUM_CLIENTS})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform)
        # we will use this function to dataset.with_transform(apply_transforms)
        # The transforms object is exactly the same
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    # Create train/val for each partition and wrap it into DataLoader
    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=BATCH_SIZE, shuffle=True
    )
    valloader = DataLoader(partition_train_test["test"], batch_size=BATCH_SIZE)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloader, valloader, testloader
We now have a function that can return a training set and validation set (trainloader and valloader) representing one dataset from one of ten different organizations. Each trainloader/valloader pair contains 4000 training examples and 1000 validation examples. There’s also a single testloader (we did not split the test set). Again, this is only necessary for building research or educational systems; actual federated learning systems have their data naturally distributed across multiple partitions.
Let’s take a look at the first batch of images and labels in the first training set (i.e., trainloader from partition_id=0) before we move on:
[ ]:
trainloader, _, _ = load_datasets(partition_id=0)
batch = next(iter(trainloader))
images, labels = batch["img"], batch["label"]
# Reshape and convert images to a NumPy array
# matplotlib requires images with the shape (height, width, 3)
images = images.permute(0, 2, 3, 1).numpy()
# Denormalize
images = images / 2 + 0.5
# Create a figure and a grid of subplots
fig, axs = plt.subplots(4, 8, figsize=(12, 6))
# Loop over the images and plot them
for i, ax in enumerate(axs.flat):
    ax.imshow(images[i])
    ax.set_title(trainloader.dataset.features["label"].int2str([labels[i]])[0])
    ax.axis("off")
# Show the plot
fig.tight_layout()
plt.show()
The output above shows a random batch of images from the trainloader from the first of ten partitions. It also prints the labels associated with each image (i.e., one of the ten possible labels we’ve seen above). If you run the cell again, you should see another batch of images.
Step 1: Centralized training with PyTorch¶
Next, we're going to use PyTorch to define a simple convolutional neural network. This introduction assumes basic familiarity with PyTorch, so it doesn't cover the PyTorch-related aspects in full detail. If you want to dive deeper into PyTorch, we recommend DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ.
Define the model¶
We use the simple CNN described in the `PyTorch tutorial <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network>`__:
[ ]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
Let's continue with the usual training and test functions:
[ ]:
def train(net, trainloader, epochs: int, verbose=False):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"].to(DEVICE), batch["label"].to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        if verbose:
            print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"].to(DEVICE), batch["label"].to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy
Train the model¶
We now have all the basic building blocks we need: a dataset, a model, a training function, and a test function. Let’s put them together to train the model on the dataset of one of our organizations (partition_id=0). This simulates the reality of most machine learning projects today: each organization has their own data and trains models only on this internal data:
[ ]:
trainloader, valloader, testloader = load_datasets(partition_id=0)
net = Net().to(DEVICE)

for epoch in range(5):
    train(net, trainloader, 1)
    loss, accuracy = test(net, valloader)
    print(f"Epoch {epoch+1}: validation loss {loss}, accuracy {accuracy}")

loss, accuracy = test(net, testloader)
print(f"Final test set performance:\n\tloss {loss}\n\taccuracy {accuracy}")
loss, accuracy = test(net, testloader)
print(f"Final test set performance:\n\tloss {loss}\n\taccuracy {accuracy}")
Training the simple CNN on our CIFAR-10 split for 5 epochs should result in a test set accuracy of about 41%, which is not good, but at the same time, it doesn’t really matter for the purposes of this tutorial. The intent was just to show a simple centralized training pipeline that sets the stage for what comes next - federated learning!
Step 2: Federated Learning with Flower¶
Step 1 demonstrated a simple centralized training pipeline. All data was in one place (i.e., a single trainloader and a single valloader). Next, we'll simulate a situation where we have multiple datasets in multiple organizations and where we train a model over these organizations using federated learning.
Update model parameters¶
In federated learning, the server sends global model parameters to the client, and the client updates the local model with parameters received from the server. It then trains the model on the local data (which changes the model parameters locally) and sends the updated/changed model parameters back to the server (or, alternatively, it sends just the gradients back to the server, not the full model parameters).
We need two helper functions to update the local model with parameters received from the server and to get the updated model parameters from the local model: set_parameters and get_parameters. The following two functions do just that for the PyTorch model above.
The details of how this works are not really important here (feel free to consult the PyTorch documentation if you want to learn more). In essence, we use state_dict to access PyTorch model parameter tensors. The parameter tensors are then converted to/from a list of NumPy ndarray’s (which the Flower NumPyClient knows how to serialize/deserialize):
[ ]:
def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]
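As a quick sanity check, copying parameters from one model into another via these two helpers should leave both models with identical weights. The snippet below uses a minimal stand-in model instead of the CNN above so that it is self-contained:

```python
from collections import OrderedDict
from typing import List

import numpy as np
import torch
import torch.nn as nn


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


# Two independently initialized models start with different weights ...
net_a, net_b = nn.Linear(4, 2), nn.Linear(4, 2)
set_parameters(net_b, get_parameters(net_a))

# ... but after the round trip, every tensor matches exactly
for p_a, p_b in zip(get_parameters(net_a), get_parameters(net_b)):
    assert np.array_equal(p_a, p_b)
print("round trip OK")
```

This round trip (model → NumPy arrays → model) is exactly what happens every time parameters travel between server and client.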
Define the Flower ClientApp¶
With that out of the way, let’s move on to the interesting part. Federated learning systems consist of a server and multiple clients. In Flower, we create a ServerApp and a ClientApp to run the server-side and client-side code, respectively.
The first step toward creating a ClientApp is to implement a subclass of flwr.client.Client or flwr.client.NumPyClient. We use NumPyClient in this tutorial because it is easier to implement and requires us to write less boilerplate. To implement NumPyClient, we create a subclass that implements the three methods get_parameters, fit, and evaluate:

- get_parameters: Return the current local model parameters
- fit: Receive model parameters from the server, train the model on the local data, and return the updated model parameters to the server
- evaluate: Receive model parameters from the server, evaluate the model on the local data, and return the evaluation result to the server
We mentioned that our clients will use the previously defined PyTorch components for model training and evaluation. Let's see a simple Flower client implementation that brings everything together:
[ ]:
class FlowerClient(NumPyClient):
    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        return get_parameters(self.net)

    def fit(self, parameters, config):
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
Our class FlowerClient defines how local training/evaluation will be performed and allows Flower to call the local training/evaluation through fit and evaluate. Each instance of FlowerClient represents a single client in our federated learning system. Federated learning systems have multiple clients (otherwise, there’s not much to federate), so each client will be represented by its own instance of FlowerClient. If we have, for example, three clients in our workload, then we’d have three instances of FlowerClient (one on each of the machines we’d start the client on). Flower calls FlowerClient.fit on the respective instance when the server selects a particular client for training (and FlowerClient.evaluate for evaluation).
In this notebook, we want to simulate a federated learning system with 10 clients on a single machine. This means that the server and all 10 clients will live on a single machine and share resources such as CPU, GPU, and memory. Having 10 clients would mean having 10 instances of FlowerClient in memory. Doing this on a single machine can quickly exhaust the available memory resources, even if only a subset of these clients participates in a single round of federated learning.
In addition to the regular capabilities where server and clients run on multiple machines, Flower, therefore, provides special simulation capabilities that create FlowerClient instances only when they are actually necessary for training or evaluation. To enable the Flower framework to create clients when necessary, we need to implement a function that creates a FlowerClient instance on demand. We typically call this function client_fn. Flower calls client_fn whenever it needs an instance of one particular client to call fit or evaluate (those instances are usually discarded after use, so they should not keep any local state). In federated learning experiments using Flower, clients are identified by a partition ID, or partition-id. This partition-id is used to load different local data partitions for different clients, as can be seen below. The value of partition-id is retrieved from the node_config dictionary in the Context object, which holds the information that persists throughout each training round.
With this, we have the class FlowerClient which defines client-side training/evaluation and client_fn which allows Flower to create FlowerClient instances whenever it needs to call fit or evaluate on one particular client. Last, but definitely not least, we create an instance of ClientApp and pass it the client_fn. ClientApp is the entrypoint that a running Flower client uses to call your code (as defined in, for example, FlowerClient.fit).
[ ]:
def client_fn(context: Context) -> Client:
    """Create a Flower client representing a single organization."""

    # Load model
    net = Net().to(DEVICE)

    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on their own unique data partition
    # Read the node_config to fetch data partition associated to this node
    partition_id = context.node_config["partition-id"]
    trainloader, valloader, _ = load_datasets(partition_id=partition_id)

    # Create a single Flower client representing a single organization
    # FlowerClient is a subclass of NumPyClient, so we need to call .to_client()
    # to convert it to a subclass of `flwr.client.Client`
    return FlowerClient(net, trainloader, valloader).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)
Define the Flower ServerApp¶
On the server side, we need to configure a strategy which encapsulates the federated learning approach/algorithm, for example, Federated Averaging (FedAvg). Flower has a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in FedAvg implementation and customize it using a few basic parameters:
[ ]:
# Create FedAvg strategy
strategy = FedAvg(
    fraction_fit=1.0,  # Sample 100% of available clients for training
    fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
    min_fit_clients=10,  # Never sample less than 10 clients for training
    min_evaluate_clients=5,  # Never sample less than 5 clients for evaluation
    min_available_clients=10,  # Wait until all 10 clients are available
)
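How do these parameters interact? Roughly, FedAvg samples a fraction of the connected clients, but never fewer than the configured minimum. The sketch below illustrates that sampling rule for the settings above (a simplification for illustration, not the exact library code):

```python
def num_sampled_clients(num_available: int, fraction: float, min_clients: int) -> int:
    # Sample a fraction of the available clients, but never fewer than min_clients
    return max(int(num_available * fraction), min_clients)


# With the strategy configured above and 10 connected clients:
print(num_sampled_clients(10, 1.0, 10))  # training: 10 clients
print(num_sampled_clients(10, 0.5, 5))   # evaluation: 5 clients
```

So with ten available clients, every round trains on all ten and evaluates on five of them.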
Similar to ClientApp, we create a ServerApp using a utility function server_fn. In server_fn, we pass an instance of ServerConfig for defining the number of federated learning rounds (num_rounds) and we also pass the previously created strategy. The server_fn returns a ServerAppComponents object containing the settings that define the ServerApp behaviour. ServerApp is the entrypoint that Flower uses to call all your server-side code (for example, the strategy).
[ ]:
def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour.

    You can use the settings in `context.run_config` to parameterize the
    construction of all elements (e.g the strategy or the number of rounds)
    wrapped in the returned ServerAppComponents object.
    """

    # Configure the server for 5 rounds of training
    config = ServerConfig(num_rounds=5)
    return ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)
Run the training¶
In simulation, we often want to control the amount of resources each client can use. In the next cell, we specify a backend_config dictionary with the client_resources key (required) for defining the amount of CPU and GPU resources each client can access.
[ ]:
# Specify the resources each of your clients need
# By default, each client will be allocated 1x CPU and 0x GPUs
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}
# When running on GPU, assign an entire GPU for each client
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1.0}}
# Refer to our Flower framework documentation for more details about Flower simulations
# and how to set up the `backend_config`
The last step is the actual call to run_simulation which - you guessed it - runs the simulation. run_simulation accepts a number of arguments:

- server_app and client_app: the previously created ServerApp and ClientApp objects, respectively
- num_supernodes: the number of SuperNodes to simulate, which equals the number of clients for Flower simulation
- backend_config: the resource allocation used in this simulation
[ ]:
# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)
Behind the scenes¶
So how does this work? How does Flower execute this simulation?
When we call run_simulation, we tell Flower that there are 10 clients (num_supernodes=10, where 1 SuperNode launches 1 ClientApp). Flower then goes ahead and asks the ServerApp to issue instructions to those nodes using the FedAvg strategy. FedAvg knows that it should select 100% of the available clients (fraction_fit=1.0), so it goes ahead and selects 10 random clients (i.e., 100% of 10).
Flower then asks the selected 10 clients to train the model. Each of the 10 ClientApp instances receives a message, which causes it to call client_fn to create an instance of FlowerClient. It then calls .fit() on each of the FlowerClient instances and returns the resulting model parameter updates to the ServerApp. When the ServerApp receives the model parameter updates from the clients, it hands those updates over to the strategy (FedAvg) for aggregation. The strategy aggregates those updates and returns the new global model, which then gets used in the next round of federated learning.
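The aggregation step at the heart of FedAvg is a weighted average of the clients' parameter updates, weighted by the number of training examples each client used. Here is a minimal, self-contained sketch of that arithmetic (an illustration, not the exact Flower implementation):

```python
from functools import reduce

import numpy as np


def fedavg(results):
    """Weighted-average client updates (a sketch of FedAvg aggregation).

    `results` is a list of (parameters, num_examples) tuples, where
    `parameters` is a list of NumPy arrays (one per model tensor).
    """
    total_examples = sum(num_examples for _, num_examples in results)
    # Scale each client's tensors by its number of training examples
    weighted = [[layer * num for layer in params] for params, num in results]
    # Sum per tensor, then divide by the total number of examples
    return [reduce(np.add, layers) / total_examples for layers in zip(*weighted)]


# Two clients with one tensor each: A trained on 10 examples, B on 30
new_global = fedavg([([np.array([0.0])], 10), ([np.array([4.0])], 30)])
print(new_global[0])  # (0.0 * 10 + 4.0 * 30) / 40 = 3.0
```

Client B contributes three times as many examples as client A, so the new global value lands three quarters of the way toward B's update.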
Where's the accuracy?¶
You may have noticed that all metrics except for losses_distributed are empty. Where did the {"accuracy": float(accuracy)} go?
Flower can automatically aggregate losses returned by individual clients, but it cannot do the same for metrics in the generic metrics dictionary (the one with the accuracy key). Metrics dictionaries can contain very different kinds of metrics, and even key/value pairs that are not metrics at all, so the framework does not (and cannot) know how to handle these automatically.
As users, we need to tell the framework how to handle/aggregate these custom metrics, and we do so by passing metric aggregation functions to the strategy. The strategy will then call these functions whenever it receives fit or evaluate metrics from clients. The two possible functions are fit_metrics_aggregation_fn and evaluate_metrics_aggregation_fn.
Let's create a simple weighted averaging function to aggregate the accuracy metric we return from evaluate:
[ ]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}
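Since weighted_average is plain Python, it can be sanity-checked in isolation with hand-computable numbers. The function is redefined below (with a plain dict standing in for flwr.common.Metrics) so the snippet is self-contained:

```python
from typing import List, Tuple

Metrics = dict  # stand-in for flwr.common.Metrics


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}


# 100 examples at 50% accuracy and 300 examples at 90% accuracy
result = weighted_average([(100, {"accuracy": 0.5}), (300, {"accuracy": 0.9})])
print(result)  # (100 * 0.5 + 300 * 0.9) / 400 = 0.8
```

Note that a plain (unweighted) mean would report 0.7 here, overstating the influence of the smaller client.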
[ ]:
def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour.

    You can use settings in `context.run_config` to parameterize the
    construction of all elements (e.g the strategy or the number of rounds)
    wrapped in the returned ServerAppComponents object.
    """

    # Create FedAvg strategy
    strategy = FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=0.5,
        min_fit_clients=10,
        min_evaluate_clients=5,
        min_available_clients=10,
        evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
    )

    # Configure the server for 5 rounds of training
    config = ServerConfig(num_rounds=5)
    return ServerAppComponents(strategy=strategy, config=config)


# Create a new server instance with the updated FedAvg strategy
server = ServerApp(server_fn=server_fn)

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)
We now have a full system that performs federated training and federated evaluation. It uses the weighted_average function to aggregate custom evaluation metrics and calculates a single accuracy metric across all clients on the server side.
The other two categories of metrics (losses_centralized and metrics_centralized) are still empty because they only apply when centralized evaluation is being used. Part two of the Flower tutorial covers centralized evaluation.
Final remarks¶
Congratulations, you just trained a convolutional neural network, federated over 10 clients! With that, you understand the basics of federated learning with Flower. The same approach you've seen can be used with other machine learning frameworks (not just PyTorch) and tasks (not just CIFAR-10 image classification), for example NLP with Hugging Face Transformers or speech with SpeechBrain.
In the next notebook, we're going to cover some more advanced concepts. Want to customize your strategy? Initialize parameters on the server side? Or evaluate the aggregated model on the server side? We'll cover all this and more in the next tutorial.
Next steps¶
Before you continue, make sure to join the Flower community on Flower Discuss (Join Flower Discuss) and on Slack (Join Slack).
If there's anything you need help with, we have a dedicated #questions channel, but we'd also love to hear who you are in #introductions!
The Flower Federated Learning Tutorial - Part 2 goes into more depth about strategies and all the advanced things you can build with them.