Quickstart JAX

Dans ce tutoriel d’apprentissage fédéré, nous allons apprendre à entraîner un modèle CNN sur le jeu de données MNIST en utilisant Flower et JAX avec la bibliothèque Flax. Il est recommandé de créer un environnement virtuel et de faire tourner tout dans un virtualenv.

Utilisons flwr new pour créer un projet complet Flower+JAX. Cela générera tous les fichiers nécessaires pour exécuter une fédération de 50 nœuds en utilisant FedAvg. Par défaut, l’application générée utilise un profil de simulation local qui flwr run soumet à un SuperLink géré local, qui exécute ensuite l’exécution avec le runtime d’exécution Flower Simulation. Le jeu de données MNIST sera partitionné en utilisant Flower Datasets’s IidPartitioner.

Maintenant que nous avons une idée approximative de ce que cet exemple est sur, allons-y. Tout d’abord, installez Flower dans votre nouvel environnement :

# In a new Python environment
$ pip install flwr[simulation]

Ensuite, exécutez la commande suivante :

$ flwr new @flwrlabs/quickstart-jax

Après l’avoir lancé, vous remarquerez qu’un nouveau répertoire nommé quickstart-jax a été créé. Il devrait avoir la structure suivante:

quickstart-jax
├── jaxexample
│   ├── __init__.py
│   ├── client_app.py   # Defines your ClientApp   ├── server_app.py   # Defines your ServerApp   └── task.py         # Defines your model, training and data loading
├── pyproject.toml      # Project metadata like dependencies and configs
└── README.md

Si vous n’avez pas encore installé le projet et ses dépendances, vous pouvez le faire avec :

# From the directory where your pyproject.toml is
$ pip install -e .

Pour lancer le projet, faites:

# Run with default arguments and stream logs
$ flwr run . --stream

Le processus flwr run . soumet l’exécution, imprime l’ID d’exécution et retourne sans diffuser les journaux. Pour le workflow local complet, voir Exécuter Flower Localement avec un SuperLink Géré.

Avec les arguments par défaut, vous verrez une sortie en flux continu comme celle-ci :

Starting local SuperLink on 127.0.0.1:39093...
Successfully started run 1859953118041441032
INFO :      Starting FedAvg strategy:
INFO :          ├── Number of rounds: 5
INFO :      [ROUND 1/5]
INFO :      configure_train: Sampled 20 nodes (out of 50)
INFO :      aggregate_train: Received 20 results and 0 failures
INFO :          └──> Aggregated MetricRecord: {'train_loss': 2.1116, 'train_acc': 0.2821}
INFO :      configure_evaluate: Sampled 20 nodes (out of 50)
INFO :      aggregate_evaluate: Received 20 results and 0 failures
INFO :          └──> Aggregated MetricRecord: {'eval_loss': 1.3394, 'eval_acc': 0.4984}
INFO :      [ROUND 2/5]
INFO :      ...
INFO :      [ROUND 5/5]
INFO :      ...
INFO :      Strategy execution finished in 60.58s
INFO :      Final results:
INFO :          ServerApp-side Evaluate Metrics:
INFO :          {}

Vous pouvez également surcharger les paramètres définis dans la section [tool.flwr.app.config] de pyproject.toml comme suit:

# Override some arguments
$ flwr run . --run-config "num-server-rounds=5 batch-size=64"

Voici une explication de chaque composant dans le projet que vous venez de créer : partitionnement du jeu de données, modèle, définir la ClientApp et définir la ServerApp.

Les données

Ce tutoriel utilise Flower Datasets pour télécharger et partitionner facilement le jeu de données MNIST. Dans cet exemple, vous utiliserez IidPartitioner pour générer des partitions num_partitions. Vous pouvez choisir parmi les other partitioners disponibles dans les jeux de données Flower.

partitioner = IidPartitioner(num_partitions=num_partitions)
fds = FederatedDataset(
    dataset="mnist",
    partitioners={"train": partitioner},
)
partition = fds.load_partition(partition_id)

# Divide data on each node: 80% train, 20% test
partition = partition.train_test_split(test_size=0.2)

partition["train"].set_format("jax")
partition["test"].set_format("jax")


def apply_transforms(batch):
    """Apply transforms to the partition from FederatedDataset."""
    batch["image"] = [
        jnp.expand_dims(jnp.float32(img), 3) / 255 for img in batch["image"]
    ]
    batch["label"] = [jnp.int16(label) for label in batch["label"]]
    return batch


train_partition = (
    partition["train"]
    .batch(batch_size, num_proc=2, drop_last_batch=True)
    .with_transform(apply_transforms)
)
test_partition = (
    partition["test"]
    .batch(batch_size, num_proc=2, drop_last_batch=True)
    .with_transform(apply_transforms)
)

