Communiquez des messages personnalisés.

Bienvenue dans la prochaine partie du tutoriel d’intelligence artificielle collaborative Flower !

Dans les tutoriels précédents, vous avez créé une fédération simulée sur SuperGrid, exécuté et personnalisé des applications Flower, passé du démonstration NumPy à l’application quickstart PyTorch, personnalisé la stratégie utilisée par la ServerApp, puis construit une stratégie plus personnalisée. Dans ce tutoriel, vous allez vous concentrer sur la ClientApp et apprendre à communiquer des informations supplémentaires entre ClientApp et ServerApp via des objets Message.

Astuce

Star Flower on GitHub ⭐️ et rejoignez la communauté Flower sur Flower Discuss ou Flower Slack pour vous présenter, poser des questions et obtenir de l’aide.

Allons plus loin et voyons comment sérialiser des objets Python arbitraires et les communiquer ! 🌼

Préparation

Ce tutoriel continue à partir des tutoriels précédents PyTorch dans cette série. Si vous avez déjà l’application quickstart-pytorch, ouvrez ce répertoire et continuez d’où vous en étiez.

Si vous commencez ici directement, installez Flower et récupérez la même application :

# Install Flower with the simulation extra
$ pip install -U "flwr[simulation]"
# Fetch the app from Flower Hub
$ flwr new @flwrlabs/quickstart-pytorch
# Navigate to the app directory
$ cd quickstart-pytorch
# Install the app dependencies
$ pip install -e .

Construire des messages.

Dans Flower, le serveur et les clients communiquent en envoyant et recevant des objets Message. Un Message transporte un RecordDict comme principal payload. Le RecordDict est comme un dictionnaire Python qui peut contenir plusieurs enregistrements de types différents. Il existe trois principaux types d’enregistrements :

  • ArrayRecord: Contient des paramètres de modèle sous forme de dictionnaire de tableaux NumPy

  • MetricRecord: Contient des métriques d’entraînement ou d’évaluation sous forme de dictionnaire d’integers, de flottants, de listes d’integers ou de listes de flottants.

  • ConfigRecord: Contient des paramètres de configuration sous forme de dictionnaire d’integers, de flottants, de chaînes, de booléens ou de bytes. Les listes de ces types sont également supportées.

Voyons quelques exemples de comment travailler avec ces types de registres et, finalement, construire un RecordDict qui peut être envoyé sur un Message.

from flwr.app import ArrayRecord, MetricRecord, ConfigRecord, RecordDict

# ConfigRecord can be used to communicate configs between ServerApp and ClientApp
# They can hold scalars, but also strings and booleans
config = ConfigRecord(
    {"batch_size": 32, "use_augmentation": True, "data-path": "/my/dataset"}
)

# MetricRecords expect scalar-based metrics (i.e. int/float/list[int]/list[float])
# By limiting the types Flower can aggregate MetricRecords automatically
metrics = MetricRecord({"accuracy": 0.9, "losses": [0.1, 0.001], "perplexity": 2.31})

# ArrayRecord objects are designed to communicate arrays/tensors/weights from ML models
array_record = ArrayRecord(my_model.state_dict())  # for a PyTorch model
array_record_other = ArrayRecord(my_model.to_numpy_ndarrays())  # for other ML models

# A RecordDict is like a dictionary that holds named records.
# This is the main payload of a Message
rd = RecordDict({"my-config": config, "metrics": metrics, "my-model": array_record})

Revisiter la réponse des applications Client

Allons nous rappeler de la communication entre ClientApp et ServerApp. Une fonction ClientApp enveloppée avec @app.train() retournerait généralement les paramètres de modèle localement mis à jour, ainsi que certaines métriques pertinentes pour le processus d’entraînement, telles que la perte d’entraînement et l’exactitude. En code, cela ressemblerait à :

