Quickstart MLX

Dans ce tutoriel d’apprentissage fédéré, nous allons apprendre à entraîner un simple MLP sur MNIST en utilisant Flower et MLX. Il est recommandé de créer un environnement virtuel et d’exécuter tout dans un virtualenv.

Utilisons flwr new pour créer un projet complet Flower+MLX. Cela générera tous les fichiers nécessaires pour exécuter une fédération de 10 nœuds en utilisant FedAvg. Par défaut, l’application générée utilise un profil de simulation local qui flwr run soumet à un SuperLink géré local, qui exécute ensuite l’exécution avec Flower Simulation Runtime. Le dataset sera partitionné en utilisant les partitions de Flower Dataset’s IidPartitioner.

Maintenant que nous avons une idée approximative de ce que cet exemple est sur, allons-y. Tout d’abord, installez Flower dans votre nouvel environnement :

# In a new Python environment
$ pip install flwr[simulation]

Ensuite, exécutez la commande suivante :

$ flwr new @flwrlabs/quickstart-mlx

Après avoir exécuté cela, vous remarquerez que un nouveau répertoire nommé quickstart-mlx a été créé. Il devrait avoir la structure suivante :

quickstart-mlx
├── mlxexample
│   ├── __init__.py
│   ├── client_app.py   # Defines your ClientApp   ├── server_app.py   # Defines your ServerApp   └── task.py         # Defines your model, training and data loading
├── pyproject.toml      # Project metadata like dependencies and configs
└── README.md

Si vous n’avez pas encore installé le projet et ses dépendances, vous pouvez le faire avec :

# From the directory where your pyproject.toml is
$ pip install -e .

Pour exécuter le projet, faites :

# Run with default arguments and stream logs
$ flwr run . --stream

Le processus flwr run . soumet l’exécution, imprime l’ID d’exécution et retourne sans diffuser les journaux. Pour le workflow local complet, voir Exécuter Flower Localement avec un SuperLink Géré.

Avec les arguments par défaut, vous verrez une sortie en flux comme celle-ci :

Starting local SuperLink on 127.0.0.1:39093...
Successfully started run 1859953118041441032
INFO :      Starting FedAvg strategy:
INFO :          ├── Number of rounds: 3
INFO :      [ROUND 1/3]
INFO :      configure_train: Sampled 10 nodes (out of 10)
INFO :      aggregate_train: Received 10 results and 0 failures
INFO :          └──> Aggregated MetricRecord: {'accuracy': 0.270375007390976, 'loss': 2.2390866}
INFO :      configure_evaluate: Sampled 10 nodes (out of 10)
INFO :      aggregate_evaluate: Received 10 results and 0 failures
INFO :          └──> Aggregated MetricRecord: {'accuracy': 0.2720000118017197, 'loss': 2.24028}
INFO :      [ROUND 2/3]
INFO :      ...
INFO :      [ROUND 3/3]
INFO :      ...
INFO :      Strategy execution finished in 9.96s
INFO :      Final results:
INFO :          ServerApp-side Evaluate Metrics:
INFO :          {}

Vous pouvez également surcharger les paramètres définis dans la section [tool.flwr.app.config] du fichier pyproject.toml comme suit :

# Override some arguments
$ flwr run . --run-config "num-server-rounds=5 learning-rate=0.05"

Voici une explication de chaque composant dans le projet que vous venez de créer : partitionnement des jeux de données, modèle, définition de la stratégie et définition de l’application.

Les données

Nous allons utiliser Flower Datasets pour télécharger et partitionner facilement le dataset MNIST. Dans cet exemple, vous utiliserez la IidPartitioner pour générer num_partitions ` partitions. You can choose from other partitioners <https://flower.ai/docs/datasets/ref-api/flwr_datasets.partitioner.html>`_ disponibles dans les données de Flower :

partitioner = IidPartitioner(num_partitions=num_partitions)
fds = FederatedDataset(
    dataset="ylecun/mnist",
    partitioners={"train": partitioner},
    trust_remote_code=True,
)
partition = fds.load_partition(partition_id)
partition_splits = partition.train_test_split(test_size=0.2, seed=42)

partition_splits["train"].set_format("numpy")
partition_splits["test"].set_format("numpy")

