Quickstart TensorFlow#

Let's build a federated learning system in less than 20 lines of code!

Before Flower can be imported, we have to install it:

$ pip install flwr

Since we want to use the Keras API of TensorFlow (TF), we have to install TF as well:

$ pip install tensorflow

Flower Client#

Next, in a file called client.py, import Flower and TensorFlow:

import flwr as fl
import tensorflow as tf

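We use the Keras utilities of TF to load CIFAR10, a popular color image classification dataset for machine learning. The call to tf.keras.datasets.cifar10.load_data() downloads CIFAR10, caches it locally, and then returns the entire training and test sets as NumPy ndarrays.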

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

Next, we need a model. For this tutorial, we use MobileNetV2 with 10 output classes:

model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])

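The Flower server interacts with clients through an interface called Client. When the server selects a particular client for training, it sends training instructions over the network. The client receives those instructions and calls one of the Client methods to run your code (i.e., to train the neural network we defined earlier).

Flower provides a convenience class called NumPyClient, which makes it easier to implement the Client interface when your workload uses Keras. The NumPyClient interface defines three methods, which can be implemented in the following way: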

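class CifarClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        # Return the current local model weights as a list of NumPy arrays.
        return model.get_weights()

    def fit(self, parameters, config):
        # Load the weights received from the server, then train locally.
        model.set_weights(parameters)
        # steps_per_epoch=3 limits each round to three batches of 32 images.
        model.fit(x_train, y_train, epochs=1, batch_size=32, steps_per_epoch=3)
        # Return updated weights, the local training set size, and an empty metrics dict.
        return model.get_weights(), len(x_train), {}

    def evaluate(self, parameters, config):
        # Evaluate the received weights on the local test set.
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(x_test, y_test)
        return loss, len(x_test), {"accuracy": float(accuracy)}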
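fit returns the number of local training examples alongside the updated weights so that the server can weight each client's contribution. Flower's default strategy, FedAvg, averages client weights in proportion to these example counts. The following is a minimal sketch of that idea, not Flower's actual implementation; the function name weighted_average_weights and the toy values are made up for illustration. Each "layer" is a plain float here so the example runs without extra dependencies, but the same arithmetic applies to NumPy arrays.

```python
def weighted_average_weights(results):
    """Average per-layer weights, weighting each client by its example count.

    `results` is a list of (weights, num_examples) pairs, mirroring the first
    two values returned by `fit`. `weights` holds one entry per layer.
    """
    total_examples = sum(num_examples for _, num_examples in results)
    num_layers = len(results[0][0])
    return [
        sum(weights[layer] * num_examples for weights, num_examples in results)
        / total_examples
        for layer in range(num_layers)
    ]


# Two hypothetical clients with a two-"layer" model.
client_a = ([1.0, 10.0], 100)  # trained on 100 local examples
client_b = ([3.0, 30.0], 300)  # trained on 300 local examples
global_weights = weighted_average_weights([client_a, client_b])
print(global_weights)  # [2.5, 25.0]
```

The client with three times as many examples pulls the average three times as strongly, which is why the result sits closer to client_b's values.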

We can now create an instance of our class CifarClient and add one line to actually run this client:

fl.client.start_client(server_address="[::]:8080", client=CifarClient().to_client())

That's it for the client. We only have to implement Client or NumPyClient and call fl.client.start_client(). If you implement a client of type NumPyClient you'll need to first call its to_client() method. The string "[::]:8080" tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use "[::]:8080". If we run a truly federated workload with the server and clients running on different machines, all that needs to change is the server_address we point the client at.
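Because server_address is the only thing that changes between a local run and a distributed one, it can be convenient to read it from the command line rather than hard-coding it. Below is a minimal sketch using Python's standard argparse module; the --server-address flag name and the example remote address are hypothetical and not part of Flower.

```python
import argparse


def parse_server_address(argv=None):
    """Return the server address given on the command line.

    Defaults to "[::]:8080", the address used in this tutorial.
    """
    parser = argparse.ArgumentParser(description="Flower client")
    parser.add_argument(
        "--server-address",
        default="[::]:8080",
        help="host:port of the Flower server",
    )
    return parser.parse_args(argv).server_address


# No flag given: fall back to the tutorial default.
print(parse_server_address([]))
# Hypothetical server running on another machine.
print(parse_server_address(["--server-address", "192.168.1.10:8080"]))
```

In client.py, the returned string could then be passed as the server_address argument of fl.client.start_client().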

Flower Server#

For simple workloads, we can start a Flower server and leave all configuration options at their default values. In a file named server.py, import Flower and start the server:

import flwr as fl

fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3))

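Train the model, federated!#

With both client and server ready, we can now run everything and see federated learning in action. FL systems usually have a server and multiple clients. We therefore have to start the server first: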

$ python server.py

Once the server is running, we can start the clients in different terminals. Open a new terminal and start the first client:

$ python client.py

Open another terminal and start the second client:

$ python client.py

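Each client has its own dataset. In this quickstart, every client loads its own copy of the full CIFAR10 dataset rather than a disjoint partition.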
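Since client.py loads CIFAR10 without partitioning it, each client's dataset is the complete training and test set. To give each client a distinct shard instead, one option is to slice the arrays by a client index. The sketch below shows only the index arithmetic; partition_bounds, client_id, and num_clients are hypothetical names, and the contiguous split is an assumption rather than something this tutorial prescribes.

```python
def partition_bounds(num_examples, num_clients, client_id):
    """Return the [start, end) slice of examples assigned to one client.

    Examples are split into contiguous, near-equal shards; the first
    `num_examples % num_clients` clients receive one extra example.
    """
    if not 0 <= client_id < num_clients:
        raise ValueError("client_id must be in [0, num_clients)")
    base, extra = divmod(num_examples, num_clients)
    start = client_id * base + min(client_id, extra)
    end = start + base + (1 if client_id < extra else 0)
    return start, end


# CIFAR10 has 50,000 training images; split them across two clients.
for cid in range(2):
    print(cid, partition_bounds(50_000, 2, cid))
# 0 (0, 25000)
# 1 (25000, 50000)
```

A client could then train on x_train[start:end] and y_train[start:end] instead of the full arrays.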

You should now see how the training progresses in the very first terminal (the one that started the server):

INFO flower 2021-02-25 14:15:46,741 | app.py:76 | Flower server running (insecure, 3 rounds)
INFO flower 2021-02-25 14:15:46,742 | server.py:72 | Getting initial parameters
INFO flower 2021-02-25 14:16:01,770 | server.py:74 | Evaluating initial parameters
INFO flower 2021-02-25 14:16:01,770 | server.py:87 | [TIME] FL starting
DEBUG flower 2021-02-25 14:16:12,341 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:21:17,235 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:21:17,512 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:21:29,628 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:21:29,696 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:25:59,917 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:26:00,227 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:26:11,457 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:26:11,530 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:30:43,389 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:30:43,630 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:30:53,384 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:30:53,384 | server.py:122 | [TIME] FL finished in 891.6143046000007
INFO flower 2021-02-25 14:30:53,385 | app.py:109 | app_fit: losses_distributed [(1, 2.3196680545806885), (2, 2.3202896118164062), (3, 2.1818180084228516)]
INFO flower 2021-02-25 14:30:53,385 | app.py:110 | app_fit: accuracies_distributed []
INFO flower 2021-02-25 14:30:53,385 | app.py:111 | app_fit: losses_centralized []
INFO flower 2021-02-25 14:30:53,385 | app.py:112 | app_fit: accuracies_centralized []
DEBUG flower 2021-02-25 14:30:53,442 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:31:02,848 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:31:02,848 | app.py:121 | app_evaluate: federated loss: 2.1818180084228516
INFO flower 2021-02-25 14:31:02,848 | app.py:125 | app_evaluate: results [('ipv4:127.0.0.1:57158', EvaluateRes(loss=2.1818180084228516, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.21610000729560852})), ('ipv4:127.0.0.1:57160', EvaluateRes(loss=2.1818180084228516, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.21610000729560852}))]
INFO flower 2021-02-25 14:31:02,848 | app.py:127 | app_evaluate: failures []
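The log above reports accuracies_distributed as empty even though both clients return an accuracy metric. The server aggregates the loss, but custom metrics are only combined if the strategy is given a function to do so. A minimal sketch of such a function follows, assuming the list of (num_examples, metrics) pairs that recent Flower releases pass to evaluate_metrics_aggregation_fn; the name weighted_average is our own.

```python
def weighted_average(metrics):
    """Combine per-client accuracies, weighted by each client's example count.

    `metrics` is a list of (num_examples, metrics_dict) pairs.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}


# Values taken from the two EvaluateRes entries in the log above.
client_results = [
    (10000, {"accuracy": 0.21610000729560852}),
    (10000, {"accuracy": 0.21610000729560852}),
]
print(weighted_average(client_results))
```

Under that assumption, the function would be passed as fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average), and the resulting strategy supplied to fl.server.start_server() through its strategy argument.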

Congratulations! You've successfully built and run your first federated learning system. The full source code can be found in examples/quickstart-tensorflow/client.py.