Le Modèle

Nous utilisons Flax pour définir un modèle CNN simple pour la classification d’images :

class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=6, kernel_size=(5, 5))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=16, kernel_size=(5, 5))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=120)(x)
        x = nn.relu(x)
        x = nn.Dense(features=84)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


def create_train_state(learning_rate: float) -> TrainState:
    """Creates initial `TrainState`."""

    tx = optax.sgd(learning_rate, momentum=0.9)
    model, model_params = create_model(rng)
    return TrainState.create(apply_fn=model.apply, params=model_params, tx=tx)

En plus de définir l’architecture du modèle, nous incluons également des fonctions d’utilité pour effectuer à la fois l’entraînement (c’est-à-dire train()) et l’évaluation en utilisant le modèle ci-dessus.

@jax.jit
def apply_model(
    state: TrainState, images: Array, labels: Array
) -> Tuple[Any, Array, Array]:
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(params):
        logits = state.apply_fn({"params": params}, images)
        one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy


@jax.jit
def update_model(state: TrainState, grads: Any) -> TrainState:
    return state.apply_gradients(grads=grads)


def train(state: TrainState, train_ds) -> Tuple[TrainState, float, float]:
    """Train for a single epoch."""

    epoch_loss = []
    epoch_accuracy = []

    for batch in train_ds:
        batch_images = batch["image"]
        batch_labels = batch["label"]
        grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
        state = update_model(state, grads)
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)
    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, float(train_loss), float(train_accuracy)

L’Application Client

Les principales modifications que nous devons faire pour utiliser JAX avec Flower ont affaire à convertir les ArrayRecord reçues dans Message en tableaux NumPy et vice versa lors de la génération de la réponse Message du serveur ClientApp. Nous avons également besoin d’introduire les fonctions get_params() et set_params() pour définir des valeurs de paramètres pour le modèle JAX. Dans get_params(), les paramètres du modèle JAX sont extraits et représentés sous forme de liste de tableaux NumPy. La fonction set_params() est l’inverse : donnée une liste de tableaux NumPy, elle crée un nouveau modèle TrainState avec ces paramètres. Nous allons combiner ces fonctions avec les méthodes intégrées dans le ArrayRecord pour faire ces conversions :

def get_params(params: Any) -> List[npt.NDArray[Any]]:
    """Get model parameters as list of numpy arrays."""
    return [np.array(param) for param in jax.tree_util.tree_leaves(params)]


def set_params(
    train_state: TrainState, global_params: Sequence[npt.NDArray[Any]]
) -> TrainState:
    """Create a new trainstate with the global_params."""
    new_params_dict = jax.tree_util.tree_unflatten(
        jax.tree_util.tree_structure(train_state.params), list(global_params)
    )
    return train_state.replace(params=new_params_dict)
# Create train state object (model + optimizer)
lr = float(context.run_config["learning-rate"])
train_state = create_train_state(lr)

# Extract ArrayRecord from Message and convert to NumPy arrays
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
# Set JAX model parameters using the converted NumPy arrays
train_state = set_params(train_state, ndarrays)

# ... do some training

# Extract NumPy arrays from the JAX model and convert back into an ArrayRecord
params = get_params(train_state.params)
model_record = ArrayRecord(params)

Le reste de la fonctionnalité est directement inspiré du cas centralisé. Le ClientApp comporte trois méthodes de base (train, evaluate, et query) que nous pouvons implémenter à des fins différentes. Par exemple : train pour entraîner le modèle reçu en utilisant les données locales ; evaluate pour évaluer la performance du modèle reçu sur un jeu de validation ; et query pour récupérer des informations sur le nœud exécutant le ClientApp. Dans ce tutoriel, nous ne ferons que faire usage de train et evaluate.

Voyons comment la méthode train peut être implémentée. Elle reçoit en arguments d’entrée un Message depuis le ServerApp. Par défaut, elle porte :

  • un ArrayRecord avec les tableaux du modèle à fédérer. Par défaut, ils peuvent être récupérés avec la clé "arrays" lors de l’accès au contenu du message.

  • une ConfigRecord avec la configuration transmise depuis le ServerApp. Par défaut, elle peut être récupérée avec la clé "config" lors de l’accès au contenu du message.

La méthode train reçoit également le Context, donnant accès aux configurations pour votre exécution et nœud. Les hyperparamètres de la configuration d’exécution sont définis dans la section pyproject.toml de votre application Flower. La configuration du nœud ne peut être configurée que lors de l’exécution de Flower avec le Deployment Runtime et ce, directement configurable pendant les simulations.