@app.train()
def train(msg: Message, context: Context):
    """Train the model on local data."""

    # ... prepare model, load data, train locally

    # Construct and return reply Message
    model_record = ArrayRecord(model.state_dict())
    metrics = {
        "train_loss": train_loss,
        "num-examples": len(trainloader.dataset),
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict({"arrays": model_record, "metrics": metric_record})
    return Message(content=content, reply_to=msg)

Ensuite, sur ServerApp, la stratégie Flower aggregera automatiquement les ArrayRecord et MetricRecord de chaque client dans un seul ArrayRecord et MetricRecord qui peut être utilisé pour mettre à jour le modèle global et enregistrer les métriques agrégées. Maintenant, qu’est-ce que nous voulions envoyer des informations supplémentaires du ClientApp vers le ServerApp ? Par exemple, supposons que nous voulions envoyer la durée d’exécution de l”ClientApp. Nous pouvons faire cela en ajoutant une nouvelle métrique à la MetricRecord. Elle sera également agrégée automatiquement par la stratégie. Si vous faites par exemple :

import time


@app.train()
def train(msg: Message, context: Context):
    """Train the model on local data."""

    start_time = time.time()

    # ... prepare model, load data, train locally

    end_time = time.time()
    training_time = end_time - start_time

    # Construct and return reply Message
    model_record = ArrayRecord(model.state_dict())
    metrics = {
        "train_loss": train_loss,
        "num-examples": len(trainloader.dataset),
        "training_time": training_time,  # New metric
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict({"arrays": model_record, "metrics": metric_record})
    return Message(content=content, reply_to=msg)

Si vous souhaitez communiquer d’autres types d’objets et les laisser hors du processus de mise en aggregation, vous pouvez utiliser un ConfigRecord. En plus des entiers et des flottants, vous pouvez utiliser un ConfigRecord pour envoyer des chaînes, des booléens et même des bytes. Dans le prochain chapitre, nous allons apprendre à communiquer des objets Python arbitraires en les sérialisant d’abord en bytes.

Communiquer des objets arbitraires

Supposons que la phase d’entraînement de notre ClientApp produise une classe de données comme celle ci-dessous et nous voudrions communiquer cela au ServerApp via le Message. Allons-y et définissons cela dans task.py:

from dataclasses import dataclass


@dataclass
class TrainProcessMetadata:
    """Metadata about the training process."""

    training_time: float
    converged: bool
    training_losses: dict[str, float]  # e.g. { "epoch_1": 0.5, "epoch_2": 0.3 }

Maintenant, voyons comment le ClientApp peut sérialiser cet objet, l’envoyer au ServerApp, faire la stratégie désérialiser-le à nouveau en objet original, et l’utiliser.

Envoi depuis les ClientApps

Supposons que notre ClientApp entraîne le modèle localement et génère une instance de TrainProcessMetadata. Pour envoyer cela comme partie de la réponse du message, nous devons le sérialiser en octets. Dans ce cas, nous pouvons utiliser le module pickle de la bibliothèque standard Python. Nous pouvons ensuite envoyer l’objet sérialisé dans un ConfigRecord dans la réponse Message. Voyons comment cela ressemblerait à du code:

L’exemple ci-dessous se concentre sur la logique des métadonnées supplémentaires ; gardez le modèle et les données de votre fonction d’entraînement existante inchangés.

Avertissement

Le code suivant est destiné à des fins de démonstration uniquement. Dans les applications réelles, puisque pickle peut exécuter du code arbitraire lors de l’unpacking, vous devriez utiliser une méthode de sérialisation plus sûre que pickle, comme json ou une solution personnalisée simple si l’objet n’est pas trop complexe. pickle est utilisé ici uniquement pour la simplicité.

import pickle
import time


@app.train()
def train(msg: Message, context: Context):
    """Train the model on local data."""

    # ... prepare model, load data, train locally
    # The train function returns the training loss
    start_time = time.time()
    train_loss = train_fn(...)
    # Construct a TrainProcessMetadata object
    train_metadata = TrainProcessMetadata(
        training_time=time.time() - start_time,
        converged=True,
        training_losses={"final": train_loss},
    )

    # Serialize the TrainProcessMetadata object to bytes
    train_meta_bytes = pickle.dumps(train_metadata)
    # Construct a ConfigRecord
    config_record = ConfigRecord({"meta": train_meta_bytes})

    # Construct and return reply Message
    model_record = ArrayRecord(model.state_dict())
    metrics = {
        "train_loss": train_loss,
        "num-examples": len(trainloader.dataset),
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict(
        {
            "arrays": model_record,
            "metrics": metric_record,
            "train_metadata": config_record,
        }
    )
    return Message(content=content, reply_to=msg)

Voyons ensuite comment la stratégie sur le ServerApp peut désérialiser l’objet à nouveau en sa forme originale et l’utiliser.

Réception sur ServerApps

Comme vous le savez, une stratégie Flower aggregera automatiquement les ArrayRecord et MetricRecord de chaque client. Cependant, elle ne fera rien avec le ConfigRecord que nous venons de envoyer. Nous pouvons surcharger la méthode aggregate_train de notre stratégie pour gérer la désérialisation et l’utilisation de l’objet TrainProcessMetadata.

Note

Nous surchargeons la méthode aggregate_train car nous avons envoyé l’objet depuis une fonction @app.train(). Si nous avions envoyé cela depuis une fonction @app.evaluate(), nous aurions surchargé la méthode aggregate_evaluate au lieu de celle-ci.

Créons une nouvelle stratégie personnalisée, ou réutilisons celle créée dans les tutoriels de stratégies précédents, dans server_app.py qui étend la stratégie FedAdagrad et surcharge la méthode aggregate_train pour désérialiser l’objet TrainProcessMetadata de chaque client et imprimer le temps d’entraînement et l’état de convergence:

import pickle
from dataclasses import asdict
from typing import Iterable, Optional


class CustomFedAdagrad(FedAdagrad):

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages."""

        for reply in replies:
            if reply.has_content():
                # Retrieve the ConfigRecord from the message
                config_record = reply.content["train_metadata"]
                metadata_bytes = config_record["meta"]
                # Deserialize it
                train_meta = pickle.loads(metadata_bytes)
                print(asdict(train_meta))
        # Aggregate the ArrayRecords and MetricRecords as usual
        return super().aggregate_train(server_round, replies)

Finalement, nous exécutons l’application Flower.

$ flwr run . local --stream

Le flwr run . local soumet l’exécution, imprime l’ID d’exécution et retourne sans diffuser les journaux. Voir Exécuter Flower Localement avec un SuperLink Géré pour le workflow local complet.

Vous remarquerez que les métadonnées d’entraînement de chaque client sont enregistrés dans la console du ServerApp. Si vous avez terminé l’implémentation de la création de l’objet TrainProcessMetadata dans le ClientApp, vous devriez voir un sortie similaire à celui-ci :

INFO :      [ROUND 1/3]
INFO :      configure_train: Sampled 5 SuperNodes (out of 50)
{'training_time': 123.45, 'converged': True, 'training_losses': {'epoch1': 0.56, 'epoch2': 0.34}}
{'training_time': 130.67, 'converged': False, 'training_losses': {'epoch1': 0.60, 'epoch2': 0.40}}
...

Vous pouvez maintenant utiliser ces informations dans votre logique de stratégie comme nécessaire. Par exemple, pour mettre en œuvre une méthode d’agrégation personnalisée sur la base de l’état de convergence ou pour logger des métriques supplémentaires.

Récapitulation

Dans cette partie du tutoriel, nous avons vu comment communiquer des objets Python arbitraires entre le ClientApp et le ServerApp en les sérialisant en octets et en les envoyant comme un ConfigRecord dans un Message. Nous avons également appris à désérialiser-les à nouveau en leur forme originale sur le côté serveur et l’utiliser dans une stratégie personnalisée. Notez que les étapes présentées ici sont identiques si vous avez besoin de sérialiser des objets dans la stratégie pour les envoyer aux clients.

Prochaines étapes

Avant de continuer, assurez-vous d’adhérer à la communauté Flower sur Flower Discuss (Join Flower Discuss) et Slack (Join Slack).

Il existe un canal Slack dédié si vous avez besoin d’aide, mais nous aimerions également entendre qui vous êtes dans #introductions !

C’est la dernière partie du tutoriel Flower (pour l’instant !), félicitations ! Tu es maintenant bien équipé pour comprendre le reste de la documentation. Il y a de nombreux sujets que nous n’avons pas abordés dans le tutoriel, nous te recommandons les ressources suivantes :