scikit-learn Quickstart

In this tutorial, we will learn how to train a Logistic Regression model on MNIST using Flower and scikit-learn.

It is recommended to create a virtual environment and run everything within that virtualenv.

Our example consists of one *server* and two *clients*, all of which have the same model.

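*Clients* are responsible for generating individual model parameter updates based on their local datasets. These updates are then sent to the *server*, which aggregates them to produce an updated global model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of parameter updates is called a *round*.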
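To make the aggregation step of a round more concrete, here is a minimal NumPy sketch. It assumes the server combines the client parameters with an average weighted by the number of local examples, which is the idea behind federated averaging. The function name `aggregate` and the toy numbers are illustrative only and are not part of Flower's API.

```python
import numpy as np


def aggregate(updates):
    """Weighted average of client parameters.

    `updates` is a list of (parameters, num_examples) pairs, where
    `parameters` is a list of NumPy arrays (one per model tensor).
    """
    total_examples = sum(num_examples for _, num_examples in updates)
    num_tensors = len(updates[0][0])
    return [
        sum(params[i] * num_examples for params, num_examples in updates)
        / total_examples
        for i in range(num_tensors)
    ]


# Two hypothetical clients, each holding one weight vector and one bias.
client_a = ([np.array([1.0, 1.0]), np.array([0.0])], 100)
client_b = ([np.array([3.0, 5.0]), np.array([4.0])], 300)

global_params = aggregate([client_a, client_b])
print(global_params)
```

The client with more examples pulls the average towards its own parameters, so here the result sits closer to `client_b`.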

Now that we have a rough idea of what is going on, let's get started. We first need to install Flower. You can do this by running:

$ pip install flwr

Since we want to use scikit-learn, let's go ahead and install it:

$ pip install scikit-learn

Or simply install all dependencies using Poetry:

$ poetry install

Flower Client

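Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Before setting up the client and server, however, we define everything our federated learning setup needs in utils.py. utils.py contains the functions that cover the machine learning basics: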

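  • get_model_parameters()
    • Returns the parameters of a sklearn LogisticRegression model

  • set_model_params()
    • Sets the parameters of a sklearn LogisticRegression model

  • set_initial_params()
    • Initializes the model parameters that the Flower server will ask for

  • load_mnist()
    • Loads the MNIST dataset using OpenML

  • shuffle()
    • Shuffles data and its labels

  • partition()
    • Splits the dataset into a number of partitions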
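As a rough sketch of the three parameter helpers listed above (the versions in the repository may differ in detail), the functions below assume the model exposes `coef_`, `intercept_`, `classes_` and `fit_intercept`, as scikit-learn's LogisticRegression does, and that MNIST has 10 classes and 784 (28×28) features. To keep the demo runnable without scikit-learn, a `SimpleNamespace` stand-in is used in place of a real model. That stand-in is an assumption made for illustration only.

```python
from types import SimpleNamespace

import numpy as np


def get_model_parameters(model):
    """Return the model parameters as a list of NumPy arrays."""
    if model.fit_intercept:
        return [model.coef_, model.intercept_]
    return [model.coef_]


def set_model_params(model, params):
    """Overwrite the model parameters with the given arrays."""
    model.coef_ = params[0]
    if model.fit_intercept:
        model.intercept_ = params[1]
    return model


def set_initial_params(model, n_classes=10, n_features=784):
    """Zero-initialize the parameters so they exist before any call to fit()."""
    model.classes_ = np.arange(n_classes)
    model.coef_ = np.zeros((n_classes, n_features))
    if model.fit_intercept:
        model.intercept_ = np.zeros((n_classes,))


# Stand-in object (assumption): it only mimics the attributes used above.
model = SimpleNamespace(fit_intercept=True)
set_initial_params(model)
params = get_model_parameters(model)
print([p.shape for p in params])  # [(10, 784), (10,)]
```

Initializing the parameters up front matters because an unfitted scikit-learn model has no `coef_` yet, while the server may request parameters before any training has happened.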

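Please check out utils.py `here <https://github.com/adap/flower/blob/main/examples/sklearn-logreg-mnist/utils.py>`_ for more details. The pre-defined functions are imported and used in client.py. client.py also needs to import several packages such as Flower and scikit-learn: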

import warnings
import flwr as fl
import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

import utils

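We load the MNIST dataset from OpenML, a popular image classification dataset of handwritten digits for machine learning. The utility utils.load_mnist() downloads the training and test data. The training set is then split into 10 partitions with utils.partition(), and each client picks one of them at random.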

if __name__ == "__main__":

    (X_train, y_train), (X_test, y_test) = utils.load_mnist()

    partition_id = np.random.choice(10)
    (X_train, y_train) = utils.partition(X_train, y_train, 10)[partition_id]
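The shuffling and partitioning used above can be sketched with plain NumPy. This is a minimal illustration, assuming the helpers keep samples and labels aligned and split them into roughly equal chunks. The `seed` parameter and the toy data are assumptions added for the example and stand in for the real MNIST arrays.

```python
import numpy as np


def shuffle(X, y, seed=0):
    """Shuffle X and y with the same random permutation."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    return X[idx], y[idx]


def partition(X, y, num_partitions):
    """Split X and y into `num_partitions` aligned (X, y) chunks."""
    return list(
        zip(np.array_split(X, num_partitions), np.array_split(y, num_partitions))
    )


# Toy data standing in for MNIST: 20 samples with 3 features each.
X = np.arange(60).reshape(20, 3)
y = np.arange(20)

X, y = shuffle(X, y)
partitions = partition(X, y, 10)

# Pick one partition at random, as the client does.
X_part, y_part = partitions[np.random.choice(10)]
print(len(partitions), X_part.shape, y_part.shape)
```

Because every client draws its partition independently, two clients can end up with the same partition. For a quickstart that is acceptable, but a real deployment would assign partitions explicitly.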

Next, the logistic regression model is defined and initialized with utils.set_initial_params().

model = LogisticRegression(
    penalty="l2",
    max_iter=1,  # local epoch
    warm_start=True,  # prevent refreshing weights when fitting
)

utils.set_initial_params(model)
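The initialization also explains the first loss value in the server log further down. Assuming set_initial_params() sets all coefficients and intercepts to zero, every class gets probability 1/10, so the log loss is ln(10) ≈ 2.3026 regardless of the data. The sketch below checks this with NumPy only. The random inputs are placeholders, not MNIST.

```python
import numpy as np

n_classes, n_features = 10, 784

# Zero-initialized parameters (assumed to match what set_initial_params() produces).
coef = np.zeros((n_classes, n_features))
intercept = np.zeros(n_classes)

# A few random inputs and labels; with zero weights the data does not matter.
rng = np.random.default_rng(0)
X = rng.random((5, n_features))
y = rng.integers(0, n_classes, size=5)

# Softmax over the class scores.
logits = X @ coef.T + intercept
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# Mean negative log-likelihood of the true class (the log loss).
loss = -np.log(probs[np.arange(len(y)), y]).mean()
print(loss)  # approximately 2.302585, i.e. ln(10)
```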

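The Flower server interacts with clients through an interface called Client. When the server selects a particular client for training, it sends training instructions over the network. The client receives those instructions and calls one of the Client methods to run your code (i.e., to fit the logistic regression we defined earlier).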

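Flower provides a convenience class called NumPyClient, which makes it easier to implement the Client interface when your workload uses scikit-learn. Implementing NumPyClient usually means defining the following methods (set_parameters is optional):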

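  1. get_parameters
    • return the model parameters as a list of NumPy ndarrays

  2. set_parameters (optional)
    • update the local model parameters with the parameters received from the server

    • provided directly by utils.set_model_params()

  3. fit
    • set the local model parameters

    • train the local model

    • return the updated local model parameters

  4. evaluate
    • test the local model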

The methods can be implemented in the following way:

class MnistClient(fl.client.NumPyClient):
    def get_parameters(self, config):  # type: ignore
        return utils.get_model_parameters(model)

    def fit(self, parameters, config):  # type: ignore
        utils.set_model_params(model, parameters)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model.fit(X_train, y_train)
        print(f"Training finished for round {config['server_round']}")
        return utils.get_model_parameters(model), len(X_train), {}

    def evaluate(self, parameters, config):  # type: ignore
        utils.set_model_params(model, parameters)
        loss = log_loss(y_test, model.predict_proba(X_test))
        accuracy = model.score(X_test, y_test)
        return loss, len(X_test), {"accuracy": accuracy}

We can now create an instance of our class MnistClient and add one line to actually run this client:

fl.client.start_client(server_address="0.0.0.0:8080", client=MnistClient().to_client())

That's it for the client. We only have to implement Client or NumPyClient and call fl.client.start_client(). If you implement a client of type NumPyClient you'll need to first call its to_client() method. The string "0.0.0.0:8080" tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use "0.0.0.0:8080". If we run a truly federated workload with the server and clients running on different machines, all that needs to change is the server_address we pass to the client.
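If the server does run on a different machine, one way to avoid editing the source each time is to read the address from the command line. The sketch below uses only the standard library. The `--server-address` flag and the address `192.0.2.10:8080` are made up for illustration and are not part of the example code.

```python
import argparse


def parse_server_address(argv=None):
    """Read the server address from the command line, defaulting to the local one."""
    parser = argparse.ArgumentParser(description="Flower client (sketch)")
    parser.add_argument(
        "--server-address",  # hypothetical flag, not defined by the example
        default="0.0.0.0:8080",
        help="host:port of the Flower server",
    )
    return parser.parse_args(argv).server_address


# Same machine (default) versus a server on another host (made-up address).
print(parse_server_address([]))
print(parse_server_address(["--server-address", "192.0.2.10:8080"]))
```

The returned string would then be passed to fl.client.start_client() as the server address in place of the hard-coded value.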

Flower Server

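The following Flower server is a little more advanced and returns an evaluation function for server-side evaluation. First, we again import all required libraries such as Flower and scikit-learn.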

server.py, import Flower and start the server:

import flwr as fl
import utils
from flwr.common import NDArrays, Scalar
from sklearn.metrics import log_loss
from sklearn.linear_model import LogisticRegression
from typing import Dict, Optional, Tuple

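fit_round() sends the current round number to the clients as part of the fit configuration, and the evaluation is defined in get_evaluate_fn(). The evaluation function is called after each federated learning round and reports the loss and accuracy.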

def fit_round(server_round: int) -> Dict:
    """Send round number to client."""
    return {"server_round": server_round}


def get_evaluate_fn(model: LogisticRegression):
    """Return an evaluation function for server-side evaluation."""

    _, (X_test, y_test) = utils.load_mnist()

    def evaluate(
        server_round: int, parameters: NDArrays, config: Dict[str, Scalar]
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        utils.set_model_params(model, parameters)
        loss = log_loss(y_test, model.predict_proba(X_test))
        accuracy = model.score(X_test, y_test)
        return loss, {"accuracy": accuracy}

    return evaluate
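To see what happens with fit_round() once it is registered as on_fit_config_fn, here is a small simulation that does not use Flower at all. The strategy calls the callback once per round with the current round number and sends the resulting dictionary to the clients as the `config` argument of fit(). The loop and `fake_client_fit()` are stand-ins for the server and client written for this sketch, not Flower code.

```python
from typing import Dict


def fit_round(server_round: int) -> Dict:
    """Send round number to client."""
    return {"server_round": server_round}


def fake_client_fit(config: Dict) -> str:
    """Stand-in for the client's fit(): only reads the config it receives."""
    return f"Training finished for round {config['server_round']}"


# Stand-in for the server loop: build a fresh config for every round.
messages = [fake_client_fit(fit_round(r)) for r in range(1, 4)]
for message in messages:
    print(message)
```

The same mechanism could carry other per-round settings, as long as both sides agree on the dictionary keys.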

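The main block contains the server-side parameter initialization utils.set_initial_params() as well as the aggregation strategy fl.server.strategy.FedAvg(). The strategy is the default one, federated averaging (or FedAvg), configured to wait for two clients and to evaluate after each federated learning round. The server can be started with fl.server.start_server(server_address="0.0.0.0:8080", strategy=strategy, config=fl.server.ServerConfig(num_rounds=3)).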

# Start Flower server for three rounds of federated learning
if __name__ == "__main__":
    model = LogisticRegression()
    utils.set_initial_params(model)
    strategy = fl.server.strategy.FedAvg(
        min_available_clients=2,
        evaluate_fn=get_evaluate_fn(model),
        on_fit_config_fn=fit_round,
    )
    fl.server.start_server(server_address="0.0.0.0:8080", strategy=strategy, config=fl.server.ServerConfig(num_rounds=3))

Train the model, federated!

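With both client and server ready, we can now run everything and see federated learning in action. Federated learning systems usually have one server and multiple clients. We therefore have to start the server first: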

$ python3 server.py

Once the server is running, we can start the clients in different terminals. Open a new terminal and start the first client:

$ python3 client.py

Open another terminal and start the second client:

$ python3 client.py

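Each client will have its own dataset. You should now see how the training progresses in the very first terminal (the one that started the server):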

INFO flower 2022-01-13 13:43:14,859 | app.py:73 | Flower server running (insecure, 3 rounds)
INFO flower 2022-01-13 13:43:14,859 | server.py:118 | Getting initial parameters
INFO flower 2022-01-13 13:43:17,903 | server.py:306 | Received initial parameters from one random client
INFO flower 2022-01-13 13:43:17,903 | server.py:120 | Evaluating initial parameters
INFO flower 2022-01-13 13:43:17,992 | server.py:123 | initial parameters (loss, other metrics): 2.3025850929940455, {'accuracy': 0.098}
INFO flower 2022-01-13 13:43:17,992 | server.py:133 | FL starting
DEBUG flower 2022-01-13 13:43:19,814 | server.py:251 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:20,046 | server.py:260 | fit_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:20,220 | server.py:148 | fit progress: (1, 1.3365667871792377, {'accuracy': 0.6605}, 2.227397900000142)
INFO flower 2022-01-13 13:43:20,220 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-01-13 13:43:20,220 | server.py:251 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:20,456 | server.py:260 | fit_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:20,603 | server.py:148 | fit progress: (2, 0.721620492535375, {'accuracy': 0.7796}, 2.6108531999998377)
INFO flower 2022-01-13 13:43:20,603 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-01-13 13:43:20,603 | server.py:251 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:20,837 | server.py:260 | fit_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:20,967 | server.py:148 | fit progress: (3, 0.5843629244915138, {'accuracy': 0.8217}, 2.9750180000010005)
INFO flower 2022-01-13 13:43:20,968 | server.py:199 | evaluate_round: no clients selected, cancel
INFO flower 2022-01-13 13:43:20,968 | server.py:172 | FL finished in 2.975252800000817
INFO flower 2022-01-13 13:43:20,968 | app.py:109 | app_fit: losses_distributed []
INFO flower 2022-01-13 13:43:20,968 | app.py:110 | app_fit: metrics_distributed {}
INFO flower 2022-01-13 13:43:20,968 | app.py:111 | app_fit: losses_centralized [(0, 2.3025850929940455), (1, 1.3365667871792377), (2, 0.721620492535375), (3, 0.5843629244915138)]
INFO flower 2022-01-13 13:43:20,968 | app.py:112 | app_fit: metrics_centralized {'accuracy': [(0, 0.098), (1, 0.6605), (2, 0.7796), (3, 0.8217)]}
DEBUG flower 2022-01-13 13:43:20,968 | server.py:201 | evaluate_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:21,232 | server.py:210 | evaluate_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:21,232 | app.py:121 | app_evaluate: federated loss: 0.5843629240989685
INFO flower 2022-01-13 13:43:21,232 | app.py:122 | app_evaluate: results [('ipv4:127.0.0.1:53980', EvaluateRes(loss=0.5843629240989685, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.8217})), ('ipv4:127.0.0.1:53982', EvaluateRes(loss=0.5843629240989685, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.8217}))]
INFO flower 2022-01-13 13:43:21,232 | app.py:127 | app_evaluate: failures []
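The centralized accuracy values in the log above can be turned into per-round gains with a few lines of Python. This is only a reading aid for the log. The numbers are copied from the metrics_centralized line and nothing here calls Flower.

```python
# Centralized accuracy per round, copied from the metrics_centralized log line.
accuracy = [(0, 0.098), (1, 0.6605), (2, 0.7796), (3, 0.8217)]

# Improvement contributed by each round compared with the previous one.
gains = [
    (round_now, round(acc_now - acc_prev, 4))
    for (_, acc_prev), (round_now, acc_now) in zip(accuracy, accuracy[1:])
]
print(gains)
```

The gains shrink from round to round, which is expected as the model moves away from its zero initialization and the remaining errors become harder to fix.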

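Congratulations! You've successfully built and run your first federated learning system. The `full source code <https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist>`_ for this example can be found in examples/sklearn-logreg-mnist.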