# Flower ClientApp
app = ClientApp()


@app.train()
def train(msg: Message, context: Context):
    """Train the model on local data."""

    # Create train state object (model + optimizer)
    lr = float(context.run_config["learning-rate"])
    train_state = create_train_state(lr)
    # Extract numpy arrays from ArrayRecord before applying
    ndarrays = msg.content["arrays"].to_numpy_ndarrays()
    train_state = set_params(train_state, ndarrays)

    # Load the data
    partition_id = int(context.node_config["partition-id"])
    num_partitions = int(context.node_config["num-partitions"])
    batch_size = int(context.run_config["batch-size"])
    trainloader, _ = load_data(partition_id, num_partitions, batch_size)

    train_state, loss, acc = jax_train(train_state, trainloader)
    params = get_params(train_state.params)

    # Construct and return reply Message
    model_record = ArrayRecord(params)
    metrics = {
        "train_loss": float(loss),
        "train_acc": float(acc),
        "num-examples": int(len(trainloader) * batch_size),
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict({"arrays": model_record, "metrics": metric_record})
    return Message(content=content, reply_to=msg)

La méthode @app.evaluate() serait presque identique avec deux exceptions : (1) le modèle n’est pas entraîné localement, mais il est utilisé pour évaluer sa performance sur le jeu de validation local ; (2) inclure le modèle dans la réponse Message n’est plus nécessaire car il ne subit aucune modification locale.

L’Application Serveur

Pour construire un ServerApp, nous définissons sa méthode @app.main(). Cette méthode reçoit en arguments d’entrée :

  • Un objet Grid qui sera utilisé pour interagir avec les nœuds s’exécutant la ClientApp afin d’impliquer les clients dans une ronde d’entraînement/évaluation/requête ou autre.

  • Un objet Context qui fournit accès à la configuration de l’exécution.

Dans cet exemple, nous utilisons la fonction FedAvg et la configurons avec des valeurs spécifiques lues à partir de la configuration d’exécution. Vous pouvez trouver les valeurs par défaut définies dans le fichier pyproject.toml. Ensuite, l’exécution de la stratégie est lancée lorsqu’on invoque sa méthode start. À cela on passe :

  • l’objet Grid.

  • Un objet ArrayRecord portant un modèle initialisé aléatoirement qui servira de modèle global à fédérer.

  • un objet ConfigRecord avec les hyperparamètres d’entraînement (taux d’apprentissage) à envoyer aux clients. La stratégie insérera également le numéro de ronde actuel dans ce config avant de l’envoyer aux nœuds participants.

  • Le paramètre num_rounds spécifiant combien de rondes de FedAvg effectuer.

# Create ServerApp
app = ServerApp()


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""

    # Read run config
    fraction_evaluate: float = float(context.run_config["fraction-evaluate"])
    num_rounds: int = int(context.run_config["num-server-rounds"])
    lr: float = float(context.run_config["learning-rate"])

    rng = random.PRNGKey(0)
    rng, _ = random.split(rng)
    _, model_params = create_model(rng)
    params = get_params(model_params)

    # Initialize FedAvg strategy
    strategy = FedAvg(
        fraction_train=0.4,
        fraction_evaluate=fraction_evaluate,
    )

    # Start strategy, run FedAvg for `num_rounds`
    result = strategy.start(
        grid=grid,
        initial_arrays=ArrayRecord(params),
        train_config=ConfigRecord({"lr": lr}),
        num_rounds=num_rounds,
    )

    if context.run_config["save-model"]:
        # Save final model to disk
        print("\nSaving final model to disk...")
        ndarrays = result.arrays.to_numpy_ndarrays()
        np.savez("final_model.npz", *ndarrays)

Notez que le méthode start du stratégie retourne un objet résultat. Cet objet contient toutes les informations pertinentes sur le processus FL, y compris les poids de modèle final sous forme de ArrayRecord, et les métriques de formation et d’évaluation fédérées comme des MetricRecords. Vous pouvez facilement logger les métriques en utilisant Python’s pprint et, si save-model est défini sur true, sauvegarder les tableaux NumPy globaux du modèle à l’aide de np.savez() comme montré ci-dessus.

Félicitations ! Vous avez réussi à créer et à lancer votre premier système d’apprentissage fédéré pour JAX avec Flower!

Astuce

Vérifiez la documentation de Run simulations pour en savoir plus sur la façon de configurer et d’exécuter les simulations Flower.

Note

Vérifiez le code source de la version étendue de ce tutoriel dans examples/quickstart-jax dans le dépôt GitHub Flower.