Écrivez votre première application Flower avec PyTorch

Bienvenue dans la prochaine partie du tutoriel d’intelligence artificielle collaborative Flower !

Dans les étapes précédentes, vous avez créé une fédération simulée sur SuperGrid, exécuté une application Flower, téléchargé l’application @flwrlabs/demo et appris comment ServerApp, ClientApp, des stratégies et pyproject.toml s’intègrent. Dans cette étape, vous allez utiliser le même workflow avec une application Flower plus réaliste : une application PyTorch qui entraîne un petit classificateur d’images sur CIFAR-10.

Astuce

Star Flower on GitHub ⭐️ et rejoignez la communauté Flower sur Flower Discuss ou Flower Slack pour vous présenter, poser des questions et obtenir de l’aide.

Allons-y ! 🌼

Créer l’application

Utilisez flwr new pour récupérer l’application quickstart PyTorch à partir de Hub Flower:

$ flwr new @flwrlabs/quickstart-pytorch

Après avoir exécuté la commande, un nouveau répertoire nommé quickstart-pytorch sera créé:

quickstart-pytorch
├── pytorchexample
│   ├── __init__.py
│   ├── client_app.py   # Defines your ClientApp   ├── server_app.py   # Defines your ServerApp   └── task.py         # Defines your model, training and data loading
├── pyproject.toml      # Project metadata like dependencies and configs
└── README.md

Cette application a la même structure Flower que le démonstrateur NumPy de l’étape précédente, mais la charge de travail est maintenant une tâche d’entraînement PyTorch réelle. L’application entraîne un petit réseau neuronal convolutionnel sur CIFAR-10, un jeu de données d’apprentissage d’image avec dix classes telles que avion, automobile, oiseau, chat, chien, navire et camion.

Aperçu rapide de l’application

Note

Un passage plus détaillé de l’application est disponible plus tard dans cette étape.

Avant d’exécuter l’application, il est utile de savoir ce que chaque fichier est responsable :

  • pytorchexample/task.py contient le code spécifique au PyTorch: le réseau neuronal, le chargement et la partitionnement des données CIFAR-10, le boucle d’entraînement local, la boucle d’évaluation, et les aides à l’évaluation serveur.

  • pytorchexample/client_app.py définit le ClientApp. Son gestionnaire de @app.train() reçoit le modèle global actuel, charge une partition CIFAR-10, entraîne le modèle localement et répond avec les paramètres mis à jour du modèle ainsi que des métriques.

  • pytorchexample/server_app.py définit le ServerApp. Il crée le modèle initial PyTorch, enveloppe les paramètres du modèle dans un ArrayRecord, crée une stratégie FedAvg, et démarre l’exécution fédérée.

  • pyproject.toml déclare les métadonnées de l’application et les dépendances, pointe vers les objets Flower à importer, et définit des valeurs de configuration d’exécution telles que le nombre de tours serveur, la taille du lot, les épochs locales, le taux d’apprentissage, et les paramètres d’évaluation.

L’idée importante est la même que précédemment : le ServerApp démarre l’exécution, le FedAvg coordonne chaque tour d’apprentissage fédéré, et chaque ClientApp entraîne ou évalue le modèle à l’aide des données disponibles sur son SuperNode.

Cette application utilise Flower Datasets pour télécharger CIFAR-10 et le partitionner en parties, une pour chaque client simulé. C’est idéal pour les simulations car cela vous permet d’expérimenter avec l’apprentissage fédéré même lorsque vous commencez à partir d’un seul ensemble de données centralisé. Dans une application Flower typique qui s’exécute en dehors des simulations, vous créez généralement des partitions artificielles. Au lieu de cela, chaque ClientApp charge les données déjà disponibles sur le SuperNode où il s’exécute.

Exécutez l’App sur SuperGrid

Note

Si vous n’avez pas déjà fait cela, complétez le first tutorial pour créer un compte SuperGrid et une fédération simulée.

Ouvrez une fenêtre de terminal, activez votre environnement Python et exécutez la commande suivante pour vous connecter à SuperGrid :

# This will open a browser window where you can enter your SuperGrid credentials.
$ flwr login supergrid

Une fois que vous êtes connecté, exécutez la commande suivante pour exécuter l’application sur SuperGrid et dans la fédération que vous avez créée dans le tutoriel précédent :

# Navigate to the directory of the app you want to run
$ cd /path/to/quickstart-pytorch
# Run the app across the federation you created in the previous tutorial
$ flwr run . supergrid --federation @<username>/<federation-name>
# for example
# flwr run . supergrid --federation @peter123/my-first-federation