train_partition = partition_splits["train"].map(
    lambda img: {"img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0},
    input_columns="image",
)
test_partition = partition_splits["test"].map(
    lambda img: {"img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0},
    input_columns="image",
)

data = (
    train_partition["img"],
    train_partition["label"].astype(np.uint32),
    test_partition["img"],
    test_partition["label"].astype(np.uint32),
)

train_images, train_labels, test_images, test_labels = map(mx.array, data)

Le Modèle

Nous définissons le modèle comme indiqué dans le centralized MLX example, c’est un simple MLP :

class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = mx.maximum(l(x), 0.0)
        return self.layers[-1](x)

Nous définissons également certaines fonctions d’utilité pour tester notre modèle et pour itérer sur des lots.

def loss_fn(model, X, y):
    return mx.mean(nn.losses.cross_entropy(model(X), y))


def eval_fn(model, X, y):
    return mx.mean(mx.argmax(model(X), axis=1) == y)


def batch_iterate(batch_size, X, y):
    perm = mx.array(np.random.permutation(y.size))
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]

L’Application Client

Les principales modifications que nous devons apporter pour utiliser MLX avec Flower seront trouvées dans les fonctions get_params() et set_params(). MLX ne fournit pas une façon facile de convertir les paramètres du modèle en une liste d’objets np.array (le format dont nous avons besoin pour la sérialisation des messages).

MLX stocke ses paramètres comme suit :

{
"layers": [
    {"weight": mlx.core.array, "bias": mlx.core.array},
    {"weight": mlx.core.array, "bias": mlx.core.array},
    ...,
    {"weight": mlx.core.array, "bias": mlx.core.array}
]
}

Par conséquent, pour obtenir notre liste d’objets np.array, nous avons besoin d’extrait chaque tableau et de les convertir en tableaux NumPy :

def get_params(model):
    layers = model.parameters()["layers"]
    return [np.array(val) for layer in layers for _, val in layer.items()]

Pour la fonction set_params(), nous effectuons l’opération inverse. Nous recevons une liste de tableaux NumPy et voulons les convertir en paramètres MLX. Par conséquent, nous itérons à travers des paires de paramètres et leur assignons aux clés weight et bias de chaque dictionnaire de couche :

def set_params(model, parameters):
    new_params = {}
    new_params["layers"] = [
        {"weight": mx.array(parameters[i]), "bias": mx.array(parameters[i + 1])}
        for i in range(0, len(parameters), 2)
    ]
    model.update(new_params)

Le reste de la fonctionnalité est directement inspiré du cas centralisé. Le ClientApp entraînera le modèle sur les données locales en utilisant le boucle d’entraînement standard MLX :

# Train the model on local data
for _ in range(num_epochs):
    for X, y in batch_iterate(batch_size, train_images, train_labels):
        _, grads = loss_and_grad_fn(model, X, y)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)

Asseyons tout ensemble et voyons l’implémentation complète du ClientApp. Tout d’abord, la conduite dans un tour de formation est définie à l’intérieur d’une fonction enveloppée par le décorateur @app.train().

Après avoir lu les paramètres de configuration à partir du Context, nous instancions le modèle et appliquons les paramètres globaux envoyés par le serveur en utilisant la fonction set_params() définie ci-dessus. Nous définissons ensuite l’optimiseur et la fonction de perte, chargeons la partition de données locales à l’aide de load_data(), et entraînons le modèle sur les données. Enfin, nous calculons l’exactitude et la perte sur les données d’entraînement et construisons une réponse Message contenant un ArrayRecord avec les paramètres du modèle mis à jour et un MetricRecord avec l’exactitude et la perte de formation. Il est très important qu’elle contienne également la clé num-examples qui sera utilisée par le serveur pour effectuer une moyenne pondérée des paramètres du modèle. La valeur de cette clé est le nombre d’exemples de formation dans la partition de données locales.

# Flower ClientApp
app = ClientApp()


