Get started with Flower¶
Welcome to the Flower federated learning tutorial!
In this tutorial, we'll build a federated learning system using the Flower framework, Flower Datasets and PyTorch. In part 1, we use PyTorch for model training and data loading. In part 2, we federate this PyTorch project using Flower.
Tip
Star Flower on GitHub ⭐️ and join the Flower community on Flower Discuss and the Flower Slack to connect, ask questions, and get help:
- Join Flower Discuss: We'd love to hear from you in the Introduction topic! If anything is unclear, post in Flower Help - Beginners.
- Join Flower Slack: We'd love to hear from you in the #introductions channel! If anything is unclear, head over to the #questions channel.
Let's get started! 🌼
Preparation¶
Before we begin with any actual code, let's make sure that we have everything we need.
Install dependencies¶
First, we install the Flower package flwr:
# In a new Python environment
$ pip install -U "flwr[simulation]"
Then, we create a new Flower app called flower-tutorial using the PyTorch template. We also specify a username (flwrlabs) for the project:
$ flwr new flower-tutorial --framework pytorch --username flwrlabs
After running the command, a new directory called flower-tutorial will be created. It should have the following structure:
flower-tutorial
├── README.md
├── flower_tutorial
│   ├── __init__.py
│   ├── client_app.py   # Defines your ClientApp
│   ├── server_app.py   # Defines your ServerApp
│   └── task.py         # Defines your model, training and data loading
└── pyproject.toml      # Project metadata like dependencies and configs
Next, we install the project and its dependencies, which are specified in the pyproject.toml file:
$ cd flower-tutorial
$ pip install -e .
Before we dive into federated learning, we'll take a look at the dataset that we'll be using for this tutorial, which is the CIFAR-10 dataset, and run a simple centralized training pipeline using PyTorch.
The CIFAR-10 dataset¶
Federated learning can be applied to many different types of tasks across different domains. In this tutorial, we introduce federated learning by training a simple convolutional neural network (CNN) on the popular CIFAR-10 dataset. CIFAR-10 can be used to train image classifiers that distinguish between images from ten different classes: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', and 'truck'.
We simulate having multiple datasets from multiple organizations (also called the “cross-silo” setting in federated learning) by splitting the original CIFAR-10 dataset into multiple partitions. Each partition will represent the data from a single organization. We're doing this purely for experimentation purposes. In the real world there's no need for data splitting because each organization already has their own data (the data is naturally partitioned).
Each organization will act as a client in the federated learning system. Having ten organizations participate in a federation means having ten clients connected to the federated learning server.
We use the Flower Datasets library (flwr-datasets) to partition CIFAR-10 into ten partitions using FederatedDataset. Using the load_data() function defined in task.py, we will create a small training and test set for each of the ten organizations and wrap each of these into a PyTorch DataLoader:
def load_data(partition_id: int, num_partitions: int):
    """Load partition CIFAR10 data."""
    # Only initialize `FederatedDataset` once
    global fds
    if fds is None:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(
            dataset="uoft-cs/cifar10",
            partitioners={"train": partitioner},
        )
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = Compose(
        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        """Apply transforms to the partition from FederatedDataset."""
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
    testloader = DataLoader(partition_train_test["test"], batch_size=32)
    return trainloader, testloader
We now have a function that returns a training set and a validation set (trainloader and valloader) representing the data of one of the ten organizations. Each trainloader/valloader pair contains 4000 training examples and 1000 validation examples, obtained from the 80/20 split of that organization's partition (load_data returns the second loader under the name testloader, but we will use it as a local validation set). Again, this splitting is only necessary for building research or educational systems; actual federated learning systems have their data naturally distributed across multiple partitions.
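If you want to convince yourself that the partitioning works as described, a quick check like the one below can help. This is a small illustrative sketch, not part of the generated project; it assumes you run it from the project root so that flower_tutorial.task is importable, and it uses an arbitrary partition index between 0 and 9.
from flower_tutorial.task import load_data

trainloader, valloader = load_data(partition_id=0, num_partitions=10)
print(len(trainloader.dataset), len(valloader.dataset))  # expected: 4000 1000
batch = next(iter(trainloader))
print(batch["img"].shape, batch["label"].shape)  # e.g. torch.Size([32, 3, 32, 32]) and torch.Size([32])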
The model, training, and test functions¶
Next, we're going to use PyTorch to define a simple convolutional neural network. This introduction assumes basic familiarity with PyTorch, so it doesn't cover the PyTorch-related aspects in full detail. If you want to dive deeper into PyTorch, we recommend this introductory tutorial.
Model¶
We will use the simple CNN described in the aforementioned PyTorch tutorial (the following code is already defined in task.py):
class Net(nn.Module):
    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
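As a quick, illustrative sanity check (not part of the template), you can verify that the network accepts CIFAR-10-sized inputs (3x32x32) and produces ten logits, one per class:
import torch

from flower_tutorial.task import Net

net = Net()
dummy = torch.randn(1, 3, 32, 32)  # one fake CIFAR-10 image
print(net(dummy).shape)  # torch.Size([1, 10])
print(sum(p.numel() for p in net.parameters()))  # total number of trainable parameters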
Training and test functions¶
The PyTorch template also provides the usual training and test functions:
def train(net, trainloader, epochs, lr, device):
    """Train the model on the training set."""
    net.to(device)  # move model to GPU if available
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    net.train()
    running_loss = 0.0
    for _ in range(epochs):
        for batch in trainloader:
            images = batch["img"].to(device)
            labels = batch["label"].to(device)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

    avg_trainloss = running_loss / len(trainloader)
    return avg_trainloss


def test(net, testloader, device):
    """Validate the model on the test set."""
    net.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    correct, loss = 0, 0.0
    with torch.no_grad():
        for batch in testloader:
            images = batch["img"].to(device)
            labels = batch["label"].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
    accuracy = correct / len(testloader.dataset)
    loss = loss / len(testloader)
    return loss, accuracy
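Before federating anything, these functions can be combined into the simple centralized pipeline mentioned at the beginning of this section. The snippet below is an illustrative sketch rather than part of the template; the partition index, epoch count, and learning rate are arbitrary choices for a quick smoke test.
import torch

from flower_tutorial.task import Net, load_data, test, train

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = Net()
trainloader, valloader = load_data(partition_id=0, num_partitions=10)
avg_loss = train(net, trainloader, epochs=1, lr=0.01, device=device)
val_loss, val_acc = test(net, valloader, device)
print(f"train loss {avg_loss:.3f} | val loss {val_loss:.3f} | val acc {val_acc:.3f}")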
Federated Learning with Flower¶
In federated learning, the server sends global model parameters to the client, and the client updates the local model with parameters received from the server. It then trains the model on the local data (which changes the model parameters locally) and sends the updated/changed model parameters back to the server (or, alternatively, it sends just the gradients back to the server, not the full model parameters).
Constructing Messages¶
In Flower, the server and clients communicate by sending and receiving Message objects. A Message carries a RecordDict as its main payload. The RecordDict is like a Python dictionary that can contain multiple records of different types. There are three main types of records:
- ArrayRecord: Contains model parameters as a dictionary of NumPy arrays.
- MetricRecord: Contains training or evaluation metrics as a dictionary of integers, floats, lists of integers, or lists of floats.
- ConfigRecord: Contains configuration parameters as a dictionary of integers, floats, strings, booleans, or bytes. Lists of these types are also supported.
Let's see a few examples of how to work with these types of records and, ultimately, construct a RecordDict that can be sent over a Message.
from flwr.app import ArrayRecord, MetricRecord, ConfigRecord, RecordDict

# ConfigRecord can be used to communicate configs between ServerApp and ClientApp
# They can hold scalars, but also strings and booleans
config = ConfigRecord(
    {"batch_size": 32, "use_augmentation": True, "data-path": "/my/dataset"}
)

# MetricRecords expect scalar-based metrics (i.e. int/float/list[int]/list[float])
# By limiting the types Flower can aggregate MetricRecords automatically
metrics = MetricRecord({"accuracy": 0.9, "losses": [0.1, 0.001], "perplexity": 2.31})

# ArrayRecord objects are designed to communicate arrays/tensors/weights from ML models
array_record = ArrayRecord(my_model.state_dict())  # for a PyTorch model
array_record_other = ArrayRecord(my_model.to_numpy_ndarrays())  # for other ML models

# A RecordDict is like a dictionary that holds named records.
# This is the main payload of a Message
rd = RecordDict({"my-config": config, "metrics": metrics, "my-model": array_record})
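Continuing the example above, records behave much like nested dictionaries: you can read individual entries back and convert an ArrayRecord into a framework-specific format. The to_torch_state_dict() call is the same one used later in this tutorial; treat the rest as an illustrative sketch.
# Access records and their entries by key
print(rd["my-config"]["batch_size"])  # 32
print(rd["metrics"]["accuracy"])      # 0.9

# Convert the ArrayRecord back into a PyTorch state_dict and load it into the model
state_dict = rd["my-model"].to_torch_state_dict()
my_model.load_state_dict(state_dict)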
Define the Flower ClientApp¶
Federated learning systems consist of a server and multiple nodes or clients. In Flower, we create a ServerApp and a ClientApp to run the server-side and client-side code, respectively.
The core functionality of the ClientApp is to perform some action with the local data that the node it runs on (e.g. an edge device, a server in a data center, or a laptop) has access to. In this tutorial, that action is to train and evaluate the small CNN model defined earlier using the local training and validation data.
Training¶
We can define how the ClientApp performs training by wrapping a function with the @app.train() decorator. In this case we name this function train because we'll use it to train the model on the local data. The function always expects two arguments:
- A Message: the message received from the server. It contains the model parameters and any other configuration information sent by the server.
- A Context: the context object that contains information about the node executing the ClientApp and about the current run.
Through the context you can retrieve the config settings defined in the pyproject.toml of your app. The context can be used to persist the state of the client across multiple calls to train or evaluate. In Flower, ClientApps are ephemeral objects that get instantiated for the execution of one Message and destroyed when a reply is communicated back to the server.
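The snippet below sketches how these pieces are typically read inside a function decorated with @app.train(). The run_config and node_config accesses mirror the code shown later in this tutorial; the context.state usage is an assumption about how per-node state can be persisted across calls and should be checked against the Flower API reference.
# Inside a function decorated with @app.train() or @app.evaluate():
local_epochs = context.run_config["local-epochs"]   # from [tool.flwr.app.config] in pyproject.toml
partition_id = context.node_config["partition-id"]  # set per node by the simulation
# Hypothetical: persist a value between calls on the same node
context.state["my-counters"] = ConfigRecord({"times-trained": 1})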
Let's see an implementation of ClientApp that uses the previously defined PyTorch CNN model, applies the parameters received from the ServerApp via the message, loads its local data, trains the model with it (using the train_fn function), and generates a reply Message containing the updated model parameters as well as some metrics of interest.
from flower_tutorial.task import train as train_fn

# Flower ClientApp
app = ClientApp()


@app.train()
def train(msg: Message, context: Context):
    """Train the model on local data."""

    # Load the model and initialize it with the received weights
    model = Net()
    model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load the data
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    trainloader, _ = load_data(partition_id, num_partitions)

    # Call the training function
    train_loss = train_fn(
        model,
        trainloader,
        context.run_config["local-epochs"],
        msg.content["config"]["lr"],
        device,
    )

    # Construct and return reply Message
    model_record = ArrayRecord(model.state_dict())
    metrics = {
        "train_loss": train_loss,
        "num-examples": len(trainloader.dataset),
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict({"arrays": model_record, "metrics": metric_record})
    return Message(content=content, reply_to=msg)
Note that train_fn is simply an alias for the train function defined earlier in this tutorial (where we defined the PyTorch training loop and optimizer). To this function we pass the model we want to train locally and the data loader, but also the number of local epochs and the learning rate (lr) to use. Note how the local-epochs setting is read from the run config via the Context, while the lr is read from the ConfigRecord sent by the server via the Message. This can be used to adjust the learning rate on each round from the server. When this dynamism isn't needed, reading the lr from the run config via the Context is also perfectly valid. The sketch below shows the two configuration sources side by side.
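This minimal sketch only repeats calls already present in the train() function above; it does not introduce any new API.
# Static setting, defined in pyproject.toml under [tool.flwr.app.config]
local_epochs = context.run_config["local-epochs"]

# Per-round setting, sent by the ServerApp inside the Message's ConfigRecord
lr = msg.content["config"]["lr"]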
Once training is completed, the ClientApp constructs a reply Message. This reply typically includes a RecordDict with two records:
- An ArrayRecord containing the updated model parameters
- A MetricRecord with relevant metrics (in this case, the training loss and the number of examples used for training)
Note
Returning the number of examples under the "num-examples" key is required, because strategies such as the FedAvg used by the ServerApp rely on this key to aggregate both models and metrics by default, unless you override the weighted_by_key argument (for example: FedAvg(weighted_by_key="my-different-key")).
After constructing the reply Message, the ClientApp returns it. Flower then handles sending the reply back to the server automatically.
Evaluation¶
In a typical federated learning setup, the ClientApp would also implement an @app.evaluate() function to evaluate the model received from the ServerApp on local validation data. This is especially useful to monitor the performance of the global model on each client during training. The implementation of the evaluate function is very similar to the train function, except that it calls the test_fn function defined earlier in this tutorial (which implements the PyTorch evaluation loop) and returns a Message containing only a MetricRecord with the evaluation metrics (no ArrayRecord, because the model parameters are not updated during evaluation). Here's what the evaluate function looks like:
from flower_tutorial.task import test as test_fn


@app.evaluate()
def evaluate(msg: Message, context: Context):
    """Evaluate the model on local data."""

    # Load the model and initialize it with the received weights
    model = Net()
    model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load the data
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    _, valloader = load_data(partition_id, num_partitions)

    # Call the evaluation function
    eval_loss, eval_acc = test_fn(
        model,
        valloader,
        device,
    )

    # Construct and return reply Message
    metrics = {
        "eval_loss": eval_loss,
        "eval_acc": eval_acc,
        "num-examples": len(valloader.dataset),
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict({"metrics": metric_record})
    return Message(content=content, reply_to=msg)
As you can see, the evaluate implementation is nearly identical to the train implementation, except that it calls the test_fn function instead of the train_fn function and returns a Message containing only a MetricRecord with metrics relevant to evaluation (eval_loss and eval_acc, both scalars). We also need to include the num-examples key in the metrics so the server can aggregate the evaluation metrics correctly.
Define the Flower ServerApp¶
On the server side, we need to configure a strategy which encapsulates the federated learning approach/algorithm, for example, Federated Averaging (FedAvg). Flower has a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this tutorial, we use the built-in FedAvg implementation and customize it slightly by specifying the fraction of connected nodes to involve in a round of training.
To construct a ServerApp, we define its @app.main() method. This method receives as input arguments:
- a Grid object that will be used to interface with the nodes running the ClientApp, involving them in a round of train/evaluate/query or other tasks.
- a Context object that provides access to the run configuration.
Before launching the strategy via the start method, we want to initialize the global model. This will be the model that gets sent to the ClientApp running on the clients in the first round of federated learning. We can do this by creating an instance of the model (Net), extracting the parameters in its state_dict, and constructing an ArrayRecord with them. We can then make it available to the strategy via the initial_arrays argument of the start() method.
We can also optionally pass to the start() method a ConfigRecord containing settings that we would like to communicate to the clients. These will be sent as part of the Message that also carries the model parameters.
app = ServerApp()


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""

    # Read run config
    fraction_train: float = context.run_config["fraction-train"]
    num_rounds: int = context.run_config["num-server-rounds"]
    lr: float = context.run_config["lr"]

    # Load global model
    global_model = Net()
    arrays = ArrayRecord(global_model.state_dict())

    # Initialize FedAvg strategy
    strategy = FedAvg(fraction_train=fraction_train)

    # Start strategy, run FedAvg for `num_rounds`
    result = strategy.start(
        grid=grid,
        initial_arrays=arrays,
        train_config=ConfigRecord({"lr": lr}),
        num_rounds=num_rounds,
    )

    # Save final model to disk
    print("\nSaving final model to disk...")
    state_dict = result.arrays.to_torch_state_dict()
    torch.save(state_dict, "final_model.pt")
Most of the execution of the ServerApp happens inside the strategy.start() method. After the specified number of rounds (num_rounds), the start() method returns a Result object containing the final model parameters and metrics received from the clients or generated by the strategy itself. We can then save the final model to disk for later use.
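Once final_model.pt has been written, the global model can be restored like any other PyTorch checkpoint. This is an illustrative sketch, not part of the generated template:
import torch

from flower_tutorial.task import Net

model = Net()
model.load_state_dict(torch.load("final_model.pt"))
model.eval()  # ready for local evaluation or inference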
Run the training¶
With all of these components in place, we can now run the federated learning simulation with Flower! The last step is to run our simulation in the command line, as follows:
$ flwr run .
This will execute the federated learning simulation with 10 clients, or SuperNodes, defined in the [tool.flwr.federations.local-simulation] section in the pyproject.toml. You should expect an output log similar to this:
Loading project configuration...
Success
INFO : Starting FedAvg strategy:
INFO : ├── Number of rounds: 3
INFO : ├── ArrayRecord (0.24 MB)
INFO : ├── ConfigRecord (train): {'lr': 0.01}
INFO : ├── ConfigRecord (evaluate): (empty!)
INFO : ├──> Sampling:
INFO : │ ├──Fraction: train (0.50) | evaluate ( 1.00)
INFO : │ ├──Minimum nodes: train (2) | evaluate (2)
INFO : │ └──Minimum available nodes: 2
INFO : └──> Keys in records:
INFO : ├── Weighted by: 'num-examples'
INFO : ├── ArrayRecord key: 'arrays'
INFO : └── ConfigRecord key: 'config'
INFO :
INFO :
INFO : [ROUND 1/3]
INFO : configure_train: Sampled 5 nodes (out of 10)
INFO : aggregate_train: Received 5 results and 0 failures
INFO : └──> Aggregated MetricRecord: {'train_loss': 2.25811}
INFO : configure_evaluate: Sampled 10 nodes (out of 10)
INFO : aggregate_evaluate: Received 10 results and 0 failures
INFO : └──> Aggregated MetricRecord: {'eval_loss': 2.304821, 'eval_acc': 0.0965}
INFO :
INFO : [ROUND 2/3]
INFO : configure_train: Sampled 5 nodes (out of 10)
INFO : aggregate_train: Received 5 results and 0 failures
INFO : └──> Aggregated MetricRecord: {'train_loss': 2.17333}
INFO : configure_evaluate: Sampled 10 nodes (out of 10)
INFO : aggregate_evaluate: Received 10 results and 0 failures
INFO : └──> Aggregated MetricRecord: {'eval_loss': 2.304577, 'eval_acc': 0.10030}
INFO :
INFO : [ROUND 3/3]
INFO : configure_train: Sampled 5 nodes (out of 10)
INFO : aggregate_train: Received 5 results and 0 failures
INFO : └──> Aggregated MetricRecord: {'train_loss': 2.16953}
INFO : configure_evaluate: Sampled 10 nodes (out of 10)
INFO : aggregate_evaluate: Received 10 results and 0 failures
INFO : └──> Aggregated MetricRecord: {'eval_loss': 2.29976, 'eval_acc': 0.1015}
INFO :
INFO : Strategy execution finished in 17.18s
INFO :
INFO : Final results:
INFO :
INFO : Global Arrays:
INFO : ArrayRecord (0.238 MB)
INFO :
INFO : Aggregated ClientApp-side Train Metrics:
INFO : { 1: {'train_loss': '2.2581e+00'},
INFO : 2: {'train_loss': '2.1733e+00'},
INFO : 3: {'train_loss': '2.1695e+00'}}
INFO :
INFO : Aggregated ClientApp-side Evaluate Metrics:
INFO : { 1: {'eval_acc': '9.6500e-02', 'eval_loss': '2.3048e+00'},
INFO : 2: {'eval_acc': '1.0030e-01', 'eval_loss': '2.3046e+00'},
INFO : 3: {'eval_acc': '1.0150e-01', 'eval_loss': '2.2998e+00'}}
INFO :
INFO : ServerApp-side Evaluate Metrics:
INFO : {}
INFO :
Saving final model to disk...
You can also override the parameters defined in the [tool.flwr.app.config] section in pyproject.toml like this:
# Run the simulation with 5 server rounds and 3 local epochs
$ flwr run . --run-config "num-server-rounds=5 local-epochs=3"
Tip
Learn more about how to configure the execution of your Flower App by checking the pyproject.toml guide.
Behind the scenes¶
So how does this work? How does Flower execute this simulation?
When we execute flwr run, we tell Flower that there are 10 clients (options.num-supernodes = 10, where each SuperNode launches one ClientApp). Flower then asks the ServerApp to issue instructions to those nodes using the FedAvg strategy. In this example, FedAvg is configured with two key parameters:
- fraction-train=0.5 → select 50% of the available clients for training
- fraction-evaluate=1.0 → select 100% of the available clients for evaluation
This means in our example, 5 out of 10 clients will be selected for training, and all 10 clients will later participate in evaluation.
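If you want to make both fractions explicit when constructing the strategy, the call would look roughly like the sketch below. The fraction_train argument appears in the ServerApp code above; the fraction_evaluate keyword is an assumption based on the run log and should be verified against the FedAvg API reference.
# Hypothetical: set both sampling fractions explicitly
strategy = FedAvg(fraction_train=0.5, fraction_evaluate=1.0)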
A typical round looks like this:
Training
1. FedAvg randomly selects 5 clients (50% of 10).
2. Flower sends a TRAIN message to each selected ClientApp.
3. Each ClientApp calls the function decorated with @app.train(), then returns a Message containing an ArrayRecord (the updated model parameters) and a MetricRecord (the training loss and number of examples).
4. The ServerApp receives all replies. FedAvg aggregates all ArrayRecord objects into a new ArrayRecord representing the new global model and combines all MetricRecord objects.
Evaluation
1. FedAvg selects all 10 clients (100%).
2. Flower sends an EVALUATE message to each ClientApp.
3. Each ClientApp calls the function decorated with @app.evaluate() and returns a Message containing a MetricRecord (the evaluation loss, accuracy, and number of examples).
4. The ServerApp receives all replies. FedAvg aggregates all MetricRecord objects.
Once both training and evaluation are done, the next round begins: another training step, then another evaluation step, and so on, until the configured number of rounds is reached.
Final remarks¶
Congratulations, you just trained a convolutional neural network, federated over 10 clients! With that, you understand the basics of federated learning with Flower. The same approach you've seen can be used with other machine learning frameworks (not just PyTorch) and tasks (not just CIFAR-10 image classification), for example NLP with Hugging Face Transformers or speech with SpeechBrain.
In the next tutorial, we're going to cover some more advanced concepts. Want to customize your strategy? Implement learning rate decay in the strategy and communicate it to the clients? Or evaluate the aggregated model on the server side? We'll cover all this and more in the next tutorial.
Next steps¶
Before you continue, make sure to join the Flower community on Flower Discuss (Join Flower Discuss) and on Slack (Join Slack).
There's a dedicated #questions Slack channel if you need help, but we'd also love to hear who you are in #introductions!
The Flower Federated Learning Tutorial - Part 2 goes into more depth about strategies and all the advanced things you can build with them.