Quickstart XGBoost#

Federated XGBoost#

EXtreme Gradient Boosting (XGBoost) is a robust and efficient implementation of gradient-boosted decision trees (GBDT) that pushes the computational limits of boosted tree methods. It is primarily designed to improve both the performance and the computational speed of machine learning models. In XGBoost, trees are constructed concurrently, unlike the sequential approach taken by GBDT.

Often, for tabular data on medium-sized datasets with fewer than 10k training examples, XGBoost surpasses the results of deep learning techniques.

Why federated XGBoost?#

Indeed, as the demand for data privacy and decentralised learning grows, an increasing number of specialised applications, such as survival analysis and financial fraud detection, call for federated XGBoost systems.

Federated learning ensures that raw data remains on the local device, which makes it an attractive approach for sensitive domains where data security and privacy are paramount. Given the robustness and efficiency of XGBoost, combining it with federated learning offers a promising solution to these specific challenges.

In this tutorial, we learn how to train a federated XGBoost model on the HIGGS dataset using Flower and the xgboost package. We use a simple example (full code: xgboost-quickstart) with two *clients* and one *server* to demonstrate how federated XGBoost works, and then we dive into a more complex example (full code: xgboost-comprehensive) to run various experiments.

Environment setup#

First of all, it is recommended to create a virtual environment and run everything within a virtualenv.

We first need to install Flower and Flower Datasets. You can do this by running:

$ pip install flwr flwr-datasets

Since we want to use the xgboost package to build XGBoost trees, let's go ahead and install xgboost:

$ pip install xgboost
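
If you want to sanity-check the installation, the following sketch trains a plain (non-federated) XGBoost model on synthetic data; it is only an illustration and not part of the tutorial code:

# A minimal, non-federated sanity check on synthetic tabular data (illustration only).
import numpy as np
import xgboost as xgb

X = np.random.rand(1000, 20)          # 1000 examples, 20 features
y = np.random.randint(0, 2, 1000)     # binary labels

dtrain = xgb.DMatrix(X[:800], label=y[:800])
dvalid = xgb.DMatrix(X[800:], label=y[800:])

bst = xgb.train(
    {"objective": "binary:logistic", "eval_metric": "auc"},
    dtrain,
    num_boost_round=10,
    evals=[(dvalid, "valid")],
)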

Flower Client#

*Clients* are responsible for generating individual model parameter updates based on their local datasets. Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server.

In a file called client.py, import xgboost, Flower, Flower Datasets and other related functions:

import argparse
from typing import Union
from logging import INFO
from datasets import Dataset, DatasetDict
import xgboost as xgb

import flwr as fl
from flwr_datasets import FederatedDataset
from flwr.common.logger import log
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Parameters,
    Status,
)
from flwr_datasets.partitioner import IidPartitioner

Dataset partition and hyper-parameter selection#

Prior to local training, we need to load the HIGGS dataset from Flower Datasets and conduct data partitioning for FL:

# Load (HIGGS) dataset and conduct partitioning
# We use a small subset (num_partitions=30) of the dataset for demonstration to speed up the data loading process.
partitioner = IidPartitioner(num_partitions=30)
fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner})

# Load the partition for this `node_id`
partition = fds.load_partition(node_id=args.node_id, split="train")
partition.set_format("numpy")

In this example, we split the dataset into 30 uniformly distributed partitions (IidPartitioner(num_partitions=30)). Then, we load the partition for the given client based on node_id:

# We first define arguments parser for user to specify the client/node ID.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--node-id",
    default=0,
    type=int,
    help="Node ID used for the current client.",
)
args = parser.parse_args()

# Load the partition for this `node_id`.
partition = fds.load_partition(node_id=args.node_id, split="train")
partition.set_format("numpy")

Then, we perform train/test splitting on the given partition (the client's local data) and reformat the data for the xgboost package.

# Train/test splitting
train_data, valid_data, num_train, num_val = train_test_split(
    partition, test_fraction=0.2, seed=42
)

# Reformat data to DMatrix for xgboost
train_dmatrix = transform_dataset_to_dmatrix(train_data)
valid_dmatrix = transform_dataset_to_dmatrix(valid_data)

The functions train_test_split and transform_dataset_to_dmatrix are defined as follows:

# Define data partitioning related functions
def train_test_split(partition: Dataset, test_fraction: float, seed: int):
    """Split the data into train and validation set given split rate."""
    train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
    partition_train = train_test["train"]
    partition_test = train_test["test"]

    num_train = len(partition_train)
    num_test = len(partition_test)

    return partition_train, partition_test, num_train, num_test


def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix:
    """Transform dataset to DMatrix format for xgboost."""
    x = data["inputs"]
    y = data["label"]
    new_data = xgb.DMatrix(x, label=y)
    return new_data

Finally, we define the hyper-parameters used for XGBoost training.

num_local_round = 1
params = {
    "objective": "binary:logistic",
    "eta": 0.1,  # lr
    "max_depth": 8,
    "eval_metric": "auc",
    "nthread": 16,
    "num_parallel_tree": 1,
    "subsample": 1,
    "tree_method": "hist",
}

num_local_round represents the number of iterations for local tree boosting. We use CPU for training by default; it can be shifted to GPU by setting tree_method to gpu_hist. We use AUC as the evaluation metric.
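
For illustration, a GPU variant of the hyper-parameters might look as follows (a sketch assuming a CUDA-capable GPU and a GPU-enabled xgboost build):

# Same hyper-parameters as above, but with the tree method switched to GPU (sketch).
params_gpu = {
    "objective": "binary:logistic",
    "eta": 0.1,  # lr
    "max_depth": 8,
    "eval_metric": "auc",
    "nthread": 16,
    "num_parallel_tree": 1,
    "subsample": 1,
    "tree_method": "gpu_hist",  # "hist" runs on CPU, "gpu_hist" on GPU
}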

Flower client definition for XGBoost#

After loading the dataset, we define the Flower client. We follow the general rule and define an XgbClient class that inherits from fl.client.Client.

class XgbClient(fl.client.Client):
    def __init__(self):
        self.bst = None
        self.config = None

self.bst holds the Booster object that is kept consistent across rounds, allowing it to store the predictions of trees integrated in earlier rounds and to maintain other essential data structures for training.

Then, we override the get_parameters, fit and evaluate methods inside the XgbClient class as follows.

def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
    _ = (self, ins)
    return GetParametersRes(
        status=Status(
            code=Code.OK,
            message="OK",
        ),
        parameters=Parameters(tensor_type="", tensors=[]),
    )

Unlike neural network training, XGBoost trees are not started from specified random parameters. In this case, we do not use get_parameters and set_parameters to initialise model parameters for XGBoost. As a result, let's return an empty tensor in get_parameters when it is called by the server at the first round.

def fit(self, ins: FitIns) -> FitRes:
    if not self.bst:
        # First round local training
        log(INFO, "Start training at round 1")
        bst = xgb.train(
            params,
            train_dmatrix,
            num_boost_round=num_local_round,
            evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")],
        )
        self.config = bst.save_config()
        self.bst = bst
    else:
        for item in ins.parameters.tensors:
            global_model = bytearray(item)

        # Load global model into booster
        self.bst.load_model(global_model)
        self.bst.load_config(self.config)

        bst = self._local_boost()

    local_model = bst.save_raw("json")
    local_model_bytes = bytes(local_model)

    return FitRes(
        status=Status(
            code=Code.OK,
            message="OK",
        ),
        parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
        num_examples=num_train,
        metrics={},
    )

In fit, at the first round we call xgb.train() to build the first set of trees; the returned Booster object and its config are stored in self.bst and self.config, respectively. From the second round onwards, we load the global model sent from the server into self.bst, and then update the model weights on the local training data with the function _local_boost as follows:

def _local_boost(self):
    # Update trees based on local training data.
    for i in range(num_local_round):
        self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())

    # Extract the last N=num_local_round trees for server aggregation
    bst = self.bst[
        self.bst.num_boosted_rounds()
        - num_local_round : self.bst.num_boosted_rounds()
    ]

    return bst

Given num_local_round, we update the trees by calling self.bst.update. After training, the last N=num_local_round trees are extracted and sent to the server.

def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
    eval_results = self.bst.eval_set(
        evals=[(valid_dmatrix, "valid")],
        iteration=self.bst.num_boosted_rounds() - 1,
    )
    auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)

    return EvaluateRes(
        status=Status(
            code=Code.OK,
            message="OK",
        ),
        loss=0.0,
        num_examples=num_val,
        metrics={"AUC": auc},
    )

In evaluate, we call self.bst.eval_set to conduct evaluation on the validation set. The AUC value is returned.
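
For reference, eval_set() returns a plain string; the snippet below sketches how the parsing above extracts the AUC (the numbers are hypothetical):

# Hypothetical return value of `self.bst.eval_set(...)`, for illustration only.
eval_results = "[0]\tvalid-auc:0.74700"
auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
print(auc)  # 0.747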

Now, we can create an instance of our class XgbClient and add one line to actually run this client:

fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient())

That's it for the client. We only have to implement Client and call fl.client.start_client(). The string "127.0.0.1:8080" tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use "127.0.0.1:8080". If we run a truly federated workload with the server and clients running on different machines, all that needs to change is the server_address the client points to.

Flower Server#

These updates are then sent to the *server*, which aggregates them to produce a better model. Finally, the *server* sends this improved version of the model back to each *client* to finish a complete FL round.

In a file named server.py, import Flower and FedXgbBagging from flwr.server.strategy.
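
A minimal sketch of those imports (the example file may import additional helpers):

import flwr as fl
from flwr.server.strategy import FedXgbBagging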

We first define a strategy for XGBoost bagging aggregation.

# Define strategy
strategy = FedXgbBagging(
    fraction_fit=1.0,
    min_fit_clients=2,
    min_available_clients=2,
    min_evaluate_clients=2,
    fraction_evaluate=1.0,
    evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
)

def evaluate_metrics_aggregation(eval_metrics):
    """Return an aggregated metric (AUC) for evaluation."""
    total_num = sum([num for num, _ in eval_metrics])
    auc_aggregated = (
        sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num
    )
    metrics_aggregated = {"AUC": auc_aggregated}
    return metrics_aggregated

This example uses two clients. We define an evaluate_metrics_aggregation function that collects and averages the AUC values from the clients.
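
The start_server call below also references num_rounds; a minimal sketch of how it can be defined (the server output shown further down runs 5 rounds):

# Number of FL rounds for this quickstart (sketch; the exact value in the example code may differ).
num_rounds = 5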

Then, we start the server:

# Start Flower server
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=num_rounds),
    strategy=strategy,
)

Tree-based bagging aggregation#

You may be curious about how bagging aggregation works. Let's look into the details.

In the file flwr.server.strategy.fedxgb_bagging.py, we define FedXgbBagging, which inherits from flwr.server.strategy.FedAvg. Then, we override the aggregate_fit, aggregate_evaluate and evaluate methods as follows:

import json
from logging import WARNING
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

from flwr.common import EvaluateRes, FitRes, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy

from .fedavg import FedAvg


class FedXgbBagging(FedAvg):
    """Configurable FedXgbBagging strategy implementation."""

    def __init__(
        self,
        evaluate_function: Optional[
            Callable[
                [int, Parameters, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        **kwargs: Any,
    ):
        self.evaluate_function = evaluate_function
        self.global_model: Optional[bytes] = None
        super().__init__(**kwargs)

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using bagging."""
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # Aggregate all the client trees
        global_model = self.global_model
        for _, fit_res in results:
            update = fit_res.parameters.tensors
            for bst in update:
                global_model = aggregate(global_model, bst)

        self.global_model = global_model

        return (
            Parameters(tensor_type="", tensors=[cast(bytes, global_model)]),
            {},
        )

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation metrics using average."""
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.evaluate_metrics_aggregation_fn:
            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")

        return 0, metrics_aggregated

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function."""
        if self.evaluate_function is None:
            # No evaluation function provided
            return None
        eval_res = self.evaluate_function(server_round, parameters, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics

In aggregate_fit, we sequentially aggregate the clients' XGBoost trees by calling the aggregate() function:

def aggregate(
    bst_prev_org: Optional[bytes],
    bst_curr_org: bytes,
) -> bytes:
    """Conduct bagging aggregation for given trees."""
    if not bst_prev_org:
        return bst_curr_org

    # Get the tree numbers
    tree_num_prev, _ = _get_tree_nums(bst_prev_org)
    _, paral_tree_num_curr = _get_tree_nums(bst_curr_org)

    bst_prev = json.loads(bytearray(bst_prev_org))
    bst_curr = json.loads(bytearray(bst_curr_org))

    bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
        "num_trees"
    ] = str(tree_num_prev + paral_tree_num_curr)
    iteration_indptr = bst_prev["learner"]["gradient_booster"]["model"][
        "iteration_indptr"
    ]
    bst_prev["learner"]["gradient_booster"]["model"]["iteration_indptr"].append(
        iteration_indptr[-1] + paral_tree_num_curr
    )

    # Aggregate new trees
    trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
    for tree_count in range(paral_tree_num_curr):
        trees_curr[tree_count]["id"] = tree_num_prev + tree_count
        bst_prev["learner"]["gradient_booster"]["model"]["trees"].append(
            trees_curr[tree_count]
        )
        bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0)

    bst_prev_bytes = bytes(json.dumps(bst_prev), "utf-8")

    return bst_prev_bytes


def _get_tree_nums(xgb_model_org: bytes) -> Tuple[int, int]:
    xgb_model = json.loads(bytearray(xgb_model_org))
    # Get the number of trees
    tree_num = int(
        xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
            "num_trees"
        ]
    )
    # Get the number of parallel trees
    paral_tree_num = int(
        xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
            "num_parallel_tree"
        ]
    )
    return tree_num, paral_tree_num

In this function, we first fetch the number of trees and the number of parallel trees for the current and previous models by calling _get_tree_nums. Then, the fetched information is aggregated, and the trees (containing model parameters) are appended to produce a new tree model.

After traversing all clients' models, a new global model is generated; it is then serialised and sent back to each client.
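
As a hypothetical illustration of the bookkeeping above, the sketch below aggregates two locally trained boosters and inspects the resulting tree counts. It assumes the aggregate() and _get_tree_nums() helpers shown above are in scope and an xgboost version (>= 2.0) whose JSON model contains iteration_indptr:

import numpy as np
import xgboost as xgb

# Two tiny boosters trained on synthetic data (illustration only).
X, y = np.random.rand(64, 8), np.random.randint(0, 2, 64)
dtrain = xgb.DMatrix(X, label=y)
bst_prev = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=3)
bst_curr = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=1)

# Previous "global model" with 3 trees, client update with 1 tree (as with num_local_round=1).
global_model = aggregate(
    bytes(bst_prev.save_raw("json")),
    bytes(bst_curr.save_raw("json")),
)
print(_get_tree_nums(global_model))  # (4, 1): 3 + 1 trees, 1 parallel tree per iteration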

Launch Federated XGBoost!#

With both the client and server ready, we can now run everything and see federated learning in action. FL systems usually have a server and multiple clients, so we have to start the server first:

$ python3 server.py

Once the server is running, we can start the clients in different terminals. Open a new terminal and start the first client:

$ python3 client.py --node-id=0

Open another terminal and start the second client:

$ python3 client.py --node-id=1

Each client has its own dataset. You should now see how the training progresses in the very first terminal (the one that started the server):

INFO flwr 2023-11-20 11:21:56,454 | app.py:163 | Starting Flower server, config: ServerConfig(num_rounds=5, round_timeout=None)
INFO flwr 2023-11-20 11:21:56,473 | app.py:176 | Flower ECE: gRPC server running (5 rounds), SSL is disabled
INFO flwr 2023-11-20 11:21:56,473 | server.py:89 | Initializing global parameters
INFO flwr 2023-11-20 11:21:56,473 | server.py:276 | Requesting initial parameters from one random client
INFO flwr 2023-11-20 11:22:38,302 | server.py:280 | Received initial parameters from one random client
INFO flwr 2023-11-20 11:22:38,302 | server.py:91 | Evaluating initial parameters
INFO flwr 2023-11-20 11:22:38,302 | server.py:104 | FL starting
DEBUG flwr 2023-11-20 11:22:38,302 | server.py:222 | fit_round 1: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-11-20 11:22:38,636 | server.py:236 | fit_round 1 received 2 results and 0 failures
DEBUG flwr 2023-11-20 11:22:38,643 | server.py:173 | evaluate_round 1: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-11-20 11:22:38,653 | server.py:187 | evaluate_round 1 received 2 results and 0 failures
DEBUG flwr 2023-11-20 11:22:38,653 | server.py:222 | fit_round 2: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-11-20 11:22:38,721 | server.py:236 | fit_round 2 received 2 results and 0 failures
DEBUG flwr 2023-11-20 11:22:38,745 | server.py:173 | evaluate_round 2: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-11-20 11:22:38,756 | server.py:187 | evaluate_round 2 received 2 results and 0 failures
DEBUG flwr 2023-11-20 11:22:38,756 | server.py:222 | fit_round 3: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-11-20 11:22:38,831 | server.py:236 | fit_round 3 received 2 results and 0 failures
DEBUG flwr 2023-11-20 11:22:38,868 | server.py:173 | evaluate_round 3: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-11-20 11:22:38,881 | server.py:187 | evaluate_round 3 received 2 results and 0 failures
DEBUG flwr 2023-11-20 11:22:38,881 | server.py:222 | fit_round 4: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-11-20 11:22:38,960 | server.py:236 | fit_round 4 received 2 results and 0 failures
DEBUG flwr 2023-11-20 11:22:39,012 | server.py:173 | evaluate_round 4: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-11-20 11:22:39,026 | server.py:187 | evaluate_round 4 received 2 results and 0 failures
DEBUG flwr 2023-11-20 11:22:39,026 | server.py:222 | fit_round 5: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-11-20 11:22:39,111 | server.py:236 | fit_round 5 received 2 results and 0 failures
DEBUG flwr 2023-11-20 11:22:39,177 | server.py:173 | evaluate_round 5: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-11-20 11:22:39,193 | server.py:187 | evaluate_round 5 received 2 results and 0 failures
INFO flwr 2023-11-20 11:22:39,193 | server.py:153 | FL finished in 0.8905023969999988
INFO flwr 2023-11-20 11:22:39,193 | app.py:226 | app_fit: losses_distributed [(1, 0), (2, 0), (3, 0), (4, 0), (5, 0)]
INFO flwr 2023-11-20 11:22:39,193 | app.py:227 | app_fit: metrics_distributed_fit {}
INFO flwr 2023-11-20 11:22:39,193 | app.py:228 | app_fit: metrics_distributed {'AUC': [(1, 0.7572), (2, 0.7705), (3, 0.77595), (4, 0.78), (5, 0.78385)]}
INFO flwr 2023-11-20 11:22:39,193 | app.py:229 | app_fit: losses_centralized []
INFO flwr 2023-11-20 11:22:39,193 | app.py:230 | app_fit: metrics_centralized {}

Congratulations! You've successfully built and run your first federated XGBoost system. The AUC values can be checked in metrics_distributed; we can see that the averaged AUC increases over FL rounds.

The full `source code <https://github.com/adap/flower/blob/main/examples/xgboost-quickstart/>`_ for this example can be found in examples/xgboost-quickstart.

Comprehensive Federated XGBoost#

Now that you know how federated XGBoost works with Flower, it's time to run some more comprehensive experiments by customising the experimental settings. In the xgboost-comprehensive example (full code), we provide more options to define various experimental setups, including aggregation strategies, data partitioning and centralised/distributed evaluation. We also support Flower simulation, making it easy to simulate large client cohorts in a resource-aware manner. Let's take a look!

Cyclic training#

In addition to bagging aggregation, we offer a cyclic training scheme, which performs FL in a client-by-client fashion. Instead of aggregating multiple clients, only a single client participates in training per round in the cyclic training scenario. The trained local XGBoost trees are passed to the next client as the initialised model for the next round's boosting.

To do this, we first customise a ClientManager in server_utils.py:

class CyclicClientManager(SimpleClientManager):
    """Provides a cyclic client selection rule."""

    def sample(
        self,
        num_clients: int,
        min_num_clients: Optional[int] = None,
        criterion: Optional[Criterion] = None,
    ) -> List[ClientProxy]:
        """Sample a number of Flower ClientProxy instances."""

        # Block until at least num_clients are connected.
        if min_num_clients is None:
            min_num_clients = num_clients
        self.wait_for(min_num_clients)

        # Sample clients which meet the criterion
        available_cids = list(self.clients)
        if criterion is not None:
            available_cids = [
                cid for cid in available_cids if criterion.select(self.clients[cid])
            ]

        if num_clients > len(available_cids):
            log(
                INFO,
                "Sampling failed: number of available clients"
                " (%s) is less than number of requested clients (%s).",
                len(available_cids),
                num_clients,
            )
            return []

        # Return all available clients
        return [self.clients[cid] for cid in available_cids]

The customised ClientManager samples all available clients in each FL round based on the order of connection to the server. Then, we define a new strategy FedXgbCyclic in flwr.server.strategy.fedxgb_cyclic.py, in order to sequentially select only one client in a given round and pass the received model to the next client.

class FedXgbCyclic(FedAvg):
    """Configurable FedXgbCyclic strategy implementation."""

    # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
    def __init__(
        self,
        **kwargs: Any,
    ):
        self.global_model: Optional[bytes] = None
        super().__init__(**kwargs)

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using bagging."""
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # Fetch the client model from last round as global model
        for _, fit_res in results:
            update = fit_res.parameters.tensors
            for bst in update:
                self.global_model = bst

        return (
            Parameters(tensor_type="", tensors=[cast(bytes, self.global_model)]),
            {},
        )

Unlike the original FedAvg, we don't perform aggregation here. Instead, we simply make a copy of the received client model as the global model by overriding aggregate_fit.

Also, the customised configure_fit and configure_evaluate methods ensure that clients are selected sequentially for a given FL round:

def configure_fit(
    self, server_round: int, parameters: Parameters, client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
    """Configure the next round of training."""
    config = {}
    if self.on_fit_config_fn is not None:
        # Custom fit config function provided
        config = self.on_fit_config_fn(server_round)
    fit_ins = FitIns(parameters, config)

    # Sample clients
    sample_size, min_num_clients = self.num_fit_clients(
        client_manager.num_available()
    )
    clients = client_manager.sample(
        num_clients=sample_size,
        min_num_clients=min_num_clients,
    )

    # Sample the clients sequentially given server_round
    sampled_idx = (server_round - 1) % len(clients)
    sampled_clients = [clients[sampled_idx]]

    # Return client/config pairs
    return [(client, fit_ins) for client in sampled_clients]

def configure_evaluate(
    self, server_round: int, parameters: Parameters, client_manager: ClientManager
) -> List[Tuple[ClientProxy, EvaluateIns]]:
    """Configure the next round of evaluation."""
    # Do not configure federated evaluation if fraction eval is 0.
    if self.fraction_evaluate == 0.0:
        return []

    # Parameters and config
    config = {}
    if self.on_evaluate_config_fn is not None:
        # Custom evaluation config function provided
        config = self.on_evaluate_config_fn(server_round)
    evaluate_ins = EvaluateIns(parameters, config)

    # Sample clients
    sample_size, min_num_clients = self.num_evaluation_clients(
        client_manager.num_available()
    )
    clients = client_manager.sample(
        num_clients=sample_size,
        min_num_clients=min_num_clients,
    )

    # Sample the clients sequentially given server_round
    sampled_idx = (server_round - 1) % len(clients)
    sampled_clients = [clients[sampled_idx]]

    # Return client/config pairs
    return [(client, evaluate_ins) for client in sampled_clients]

Customised data partitioning#

In dataset.py, we have a function instantiate_partitioner to instantiate the data partitioner based on the given num_partitions and partitioner_type. Currently, we provide four supported partitioner types (uniform, linear, square, exponential) to simulate uniformity/non-uniformity in data quantity.

from flwr_datasets.partitioner import (
    IidPartitioner,
    LinearPartitioner,
    SquarePartitioner,
    ExponentialPartitioner,
)

CORRELATION_TO_PARTITIONER = {
    "uniform": IidPartitioner,
    "linear": LinearPartitioner,
    "square": SquarePartitioner,
    "exponential": ExponentialPartitioner,
}


def instantiate_partitioner(partitioner_type: str, num_partitions: int):
    """Initialise partitioner based on selected partitioner type and number of
    partitions."""
    partitioner = CORRELATION_TO_PARTITIONER[partitioner_type](
        num_partitions=num_partitions
    )
    return partitioner
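
For example, a sketch of simulating non-uniform data quantities across 5 clients with the exponential partitioner:

# Instantiate an exponential partitioner for 5 clients (usage sketch).
partitioner = instantiate_partitioner(
    partitioner_type="exponential", num_partitions=5
)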

Customised centralised/distributed evaluation#

To facilitate centralised evaluation, we define a function in server_utils.py:

def get_evaluate_fn(test_data):
    """Return a function for centralised evaluation."""

    def evaluate_fn(
        server_round: int, parameters: Parameters, config: Dict[str, Scalar]
    ):
        # If at the first round, skip the evaluation
        if server_round == 0:
            return 0, {}
        else:
            bst = xgb.Booster(params=params)
            for para in parameters.tensors:
                para_b = bytearray(para)

            # Load global model
            bst.load_model(para_b)
            # Run evaluation
            eval_results = bst.eval_set(
                evals=[(test_data, "valid")],
                iteration=bst.num_boosted_rounds() - 1,
            )
            auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
            log(INFO, f"AUC = {auc} at round {server_round}")

            return 0, {"AUC": auc}

    return evaluate_fn

This function returns an evaluation function, which instantiates a Booster object and loads the global model parameters into it. The evaluation is conducted by calling the eval_set() method, and the tested AUC value is reported.

As for distributed evaluation on the clients, it's the same as in the quickstart example: we override the evaluate() method inside the XgbClient class in client_utils.py.

Flower simulation#

We also provide an example code (sim.py) to use the simulation capabilities of Flower to simulate federated XGBoost training on either a single machine or a cluster of machines.

from logging import INFO
import xgboost as xgb
from tqdm import tqdm

import flwr as fl
from flwr_datasets import FederatedDataset
from flwr.common.logger import log
from flwr.server.strategy import FedXgbBagging, FedXgbCyclic

from dataset import (
    instantiate_partitioner,
    train_test_split,
    transform_dataset_to_dmatrix,
    separate_xy,
    resplit,
)
from utils import (
    sim_args_parser,
    NUM_LOCAL_ROUND,
    BST_PARAMS,
)
from server_utils import (
    eval_config,
    fit_config,
    evaluate_metrics_aggregation,
    get_evaluate_fn,
    CyclicClientManager,
)
from client_utils import XgbClient

After importing all required packages, we define a main() function to perform the simulation process:

def main():
  # Parse arguments for experimental settings
  args = sim_args_parser()

  # Load (HIGGS) dataset and conduct partitioning
  partitioner = instantiate_partitioner(
      partitioner_type=args.partitioner_type, num_partitions=args.pool_size
  )
  fds = FederatedDataset(
      dataset="jxie/higgs",
      partitioners={"train": partitioner},
      resplitter=resplit,
  )

  # Load centralised test set
  if args.centralised_eval or args.centralised_eval_client:
      log(INFO, "Loading centralised test set...")
      test_data = fds.load_split("test")
      test_data.set_format("numpy")
      num_test = test_data.shape[0]
      test_dmatrix = transform_dataset_to_dmatrix(test_data)

  # Load partitions and reformat data to DMatrix for xgboost
  log(INFO, "Loading client local partitions...")
  train_data_list = []
  valid_data_list = []

  # Load and process all client partitions. This upfront cost is amortized soon
  # after the simulation begins since clients won't need to preprocess their partition.
  for node_id in tqdm(range(args.pool_size), desc="Extracting client partition"):
      # Extract partition for client with node_id
      partition = fds.load_partition(node_id=node_id, split="train")
      partition.set_format("numpy")

      if args.centralised_eval_client:
          # Use centralised test set for evaluation
          train_data = partition
          num_train = train_data.shape[0]
          x_test, y_test = separate_xy(test_data)
          valid_data_list.append(((x_test, y_test), num_test))
      else:
          # Train/test splitting
          train_data, valid_data, num_train, num_val = train_test_split(
              partition, test_fraction=args.test_fraction, seed=args.seed
          )
          x_valid, y_valid = separate_xy(valid_data)
          valid_data_list.append(((x_valid, y_valid), num_val))

      x_train, y_train = separate_xy(train_data)
      train_data_list.append(((x_train, y_train), num_train))

We first load the dataset and perform data partitioning, and the pre-processed data is stored in a list. After the simulation begins, the clients won't need to pre-process their partitions again.

Then, we define the strategies and other hyper-parameters:

# Define strategy
if args.train_method == "bagging":
    # Bagging training
    strategy = FedXgbBagging(
        evaluate_function=get_evaluate_fn(test_dmatrix)
        if args.centralised_eval
        else None,
        fraction_fit=(float(args.num_clients_per_round) / args.pool_size),
        min_fit_clients=args.num_clients_per_round,
        min_available_clients=args.pool_size,
        min_evaluate_clients=args.num_evaluate_clients
        if not args.centralised_eval
        else 0,
        fraction_evaluate=1.0 if not args.centralised_eval else 0.0,
        on_evaluate_config_fn=eval_config,
        on_fit_config_fn=fit_config,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation
        if not args.centralised_eval
        else None,
    )
else:
    # Cyclic training
    strategy = FedXgbCyclic(
        fraction_fit=1.0,
        min_available_clients=args.pool_size,
        fraction_evaluate=1.0,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
        on_evaluate_config_fn=eval_config,
        on_fit_config_fn=fit_config,
    )

# Resources to be assigned to each virtual client
# In this example we use CPU by default
client_resources = {
    "num_cpus": args.num_cpus_per_client,
    "num_gpus": 0.0,
}

# Hyper-parameters for xgboost training
num_local_round = NUM_LOCAL_ROUND
params = BST_PARAMS

# Setup learning rate
if args.train_method == "bagging" and args.scaled_lr:
    new_lr = params["eta"] / args.pool_size
    params.update({"eta": new_lr})

After that, we start the simulation by calling fl.simulation.start_simulation:

# Start simulation
fl.simulation.start_simulation(
    client_fn=get_client_fn(
        train_data_list,
        valid_data_list,
        args.train_method,
        params,
        num_local_round,
    ),
    num_clients=args.pool_size,
    client_resources=client_resources,
    config=fl.server.ServerConfig(num_rounds=args.num_rounds),
    strategy=strategy,
    client_manager=CyclicClientManager() if args.train_method == "cyclic" else None,
)

One of the key parameters for start_simulation is client_fn, which returns a function to construct a client. We define it as follows:

def get_client_fn(
    train_data_list, valid_data_list, train_method, params, num_local_round
):
    """Return a function to construct a client.

    The VirtualClientEngine will execute this function whenever a client is sampled by
    the strategy to participate.
    """

    def client_fn(cid: str) -> fl.client.Client:
        """Construct a FlowerClient with its own dataset partition."""
        x_train, y_train = train_data_list[int(cid)][0]
        x_valid, y_valid = valid_data_list[int(cid)][0]

        # Reformat data to DMatrix
        train_dmatrix = xgb.DMatrix(x_train, label=y_train)
        valid_dmatrix = xgb.DMatrix(x_valid, label=y_valid)

        # Fetch the number of examples
        num_train = train_data_list[int(cid)][1]
        num_val = valid_data_list[int(cid)][1]

        # Create and return client
        return XgbClient(
            train_dmatrix,
            valid_dmatrix,
            num_train,
            num_val,
            num_local_round,
            params,
            train_method,
        )

    return client_fn

Arguments parser#

In utils.py, we define the argument parsers for the clients, the server and the simulation, allowing users to specify different experimental settings. Let's first look at the server side:

import argparse


def server_args_parser():
  """Parse arguments to define experimental settings on server side."""
  parser = argparse.ArgumentParser()

  parser.add_argument(
      "--train-method",
      default="bagging",
      type=str,
      choices=["bagging", "cyclic"],
      help="Training methods selected from bagging aggregation or cyclic training.",
  )
  parser.add_argument(
      "--pool-size", default=2, type=int, help="Number of total clients."
  )
  parser.add_argument(
      "--num-rounds", default=5, type=int, help="Number of FL rounds."
  )
  parser.add_argument(
      "--num-clients-per-round",
      default=2,
      type=int,
      help="Number of clients participate in training each round.",
  )
  parser.add_argument(
      "--num-evaluate-clients",
      default=2,
      type=int,
      help="Number of clients selected for evaluation.",
  )
  parser.add_argument(
      "--centralised-eval",
      action="store_true",
      help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).",
  )

  args = parser.parse_args()
  return args

This allows users to specify the training strategy, the total number of clients, the number of FL rounds, the number of participating clients, the number of clients for evaluation, and the evaluation fashion. Note that with --centralised-eval, the server will do centralised evaluation and all functionalities for client evaluation will be disabled.

Then, the argument parser on the client side:

def client_args_parser():
  """Parse arguments to define experimental settings on client side."""
  parser = argparse.ArgumentParser()

  parser.add_argument(
      "--train-method",
      default="bagging",
      type=str,
      choices=["bagging", "cyclic"],
      help="Training methods selected from bagging aggregation or cyclic training.",
  )
  parser.add_argument(
      "--num-partitions", default=10, type=int, help="Number of partitions."
  )
  parser.add_argument(
      "--partitioner-type",
      default="uniform",
      type=str,
      choices=["uniform", "linear", "square", "exponential"],
      help="Partitioner types.",
  )
  parser.add_argument(
      "--node-id",
      default=0,
      type=int,
      help="Node ID used for the current client.",
  )
  parser.add_argument(
      "--seed", default=42, type=int, help="Seed used for train/test splitting."
  )
  parser.add_argument(
      "--test-fraction",
      default=0.2,
      type=float,
      help="Test fraction for train/test splitting.",
  )
  parser.add_argument(
      "--centralised-eval",
      action="store_true",
      help="Conduct evaluation on centralised test set (True), or on hold-out data (False).",
  )
  parser.add_argument(
      "--scaled-lr",
      action="store_true",
      help="Perform scaled learning rate based on the number of clients (True).",
  )

  args = parser.parse_args()
  return args

This defines various options for client data partitioning. Besides, clients also have an option to conduct evaluation on the centralised test set by setting --centralised-eval, as well as an option to scale the learning rate based on the number of clients by setting --scaled-lr.

We also have an argument parser for simulation:

def sim_args_parser():
    """Parse arguments to define experimental settings on server side."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--train-method",
        default="bagging",
        type=str,
        choices=["bagging", "cyclic"],
        help="Training methods selected from bagging aggregation or cyclic training.",
    )

    # Server side
    parser.add_argument(
        "--pool-size", default=5, type=int, help="Number of total clients."
    )
    parser.add_argument(
        "--num-rounds", default=30, type=int, help="Number of FL rounds."
    )
    parser.add_argument(
        "--num-clients-per-round",
        default=5,
        type=int,
        help="Number of clients participate in training each round.",
    )
    parser.add_argument(
        "--num-evaluate-clients",
        default=5,
        type=int,
        help="Number of clients selected for evaluation.",
    )
    parser.add_argument(
        "--centralised-eval",
        action="store_true",
        help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).",
    )
    parser.add_argument(
        "--num-cpus-per-client",
        default=2,
        type=int,
        help="Number of CPUs used for per client.",
    )

    # Client side
    parser.add_argument(
        "--partitioner-type",
        default="uniform",
        type=str,
        choices=["uniform", "linear", "square", "exponential"],
        help="Partitioner types.",
    )
    parser.add_argument(
        "--seed", default=42, type=int, help="Seed used for train/test splitting."
    )
    parser.add_argument(
        "--test-fraction",
        default=0.2,
        type=float,
        help="Test fraction for train/test splitting.",
    )
    parser.add_argument(
        "--centralised-eval-client",
        action="store_true",
        help="Conduct evaluation on centralised test set (True), or on hold-out data (False).",
    )
    parser.add_argument(
        "--scaled-lr",
        action="store_true",
        help="Perform scaled learning rate based on the number of clients (True).",
    )

    args = parser.parse_args()
    return args

This integrates all arguments for both client and server sides.

Example commands#

To run a centralised-evaluation experiment with the bagging strategy on 5 clients with exponential distribution for 50 rounds, we first start the server as below:

$ python3 server.py --train-method=bagging --pool-size=5 --num-rounds=50 --num-clients-per-round=5 --centralised-eval

Then, we start the clients on each client terminal:

$ python3 clients.py --train-method=bagging --num-partitions=5 --partitioner-type=exponential --node-id=NODE_ID

To run the same experiment with Flower simulation:

$ python3 sim.py --train-method=bagging --pool-size=5 --num-rounds=50 --num-clients-per-round=5 --partitioner-type=exponential --centralised-eval

The full code for this comprehensive example can be found in examples/xgboost-comprehensive.