Démarrage rapide de JAX#

This tutorial will show you how to use Flower to build a federated version of an existing JAX workload. We are using JAX to train a linear regression model on a scikit-learn dataset. We will structure the example similar to our PyTorch - From Centralized To Federated walkthrough. First, we build a centralized training approach based on the Linear Regression with JAX tutorial`. Then, we build upon the centralized training code to run the training in a federated fashion.

Avant de commencer à construire notre exemple JAX, nous devons installer les paquets jax, jaxlib, scikit-learn, et flwr :

$ pip install jax jaxlib scikit-learn flwr

Régression linéaire avec JAX#

Nous commençons par une brève description du code d’entraînement centralisé basé sur un modèle Régression linéaire. Si tu veux une explication plus approfondie de ce qui se passe, jette un coup d’œil à la documentation officielle JAX.

Créons un nouveau fichier appelé jax_training.py avec tous les composants nécessaires pour un apprentissage traditionnel (centralisé) de la régression linéaire. Tout d’abord, les paquets JAX jax et jaxlib doivent être importés. En outre, nous devons importer sklearn puisque nous utilisons make_regression pour le jeu de données et train_test_split pour diviser le jeu de données en un jeu d’entraînement et un jeu de test. Tu peux voir que nous n’avons pas encore importé le paquet flwr pour l’apprentissage fédéré, ce qui sera fait plus tard.

from typing import Dict, List, Tuple, Callable
import jax
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

key = jax.random.PRNGKey(0)

La fonction load_data() charge les ensembles d’entraînement et de test mentionnés.

def load_data() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    # create our dataset and start with similar datasets for different clients
    X, y = make_regression(n_features=3, random_state=0)
    X, X_test, y, y_test = train_test_split(X, y)
    return X, y, X_test, y_test

L’architecture du modèle (un modèle Régression linéaire très simple) est définie dans load_model().

def load_model(model_shape) -> Dict:
    # model weights
    params = {
        'b' : jax.random.uniform(key),
        'w' : jax.random.uniform(key, model_shape)
    }
    return params

Nous devons maintenant définir l’entraînement (fonction train()), qui boucle sur l’ensemble d’entraînement et mesure la perte (fonction loss_fn()) pour chaque lot d’exemples d’entraînement. La fonction de perte est séparée puisque JAX prend des dérivés avec une fonction grad() (définie dans la fonction main() et appelée dans train()).

def loss_fn(params, X, y) -> Callable:
    err = jnp.dot(X, params['w']) + params['b'] - y
    return jnp.mean(jnp.square(err))  # mse

def train(params, grad_fn, X, y) -> Tuple[np.array, float, int]:
    num_examples = X.shape[0]
    for epochs in range(10):
        grads = grad_fn(params, X, y)
        params = jax.tree_multimap(lambda p, g: p - 0.05 * g, params, grads)
        loss = loss_fn(params,X, y)
        # if epochs % 10 == 9:
        #     print(f'For Epoch {epochs} loss {loss}')
    return params, loss, num_examples

L’évaluation du modèle est définie dans la fonction evaluation(). La fonction prend tous les exemples de test et mesure la perte du modèle de régression linéaire.

def evaluation(params, grad_fn, X_test, y_test) -> Tuple[float, int]:
    num_examples = X_test.shape[0]
    err_test = loss_fn(params, X_test, y_test)
    loss_test = jnp.mean(jnp.square(err_test))
    # print(f'Test loss {loss_test}')
    return loss_test, num_examples

Après avoir défini le chargement des données, l’architecture du modèle, l’entraînement et l’évaluation, nous pouvons tout assembler et entraîner notre modèle à l’aide de JAX. Comme nous l’avons déjà mentionné, la fonction jax.grad() est définie dans main() et transmise à train().

def main():
    X, y, X_test, y_test = load_data()
    model_shape = X.shape[1:]
    grad_fn = jax.grad(loss_fn)
    print("Model Shape", model_shape)
    params = load_model(model_shape)
    params, loss, num_examples = train(params, grad_fn, X, y)
    evaluation(params, grad_fn, X_test, y_test)


if __name__ == "__main__":
    main()

Tu peux maintenant exécuter ta charge de travail (centralisée) de régression linéaire JAX :

python3 jax_training.py

Jusqu’à présent, tout cela devrait te sembler assez familier si tu as déjà utilisé JAX. Passons à l’étape suivante et utilisons ce que nous avons construit pour créer un simple système d’apprentissage fédéré composé d’un serveur et de deux clients.

JAX rencontre Flower#

Le concept de fédération d’une charge de travail existante est toujours le même et facile à comprendre. Nous devons démarrer un serveur, puis utiliser le code dans jax_training.py pour les clients qui sont connectés au serveur.Le serveur envoie les paramètres du modèle aux clients.Les clients exécutent la formation et mettent à jour les paramètres.Les paramètres mis à jour sont renvoyés au serveur, qui fait la moyenne de toutes les mises à jour de paramètres reçues.Ceci décrit un tour du processus d’apprentissage fédéré, et nous répétons cette opération pour plusieurs tours.

Notre exemple consiste en un serveur et deux clients. Commençons par configurer server.py. Le serveur doit importer le paquet Flower flwr. Ensuite, nous utilisons la fonction start_server pour démarrer un serveur et lui demander d’effectuer trois cycles d’apprentissage fédéré.

import flwr as fl

if __name__ == "__main__":
    fl.server.start_server(server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3))

Nous pouvons déjà démarrer le serveur :

python3 server.py

Enfin, nous allons définir la logique de notre client dans client.py et nous appuyer sur la formation JAX définie précédemment dans jax_training.py. Notre client doit importer flwr, mais aussi jax et jaxlib pour mettre à jour les paramètres de notre modèle JAX :

from typing import Dict, List, Callable, Tuple

import flwr as fl
import numpy as np
import jax
import jax.numpy as jnp

import jax_training

L’implémentation d’un client Flower signifie essentiellement l’implémentation d’une sous-classe de flwr.client.Client ou flwr.client.NumPyClient. Notre implémentation sera basée sur flwr.client.NumPyClient et nous l’appellerons FlowerClient. NumPyClient est légèrement plus facile à implémenter que Client si vous utilisez un framework avec une bonne interopérabilité NumPy (comme JAX) parce qu’il évite une partie du boilerplate qui serait autrement nécessaire. FlowerClient doit implémenter quatre méthodes, deux méthodes pour obtenir/régler les paramètres du modèle, une méthode pour former le modèle, et une méthode pour tester le modèle :

  1. set_parameters (optional)
    • règle les paramètres du modèle local reçus du serveur

    • transforme les paramètres en NumPy ndarray’s

    • boucle sur la liste des paramètres du modèle reçus sous forme de NumPy ndarray’s (pensez à la liste des couches du réseau neuronal)

  2. get_parameters
    • récupère les paramètres du modèle et les renvoie sous forme de liste de ndarray NumPy (ce qui correspond à ce que flwr.client.NumPyClient attend)

  3. fit
    • mettre à jour les paramètres du modèle local avec les paramètres reçus du serveur

    • entraîne le modèle sur l’ensemble d’apprentissage local

    • récupère les paramètres du modèle local mis à jour et les renvoie au serveur

  4. évaluer
    • mettre à jour les paramètres du modèle local avec les paramètres reçus du serveur

    • évaluer le modèle mis à jour sur l’ensemble de test local

    • renvoie la perte locale au serveur

La partie la plus difficile consiste à transformer les paramètres du modèle JAX de DeviceArray en NumPy ndarray pour les rendre compatibles avec NumPyClient.

Les deux méthodes NumPyClient fit et evaluate utilisent les fonctions train() et evaluate() définies précédemment dans jax_training.py. Ce que nous faisons vraiment ici, c’est que nous indiquons à Flower, par le biais de notre sous-classe NumPyClient, laquelle de nos fonctions déjà définies doit être appelée pour l’entraînement et l’évaluation. Nous avons inclus des annotations de type pour te donner une meilleure compréhension des types de données qui sont transmis.

class FlowerClient(fl.client.NumPyClient):
    """Flower client implementing using linear regression and JAX."""

    def __init__(
        self,
        params: Dict,
        grad_fn: Callable,
        train_x: List[np.ndarray],
        train_y: List[np.ndarray],
        test_x: List[np.ndarray],
        test_y: List[np.ndarray],
    ) -> None:
        self.params= params
        self.grad_fn = grad_fn
        self.train_x = train_x
        self.train_y = train_y
        self.test_x = test_x
        self.test_y = test_y

    def get_parameters(self, config) -> Dict:
        # Return model parameters as a list of NumPy ndarrays
        parameter_value = []
        for _, val in self.params.items():
            parameter_value.append(np.array(val))
        return parameter_value

    def set_parameters(self, parameters: List[np.ndarray]) -> Dict:
        # Collect model parameters and update the parameters of the local model
        value=jnp.ndarray
        params_item = list(zip(self.params.keys(),parameters))
        for item in params_item:
            key = item[0]
            value = item[1]
            self.params[key] = value
        return self.params


    def fit(
        self, parameters: List[np.ndarray], config: Dict
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters, train model, return updated model parameters
        print("Start local training")
        self.params = self.set_parameters(parameters)
        self.params, loss, num_examples = jax_training.train(self.params, self.grad_fn, self.train_x, self.train_y)
        results = {"loss": float(loss)}
        print("Training results", results)
        return self.get_parameters(config={}), num_examples, results

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict
    ) -> Tuple[float, int, Dict]:
        # Set model parameters, evaluate the model on a local test dataset, return result
        print("Start evaluation")
        self.params = self.set_parameters(parameters)
        loss, num_examples = jax_training.evaluation(self.params,self.grad_fn, self.test_x, self.test_y)
        print("Evaluation accuracy & loss", loss)
        return (
            float(loss),
            num_examples,
            {"loss": float(loss)},
        )

Après avoir défini le processus de fédération, nous pouvons l’exécuter.

def main() -> None:
    """Load data, start MNISTClient."""

    # Load data
    train_x, train_y, test_x, test_y = jax_training.load_data()
    grad_fn = jax.grad(jax_training.loss_fn)

    # Load model (from centralized training) and initialize parameters
    model_shape = train_x.shape[1:]
    params = jax_training.load_model(model_shape)

    # Start Flower client
    client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y)
    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())

if __name__ == "__main__":
    main()

Tu peux maintenant ouvrir deux autres fenêtres de terminal et exécuter les commandes suivantes

python3 client.py

dans chaque fenêtre (assure-toi que le serveur est toujours en cours d’exécution avant de le faire) et tu verras que ton projet JAX exécute l’apprentissage fédéré sur deux clients. Félicitations !

Prochaines étapes#

The source code of this example was improved over time and can be found here: Quickstart JAX. Our example is somewhat over-simplified because both clients load the same dataset.

Tu es maintenant prêt à approfondir ce sujet. Pourquoi ne pas utiliser un modèle plus sophistiqué ou un ensemble de données différent ? Pourquoi ne pas ajouter d’autres clients ?