Quickstart scikit-learn#

In this tutorial, we will learn how to train a Logistic Regression model on MNIST using Flower and scikit-learn.

We recommend creating a virtual environment and running everything inside it.

Our example consists of one *server* and two *clients*, all of which have the same model.

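*Clients* are responsible for generating individual model parameter updates based on their local datasets. These updates are then sent to the *server*, which aggregates them to produce an updated global model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of parameter updates is called a *round*.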
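The aggregation step of a round can be sketched with plain NumPy. This is a minimal illustration of example-weighted averaging, which is the idea behind the FedAvg strategy used later in this tutorial; the function name `federated_average` and the toy parameter values are assumptions made for this sketch, not part of Flower's API.

```python
from typing import List, Tuple

import numpy as np


def federated_average(
    results: List[Tuple[List[np.ndarray], int]]
) -> List[np.ndarray]:
    """Average per-client parameter lists, weighted by number of examples.

    Hypothetical helper for illustration only; Flower's strategies
    perform the aggregation for you.
    """
    total_examples = sum(num_examples for _, num_examples in results)
    num_arrays = len(results[0][0])
    return [
        sum(params[i] * num_examples for params, num_examples in results)
        / total_examples
        for i in range(num_arrays)
    ]


# Two toy clients: (list of parameter arrays, number of local examples)
client_a = ([np.array([1.0, 1.0]), np.array([0.0])], 100)
client_b = ([np.array([3.0, 5.0]), np.array([4.0])], 300)

global_params = federated_average([client_a, client_b])
print(global_params)
```

The client with 300 examples pulls the average three times as hard as the client with 100 examples, so the result sits closer to `client_b`'s parameters.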

Now that we have a rough idea of what is going on, let's get started. We first need to install Flower. You can do this by running:

$ pip install flwr

Since we want to use scikit-learn, let's go ahead and install it:

$ pip install scikit-learn

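Or simply install all dependencies using Poetry (this assumes you are working in the example directory, which provides the project configuration):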

$ poetry install

Flower Client#

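Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Before setting up the client and the server, however, we will define everything we need for our federated learning setup in utils.py. utils.py contains the functions that cover the machine learning basics: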

  • get_model_parameters()
    • Returns the parameters of a sklearn LogisticRegression model

  • set_model_params()
    • Sets the parameters of a sklearn LogisticRegression model

  • set_initial_params()
    • Initializes the model parameters that the Flower server will ask for
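To make the three helpers concrete, here is a minimal sketch of how they could look. It assumes the model parameters are the `coef_` and `intercept_` arrays of the LogisticRegression model and that MNIST has 10 classes and 784 features per flattened image; the real utils.py in the example repository may differ in details.

```python
from typing import List

import numpy as np
from sklearn.linear_model import LogisticRegression


def get_model_parameters(model: LogisticRegression) -> List[np.ndarray]:
    """Return the model parameters as a list of NumPy arrays."""
    if model.fit_intercept:
        return [model.coef_, model.intercept_]
    return [model.coef_]


def set_model_params(
    model: LogisticRegression, params: List[np.ndarray]
) -> LogisticRegression:
    """Overwrite the model parameters with the given arrays."""
    model.coef_ = params[0]
    if model.fit_intercept:
        model.intercept_ = params[1]
    return model


def set_initial_params(model: LogisticRegression) -> None:
    """Give an unfitted model zero-valued parameters.

    scikit-learn only creates coef_ and intercept_ during fit(), but the
    server asks for parameters before any training has happened.
    """
    n_classes = 10  # assumption: MNIST digits 0-9
    n_features = 784  # assumption: 28 x 28 pixels, flattened
    model.classes_ = np.arange(n_classes)
    model.coef_ = np.zeros((n_classes, n_features))
    if model.fit_intercept:
        model.intercept_ = np.zeros((n_classes,))


model = LogisticRegression()
set_initial_params(model)
print([p.shape for p in get_model_parameters(model)])
```

Setting the attributes by hand is what allows the model to be evaluated or shipped to clients before it has ever been fitted.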

For more details, see `utils.py <https://github.com/adap/flower/blob/main/examples/sklearn-logreg-mnist/utils.py>`_. The pre-defined functions are imported and used in client.py. client.py also needs to import several packages, such as Flower and scikit-learn:

import argparse
import warnings

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

import flwr as fl
import utils
from flwr_datasets import FederatedDataset

Prior to local training, we need to load the MNIST dataset, a popular image classification dataset of handwritten digits for machine learning, and partition the dataset for FL. This can be conveniently achieved using Flower Datasets. The FederatedDataset.load_partition() method loads the partitioned training set for each partition ID defined in the --partition-id argument.

if __name__ == "__main__":
    N_CLIENTS = 10

    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--partition-id",
        type=int,
        choices=range(0, N_CLIENTS),
        required=True,
        help="Specifies the artificial data partition",
    )
    args = parser.parse_args()
    partition_id = args.partition_id

    fds = FederatedDataset(dataset="mnist", partitioners={"train": N_CLIENTS})

    dataset = fds.load_partition(partition_id, "train").with_format("numpy")
    X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]

    X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
    y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]
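The flattening and the 80/20 split at the end of the block above can be tried on synthetic data without downloading MNIST. In this sketch the array `images` is a made-up stand-in for `dataset["image"]`, with 50 fake 28 x 28 images.

```python
import numpy as np

# Stand-in for one client's partition: 50 fake 28 x 28 grayscale images
images = np.zeros((50, 28, 28), dtype=np.float32)
labels = np.arange(50) % 10

# Flatten each image into a 784-dimensional feature vector
X = images.reshape((len(images), -1))
y = labels

# First 80% of the rows for training, remaining 20% for local testing
split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]

print(X_train.shape, X_test.shape)
```

Note that the split is purely positional: no shuffling takes place, so the local test set is simply the last fifth of the partition.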

Next, the logistic regression model is defined and initialized with utils.set_initial_params().

model = LogisticRegression(
    penalty="l2",
    max_iter=1,  # local epoch
    warm_start=True,  # prevent refreshing weights when fitting
)

utils.set_initial_params(model)
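The combination of `max_iter=1` and `warm_start=True` is what turns each call to `fit()` into one local training step that continues from the current parameters. The following sketch illustrates this on synthetic data; the data set and the loop are assumptions for illustration and are not part of the example code.

```python
import warnings

import numpy as np
from sklearn.linear_model import LogisticRegression

# Small synthetic binary classification problem
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)

model = LogisticRegression(max_iter=1, warm_start=True)

coefs = []
with warnings.catch_warnings():
    # A single solver iteration does not converge; silence the warning
    warnings.simplefilter("ignore")
    for _ in range(3):
        model.fit(X, y)  # continues from the previous coef_ (warm start)
        coefs.append(model.coef_.copy())

for i, coef in enumerate(coefs, start=1):
    print(f"after fit call {i}: {coef.round(3)}")
```

In the federated setting, the parameters received from the server take the place of the previous local coefficients, so each round continues from the global model.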

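The Flower server interacts with clients through an interface called Client. When the server selects a particular client for training, it sends training instructions over the network. The client receives those instructions and calls one of the Client methods to run your code (i.e., to fit the logistic regression we defined earlier).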

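Flower provides a convenience class called NumPyClient which makes it easier to implement the Client interface when your workload uses scikit-learn. Implementing NumPyClient usually means defining the following methods (set_parameters is optional):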

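  1. get_parameters
    • return the model parameters as a list of NumPy ndarrays

  2. set_parameters (optional)
    • update the local model parameters with the parameters received from the server

    • implemented here by calling utils.set_model_params() directly

  3. fit
    • set the local model parameters

    • train the local model

    • return the updated local model parameters

  4. evaluate
    • test the local model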

These methods can be implemented as follows:

class MnistClient(fl.client.NumPyClient):
    def get_parameters(self, config):  # type: ignore
        return utils.get_model_parameters(model)

    def fit(self, parameters, config):  # type: ignore
        utils.set_model_params(model, parameters)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model.fit(X_train, y_train)
        print(f"Training finished for round {config['server_round']}")
        return utils.get_model_parameters(model), len(X_train), {}

    def evaluate(self, parameters, config):  # type: ignore
        utils.set_model_params(model, parameters)
        loss = log_loss(y_test, model.predict_proba(X_test))
        accuracy = model.score(X_test, y_test)
        return loss, len(X_test), {"accuracy": accuracy}

We can now create an instance of our class MnistClient and add one line to actually run this client:

fl.client.start_client(server_address="0.0.0.0:8080", client=MnistClient().to_client())

That's it for the client. We only have to implement Client or NumPyClient and call fl.client.start_client(). If you implement a client of type NumPyClient, you first need to call its to_client() method. The string "0.0.0.0:8080" tells the client which server to connect to. In our case the server and the clients run on the same machine, so we use "0.0.0.0:8080". For a truly federated workload, with the server and clients running on different machines, all that needs to change is the server_address we pass to the client.
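One way to avoid editing the code when the server moves to another machine is to read the address from the command line. The sketch below extends the argument parser shown earlier with a hypothetical `--server-address` flag; the flag, the helper name `parse_client_args`, and the sample address are assumptions for illustration and are not part of the example code.

```python
import argparse
from typing import List, Optional


def parse_client_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the client's command-line arguments."""
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--partition-id",
        type=int,
        choices=range(0, 10),
        required=True,
        help="Specifies the artificial data partition",
    )
    # Hypothetical flag: not present in the original client.py
    parser.add_argument(
        "--server-address",
        type=str,
        default="0.0.0.0:8080",
        help="Address of the Flower server (host:port)",
    )
    return parser.parse_args(argv)


# Illustrative address of a server running on another machine
args = parse_client_args(
    ["--partition-id", "1", "--server-address", "192.168.1.10:8080"]
)
print(args.partition_id, args.server_address)
```

The parsed value would then be passed on as the `server_address` argument of fl.client.start_client().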

Flower Server#

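The following Flower server is a little more advanced and returns an evaluation function for server-side evaluation. First, we again import all required libraries such as Flower and scikit-learn.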

In server.py, import Flower and the other required packages:

import flwr as fl
import utils
from flwr.common import NDArrays, Scalar
from sklearn.metrics import log_loss
from sklearn.linear_model import LogisticRegression
from typing import Dict, Optional, Tuple

from flwr_datasets import FederatedDataset

The fit_round() function sends the current round number to each client as part of the fit configuration, and the server-side evaluation is defined in get_evaluate_fn(). The number of federated learning rounds itself is set through fl.server.ServerConfig(num_rounds=3) when the server is started. The evaluation function is called after each federated learning round and reports loss and accuracy. Note that we also use Flower Datasets here to load the test split of the MNIST dataset for server-side evaluation.

def fit_round(server_round: int) -> Dict:
    """Send round number to client."""
    return {"server_round": server_round}


def get_evaluate_fn(model: LogisticRegression):
    """Return an evaluation function for server-side evaluation."""

    fds = FederatedDataset(dataset="mnist", partitioners={"train": 10})
    dataset = fds.load_split("test").with_format("numpy")
    X_test, y_test = dataset["image"].reshape((len(dataset), -1)), dataset["label"]

    def evaluate(
        server_round: int, parameters: NDArrays, config: Dict[str, Scalar]
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        utils.set_model_params(model, parameters)
        loss = log_loss(y_test, model.predict_proba(X_test))
        accuracy = model.score(X_test, y_test)
        return loss, {"accuracy": accuracy}

    return evaluate
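To see how the round number travels from fit_round() to the print statement in the client's fit() method, the exchange can be imitated without any networking. The loop below is a simplified stand-in for what the server and strategy do internally; `simulate_rounds` is a hypothetical helper written for this sketch, not a Flower function.

```python
from typing import Callable, Dict, List


def fit_round(server_round: int) -> Dict:
    """Send round number to client (same as in server.py)."""
    return {"server_round": server_round}


def simulate_rounds(
    on_fit_config_fn: Callable[[int], Dict], num_rounds: int
) -> List[str]:
    """Hypothetical helper: build the fit config for each round and
    return the message a client would print after training."""
    messages = []
    for server_round in range(1, num_rounds + 1):
        config = on_fit_config_fn(server_round)  # done on the server
        # The client receives `config` as the second argument of fit()
        messages.append(f"Training finished for round {config['server_round']}")
    return messages


for message in simulate_rounds(fit_round, num_rounds=3):
    print(message)
```

The same mechanism could carry other per-round settings to the clients, since the configuration is an ordinary dictionary.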

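The main block contains the server-side parameter initialization with utils.set_initial_params() as well as the aggregation strategy fl.server.strategy.FedAvg(). The strategy is the default federated averaging (FedAvg) strategy, configured here to wait for at least two available clients (min_available_clients=2) and to run the evaluation function after each round of federated learning. The server can be started with fl.server.start_server(server_address="0.0.0.0:8080", strategy=strategy, config=fl.server.ServerConfig(num_rounds=3)).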

# Start Flower server for three rounds of federated learning
if __name__ == "__main__":
    model = LogisticRegression()
    utils.set_initial_params(model)
    strategy = fl.server.strategy.FedAvg(
        min_available_clients=2,
        evaluate_fn=get_evaluate_fn(model),
        on_fit_config_fn=fit_round,
    )
    fl.server.start_server(server_address="0.0.0.0:8080", strategy=strategy, config=fl.server.ServerConfig(num_rounds=3))

Train the model, federated!#

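With both client and server ready, we can now run everything and see federated learning in action. Federated learning systems usually have one server and multiple clients. We therefore have to start the server first: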

$ python3 server.py

Once the server is running, we can start the clients in different terminals. Open a new terminal and start the first client:

$ python3 client.py --partition-id 0

Open another terminal and start the second client:

$ python3 client.py --partition-id 1

Each client has its own dataset. You should now see the training progress in the first terminal (the one that started the server):

INFO flower 2022-01-13 13:43:14,859 | app.py:73 | Flower server running (insecure, 3 rounds)
INFO flower 2022-01-13 13:43:14,859 | server.py:118 | Getting initial parameters
INFO flower 2022-01-13 13:43:17,903 | server.py:306 | Received initial parameters from one random client
INFO flower 2022-01-13 13:43:17,903 | server.py:120 | Evaluating initial parameters
INFO flower 2022-01-13 13:43:17,992 | server.py:123 | initial parameters (loss, other metrics): 2.3025850929940455, {'accuracy': 0.098}
INFO flower 2022-01-13 13:43:17,992 | server.py:133 | FL starting
DEBUG flower 2022-01-13 13:43:19,814 | server.py:251 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:20,046 | server.py:260 | fit_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:20,220 | server.py:148 | fit progress: (1, 1.3365667871792377, {'accuracy': 0.6605}, 2.227397900000142)
INFO flower 2022-01-13 13:43:20,220 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-01-13 13:43:20,220 | server.py:251 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:20,456 | server.py:260 | fit_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:20,603 | server.py:148 | fit progress: (2, 0.721620492535375, {'accuracy': 0.7796}, 2.6108531999998377)
INFO flower 2022-01-13 13:43:20,603 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-01-13 13:43:20,603 | server.py:251 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:20,837 | server.py:260 | fit_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:20,967 | server.py:148 | fit progress: (3, 0.5843629244915138, {'accuracy': 0.8217}, 2.9750180000010005)
INFO flower 2022-01-13 13:43:20,968 | server.py:199 | evaluate_round: no clients selected, cancel
INFO flower 2022-01-13 13:43:20,968 | server.py:172 | FL finished in 2.975252800000817
INFO flower 2022-01-13 13:43:20,968 | app.py:109 | app_fit: losses_distributed []
INFO flower 2022-01-13 13:43:20,968 | app.py:110 | app_fit: metrics_distributed {}
INFO flower 2022-01-13 13:43:20,968 | app.py:111 | app_fit: losses_centralized [(0, 2.3025850929940455), (1, 1.3365667871792377), (2, 0.721620492535375), (3, 0.5843629244915138)]
INFO flower 2022-01-13 13:43:20,968 | app.py:112 | app_fit: metrics_centralized {'accuracy': [(0, 0.098), (1, 0.6605), (2, 0.7796), (3, 0.8217)]}
DEBUG flower 2022-01-13 13:43:20,968 | server.py:201 | evaluate_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:21,232 | server.py:210 | evaluate_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:21,232 | app.py:121 | app_evaluate: federated loss: 0.5843629240989685
INFO flower 2022-01-13 13:43:21,232 | app.py:122 | app_evaluate: results [('ipv4:127.0.0.1:53980', EvaluateRes(loss=0.5843629240989685, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.8217})), ('ipv4:127.0.0.1:53982', EvaluateRes(loss=0.5843629240989685, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.8217}))]
INFO flower 2022-01-13 13:43:21,232 | app.py:127 | app_evaluate: failures []
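The initial loss of about 2.3026 in the log is what one would expect from a model whose parameters are all zero: every class gets the same score, so the predicted probability of each of the 10 digits is 0.1, and the log loss is -ln(0.1) = ln(10). The short check below reproduces this arithmetic with the standard library only; it assumes zero-valued initial parameters and does not touch the real model.

```python
import math

n_classes = 10

# All-zero parameters give the same score (0.0) to every class
scores = [0.0] * n_classes

# Softmax turns equal scores into a uniform distribution
exp_scores = [math.exp(s) for s in scores]
probs = [e / sum(exp_scores) for e in exp_scores]

# Log loss for one sample is -log(probability of the true class);
# with uniform probabilities the true class does not matter
loss = -math.log(probs[0])

print(probs[0], loss)
```

The initial accuracy of 0.098 fits the same picture: with all scores tied, the prediction would fall on a single class for every image, and the digit 0 accounts for 980 of the 10,000 MNIST test images.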

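Congratulations! You have successfully built and run your first federated learning system. The `full source code <https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist>`_ for this example can be found in examples/sklearn-logreg-mnist.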