@app.train()
def train(msg: Message, context: Context):
    """Train the model on local data."""

    # Read config
    num_layers = context.run_config["num-layers"]
    input_dim = context.run_config["input-dim"]
    hidden_dim = context.run_config["hidden-dim"]
    batch_size = context.run_config["batch-size"]
    learning_rate = context.run_config["learning-rate"]
    num_epochs = context.run_config["local-epochs"]

    # Instantiate model and apply global parameters
    model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
    ndarrays = msg.content["arrays"].to_numpy_ndarrays()
    set_params(model, ndarrays)

    # Define optimizer and loss function
    optimizer = optim.SGD(learning_rate=learning_rate)
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    # Load data
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    train_images, train_labels, _, _ = load_data(partition_id, num_partitions)

    # Train on local data
    for _ in range(num_epochs):
        for X, y in batch_iterate(batch_size, train_images, train_labels):
            _, grads = loss_and_grad_fn(model, X, y)
            optimizer.update(model, grads)
            mx.eval(model.parameters(), optimizer.state)

    # Compute train accuracy and loss
    accuracy = eval_fn(model, train_images, train_labels)
    loss = loss_fn(model, train_images, train_labels)
    # Construct and return reply Message
    model_record = ArrayRecord(get_params(model))
    metrics = {
        "num-examples": len(train_images),
        "accuracy": float(accuracy.item()),
        "loss": float(loss.item()),
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict({"arrays": model_record, "metrics": metric_record})
    return Message(content=content, reply_to=msg)

Le ClientApp permet également l’évaluation du modèle sur les données de test locales. Cela peut être fait en définissant une fonction enveloppée par le décorateur @app.evaluate(). La signature de la fonction est identique à celle de la fonction train(). Comme montré ci-dessous, la fonction d’évaluation est très similaire à la fonction d’entraînement, excepté que nous n’exécutons pas d’entraînement. Nous devons toujours mettre à jour les paramètres du modèle avec ceux envoyés par le serveur, puis calculer la perte et l’exactitude en utilisant les fonctions définies ci-dessus. Enfin, nous construisons une réponse Message contenant un objet MetricRecord avec l’exactitude d’évaluation et la perte, ainsi que la clé num-examples, qui sera utilisée par le serveur pour effectuer une moyenne pondérée des métriques.

@app.evaluate()
def evaluate(msg: Message, context: Context):
    """Evaluate the model on local data."""

    # ... read config, instantiate model, load data

    # Evaluate the model on local data
    accuracy = eval_fn(model, test_images, test_labels)
    loss = loss_fn(model, test_images, test_labels)

    # Construct and return reply Message
    metrics = {
        "num-examples": len(test_images),
        "accuracy": float(accuracy.item()),
        "loss": float(loss.item()),
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict({"metrics": metric_record})
    return Message(content=content, reply_to=msg)

L’Application Serveur

L’Application Serveur

Pour construire un ServerApp, nous définissons sa méthode @app.main(). Cette méthode reçoit en entrée les arguments suivants :

  • Un objet Grid qui sera utilisé pour interagir avec les nœuds s’exécutant la ClientApp afin d’impliquer les clients dans une ronde d’entraînement/évaluation/requête ou autre.

  • Un objet Context qui fournit accès à la configuration de l’exécution.

Dans cet exemple, nous utilisons FedAvg et le laissons avec ses paramètres par défaut. Ensuite, après avoir initialisé le MLP qui servirait de modèle global dans la première ronde, l’exécution de la stratégie est lancée lorsqu’on invoque sa méthode start. À cela on passe:

  • l’objet Grid.

  • un ArrayRecord portant un modèle initialisé aléatoirement qui servira de modèle global pour fédérer.

  • Le paramètre num_rounds spécifiant combien de rondes de FedAvg effectuer.

# Create ServerApp
app = ServerApp()


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""
    # Read from config
    num_rounds = context.run_config["num-server-rounds"]
    num_layers = context.run_config["num-layers"]
    input_dim = context.run_config["input-dim"]
    hidden_dim = context.run_config["hidden-dim"]

    # Initialize global model
    model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
    params = get_params(model)
    arrays = ArrayRecord(params)

    # Initialize FedAvg strategy
    strategy = FedAvg()

    # Start strategy, run FedAvg for `num_rounds`
    result = strategy.start(
        grid=grid,
        initial_arrays=arrays,
        num_rounds=num_rounds,
    )

    if context.run_config["save-model"]:
        # Save final model to disk
        print("\nSaving final model to disk...")
        ndarrays = result.arrays.to_numpy_ndarrays()
        set_params(model, ndarrays)
        model.save_weights("final_model.npz")

Notez que la méthode start du stratégie retourne un objet Result. Cet objet contient toutes les informations pertinentes sur le processus FL, y compris les poids du modèle final sous forme de ArrayRecord, et des métriques d’entraînement et d’évaluation fédérées sous forme de MetricRecords.

Félicitations ! Vous avez réussi à créer et exécuter votre premier système d’apprentissage fédéré.

Astuce

Vérifiez la documentation de Run simulations pour en savoir plus sur la façon de configurer et d’exécuter les simulations Flower.

Note

Vérifiez la partie source code de l’édition étendue de ce tutoriel dans examples/quickstart-mlx dans le dépôt GitHub Flower.