SuperGrid lancera une nouvelle exécution pour cette application. Ouvrez le SuperGrid dashboard, sélectionnez votre fédération et cliquez sur la nouvelle exécution pour suivre son progression et inspecter les journaux.

Dans les journaux, vous devriez voir Flower lancer la stratégie FedAvg et effectuer plusieurs tours d’apprentissage fédéré. Chaque tour comprend un entraînement local sur des instances sélectionnées ClientApp, une mise en forme dans le ServerApp, et des métriques d’évaluation telles que eval_loss et eval_acc.

Vous pouvez définir des valeurs à partir de pyproject.toml en temps réel. Par exemple:

# Run the app for five rounds instead of the default three rounds
$ flwr run . --federation @<username>/<federation-name> \
    --run-config "num-server-rounds=5"

# Run the app for five rounds and a smaller batch size
$ flwr run . --federation @<username>/<federation-name> \
    --run-config "num-server-rounds=5" \
    --run-config "batch-size=16"

Exécutez l’App Localement

L’exécution sur SuperGrid est la méthode recommandée pour exécuter des workflows d’intelligence artificielle collaborative avec Flower. Cependant, il est également utile d’exécuter la même application localement pendant que vous développez ou déboguez.

À partir du répertoire quickstart-pytorch, installez l’application et ses dépendances dans votre environnement Python :

$ cd /path/to/quickstart-pytorch
$ pip install -e .

Exécutez ensuite l’application localement avec la commande suivante. Flower lancera une instance locale gérée de type SuperLink — une version distillée de SuperGrid — et exécuterez l’application avec des SuperNodes simulés sur votre machine. La première exécution peut prendre plus de temps car l’application doit télécharger CIFAR-10. Avec la flague --stream, vous pouvez voir les journaux de l’exécution locale dans votre terminal.

$ flwr run . local --stream

Le flux de sortie doit inclure des journaux similaires à ceux-ci :

INFO :      Starting FedAvg strategy:
INFO :          ├── Number of rounds: 3
INFO :      ...
INFO :      [ROUND 1/3]
INFO :      configure_train: Sampled 5 SuperNodes (out of 10)
INFO :      aggregate_train: Received 5 results and 0 failures
INFO :          └──> Aggregated MetricRecord: {'train_loss': 2.149280}
INFO :      configure_evaluate: Sampled 10 SuperNodes (out of 10)
INFO :      aggregate_evaluate: Received 10 results and 0 failures
INFO :          └──> Aggregated MetricRecord: {'eval_loss': 2.31319, 'eval_acc': 0.13004}
INFO :      [ROUND 2/3]
INFO :      ...
INFO :      [ROUND 3/3]
INFO :      ...
INFO :      Strategy execution finished

Note

Dans la commande ci-dessus flwr run, vous n’êtes pas spécifique d’une fédération, c’est parce que pour la prototypage local il n’y a qu’une seule fédération disponible. En raison de cela, le drapeau --federation n’est pas requis.

Note

