开始使用Flower¶
欢迎阅读Flower联邦学习教程!
In this tutorial, we'll build a federated learning system using the Flower framework, Flower Datasets and PyTorch. In part 1, we use PyTorch for the model training pipeline and data loading. In part 2, we federate the PyTorch project using Flower.
Star Flower on GitHub ⭐️ and join the Flower community on Flower Discuss and the Flower Slack to connect, ask questions, and get help:
Join Flower Discuss We'd love to hear from you in the
Introduction
topic! If anything is unclear, post inFlower Help - Beginners
.Join Flower Slack We'd love to hear from you in the
#introductions
channel! If anything is unclear, head over to the#questions
channel.
Let's get started! 🌼
准备工作¶
在开始编写实际代码之前,让我们先确保我们已经准备好了所需的一切。
Install dependencies¶
First, we install the Flower package flwr
:
# In a new Python environment
$ pip install -U "flwr[simulation]"
Then, we create a new Flower app called flower-tutorial
using the PyTorch template.
We also specify a username (flwrlabs
) for the project:
$ flwr new flower-tutorial --framework pytorch --username flwrlabs
After running the command, a new directory called flower-tutorial
will be created.
It should have the following structure:
flower-tutorial
├── README.md
├── flower_tutorial
│ ├── __init__.py
│ ├── client_app.py # Defines your ClientApp
│ ├── server_app.py # Defines your ServerApp
│ └── task.py # Defines your model, training and data loading
├── pyproject.toml # Project metadata like dependencies and configs
└── README.md
Next, we install the project and its dependencies, which are specified in the
pyproject.toml
file.
$ cd flower-tutorial
$ pip install -e .
Before we dive into federated learning, we'll take a look at the dataset that we'll be using for this tutorial, which is the CIFAR-10 dataset, and run a simple centralized training pipeline using PyTorch.
The CIFAR-10
dataset¶
Federated learning can be applied to many different types of tasks across different domains. In this tutorial, we introduce federated learning by training a simple convolutional neural network (CNN) on the popular CIFAR-10 dataset. CIFAR-10 can be used to train image classifiers that distinguish between images from ten different classes: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', and 'truck'.
We simulate having multiple datasets from multiple organizations (also called the “cross-silo” setting in federated learning) by splitting the original CIFAR-10 dataset into multiple partitions. Each partition will represent the data from a single organization. We're doing this purely for experimentation purposes, in the real world there's no need for data splitting because each organization already has their own data (the data is naturally partitioned).
Each organization will act as a client in the federated learning system. Having ten organizations participate in a federation means having ten clients connected to the federated learning server.
We use the Flower Datasets library
(flwr-datasets
) to partition CIFAR-10 into ten partitions using
FederatedDataset
. Using the load_data()
function defined in task.py
, we will
create a small training and test set for each of the ten organizations and wrap each of
these into a PyTorch DataLoader
:
def load_data(partition_id: int, num_partitions: int):
"""Load partition CIFAR10 data."""
# Only initialize `FederatedDataset` once
global fds
if fds is None:
partitioner = IidPartitioner(num_partitions=num_partitions)
fds = FederatedDataset(
dataset="uoft-cs/cifar10",
partitioners={"train": partitioner},
)
partition = fds.load_partition(partition_id)
# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
pytorch_transforms = Compose(
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
def apply_transforms(batch):
"""Apply transforms to the partition from FederatedDataset."""
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
return batch
partition_train_test = partition_train_test.with_transform(apply_transforms)
trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
testloader = DataLoader(partition_train_test["test"], batch_size=32)
return trainloader, testloader
We now have a function that can return a training set and validation set
(trainloader
and valloader
) representing one dataset from one of ten different
organizations. Each trainloader
/valloader
pair contains 4000 training examples
and 1000 validation examples. There's also a single testloader
(we did not split the
test set). Again, this is only necessary for building research or educational systems,
actual federated learning systems have their data naturally distributed across multiple
partitions.
The model and train and evaluate functions¶
Next, we're going to use PyTorch to define a simple convolutional neural network. This introduction assumes basic familiarity with PyTorch, so it doesn't cover the PyTorch-related aspects in full detail. If you want to dive deeper into PyTorch, we recommend this introductory tutorial.
The model¶
We will use the simple CNN described in the aforementioned PyTorch tutorial (The
following code is already defined in task.py
):
class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)
The PyTorch template has also provided us with the usual training and test functions:
def train(net, trainloader, epochs, device):
"""Train the model on the training set."""
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
net.train()
running_loss = 0.0
for _ in range(epochs):
for batch in trainloader:
images = batch["img"]
labels = batch["label"]
optimizer.zero_grad()
loss = criterion(net(images.to(device)), labels.to(device))
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_trainloss = running_loss / len(trainloader)
return avg_trainloss
def test(net, testloader, device):
"""Validate the model on the test set."""
net.to(device)
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for batch in testloader:
images = batch["img"].to(device)
labels = batch["label"].to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
loss = loss / len(testloader)
return loss, accuracy
Federated Learning with Flower¶
In federated learning, the server sends global model parameters to the client, and the client updates the local model with parameters received from the server. It then trains the model on the local data (which changes the model parameters locally) and sends the updated/changed model parameters back to the server (or, alternatively, it sends just the gradients back to the server, not the full model parameters).
Update model parameters¶
We need two helper functions to get the updated model parameters from the local model
and to update the local model with parameters received from the server: get_weights
and set_weights
. The following two functions do just that for the PyTorch model
above and are predefined in task.py
.
The details of how this works are not really important here (feel free to consult the
PyTorch documentation if you want to learn more). In essence, we use state_dict
to
access PyTorch model parameter tensors. The parameter tensors are then converted to/from
a list of NumPy ndarray
s (which the Flower NumPyClient
knows how to
serialize/deserialize):
def get_weights(net):
return [val.cpu().numpy() for _, val in net.state_dict().items()]
def set_weights(net, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
Define the Flower ClientApp¶
With that out of the way, let's move on to the interesting part. Federated learning
systems consist of a server and multiple clients. In Flower, we create a ServerApp
and a ClientApp
to run the server-side and client-side code, respectively.
The first step toward creating a ClientApp
is to implement a subclasses of
flwr.client.Client
or flwr.client.NumPyClient
. We use NumPyClient
in this
tutorial because it is easier to implement and requires us to write less boilerplate. To
implement NumPyClient
, we create a subclass that implements the three methods
get_weights
, fit
, and evaluate
:
get_weights
: Return the current local model parametersfit
: Receive model parameters from the server, train the model on the local data, and return the updated model parameters to the serverevaluate
: Receive model parameters from the server, evaluate the model on the local data, and return the evaluation result to the server
We mentioned that our clients will use the previously defined PyTorch components for model training and evaluation. Let's see a simple Flower client implementation that brings everything together. Note that all of this boilerplate implementation has already been done for us in our Flower project:
class FlowerClient(NumPyClient):
def __init__(self, net, trainloader, valloader, local_epochs):
self.net = net
self.trainloader = trainloader
self.valloader = valloader
self.local_epochs = local_epochs
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.net.to(self.device)
def fit(self, parameters, config):
set_weights(self.net, parameters)
train_loss = train(
self.net,
self.trainloader,
self.local_epochs,
self.device,
)
return (
get_weights(self.net),
len(self.trainloader.dataset),
{"train_loss": train_loss},
)
def evaluate(self, parameters, config):
set_weights(self.net, parameters)
loss, accuracy = test(self.net, self.valloader, self.device)
return loss, len(self.valloader.dataset), {"accuracy": accuracy}
Our class FlowerClient
defines how local training/evaluation will be performed and
allows Flower to call the local training/evaluation through fit
and evaluate
.
Each instance of FlowerClient
represents a single client in our federated learning
system. Federated learning systems have multiple clients (otherwise, there's not much to
federate), so each client will be represented by its own instance of FlowerClient
.
If we have, for example, three clients in our workload, then we'd have three instances
of FlowerClient
(one on each of the machines we'd start the client on). Flower calls
FlowerClient.fit
on the respective instance when the server selects a particular
client for training (and FlowerClient.evaluate
for evaluation).
In this project, we want to simulate a federated learning system with 10 clients on a
single machine. This means that the server and all 10 clients will live on a single
machine and share resources such as CPU, GPU, and memory. Having 10 clients would mean
having 10 instances of FlowerClient
in memory. Doing this on a single machine can
quickly exhaust the available memory resources, even if only a subset of these clients
participates in a single round of federated learning.
In addition to the regular capabilities where server and clients run on multiple
machines, Flower, therefore, provides special simulation capabilities that create
FlowerClient
instances only when they are actually necessary for training or
evaluation. To enable the Flower framework to create clients when necessary, we need to
implement a function that creates a FlowerClient
instance on demand. We typically
call this function client_fn
. Flower calls client_fn
whenever it needs an
instance of one particular client to call fit
or evaluate
(those instances are
usually discarded after use, so they should not keep any local state). In federated
learning experiments using Flower, clients are identified by a partition ID, or
partition_id
. This partition_id
is used to load different local data partitions
for different clients, as can be seen below. The value of partition_id
is retrieved
from the node_config
dictionary in the Context
object, which holds the
information that persists throughout each training round.
With this, we have the class FlowerClient
which defines client-side
training/evaluation and client_fn
which allows Flower to create FlowerClient
instances whenever it needs to call fit
or evaluate
on one particular client.
Last, but definitely not least, we create an instance of ClientApp
and pass it the
client_fn
. ClientApp
is the entrypoint that a running Flower client uses to call
your code (as defined in, for example, FlowerClient.fit
). The following code is
reproduced from client_app.py
with additional comments:
def client_fn(context: Context):
# Load model and data
net = Net()
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
# Load data (CIFAR-10)
# Note: each client gets a different trainloader/valloader, so each client
# will train and evaluate on their own unique data partition
# Read the node_config to fetch data partition associated to this node
trainloader, valloader = load_data(partition_id, num_partitions)
local_epochs = context.run_config["local-epochs"]
# Create a single Flower client representing a single organization
# FlowerClient is a subclass of NumPyClient, so we need to call .to_client()
# to convert it to a subclass of `flwr.client.Client`
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
# Create the Flower ClientApp
app = ClientApp(client_fn=client_fn)
Define the Flower ServerApp¶
On the server side, we need to configure a strategy which encapsulates the federated
learning approach/algorithm, for example, Federated Averaging (FedAvg). Flower has a
number of built-in strategies, but we can also use our own strategy implementations to
customize nearly all aspects of the federated learning approach. For this example, we
use the built-in FedAvg
implementation and customize it using a few basic
parameters:
# Create FedAvg strategy
strategy = FedAvg(
fraction_fit=fraction_fit, # Sample this value of available client for training
fraction_evaluate=1.0, # Sample 100% of available clients for evaluation
min_available_clients=2, # Wait until 2 clients are available
initial_parameters=parameters, # Use these initial model parameters
)
Similar to ClientApp
, we create a ServerApp
using a utility function
server_fn
. This function is predefined for us in server_app.py
. In
server_fn
, we pass an instance of ServerConfig
for defining the number of
federated learning rounds (num_rounds
) and we also pass the previously created
strategy
. The server_fn
returns a ServerAppComponents
object containing the
settings that define the ServerApp
behaviour. ServerApp
is the entrypoint that
Flower uses to call all your server-side code (for example, the strategy).
def server_fn(context: Context):
"""Construct components that set the ServerApp behaviour.
You can use the settings in `context.run_config` to parameterize the
construction of all elements (e.g the strategy or the number of rounds)
wrapped in the returned ServerAppComponents object.
"""
# Read from config
num_rounds = context.run_config["num-server-rounds"]
fraction_fit = context.run_config["fraction-fit"]
# Initialize model parameters
ndarrays = get_weights(Net())
parameters = ndarrays_to_parameters(ndarrays)
# Define strategy
strategy = FedAvg(
fraction_fit=fraction_fit,
fraction_evaluate=1.0,
min_available_clients=2,
initial_parameters=parameters,
)
config = ServerConfig(num_rounds=num_rounds)
return ServerAppComponents(strategy=strategy, config=config)
Run the training¶
With all of these components in place, we can now run the federated learning simulation with Flower! The last step is to run our simulation in the command line, as follows:
$ flwr run .
This will execute the federated learning simulation with 10 clients, or SuperNodes,
defined in the [tool.flwr.federations.local-simulation]
section in the
pyproject.toml
. You can also override the parameters defined in the
[tool.flwr.app.config]
section in pyproject.toml
like this:
# Run the simulation with 5 server rounds and 3 local epochs
$ flwr run . --run-config "num-server-rounds=5 local-epochs=3"
Tip
Learn more about how to configure the execution of your Flower App by checking the pyproject.toml guide.
幕后¶
那么它是如何工作的呢?Flower 如何进行模拟?
When we execute flwr run
, we tell Flower that there are 10 clients
(options.num-supernodes = 10
, where 1 SuperNode
launches 1 ClientApp
).
Flower then goes ahead and asks the ServerApp
to issue instructions to those nodes
using the FedAvg
strategy. FedAvg
knows that it should select 50% of the
available clients (fraction-fit=0.5
), so it goes ahead and selects 5 random clients
(i.e., 50% of 10).
Flower then asks the selected 5 clients to train the model. Each of the 5 ClientApp
instances receives a message, which causes it to call client_fn
to create an
instance of FlowerClient
. It then calls .fit()
on each of the FlowerClient
instances and returns the resulting model parameter updates to the ServerApp
. When
the ServerApp
receives the model parameter updates from the clients, it hands those
updates over to the strategy (FedAvg) for aggregation. The strategy aggregates those
updates and returns the new global model, which then gets used in the next round of
federated learning.
准确度在哪里找?¶
您可能已经注意到,除了 losses_distributed
以外,所有指标都是空的。{"准确度": float(准确度)}``去哪儿了?
Flower 可以自动汇总单个客户端返回的损失值,但无法对通用度量字典中的度量进行同样的处理(即带有 "准确度 "键的度量字典)。度量值字典可以包含非常不同种类的度量值,甚至包含根本不是度量值的键/值对,因此框架不知道(也无法知道)如何自动处理这些度量值。
作为用户,我们需要告诉框架如何处理/聚合这些自定义指标,为此,我们将指标聚合函数传递给策略。然后,只要从客户端接收到拟合或评估指标,策略就会调用这些函数。两个可能的函数是 fit_metrics_aggregation_fn
和 evaluate_metrics_aggregation_fn
。
Let's create a simple weighted averaging function to aggregate the accuracy
metric
we return from evaluate
. Copy the following weighted_average()
function to
task.py
:
from typing import List, Tuple
from flwr.common.typing import Metrics
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
# Multiply accuracy of each client by number of examples used
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
examples = [num_examples for num_examples, _ in metrics]
# Aggregate and return custom metric (weighted average)
return {"accuracy": sum(accuracies) / sum(examples)}
Now, in server_app.py
, we import the function and pass it to the FedAvg
strategy:
from flower_tutorial.task import weighted_average
def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]
fraction_fit = context.run_config["fraction-fit"]
# Initialize model parameters
ndarrays = get_weights(Net())
parameters = ndarrays_to_parameters(ndarrays)
# Define strategy
strategy = FedAvg(
fraction_fit=fraction_fit,
fraction_evaluate=1.0,
min_available_clients=2,
initial_parameters=parameters,
evaluate_metrics_aggregation_fn=weighted_average,
)
config = ServerConfig(num_rounds=num_rounds)
return ServerAppComponents(strategy=strategy, config=config)
# Create ServerApp
app = ServerApp(server_fn=server_fn)
我们现在有了一个完整的系统,可以执行联邦训练和联邦评估。它使用 weighted_average
函数汇总自定义评估指标,并在服务器端计算所有客户端的单一 accuracy
指标。
其他两类指标(losses_centralized` 和 metrics_centralized)仍然是空的,因为它们只适用于集中评估。Flower 教程的第二部分将介绍集中式评估。
结束语¶
恭喜您,你刚刚训练了一个由 10 个客户端组成的卷积神经网络!这样,你就了解了使用 Flower 进行联邦学习的基础知识。你所看到的方法同样适用于其他机器学习框架(不只是 PyTorch)和任务(不只是 CIFAR-10 图像分类),例如使用 Hugging Face Transformers 的 NLP 或使用 SpeechBrain 的语音。
In the next tutorial, we're going to cover some more advanced concepts. Want to customize your strategy? Initialize parameters on the server side? Or evaluate the aggregated model on the server side? We'll cover all this and more in the next tutorial.
接下来的步骤¶
Before you continue, make sure to join the Flower community on Flower Discuss (Join Flower Discuss) and on Slack (Join Slack).
There's a dedicated #questions
Slack channel if you need help, but we'd also love to
hear who you are in #introductions
!
The Flower Federated Learning Tutorial - Part 2 goes into more depth about strategies and all the advanced things you can build with them.