Quickstart JAX

This tutorial will show you how to use Flower to build a federated version of an existing JAX workload. We are using JAX to train a linear regression model on a scikit-learn dataset. We will structure the example similar to our PyTorch - From Centralized To Federated walkthrough. First, we build a centralized training approach based on the Linear Regression with JAX tutorial. Then, we build on the centralized training code to run the training in a federated fashion.

Before we start building our JAX example, we need to install the packages jax, jaxlib, scikit-learn, and flwr:

$ pip install jax jaxlib scikit-learn flwr

Linear Regression with JAX

We begin with a brief description of the centralized training code based on a Linear Regression model. If you want a more in-depth explanation of what's going on then have a look at the official JAX documentation.

Let's create a new file called jax_training.py with all the components required for a traditional (centralized) linear regression training. First, the JAX packages jax and jaxlib need to be imported. In addition, we need to import numpy for parameter handling and sklearn, since we use make_regression for the dataset and train_test_split to split the dataset into a training and test set. You can see that we do not yet import the flwr package for federated learning. This will be done later.

from typing import Dict, List, Tuple, Callable
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

key = jax.random.PRNGKey(0)

The load_data() function loads the mentioned training and test sets.

def load_data() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    # Create our dataset; for now, every client will start from the same data
    X, y = make_regression(n_features=3, random_state=0)
    X, X_test, y, y_test = train_test_split(X, y)
    return X, y, X_test, y_test
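
For reference, make_regression creates 100 samples by default and train_test_split holds out 25% of them, so the shapes you get back look like this:

X, y, X_test, y_test = load_data()
print(X.shape, y.shape)            # (75, 3) (75,)
print(X_test.shape, y_test.shape)  # (25, 3) (25,)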

The model architecture (a very simple Linear Regression model) is defined in load_model().

def load_model(model_shape) -> Dict:
    # model weights
    params = {"b": jax.random.uniform(key), "w": jax.random.uniform(key, model_shape)}
    return params

We now need to define the training (function train()), which loops over the training set and measures the loss (function loss_fn()) for each batch of training examples. The loss function is kept separate because JAX computes derivatives with its jax.grad() function (the gradient function is created in main() and called in train()).

def loss_fn(params, X, y) -> jnp.ndarray:
    # Mean squared error of the linear model's predictions
    err = jnp.dot(X, params["w"]) + params["b"] - y
    return jnp.mean(jnp.square(err))  # MSE


def train(params, grad_fn, X, y) -> Tuple[Dict, float, int]:
    num_examples = X.shape[0]
    for epoch in range(10):
        grads = grad_fn(params, X, y)
        # One gradient-descent step: params <- params - 0.05 * grads
        params = jax.tree_util.tree_map(lambda p, g: p - 0.05 * g, params, grads)
        loss = loss_fn(params, X, y)
        # print(f"Epoch {epoch}: loss {loss}")
    return params, loss, num_examples
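
Before moving on, it can help to see that jax.grad(loss_fn) returns gradients with the same pytree structure (here, the same dict keys) as params, which is why the tree_map update in train() works. A quick check you could run at the bottom of jax_training.py:

grad_fn = jax.grad(loss_fn)
params = load_model((3,))
X, y, _, _ = load_data()
grads = grad_fn(params, X, y)
print(grads.keys())  # dict_keys(['b', 'w']), mirroring the structure of params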

The evaluation of the model is defined in the function evaluation(). The function takes all test examples and measures the loss of the linear regression model.

def evaluation(params, grad_fn, X_test, y_test) -> Tuple[float, int]:
    # grad_fn is unused here; it is kept so train() and evaluation() share a signature
    num_examples = X_test.shape[0]
    loss_test = loss_fn(params, X_test, y_test)  # loss_fn already returns the MSE
    # print(f"Test loss {loss_test}")
    return loss_test, num_examples

Having defined the data loading, model architecture, training, and evaluation we can put everything together and train our model using JAX. As already mentioned, the jax.grad() function is defined in main() and passed to train().

def main():
    X, y, X_test, y_test = load_data()
    model_shape = X.shape[1:]
    grad_fn = jax.grad(loss_fn)
    print("Model Shape", model_shape)
    params = load_model(model_shape)
    params, loss, num_examples = train(params, grad_fn, X, y)
    evaluation(params, grad_fn, X_test, y_test)


if __name__ == "__main__":
    main()

You can now run your (centralized) JAX linear regression workload:

python3 jax_training.py

So far this should all look fairly familiar if you've used JAX before. Let's take the next step and use what we've built to create a simple federated learning system consisting of one server and two clients.

JAX meets Flower

The concept of federating an existing workload is always the same and easy to understand. We have to start a server and then use the code in jax_training.py for the clients that are connected to the server. The server sends model parameters to the clients. The clients run the training and update the parameters. The updated parameters are sent back to the server, which averages all received parameter updates. This describes one round of the federated learning process, and we repeat this for multiple rounds.
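
To make the averaging step concrete, here is a minimal sketch (with hypothetical names, not actual Flower API) of the weighted parameter averaging that the server's default FedAvg strategy performs each round:

def federated_average(updates):
    # updates: list of (parameters, num_examples) tuples, one per client,
    # where each parameters entry is a list of NumPy ndarrays
    total_examples = sum(num for _, num in updates)
    num_layers = len(updates[0][0])
    # Weight each client's parameters by its number of training examples
    return [
        sum(params[i] * num for params, num in updates) / total_examples
        for i in range(num_layers)
    ]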

Our example consists of one server and two clients. Let's set up server.py first. The server needs to import the Flower package flwr. Next, we use the start_server function to start a server and tell it to perform three rounds of federated learning.

import flwr as fl

if __name__ == "__main__":
    fl.server.start_server(
        server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3)
    )
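
By default, start_server uses the FedAvg strategy. If you want more control, for example to make sure a round only starts once both clients are connected, you can pass a strategy explicitly. A sketch, assuming the FedAvg defaults otherwise:

import flwr as fl

strategy = fl.server.strategy.FedAvg(
    min_fit_clients=2,        # wait for two clients before training
    min_evaluate_clients=2,   # wait for two clients before evaluation
    min_available_clients=2,  # wait for two clients before starting a round
)

if __name__ == "__main__":
    fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=3),
        strategy=strategy,
    )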

We can already start the *server*:

python3 server.py

Finally, we will define our client logic in client.py and build upon the previously defined JAX training in jax_training.py. Our client needs to import flwr, but also jax and jaxlib to update the parameters on our JAX model:

from typing import Dict, List, Callable, Tuple

import flwr as fl
import numpy as np
import jax
import jax.numpy as jnp

import jax_training

Implementing a Flower client basically means implementing a subclass of either flwr.client.Client or flwr.client.NumPyClient. Our implementation will be based on flwr.client.NumPyClient and we'll call it FlowerClient. NumPyClient is slightly easier to implement than Client if you use a framework with good NumPy interoperability (like JAX) because it avoids some of the boilerplate that would otherwise be necessary. FlowerClient needs to implement four methods: two for getting/setting model parameters, one for training the model, and one for testing the model:

  1. set_parameters (optional)
    • set the model parameters on the local model that are received from the server

    • transform parameters to NumPy ndarray's

    • loop over the list of model parameters received as NumPy ndarray's (think list of neural network layers)

  2. get_parameters
    • get the model parameters and return them as a list of NumPy ndarray's (which is what flwr.client.NumPyClient expects)

  3. fit
    • update the parameters of the local model with the parameters received from the server

    • train the model on the local training set

    • get the updated local model parameters and return them to the server

  4. evaluate
    • update the parameters of the local model with the parameters received from the server

    • evaluate the updated model on the local test set

    • return the local loss to the server

The challenging part is to transform the JAX model parameters from DeviceArray to NumPy ndarray to make them compatible with NumPyClient.
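
The conversion itself is short in both directions; a quick illustration:

w = jax.random.uniform(jax.random.PRNGKey(0), (3,))  # JAX array living on a device
w_np = np.array(w)         # JAX array -> NumPy ndarray (copies to host memory)
w_jax = jnp.asarray(w_np)  # NumPy ndarray -> JAX array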

The two NumPyClient methods fit and evaluate make use of the functions train() and evaluation() previously defined in jax_training.py. So what we really do here is we tell Flower through our NumPyClient subclass which of our already defined functions to call for training and evaluation. We included type annotations to give you a better understanding of the data types that get passed around.

class FlowerClient(fl.client.NumPyClient):
    """Flower client implementing using linear regression and JAX."""

    def __init__(
        self,
        params: Dict,
        grad_fn: Callable,
        train_x: List[np.ndarray],
        train_y: List[np.ndarray],
        test_x: List[np.ndarray],
        test_y: List[np.ndarray],
    ) -> None:
        self.params = params
        self.grad_fn = grad_fn
        self.train_x = train_x
        self.train_y = train_y
        self.test_x = test_x
        self.test_y = test_y

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays
        return [np.array(val) for val in self.params.values()]

    def set_parameters(self, parameters: List[np.ndarray]) -> Dict:
        # Collect model parameters and update the parameters of the local model
        for key, value in zip(self.params.keys(), parameters):
            self.params[key] = value
        return self.params

    def fit(
        self, parameters: List[np.ndarray], config: Dict
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters, train model, return updated model parameters
        print("Start local training")
        self.params = self.set_parameters(parameters)
        self.params, loss, num_examples = jax_training.train(
            self.params, self.grad_fn, self.train_x, self.train_y
        )
        results = {"loss": float(loss)}
        print("Training results", results)
        return self.get_parameters(config={}), num_examples, results

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict
    ) -> Tuple[float, int, Dict]:
        # Set model parameters, evaluate the model on a local test dataset, return result
        print("Start evaluation")
        self.params = self.set_parameters(parameters)
        loss, num_examples = jax_training.evaluation(
            self.params, self.grad_fn, self.test_x, self.test_y
        )
        print("Evaluation accuracy & loss", loss)
        return (
            float(loss),
            num_examples,
            {"loss": float(loss)},
        )

Having defined the federation process, we can run it.

def main() -> None:
    """Load data, start MNISTClient."""

    # Load data
    train_x, train_y, test_x, test_y = jax_training.load_data()
    grad_fn = jax.grad(jax_training.loss_fn)

    # Load model (from centralized training) and initialize parameters
    model_shape = train_x.shape[1:]
    params = jax_training.load_model(model_shape)

    # Start Flower client
    client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y)
    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())


if __name__ == "__main__":
    main()

And that's it. You can now open two additional terminal windows and run

python3 client.py

in each window (make sure that the server is still running) and see your JAX project run federated learning across two clients. Congratulations!

Next Steps

The source code of this example has improved over time and can be found here: Quickstart JAX. Our example is somewhat over-simplified because both clients load the same dataset.

You're now prepared to explore this topic further. How about using a more sophisticated model or a different dataset? How about adding more clients?