Si vous êtes sous Windows et voyez un affichage de terminal inattendu, par exemple □[32m□[1m, consultez cette entrée de FAQ.

Pour plus d’informations sur l’utilisation de la ligne de commande Flower contre une application SuperLink en cours d’exécution localement, y compris la façon de lister vos exécutions et de visualiser leurs journaux, consultez Run Flower Locally with a Managed SuperLink.

Une Plongée Plus Profonde dans l’Application

L’application @flwrlabs/quickstart-pytorch démontre un workflow d’apprentissage fédéré simple. Dans l’apprentissage fédéré, le serveur envoie les paramètres du modèle global au client, et le client met à jour le modèle local avec les paramètres reçus du serveur. Il entraîne ensuite le modèle sur les données locales (ce qui change les paramètres du modèle localement) et envoie les paramètres mis à jour/changés du modèle vers le serveur (ou, de manière alternative, il envoie uniquement les gradients vers le serveur, pas les paramètres du modèle complet).

Définissez l’Application Client Flower

Les systèmes d’apprentissage fédéré consistent en un serveur et plusieurs clients (SuperNodes). Dans Flower, nous créons une ServerApp et une ClientApp pour exécuter le code côté serveur et client, respectivement.

La fonctionnalité centrale du ClientApp est de réaliser une action avec les données locales dont dispose l’instance sur laquelle il s’exécute (par exemple un appareil Edge, un serveur dans un centre de données ou un ordinateur portable). Dans ce tutoriel, cette action consiste à entraîner et évaluer le petit modèle CNN défini plus tôt en utilisant les données d’entraînement et de validation locales.

Chargement des données

Cette application entraîne un petit réseau neuronal convolutif sur CIFAR-10. Comme le tutoriel utilise le Simulation Runtime, toutes les données proviennent d’abord d’un jeu de données centralisé, ensuite découpé en partitions, une pour chaque SuperNode simulé.

La fonction load_data() dans task.py utilise Flower Datasets pour charger une partition, la diviser en données d’entraînement et de validation, appliquer les transformations PyTorch, et retourner deux objets DataLoader:

def load_data(partition_id: int, num_partitions: int, batch_size: int):
    """Load partition CIFAR10 data."""
    # Only initialize `FederatedDataset` once
    global fds
    if fds is None:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(
            dataset="uoft-cs/cifar10",
            partitioners={"train": partitioner},
        )
    partition = fds.load_partition(partition_id)
    # Divide data on each SuperNode: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = Compose(
        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        """Apply transforms to the partition from FederatedDataset."""
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=batch_size, shuffle=True
    )
    testloader = DataLoader(partition_train_test["test"], batch_size=batch_size)
    return trainloader, testloader

Cette partition est nécessaire uniquement pour la simulation. En déploiement, chaque SuperNode chargerait généralement ses propres données locales directement, par exemple à partir d’un chemin fourni via --node-config

Entraînement

Nous pouvons définir comment le ClientApp effectue l’entraînement en enveloppant une fonction avec le décorateur @app.train(). Dans ce cas, nous nommons cette fonction train car nous allons l’utiliser pour entraîner le modèle sur les données locales. La fonction attend toujours deux arguments :

  • Un Message: Le message reçu du serveur. Il contient les paramètres du modèle et toute autre information de configuration envoyée par le serveur.

  • A Context: L’objet de contexte qui contient des informations sur l’exécution du SuperNode et du ClientApp, ainsi que sur la dernière exécution.

À travers le contexte, vous pouvez récupérer les paramètres de configuration définis dans l”pyproject.toml de votre application. Le contexte peut être utilisé pour persister l’état du client à travers plusieurs appels à train ou evaluate. Dans Flower, les objets éphémères sont des objets qui se créent pour l’exécution d’un Message et sont détruits lorsque la réponse est communiquée vers le serveur.

Voyons une mise en œuvre de ClientApp qui utilise le modèle CNN PyTorch précédemment défini, applique les paramètres reçus via le message ServerApp, charge ses données locales, entraîne le modèle avec elles (en utilisant la fonction train_fn) et génère une réponse Message contenant les paramètres mis à jour du modèle ainsi que certaines métriques d’intérêt.

from pytorchexample.task import train as train_fn

# Flower ClientApp
app = ClientApp()


@app.train()
def train(msg: Message, context: Context):
    """Train the model on local data."""

    # Load the model and initialize it with the received weights
    model = Net()
    model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load the data
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    batch_size = context.run_config["batch-size"]
    trainloader, _ = load_data(partition_id, num_partitions, batch_size)

    # Call the training function
    train_loss = train_fn(
        model,
        trainloader,
        context.run_config["local-epochs"],
        msg.content["config"]["lr"],
        device,
    )

    # Construct and return reply Message
    model_record = ArrayRecord(model.state_dict())
    metrics = {
        "train_loss": train_loss,
        "num-examples": len(trainloader.dataset),
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict({"arrays": model_record, "metrics": metric_record})
    return Message(content=content, reply_to=msg)

Note

Les valeurs de partition-id et num-partitions affichées ci-dessus sont fournies par le Simulation Runtime. Dans un environnement de déploiement, le ClientApp chargerait généralement des données qui existent déjà sur le SuperNode. Par exemple, vous pourriez passer l’emplacement de ces données lors du démarrage du SuperNode avec --node-config "data-path=/path/to/data" et puis appeler load_data avec context.node_config["data-path"].

Notez que le train_fn est simplement un nom d’alias pointant vers la fonction d’entraînement définie plus tôt dans ce tutoriel (où nous avons défini le boucle d’entraînement PyTorch et l’optimiseur). À cette fonction, on passe le modèle que nous voulons entraîner localement et le chargeur de données, mais aussi le nombre d’épochs locaux et la vitesse d’apprentissage (lr) à utiliser. Notez comment dans ce cas, la mise en œuvre du local-epochs est lue depuis la configuration de run via le Context, tandis que le lr est lu depuis le ConfigRecord envoyé par le serveur via le Message. Cela peut être utilisé pour ajuster la vitesse d’apprentissage à chaque tour du serveur. Lorsque cette dynamique n’est pas nécessaire, lire le lr depuis la configuration de run via le Context est également parfaitement valide.

Une fois l’entraînement terminé, le ClientApp construit une réponse Message. Cette réponse comprend généralement un RecordDict avec deux enregistrements :

  • Un ArrayRecord contenant les paramètres de modèle mis à jour

  • Un MetricRecord avec des métriques pertinentes (dans ce cas, la perte d’entraînement et le nombre d’exemples utilisés pour l’entraînement)

Note

Le retour du nombre d’exemples sous la clé "num-examples" est obligatoire, car les stratégies telles que FedAvg utilisées par le ServerApp dépendent de cette clé pour agréger les modèles et les métriques par défaut, à moins que vous ne surchargiez l’argument weighted_by_key (par exemple : FedAvg(weighted_by_key="my-different-key")).

Après avoir construit la réponse Message, le ClientApp la retourne. Flower gère ensuite l’envoi de la réponse vers le serveur automatiquement.

Evaluation

Dans un cadre d’apprentissage fédéré typique, le ClientApp implémenterait également une fonction @app.evaluate() pour évaluer le modèle reçu du ServerApp sur des données de validation locales. C’est particulièrement utile pour surveiller les performances du modèle global sur chaque client pendant l’entraînement. La mise en œuvre de la fonction evaluate est très similaire à celle du train, excepté qu’elle appelle la fonction test_fn définie plus tôt dans ce tutoriel (qui implémente le boucle d’évaluation PyTorch) et elle retourne un Message contenant uniquement un MetricRecord avec les métriques d’évaluation (pas de ArrayRecord car les paramètres du modèle ne sont pas mis à jour pendant l’évaluation). Voici comment la fonction evaluate ressemble :

from pytorchexample.task import test as test_fn


@app.evaluate()
def evaluate(msg: Message, context: Context):
    """Evaluate the model on local data."""

    # Load the model and initialize it with the received weights
    model = Net()
    model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load the data
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    batch_size = context.run_config["batch-size"]
    _, valloader = load_data(partition_id, num_partitions, batch_size)

    # Call the evaluation function
    eval_loss, eval_acc = test_fn(
        model,
        valloader,
        device,
    )

    # Construct and return reply Message
    metrics = {
        "eval_loss": eval_loss,
        "eval_acc": eval_acc,
        "num-examples": len(valloader.dataset),
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict({"metrics": metric_record})
    return Message(content=content, reply_to=msg)

Comme vous pouvez le voir, l’implémentation du evaluate est presque identique à celle du train, sauf que cela appelle la fonction train_fn au lieu de la fonction Message et qu’elle retourne un MetricRecord contenant uniquement un eval_loss avec des métriques pertinentes pour l’évaluation (eval_acc, num-examples – tous deux scalaires). Nous devons également inclure la clé __PH9__ dans les métriques afin que le serveur puisse agréger correctement les métriques d’évaluation.

Définir l’application Flower ServerApp

Du côté du serveur, nous devons configurer une stratégie qui encapsule l’approche/algorithm de l’apprentissage fédéré, par exemple, Fédération Averagée (FedAvg). Flower dispose d’un certain nombre de stratégies intégrées, mais nous pouvons également utiliser nos propres implémentations de stratégie pour personnaliser presque tous les aspects de l’approche d’apprentissage fédéré. Pour ce tutoriel, nous utilisons l’implémentation intégrée FedAvg et la personnalisons légèrement en spécifiant la fraction de SuperNodes connectés à impliquer dans une ronde d’entraînement.

Pour construire un ServerApp, nous définissons sa méthode @app.main(). Cette méthode reçoit en entrée les arguments suivants :

  • Un objet Grid qui sera utilisé pour interagir avec les SuperNodes exécutant le ClientApp afin de les impliquer dans une ronde d’entraînement/évaluation/requête ou autre

  • un objet Context qui fournit accès à la configuration de run.

Avant de lancer la stratégie via la méthode start, nous voulons initialiser le modèle global. Cela sera le modèle qui sera envoyé aux clients dans la première ronde d’apprentissage fédéré. Nous pouvons faire cela en créant une instance du modèle (Net), en extrayant les paramètres dans son state_dict, et en construisant un ArrayRecord avec eux. Nous pouvons ensuite le rendre disponible à la stratégie via l’argument initial_arrays de la méthode start().

Nous pouvons également passer optionnellement à la méthode start() un objet ConfigRecord contenant les paramètres que nous voulons communiquer aux clients. Ces derniers seront envoyés en tant que partie du Message qui transporte également les paramètres du modèle.

app = ServerApp()


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""

    # Read run config
    fraction_evaluate: float = context.run_config["fraction-evaluate"]
    num_rounds: int = context.run_config["num-server-rounds"]
    lr: float = context.run_config["learning-rate"]

    # Load global model
    global_model = Net()
    arrays = ArrayRecord(global_model.state_dict())

    # Initialize FedAvg strategy
    strategy = FedAvg(fraction_evaluate=fraction_evaluate)

    # Start strategy, run FedAvg for `num_rounds`
    result = strategy.start(
        grid=grid,
        initial_arrays=arrays,
        train_config=ConfigRecord({"lr": lr}),
        num_rounds=num_rounds,
        evaluate_fn=global_evaluate,
    )

    # Save final model to disk
    print("\nSaving final model to disk...")
    state_dict = result.arrays.to_torch_state_dict()
    torch.save(state_dict, "final_model.pt")

La plupart de l’exécution de la méthode ServerApp se déroule à l’intérieur de la méthode strategy.start(). Après le nombre spécifié de tours (num_rounds), la méthode start() retourne un objet Result contenant les paramètres du modèle final et les métriques reçues des clients ou générées par la stratégie elle-même. Nous pouvons ensuite sauvegarder le modèle final sur disque pour une utilisation ultérieure.

Dans les coulisses

Alors, comment cela fonctionne-t-il ? Comment Flower exécute-t-il cette simulation ?

Lorsque nous exécutons flwr run contre la configuration de connexion locale par défaut, Flower soumet l’exécution à la connexion locale gérée SuperLink. Par défaut, la connexion locale SuperLink configurera le Simulation Runtime pour utiliser 10 clients. Chacun d’eux exécutera une instance du modèle ClientApp que nous avons défini plus tôt.

La connexion locale SuperLink commence ensuite l’exécution de ServerApp et lui demande d’envoyer des instructions à ces SuperNodes en utilisant la stratégie FedAvg. Dans cet exemple, FedAvg est configuré avec deux paramètres clés :

  • fraction-train=0.5 → sélectionner 50% des clients disponibles pour l’entraînement

  • fraction-evaluate=1.0 → sélectionner 100% des clients disponibles pour l’évaluation

Cela signifie dans notre exemple, 5 clients sur 10 seront sélectionnés pour l’entraînement, et tous les 10 clients participeront plus tard à l’évaluation.

Une ronde typique ressemble à ceci :

  • Entraînement

    1. FedAvg randomly selects 5 clients (50% of 10).

    2. Flower sends a TRAIN message to each selected ClientApp.

    3. Chaque ClientApp appelle la fonction décorée avec @app.train(), puis retourne un Message contenant un ArrayRecord (les paramètres du modèle mis à jour) et un MetricRecord (la perte d’entraînement et le nombre d’exemples).

    4. La connexion locale ServerApp reçoit toutes les réponses.

    5. FedAvg combine toutes les réponses ArrayRecord dans un nouveau ArrayRecord représentant le modèle global et combine toutes les métriques MetricRecord.

  • Évaluation

    1. FedAvg selects all 10 clients (100%).

    2. Flower sends an EVALUATE message to each ClientApp.

    3. Chaque ClientApp appelle la fonction décorée avec @app.evaluate() et retourne un Message contenant un MetricRecord (la perte d’évaluation, l’exactitude et le nombre d’exemples).

    4. La connexion locale ServerApp reçoit toutes les réponses.

    5. FedAvg aggregates all MetricRecord.

Une fois que les entraînement et l’évaluation sont terminés, commence une nouvelle ronde : une étape d’entraînement supplémentaire, puis une étape d’évaluation supplémentaire, et ainsi de suite, jusqu’à ce que le nombre configuré de tours soit atteint.

Remarques finales

Vous avez maintenant exécuté une application Flower Flower sur SuperGrid et localement. Comparée à la démo NumPy, cette application utilise un vrai modèle, un vrai jeu de données et un entraînement local réel, mais la structure est la même : ServerApp, ClientApp, stratégie et pyproject.toml.

Dans le prochain tutoriel, vous allez personnaliser la stratégie d’apprentissage fédéré pour changer la façon dont le serveur coordonne l’entraînement et l’évaluation.

Prochaines étapes

Avant de continuer, assurez-vous d’adhérer à la communauté Flower sur Flower Discuss (Join Flower Discuss) et Slack (Join Slack).

Il existe un canal Slack dédié si vous avez besoin d’aide, mais nous aimerions également entendre qui vous êtes dans #introductions !

Le Flower Collaborative AI Tutorial - Part 4: Use a federated learning strategy entre dans plus de détails sur les stratégies et le comportement avancé que vous pouvez construire avec elles.