快速入门 JAX#

本教程将向您展示如何使用 Flower 构建现有 JAX 的联邦学习版本。我们将使用 JAX 在 scikit-learn 数据集上训练线性回归模型。我们将采用与 PyTorch - 从集中式到联邦式 教程中类似的示例结构。首先,我们根据 JAX 的线性回归 教程构建集中式训练方法。然后,我们在集中式训练代码的基础上以联邦方式运行训练。

在开始构建 JAX 示例之前,我们需要安装软件包 jaxjaxlibscikit-learnflwr

$ pip install jax jaxlib scikit-learn flwr

使用 JAX 进行线性回归#

首先,我们将简要介绍基于 Linear Regression 模型的集中式训练代码。如果您想获得更深入的解释,请参阅官方的 JAX 文档

让我们创建一个名为 jax_training.py 的新文件,其中包含传统(集中式)线性回归训练所需的所有组件。首先,需要导入 JAX 包 jaxjaxlib。此外,我们还需要导入 sklearn,因为我们使用 make_regression 创建数据集,并使用 train_test_split 将数据集拆分成训练集和测试集。您可以看到,我们还没有导入用于联邦学习的 flwr 软件包,这将在稍后完成。

from typing import Dict, List, Tuple, Callable
import jax
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

key = jax.random.PRNGKey(0)

load_data() 函数会加载上述训练集和测试集。

def load_data() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    # create our dataset and start with similar datasets for different clients
    X, y = make_regression(n_features=3, random_state=0)
    X, X_test, y, y_test = train_test_split(X, y)
    return X, y, X_test, y_test

模型结构(一个非常简单的 Linear Regression 线性回归模型)在 load_model() 中定义。

def load_model(model_shape) -> Dict:
    # model weights
    params = {
        'b' : jax.random.uniform(key),
        'w' : jax.random.uniform(key, model_shape)
    }
    return params

现在,我们需要定义训练函数( train())。它循环遍历训练集,并计算每批训练数据的损失值(函数 loss_fn())。由于 JAX 使用 grad() 函数提取导数(在 main() 函数中定义,并在 train() 中调用),因此损失函数是独立的。

def loss_fn(params, X, y) -> Callable:
    err = jnp.dot(X, params['w']) + params['b'] - y
    return jnp.mean(jnp.square(err))  # mse

def train(params, grad_fn, X, y) -> Tuple[np.array, float, int]:
    num_examples = X.shape[0]
    for epochs in range(10):
        grads = grad_fn(params, X, y)
        params = jax.tree_multimap(lambda p, g: p - 0.05 * g, params, grads)
        loss = loss_fn(params,X, y)
        # if epochs % 10 == 9:
        #     print(f'For Epoch {epochs} loss {loss}')
    return params, loss, num_examples

模型的评估在函数 evaluation() 中定义。该函数获取所有测试数据,并计算线性回归模型的损失值。

def evaluation(params, grad_fn, X_test, y_test) -> Tuple[float, int]:
    num_examples = X_test.shape[0]
    err_test = loss_fn(params, X_test, y_test)
    loss_test = jnp.mean(jnp.square(err_test))
    # print(f'Test loss {loss_test}')
    return loss_test, num_examples

在定义了数据加载、模型架构、训练和评估之后,我们就可以把这些放在一起,使用 JAX 训练我们的模型了。如前所述,jax.grad() 函数在 main() 中定义,并传递给 train()

def main():
    X, y, X_test, y_test = load_data()
    model_shape = X.shape[1:]
    grad_fn = jax.grad(loss_fn)
    print("Model Shape", model_shape)
    params = load_model(model_shape)
    params, loss, num_examples = train(params, grad_fn, X, y)
    evaluation(params, grad_fn, X_test, y_test)


if __name__ == "__main__":
    main()

现在您可以运行(集中式)JAX 线性回归工作了:

python3 jax_training.py

到目前为止,如果你以前使用过 JAX,就会对这一切感到很熟悉。下一步,让我们利用已构建的代码创建一个简单的联邦学习系统(一个服务器和两个客户端)。

JAX 结合 Flower#

把现有工作联邦化的概念始终是相同的,也很容易理解。我们要启动一个*服务器*,然后对连接到*服务器*的*客户端*运行 :code:`jax_training.py`中的代码。服务器*向客户端发送模型参数,*客户端*运行训练并更新参数。更新后的参数被发回*服务器,然后服务器对所有收到的参数进行平均聚合。以上的描述构成了一轮联邦学习,我们将重复进行多轮学习。

我们的示例包括一个*服务器*和两个*客户端*。让我们先设置 server.py*服务器*需要导入 Flower 软件包 flwr。接下来,我们使用 start_server 函数启动服务器,并让它执行三轮联邦学习。

import flwr as fl

if __name__ == "__main__":
    fl.server.start_server(server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3))

我们已经可以启动*服务器*了:

python3 server.py

