XGBoost Quickstart¶
XGBoost¶
EXtreme Gradient Boosting (XGBoost) is a robust and efficient implementation of gradient-boosted decision trees (GBDT) that maximises the computational boundaries for boosted tree methods. It is primarily designed to enhance both the performance and computational speed of machine learning models. In XGBoost, trees are constructed concurrently, unlike the sequential approach taken by GBDT.
For tabular data on medium-sized datasets with fewer than 10k training examples, XGBoost often surpasses the results of deep learning techniques.
Why Federated XGBoost?¶
As the demand for data privacy and decentralized learning grows, there's an increasing requirement to implement federated XGBoost systems for specialised applications, like survival analysis and financial fraud detection.
Federated learning ensures that raw data remains on the local device, making it an attractive approach for sensitive domains where data privacy is paramount. Given the robustness and efficiency of XGBoost, combining it with federated learning offers a promising solution for these specific challenges.
Environment Setup¶
In this tutorial, we learn how to train a federated XGBoost model on the HIGGS dataset using Flower and the xgboost package to perform a binary classification task. We use a simple example (full code: xgboost-quickstart) to demonstrate how federated XGBoost works, and then we dive into a more comprehensive example (full code: xgboost-comprehensive) to run various experiments.
It is recommended to create a virtual environment and run everything within a virtualenv.
We first need to install Flower and Flower Datasets. You can do this by running:
# In a new Python environment
$ pip install flwr flwr-datasets
Since we want to use the xgboost package to build XGBoost trees, let's go ahead and install xgboost as well:
$ pip install xgboost
The Configurations¶
We define all required configurations / hyper-parameters inside the pyproject.toml
file:
[tool.flwr.app.config]
# ServerApp
num-server-rounds = 3
fraction-fit = 0.1
fraction-evaluate = 0.1
# ClientApp
local-epochs = 1
params.objective = "binary:logistic"
params.eta = 0.1 # Learning rate
params.max-depth = 8
params.eval-metric = "auc"
params.nthread = 16
params.num-parallel-tree = 1
params.subsample = 1
params.tree-method = "hist"
The local-epochs setting represents the number of iterations for local tree boosting. We use the CPU for training by default; one can switch to a GPU by setting tree_method to gpu_hist. We use AUC as the evaluation metric.
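Note that the dash-separated params.* keys above keep the TOML valid, while XGBoost itself expects underscored names such as max_depth and eval_metric. As a minimal sketch, a hypothetical helper (build_xgb_params is not part of the example code) could convert the run config into the parameter dict handed to xgb.train():
# Hypothetical helper: convert dash-separated "params.*" run-config entries
# into the underscored parameter names expected by xgboost.
def build_xgb_params(run_config: dict) -> dict:
    params = {}
    for key, value in run_config.items():
        if key.startswith("params."):
            # e.g. "params.max-depth" -> "max_depth"
            params[key[len("params."):].replace("-", "_")] = value
    return params

# Example usage with values mirroring the pyproject.toml above
run_config = {
    "local-epochs": 1,
    "params.objective": "binary:logistic",
    "params.eta": 0.1,
    "params.max-depth": 8,
    "params.eval-metric": "auc",
}
print(build_xgb_params(run_config))
# {'objective': 'binary:logistic', 'eta': 0.1, 'max_depth': 8, 'eval_metric': 'auc'}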
The Data¶
This tutorial uses Flower Datasets to easily download and partition the HIGGS dataset.
# Load (HIGGS) dataset and partition.
# We use a small subset (num_partitions=20) of the dataset for demonstration to speed up the data loading process.
partitioner = IidPartitioner(num_partitions=20)
fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner})
# Load the partition for this `partition_id`
partition = fds.load_partition(partition_id, split="train")
partition.set_format("numpy")
In this example, we split the dataset into 20 partitions with a uniform distribution (IidPartitioner). Then, we load the partition for the given client based on partition_id. Subsequently, we perform a train/test split on the given partition (the client's local data) and reformat the data into the DMatrix format used by the xgboost package.
# Train/test splitting
train_data, valid_data, num_train, num_val = train_test_split(
partition, test_fraction=0.2, seed=42
)
# Reformat data to DMatrix for xgboost
train_dmatrix = transform_dataset_to_dmatrix(train_data)
valid_dmatrix = transform_dataset_to_dmatrix(valid_data)
The train_test_split and transform_dataset_to_dmatrix functions are defined as follows:
def train_test_split(partition, test_fraction, seed):
"""Split the data into train and validation set given split rate."""
train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
partition_train = train_test["train"]
partition_test = train_test["test"]
num_train = len(partition_train)
num_test = len(partition_test)
return partition_train, partition_test, num_train, num_test
def transform_dataset_to_dmatrix(data):
"""Transform dataset to DMatrix format for xgboost."""
x = data["inputs"]
y = data["label"]
new_data = xgb.DMatrix(x, label=y)
return new_data
The ClientApp¶
Clients are responsible for generating individual weight updates for the model based on their local datasets. Let's first see how we define the Flower client for XGBoost. Following the general rule, we define a FlowerClient class that inherits from fl.client.Client.
# Define Flower Client and client_fn
class FlowerClient(Client):
def __init__(
self,
train_dmatrix,
valid_dmatrix,
num_train,
num_val,
num_local_round,
params,
):
self.train_dmatrix = train_dmatrix
self.valid_dmatrix = valid_dmatrix
self.num_train = num_train
self.num_val = num_val
self.num_local_round = num_local_round
self.params = params
All required parameters defined above are passed to FlowerClient's constructor. Then, we override the fit and evaluate methods inside the FlowerClient class as follows.
def fit(self, ins: FitIns) -> FitRes:
global_round = int(ins.config["global_round"])
if global_round == 1:
# First round local training
bst = xgb.train(
self.params,
self.train_dmatrix,
num_boost_round=self.num_local_round,
evals=[(self.valid_dmatrix, "validate"), (self.train_dmatrix, "train")],
)
else:
bst = xgb.Booster(params=self.params)
global_model = bytearray(ins.parameters.tensors[0])
# Load global model into booster
bst.load_model(global_model)
# Local training
bst = self._local_boost(bst)
# Save model
local_model = bst.save_raw("json")
local_model_bytes = bytes(local_model)
return FitRes(
status=Status(
code=Code.OK,
message="OK",
),
parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
num_examples=self.num_train,
metrics={},
)
In fit, in the first round we call xgb.train() to build up the first set of trees. From the second round onwards, we load the global model sent from the server into a newly created Booster object, and then update the model on the local training data with the _local_boost function as follows:
def _local_boost(self, bst_input):
# Update trees based on local training data.
for i in range(self.num_local_round):
bst_input.update(self.train_dmatrix, bst_input.num_boosted_rounds())
# Bagging: extract the last N=num_local_round trees for server aggregation
bst = bst_input[
bst_input.num_boosted_rounds()
- self.num_local_round : bst_input.num_boosted_rounds()
]
return bst
Given num_local_round, we update the trees by calling the bst_input.update method. After training, the last N=num_local_round trees are extracted and sent to the server.
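The slicing expression bst_input[start:end] relies on xgboost's Booster supporting model slicing, which returns a new Booster containing only the selected boosting rounds. A small self-contained sketch (using synthetic data rather than the HIGGS partitions, and assuming a reasonably recent xgboost release) illustrates the behaviour:
import numpy as np
import xgboost as xgb

# Synthetic binary-classification data, only to demonstrate Booster slicing
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8))
y = (X[:, 0] > 0).astype(int)
dtrain = xgb.DMatrix(X, label=y)

bst = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=5)
print(bst.num_boosted_rounds())  # 5

# Keep only the last 2 rounds, as _local_boost does with num_local_round
last_trees = bst[bst.num_boosted_rounds() - 2 : bst.num_boosted_rounds()]
print(last_trees.num_boosted_rounds())  # 2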
def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
# Load global model
bst = xgb.Booster(params=self.params)
para_b = bytearray(ins.parameters.tensors[0])
bst.load_model(para_b)
# Run evaluation
eval_results = bst.eval_set(
evals=[(self.valid_dmatrix, "valid")],
iteration=bst.num_boosted_rounds() - 1,
)
auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
return EvaluateRes(
status=Status(
code=Code.OK,
message="OK",
),
loss=0.0,
num_examples=self.num_val,
metrics={"AUC": auc},
)
In evaluate, after loading the global model, we call the bst.eval_set function to conduct evaluation on the validation set. The AUC value is then returned.
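The string returned by eval_set() is tab-separated, e.g. "[0]\tvalid-auc:0.76640", which is why the AUC is recovered by splitting on tabs and colons. Here is the parsing step in isolation (the AUC value is made up for illustration):
# Example of the string format produced by Booster.eval_set()
eval_results = "[0]\tvalid-auc:0.76640"  # illustrative value

auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
print(auc)  # 0.7664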
The ServerApp¶
After the local training on clients, clients' model updates are sent to the server, which aggregates them to produce a better model. Finally, the server sends this improved model version back to each client to complete a federated round.
In the file named server_app.py
, we define a strategy for XGBoost bagging
aggregation:
# Define strategy
strategy = FedXgbBagging(
fraction_fit=fraction_fit,
fraction_evaluate=fraction_evaluate,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
on_evaluate_config_fn=config_func,
on_fit_config_fn=config_func,
initial_parameters=parameters,
)
def evaluate_metrics_aggregation(eval_metrics):
"""Return an aggregated metric (AUC) for evaluation."""
total_num = sum([num for num, _ in eval_metrics])
auc_aggregated = (
sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num
)
metrics_aggregated = {"AUC": auc_aggregated}
return metrics_aggregated
def config_func(rnd: int) -> Dict[str, str]:
"""Return a configuration with global epochs."""
config = {
"global_round": str(rnd),
}
return config
An evaluate_metrics_aggregation function is defined to collect the AUC values from clients and compute their weighted average. The config_func function returns the current FL round number to the clients' fit() and evaluate() methods.
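As a quick sanity check on the weighted averaging, this is what evaluate_metrics_aggregation would return for two hypothetical clients holding different numbers of validation examples:
# Two hypothetical clients, given as (num_examples, metrics) pairs
eval_metrics = [(1000, {"AUC": 0.75}), (3000, {"AUC": 0.80})]

total_num = sum(num for num, _ in eval_metrics)
auc_aggregated = sum(m["AUC"] * num for num, m in eval_metrics) / total_num
print({"AUC": auc_aggregated})  # {'AUC': 0.7875}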
Tree-based Bagging Aggregation¶
You must be curious about how bagging aggregation works. Let's look into the details.
In the file flwr.server.strategy.fedxgb_bagging.py, we define FedXgbBagging, which inherits from flwr.server.strategy.FedAvg. Then, we override the aggregate_fit, aggregate_evaluate and evaluate methods as follows:
import json
from logging import WARNING
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from flwr.common import EvaluateRes, FitRes, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from .fedavg import FedAvg
class FedXgbBagging(FedAvg):
"""Configurable FedXgbBagging strategy implementation."""
def __init__(
self,
evaluate_function: Optional[
Callable[
[int, Parameters, Dict[str, Scalar]],
Optional[Tuple[float, Dict[str, Scalar]]],
]
] = None,
**kwargs: Any,
):
self.evaluate_function = evaluate_function
self.global_model: Optional[bytes] = None
super().__init__(**kwargs)
def aggregate_fit(
self,
server_round: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate fit results using bagging."""
if not results:
return None, {}
# Do not aggregate if there are failures and failures are not accepted
if not self.accept_failures and failures:
return None, {}
# Aggregate all the client trees
global_model = self.global_model
for _, fit_res in results:
update = fit_res.parameters.tensors
for bst in update:
global_model = aggregate(global_model, bst)
self.global_model = global_model
return (
Parameters(tensor_type="", tensors=[cast(bytes, global_model)]),
{},
)
def aggregate_evaluate(
self,
server_round: int,
results: List[Tuple[ClientProxy, EvaluateRes]],
failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
) -> Tuple[Optional[float], Dict[str, Scalar]]:
"""Aggregate evaluation metrics using average."""
if not results:
return None, {}
# Do not aggregate if there are failures and failures are not accepted
if not self.accept_failures and failures:
return None, {}
# Aggregate custom metrics if aggregation fn was provided
metrics_aggregated = {}
if self.evaluate_metrics_aggregation_fn:
eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
elif server_round == 1: # Only log this warning once
log(WARNING, "No evaluate_metrics_aggregation_fn provided")
return 0, metrics_aggregated
def evaluate(
self, server_round: int, parameters: Parameters
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
"""Evaluate model parameters using an evaluation function."""
if self.evaluate_function is None:
# No evaluation function provided
return None
eval_res = self.evaluate_function(server_round, parameters, {})
if eval_res is None:
return None
loss, metrics = eval_res
return loss, metrics
In aggregate_fit, we sequentially aggregate the clients' XGBoost trees by calling the aggregate() function:
def aggregate(
bst_prev_org: Optional[bytes],
bst_curr_org: bytes,
) -> bytes:
"""Conduct bagging aggregation for given trees."""
if not bst_prev_org:
return bst_curr_org
# Get the tree numbers
tree_num_prev, _ = _get_tree_nums(bst_prev_org)
_, paral_tree_num_curr = _get_tree_nums(bst_curr_org)
bst_prev = json.loads(bytearray(bst_prev_org))
bst_curr = json.loads(bytearray(bst_curr_org))
bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
"num_trees"
] = str(tree_num_prev + paral_tree_num_curr)
iteration_indptr = bst_prev["learner"]["gradient_booster"]["model"][
"iteration_indptr"
]
bst_prev["learner"]["gradient_booster"]["model"]["iteration_indptr"].append(
iteration_indptr[-1] + paral_tree_num_curr
)
# Aggregate new trees
trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
for tree_count in range(paral_tree_num_curr):
trees_curr[tree_count]["id"] = tree_num_prev + tree_count
bst_prev["learner"]["gradient_booster"]["model"]["trees"].append(
trees_curr[tree_count]
)
bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0)
bst_prev_bytes = bytes(json.dumps(bst_prev), "utf-8")
return bst_prev_bytes
def _get_tree_nums(xgb_model_org: bytes) -> Tuple[int, int]:
xgb_model = json.loads(bytearray(xgb_model_org))
# Get the number of trees
tree_num = int(
xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
"num_trees"
]
)
# Get the number of parallel trees
paral_tree_num = int(
xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
"num_parallel_tree"
]
)
return tree_num, paral_tree_num
In this function, we first fetch the number of trees and the number of parallel trees for the current and previous models by calling _get_tree_nums. The fetched tree counts are then combined, after which the trees themselves (containing the model weights) are merged to generate a new tree model.
After traversing all clients' models, a new global model is generated; it is then serialised and sent back to each client.
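One way to see the bagging behaviour is to compare tree counts before and after merging two serialised models. Below is a sketch under the assumptions that the aggregate and _get_tree_nums functions above are in scope, that the installed xgboost version produces the JSON model fields they access, and with two tiny boosters trained on synthetic data standing in for two clients' updates:
import numpy as np
import xgboost as xgb

# Two tiny boosters on synthetic data stand in for two clients' local updates
rng = np.random.default_rng(0)
X = rng.normal(size=(128, 4))
y = (X[:, 0] > 0).astype(int)
dtrain = xgb.DMatrix(X, label=y)

client_models = []
for _ in range(2):
    booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=1)
    client_models.append(bytes(booster.save_raw("json")))

# Merge the serialised models exactly as aggregate_fit does
global_model = None
for bst_bytes in client_models:
    global_model = aggregate(global_model, bst_bytes)

print(_get_tree_nums(global_model))  # (2, 1): two trees in total, one parallel tree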
Launch Federated XGBoost!¶
To run the project, do:
# Run with default arguments
$ flwr run .
With default arguments you will see an output like this one:
Loading project configuration...
Success
INFO : Starting Flower ServerApp, config: num_rounds=3, no round_timeout
INFO :
INFO : [INIT]
INFO : Using initial global parameters provided by strategy
INFO : Starting evaluation of initial global parameters
INFO : Evaluation returned no results (`None`)
INFO :
INFO : [ROUND 1]
INFO : configure_fit: strategy sampled 2 clients (out of 20)
INFO : aggregate_fit: received 2 results and 0 failures
INFO : configure_evaluate: strategy sampled 2 clients (out of 20)
INFO : aggregate_evaluate: received 2 results and 0 failures
INFO :
INFO : [ROUND 2]
INFO : configure_fit: strategy sampled 2 clients (out of 20)
INFO : aggregate_fit: received 2 results and 0 failures
INFO : configure_evaluate: strategy sampled 2 clients (out of 20)
INFO : aggregate_evaluate: received 2 results and 0 failures
INFO :
INFO : [ROUND 3]
INFO : configure_fit: strategy sampled 2 clients (out of 20)
INFO : aggregate_fit: received 2 results and 0 failures
INFO : configure_evaluate: strategy sampled 2 clients (out of 20)
INFO : aggregate_evaluate: received 2 results and 0 failures
INFO :
INFO : [SUMMARY]
INFO : Run finished 3 round(s) in 145.42s
INFO : History (loss, distributed):
INFO : round 1: 0
INFO : round 2: 0
INFO : round 3: 0
INFO : History (metrics, distributed, evaluate):
INFO : {'AUC': [(1, 0.7664), (2, 0.77595), (3, 0.7826)]}
INFO :
Congratulations! You've successfully built and run your first federated XGBoost system.
The AUC values can be checked in History (metrics, distributed, evaluate)
. One can
see that the average AUC increases over FL rounds.
You can also override the parameters defined in the [tool.flwr.app.config]
section
in pyproject.toml
like this:
# Override some arguments
$ flwr run . --run-config "num-server-rounds=5 params.eta=0.05"
Note
Check the full source code for this
example in examples/xgboost-quickstart
in the Flower GitHub repository.
Comprehensive Federated XGBoost¶
Now that you know how federated XGBoost works with Flower, it's time to run some more comprehensive experiments by customising the experimental settings. In the xgboost-comprehensive example (full code), we provide more options to define various experimental setups, including aggregation strategies, data partitioning and centralised / distributed evaluation. Let's take a look!
Cyclic Training¶
In addition to bagging aggregation, we offer a cyclic training scheme, which performs FL in a client-by-client fashion. Instead of aggregating multiple clients, only a single client participates in training per round in the cyclic training scenario. The trained local XGBoost trees are passed to the next client as the initialised model for the next round's boosting.
To do this, we first customise a ClientManager
in server_app.py
:
class CyclicClientManager(SimpleClientManager):
"""Provides a cyclic client selection rule."""
def sample(
self,
num_clients: int,
min_num_clients: Optional[int] = None,
criterion: Optional[Criterion] = None,
) -> List[ClientProxy]:
"""Sample a number of Flower ClientProxy instances."""
# Block until at least num_clients are connected.
if min_num_clients is None:
min_num_clients = num_clients
self.wait_for(min_num_clients)
# Sample clients which meet the criterion
available_cids = list(self.clients)
if criterion is not None:
available_cids = [
cid for cid in available_cids if criterion.select(self.clients[cid])
]
if num_clients > len(available_cids):
log(
INFO,
"Sampling failed: number of available clients"
" (%s) is less than number of requested clients (%s).",
len(available_cids),
num_clients,
)
return []
# Return all available clients
return [self.clients[cid] for cid in available_cids]
The customised ClientManager samples all available clients in each FL round based on the order of connection to the server. Then, we define a new strategy FedXgbCyclic in flwr.server.strategy.fedxgb_cyclic.py, in order to sequentially select only one client in a given round and pass the received model to the next client.
class FedXgbCyclic(FedAvg):
"""Configurable FedXgbCyclic strategy implementation."""
# pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
def __init__(
self,
**kwargs: Any,
):
self.global_model: Optional[bytes] = None
super().__init__(**kwargs)
def aggregate_fit(
self,
server_round: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate fit results using bagging."""
if not results:
return None, {}
# Do not aggregate if there are failures and failures are not accepted
if not self.accept_failures and failures:
return None, {}
# Fetch the client model from last round as global model
for _, fit_res in results:
update = fit_res.parameters.tensors
for bst in update:
self.global_model = bst
return (
Parameters(tensor_type="", tensors=[cast(bytes, self.global_model)]),
{},
)
Unlike the original FedAvg, we don't perform aggregation here. Instead, we simply keep a copy of the received client model as the global model by overriding aggregate_fit. Also, the customised configure_fit and configure_evaluate methods ensure that clients are selected sequentially given the FL round:
def configure_fit(
self, server_round: int, parameters: Parameters, client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
"""Configure the next round of training."""
config = {}
if self.on_fit_config_fn is not None:
# Custom fit config function provided
config = self.on_fit_config_fn(server_round)
fit_ins = FitIns(parameters, config)
# Sample clients
sample_size, min_num_clients = self.num_fit_clients(client_manager.num_available())
clients = client_manager.sample(
num_clients=sample_size,
min_num_clients=min_num_clients,
)
# Sample the clients sequentially given server_round
sampled_idx = (server_round - 1) % len(clients)
sampled_clients = [clients[sampled_idx]]
# Return client/config pairs
return [(client, fit_ins) for client in sampled_clients]
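The round-robin arithmetic behind sampled_idx is easy to verify in isolation; for example, with three connected clients the selection cycles through them as the rounds progress (a plain illustration, not part of the strategy code):
# Round-robin selection used above: one client per round
num_clients = 3
for server_round in range(1, 7):
    sampled_idx = (server_round - 1) % num_clients
    print(f"round {server_round} -> client index {sampled_idx}")
# rounds 1..6 select client indices 0, 1, 2, 0, 1, 2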
Customised Data Partitioning¶
In task.py, we use the instantiate_fds function to instantiate Flower Datasets and the data partitioner based on the given partitioner_type and num_partitions. Currently, we provide four supported partitioner types to simulate uniformity/non-uniformity in data quantity (uniform, linear, square, exponential).
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import (
IidPartitioner,
LinearPartitioner,
SquarePartitioner,
ExponentialPartitioner,
)
CORRELATION_TO_PARTITIONER = {
"uniform": IidPartitioner,
"linear": LinearPartitioner,
"square": SquarePartitioner,
"exponential": ExponentialPartitioner,
}
def instantiate_fds(partitioner_type, num_partitions):
"""Initialize FederatedDataset."""
# Only initialize `FederatedDataset` once
global fds
if fds is None:
partitioner = CORRELATION_TO_PARTITIONER[partitioner_type](
num_partitions=num_partitions
)
fds = FederatedDataset(
dataset="jxie/higgs",
partitioners={"train": partitioner},
preprocessor=resplit,
)
return fds
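For instance, loading a client's partition under the linear partitioner could look like the sketch below; partition_id would normally come from the client's context, and the global fds variable and the resplit preprocessor referenced above are assumed to be defined elsewhere in task.py:
# Hypothetical usage on the client side
fds = instantiate_fds(partitioner_type="linear", num_partitions=20)
partition = fds.load_partition(partition_id=0, split="train")
partition.set_format("numpy")

# Reuse the helpers from the quickstart to prepare local DMatrix data
train_data, valid_data, num_train, num_val = train_test_split(
    partition, test_fraction=0.2, seed=42
)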
Customised Centralised / Distributed Evaluation¶
To facilitate centralised evaluation, we define a function in server_app.py
:
def get_evaluate_fn(test_data, params):
"""Return a function for centralised evaluation."""
def evaluate_fn(
server_round: int, parameters: Parameters, config: Dict[str, Scalar]
):
# If at the first round, skip the evaluation
if server_round == 0:
return 0, {}
else:
bst = xgb.Booster(params=params)
for para in parameters.tensors:
para_b = bytearray(para)
# Load global model
bst.load_model(para_b)
# Run evaluation
eval_results = bst.eval_set(
evals=[(test_data, "valid")],
iteration=bst.num_boosted_rounds() - 1,
)
auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
return 0, {"AUC": auc}
return evaluate_fn
This function returns an evaluation function, which instantiates a Booster object and loads the global model weights into it. The evaluation is conducted by calling the eval_set() method, and the resulting AUC value is reported.
As for distributed evaluation on the clients, it is the same as in the quickstart example: we override the evaluate() method inside the XgbClient class in client_app.py.
Arguments Explainer¶
We define all hyper-parameters under the [tool.flwr.app.config] entry in pyproject.toml:
[tool.flwr.app.config]
# ServerApp
train-method = "bagging" # Choose from [bagging, cyclic]
num-server-rounds = 3
fraction-fit = 1.0
fraction-evaluate = 1.0
centralised-eval = false
# ClientApp
partitioner-type = "uniform" # Choose from [uniform, linear, square, exponential]
test-fraction = 0.2
seed = 42
centralised-eval-client = false
local-epochs = 1
scaled-lr = false
params.objective = "binary:logistic"
params.eta = 0.1 # Learning rate
params.max-depth = 8
params.eval-metric = "auc"
params.nthread = 16
params.num-parallel-tree = 1
params.subsample = 1
params.tree-method = "hist"
On the server side, we allow users to specify training strategies / FL rounds / participating clients / clients for evaluation, and the evaluation fashion. Note that with centralised-eval = true, the server will do centralised evaluation and all functionalities for client evaluation will be disabled.
On the client side, we can define various options for client data partitioning. Besides, clients also have an option to conduct evaluation on a centralised test set by setting centralised-eval-client = true, as well as an option to scale the learning rate based on the number of clients by setting scaled-lr = true.
Example Commands¶
To run bagging aggregation for 5 rounds evaluated on centralised test set:
flwr run . --run-config "train-method='bagging' num-server-rounds=5 centralised-eval=true"
To run cyclic training with linear partitioner type evaluated on centralised test set:
flwr run . --run-config "train-method='cyclic' partitioner-type='linear'
centralised-eval-client=true"
Note
The full code for
this comprehensive example can be found in examples/xgboost-comprehensive
in the
Flower GitHub repository.
Video Tutorial¶
Note
The video below shows how to set up an XGBoost + Flower project using our previously recommended APIs. A new video tutorial will be released that shows the new APIs (as the content above does).