Quickstart MXNet

In this tutorial, we will learn how to train a Sequential model on MNIST using Flower and MXNet.

It is recommended to create a virtual environment and run everything within this virtualenv.
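
For example, using the standard library's venv module (any environment tool works; the directory name venv is just a convention):

$ python -m venv venv
$ source venv/bin/activate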

Our example consists of one *server* and two *clients*, all having the same model.

*Clients* are responsible for generating individual model parameter updates based on their local datasets. These updates are then sent to the *server*, which will aggregate them to produce an updated global model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of parameter updates is called a *round*.
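
Conceptually, the server's aggregation step is a weighted average of the clients' parameter updates (FedAvg). The following is only an illustrative sketch with made-up names; Flower's built-in strategy performs this step for us:

import numpy as np

# Illustrative only: FedAvg-style weighted averaging of client updates.
# `client_updates` is a list of parameter lists (one list of ndarrays per
# client); `num_examples` holds the corresponding local dataset sizes.
def fedavg(client_updates, num_examples):
    total = sum(num_examples)
    return [
        sum((n / total) * update[i] for update, n in zip(client_updates, num_examples))
        for i in range(len(client_updates[0]))
    ]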

Now that we have a rough idea of what is going on, let's get started. We first need to install Flower. You can do this by running:

$ pip install flwr

Since we want to use MXNet, let's go ahead and install it:

$ pip install mxnet

Flower Client

Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on MXNet's Hand-written Digit Recognition tutorial.

In a file called client.py, import Flower and MXNet related packages:

import flwr as fl

import numpy as np

import mxnet as mx
from mxnet import nd
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import autograd as ag
import mxnet.ndarray as F

In addition, define the device allocation in MXNet with:

DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]

We use MXNet to load MNIST, a popular image classification dataset of handwritten digits for machine learning. The MXNet utility mx.test_utils.get_mnist() downloads the training and test data.

def load_data():
    print("Download Dataset")
    mnist = mx.test_utils.get_mnist()
    batch_size = 100
    train_data = mx.io.NDArrayIter(
        mnist["train_data"], mnist["train_label"], batch_size, shuffle=True
    )
    val_data = mx.io.NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size)
    return train_data, val_data

Define the training and loss with MXNet. We train the model by looping over the dataset, measuring the corresponding loss, and optimizing it.

def train(net, train_data, epoch):
    trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.03})
    accuracy_metric = mx.metric.Accuracy()
    loss_metric = mx.metric.CrossEntropy()
    metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [accuracy_metric, loss_metric]:
        metrics.add(child_metric)
    softmax_cross_entropy_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    for i in range(epoch):
        train_data.reset()
        num_examples = 0
        for batch in train_data:
            data = gluon.utils.split_and_load(
                batch.data[0], ctx_list=DEVICE, batch_axis=0
            )
            label = gluon.utils.split_and_load(
                batch.label[0], ctx_list=DEVICE, batch_axis=0
            )
            outputs = []
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    loss = softmax_cross_entropy_loss(z, y)
                    loss.backward()
                    outputs.append(z.softmax())
                    num_examples += len(x)
            metrics.update(label, outputs)
            trainer.step(batch.data[0].shape[0])
        trainings_metric = metrics.get_name_value()
        print("Accuracy & loss at epoch %d: %s" % (i, trainings_metric))
    return trainings_metric, num_examples

Next, we define the validation of our machine learning model. We loop over the test set and measure both loss and accuracy.

def test(net, val_data):
    accuracy_metric = mx.metric.Accuracy()
    loss_metric = mx.metric.CrossEntropy()
    metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [accuracy_metric, loss_metric]:
        metrics.add(child_metric)
    val_data.reset()
    num_examples = 0
    for batch in val_data:
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=DEVICE, batch_axis=0)
        label = gluon.utils.split_and_load(
            batch.label[0], ctx_list=DEVICE, batch_axis=0
        )
        outputs = []
        for x in data:
            outputs.append(net(x).softmax())
            num_examples += len(x)
        metrics.update(label, outputs)
    return metrics.get_name_value(), num_examples
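
Both helpers can be tried out without any federation, which makes for a quick sanity check. A minimal sketch (the Sequential net here mirrors the one we define for the Flower client below):

net = nn.Sequential()
net.add(nn.Dense(256, activation="relu"))
net.add(nn.Dense(64, activation="relu"))
net.add(nn.Dense(10))
net.collect_params().initialize(ctx=DEVICE)  # place parameters on DEVICE

train_data, val_data = load_data()
train(net, train_data, epoch=1)
print(test(net, val_data))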

After defining the training and testing of an MXNet machine learning model, we use these functions to implement a Flower client.

Our Flower clients will use a simple Sequential model:

def main():
    def model():
        net = nn.Sequential()
        net.add(nn.Dense(256, activation="relu"))
        net.add(nn.Dense(64, activation="relu"))
        net.add(nn.Dense(10))
        net.collect_params().initialize()
        return net

    train_data, val_data = load_data()

    model = model()
    init = nd.random.uniform(shape=(2, 784))
    model(init)

After loading the dataset with load_data(), we perform one forward pass to initialize the model and its parameters with model(init). Next, we implement a Flower client.

The Flower server interacts with clients through an interface called Client. When the server selects a particular client for training, it sends training instructions over the network. The client receives those instructions and calls one of the Client methods to run your code (i.e., to train the neural network we defined earlier).

Flower provides a convenience class called NumPyClient which makes it easier to implement the Client interface when your workload uses MXNet. Implementing NumPyClient usually means defining the following methods (set_parameters is optional though):

  1. get_parameters
    • return the model weights as a list of NumPy ndarrays

  2. set_parameters (optional)
    • update the local model weights with the parameters received from the server

  3. fit
    • set the local model weights

    • train the local model

    • receive the updated local model weights

  4. evaluate
    • test the local model

They can be implemented in the following way:

class MNISTClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        # Return the current local model parameters as a list of NumPy ndarrays
        param = []
        for val in model.collect_params(".*weight").values():
            p = val.data()
            param.append(p.asnumpy())
        return param

    def set_parameters(self, parameters):
        # Update the local model weights with the parameters received from the server
        params = zip(model.collect_params(".*weight").keys(), parameters)
        for key, value in params:
            model.collect_params().setattr(key, value)

    def fit(self, parameters, config):
        # Set the local weights, train on the local dataset, and return the
        # updated weights together with the number of training examples
        self.set_parameters(parameters)
        [accuracy, loss], num_examples = train(model, train_data, epoch=2)
        results = {"accuracy": float(accuracy[1]), "loss": float(loss[1])}
        return self.get_parameters(config={}), num_examples, results

    def evaluate(self, parameters, config):
        # Evaluate the received global weights on the local test set
        self.set_parameters(parameters)
        [accuracy, loss], num_examples = test(model, val_data)
        print("Evaluation accuracy & loss", accuracy, loss)
        return float(loss[1]), val_data.batch_size, {"accuracy": float(accuracy[1])}

We can now create an instance of our class MNISTClient and add one line to actually run this client:

fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=MNISTClient())

That's it for the client. We only have to implement Client or NumPyClient and call fl.client.start_client() or fl.client.start_numpy_client(). The string "0.0.0.0:8080" tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use "0.0.0.0:8080". If we run a truly federated workload with the server and clients running on different machines, all that needs to change is the server_address we pass to the client.
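
For instance, if the server ran on a different machine at a (hypothetical) address, only this one line would change:

fl.client.start_numpy_client(server_address="192.168.1.42:8080", client=MNISTClient())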

Flower Server

For simple workloads we can start a Flower server and leave all the configuration options at their default values. In a file named server.py, import Flower and start the server:

import flwr as fl

fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3))
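
The defaults already behave like FedAvg across all available clients. If you later want more control, you can pass a strategy explicitly; a sketch, for example waiting for at least two clients before each round:

import flwr as fl

strategy = fl.server.strategy.FedAvg(min_available_clients=2)
fl.server.start_server(
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)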

Train the model, federated!

With both client and server ready, we can now run everything and see federated learning in action. Federated learning systems usually have a server and multiple clients. We therefore have to start the server first:

$ python server.py

Once the server is running we can start the clients in different terminals. Open a new terminal and start the first client:

$ python client.py

Open another terminal and start the second client:

$ python client.py

Each client will have its own dataset. You should now see how the training does in the very first terminal (the one that started the server):

INFO flower 2021-03-11 11:59:04,512 | app.py:76 | Flower server running (insecure, 3 rounds)
INFO flower 2021-03-11 11:59:04,512 | server.py:72 | Getting initial parameters
INFO flower 2021-03-11 11:59:09,089 | server.py:74 | Evaluating initial parameters
INFO flower 2021-03-11 11:59:09,089 | server.py:87 | [TIME] FL starting
DEBUG flower 2021-03-11 11:59:11,997 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-03-11 11:59:14,652 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-03-11 11:59:14,656 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-03-11 11:59:14,811 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-03-11 11:59:14,812 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-03-11 11:59:18,499 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-03-11 11:59:18,503 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-03-11 11:59:18,784 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-03-11 11:59:18,786 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-03-11 11:59:22,551 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-03-11 11:59:22,555 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-03-11 11:59:22,789 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-03-11 11:59:22,789 | server.py:122 | [TIME] FL finished in 13.700094900001204
INFO flower 2021-03-11 11:59:22,790 | app.py:109 | app_fit: losses_distributed [(1, 1.5717803835868835), (2, 0.6093432009220123), (3, 0.4424773305654526)]
INFO flower 2021-03-11 11:59:22,790 | app.py:110 | app_fit: accuracies_distributed []
INFO flower 2021-03-11 11:59:22,791 | app.py:111 | app_fit: losses_centralized []
INFO flower 2021-03-11 11:59:22,791 | app.py:112 | app_fit: accuracies_centralized []
DEBUG flower 2021-03-11 11:59:22,793 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-03-11 11:59:23,111 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-03-11 11:59:23,112 | app.py:121 | app_evaluate: federated loss: 0.4424773305654526
INFO flower 2021-03-11 11:59:23,112 | app.py:125 | app_evaluate: results [('ipv4:127.0.0.1:44344', EvaluateRes(loss=0.443320095539093, num_examples=100, accuracy=0.0, metrics={'accuracy': 0.8752475247524752})), ('ipv4:127.0.0.1:44346', EvaluateRes(loss=0.44163456559181213, num_examples=100, accuracy=0.0, metrics={'accuracy': 0.8761386138613861}))]
INFO flower 2021-03-11 11:59:23,112 | app.py:127 | app_evaluate: failures []

Congratulations! You've successfully built and run your first federated learning system. The full source code for this example can be found in examples/quickstart-mxnet.