最后,我们将在 client.py 中定义我们的 client 逻辑,并以之前在 jax_training.py 中定义的 JAX 训练为基础。我们的 client 需要导入 flwr,还需要导入 jaxjaxlib 以更新 JAX 模型的参数:

from typing import Dict, List, Callable, Tuple

import flwr as fl
import numpy as np
import jax
import jax.numpy as jnp

import jax_training

实现一个 Flower *client*基本上意味着去实现一个 flwr.client.Clientflwr.client.NumPyClient 的子类。我们的代码实现将基于 flwr.client.NumPyClient,并将其命名为 FlowerClient。如果使用具有良好 NumPy 互操作性的框架(如 JAX),NumPyClientClient`更容易实现,因为它避免了一些不必要的操作。:code:`FlowerClient 需要实现四个方法,两个用于获取/设置模型参数,一个用于训练模型,一个用于测试模型:

  1. set_parameters (可选)
    • 在本地模型上设置从服务器接收的模型参数

    • 将参数转换为 NumPy :code:`ndarray`格式

    • 循环遍历以 NumPy ndarray 形式接收的模型参数列表(可以看作神经网络的列表)

  2. get_parameters
    • 获取模型参数,并以 NumPy :code:`ndarray`的列表形式返回(这正是 :code:`flwr.client.NumPyClient`所匹配的格式)

  3. fit
    • 用从服务器接收到的参数更新本地模型的参数

    • 在本地训练集上训练模型

    • 获取更新后的本地模型参数并返回服务器

  4. evaluate
    • 用从服务器接收到的参数更新本地模型的参数

    • 在本地测试集上评估更新后的模型

    • 向服务器返回本地损失值

具有挑战性的部分是将 JAX 模型参数从 DeviceArray 转换为 NumPy ndarray,使其与 NumPyClient 兼容。

这两个 NumPyClient 方法 fitevaluate 使用了之前在 jax_training.py 中定义的函数 train()evaluate()。因此,我们在这里要做的就是通过 NumPyClient 子类告知 Flower 在训练和评估时要调用哪些已定义的函数。我们加入了类型注解,以便让您更好地理解传递的数据类型。

class FlowerClient(fl.client.NumPyClient):
    """Flower client implementing using linear regression and JAX."""

    def __init__(
        self,
        params: Dict,
        grad_fn: Callable,
        train_x: List[np.ndarray],
        train_y: List[np.ndarray],
        test_x: List[np.ndarray],
        test_y: List[np.ndarray],
    ) -> None:
        self.params= params
        self.grad_fn = grad_fn
        self.train_x = train_x
        self.train_y = train_y
        self.test_x = test_x
        self.test_y = test_y

    def get_parameters(self, config) -> Dict:
        # Return model parameters as a list of NumPy ndarrays
        parameter_value = []
        for _, val in self.params.items():
            parameter_value.append(np.array(val))
        return parameter_value

    def set_parameters(self, parameters: List[np.ndarray]) -> Dict:
        # Collect model parameters and update the parameters of the local model
        value=jnp.ndarray
        params_item = list(zip(self.params.keys(),parameters))
        for item in params_item:
            key = item[0]
            value = item[1]
            self.params[key] = value
        return self.params


    def fit(
        self, parameters: List[np.ndarray], config: Dict
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters, train model, return updated model parameters
        print("Start local training")
        self.params = self.set_parameters(parameters)
        self.params, loss, num_examples = jax_training.train(self.params, self.grad_fn, self.train_x, self.train_y)
        results = {"loss": float(loss)}
        print("Training results", results)
        return self.get_parameters(config={}), num_examples, results

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict
    ) -> Tuple[float, int, Dict]:
        # Set model parameters, evaluate the model on a local test dataset, return result
        print("Start evaluation")
        self.params = self.set_parameters(parameters)
        loss, num_examples = jax_training.evaluation(self.params,self.grad_fn, self.test_x, self.test_y)
        print("Evaluation accuracy & loss", loss)
        return (
            float(loss),
            num_examples,
            {"loss": float(loss)},
        )

定义了联邦进程后,我们就可以运行它了。

def main() -> None:
    """Load data, start MNISTClient."""

    # Load data
    train_x, train_y, test_x, test_y = jax_training.load_data()
    grad_fn = jax.grad(jax_training.loss_fn)

    # Load model (from centralized training) and initialize parameters
    model_shape = train_x.shape[1:]
    params = jax_training.load_model(model_shape)

    # Start Flower client
    client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y)
    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())

if __name__ == "__main__":
    main()

就是这样,现在你可以打开另外两个终端窗口,然后运行

python3 client.py

确保服务器仍在运行,然后在每个客户端窗口就能看到你的 JAX 项目在两个客户端上运行联邦学习了。祝贺!

下一步工作#

此示例的源代码经过长期改进,可在此处找到: Quickstart JAX。我们的示例有些过于简单,因为两个客户端都加载了相同的数据集。

现在,您已准备好进行更深一步探索了。例如使用更复杂的模型或使用不同的数据集会如何?增加更多客户端会如何?