Utilisez une stratégie d’apprentissage fédéré

Bienvenue dans la prochaine partie du tutoriel d’intelligence artificielle collaborative Flower !

Dans les tutoriels précédents, vous avez créé une fédération simulée sur SuperGrid, lancé une application Flower depuis Flower Hub, personnalisé l’application de démonstration NumPy, et ensuite lancé l’application quickstart PyTorch sur SuperGrid et localement. Dans ce tutoriel, vous allez personnaliser cette application PyTorch en changeant et en étendant la stratégie d’apprentissage fédéré utilisée par ServerApp.

Astuce

Star Flower on GitHub ⭐️ et rejoignez la communauté Flower sur Flower Discuss ou Flower Slack pour vous présenter, poser des questions et obtenir de l’aide.

Allons plus loin que FedAvg avec les stratégies de Flower ! 🌼

Préparation

Ce tutoriel continue depuis le previous tutorial, où vous avez créé et lancé l’application @flwrlabs/quickstart-pytorch. Si vous avez terminé, ouvrez le répertoire existant quickstart-pytorch et continuez d’en là.

Si vous commencez ici directement, installez Flower et récupérez la même application :

# Install Flower with the simulation extra
$ pip install -U "flwr[simulation]"
# Fetch the app from Flower Hub
$ flwr new @flwrlabs/quickstart-pytorch
# Navigate to the app directory
$ cd quickstart-pytorch
# Install the app dependencies
$ pip install -e .

Avec cela, nous sommes prêts à introduire un certain nombre de nouvelles fonctionnalités de stratégie.

Choisissez une autre stratégie

La stratégie encapsule l’approche/algorithm de l’apprentissage fédéré, par exemple, FedAvg. Essayons d’utiliser une autre stratégie cette fois. Modifiez les lignes suivantes dans votre server_app.py pour passer de FedAvg à FedAdagrad.

# ... unchanged
# add this to the imports
from flwr.serverapp.strategy import FedAdagrad

# ... unchanged


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""

    # Read run config
    fraction_evaluate: float = context.run_config["fraction-evaluate"]
    num_rounds: int = context.run_config["num-server-rounds"]
    lr: float = context.run_config["learning-rate"]

    # Load global model
    global_model = Net()
    arrays = ArrayRecord(global_model.state_dict())

    # Initialize FedAdagrad strategy
    strategy = FedAdagrad(fraction_evaluate=fraction_evaluate)

    # Start strategy, run FedAdagrad for `num_rounds`
    result = strategy.start(
        grid=grid,
        initial_arrays=arrays,
        train_config=ConfigRecord({"lr": lr}),
        num_rounds=num_rounds,
        evaluate_fn=global_evaluate,
    )

Exécutez ensuite l’application sur SuperGrid pour confirmer que la nouvelle stratégie est utilisée :

# Log in if you are not already logged in
$ flwr login supergrid
# Run the app across the federation you created earlier in this tutorial series
$ flwr run . supergrid --federation @<username>/<federation-name>

Ouvrez le SuperGrid dashboard, sélectionnez votre fédération, et inspectez les journaux pour la nouvelle exécution. Vous devriez voir que Flower démarre la stratégie FedAdagrad au lieu de FedAvg.

Vous pouvez également exécuter le même application localement pendant le développement ou la débogage :

$ flwr run . local --stream

Paramètre côté serveur évaluation

Flower peut évaluer le modèle agrégé côté serveur ou côté client. L’évaluation côté client et côté serveur sont similaires dans certains aspects, mais différentes dans d’autres.

L’évaluation centralisée (ou évaluation côté serveur) est conceptuellement simple : elle fonctionne de la même manière que l’évaluation dans l’apprentissage automatique centralisé. S’il existe un ensemble de données côté serveur qui peut être utilisé à des fins d’évaluation, alors c’est parfait. Nous pouvons évaluer le modèle nouvellement agrégé après chaque cycle de formation sans avoir à envoyer le modèle aux clients. Nous avons également la chance que l’ensemble de notre ensemble de données d’évaluation soit disponible à tout moment.

L’évaluation fédérée (ou évaluation côté client) est plus complexe, mais aussi plus puissante : elle ne nécessite pas d’ensemble de données centralisé et nous permet d’évaluer les modèles sur un plus grand ensemble de données, ce qui donne souvent des résultats d’évaluation plus réalistes. En fait, de nombreux scénarios exigent que nous utilisions l’évaluation fédérée** si nous voulons obtenir des résultats d’évaluation représentatifs. Mais cette puissance a un coût : une fois que nous commençons à évaluer côté client, nous devons savoir que notre ensemble de données d’évaluation peut changer au cours des cycles d’apprentissage consécutifs si ces clients ne sont pas toujours disponibles. De plus, l’ensemble de données détenu par chaque client peut également changer au cours des cycles consécutifs. Cela peut conduire à des résultats d’évaluation qui ne sont pas stables, donc même si nous ne changions pas le modèle, nous verrions nos résultats d’évaluation fluctuer au cours des cycles consécutifs.

Nous avons vu comment l’évaluation fédérée fonctionne côté client (c’est-à-dire en implémentant une fonction enveloppée par le @app.evaluate décorateur dans votre ClientApp). Maintenant, voyons comment nous pouvons évaluer les paramètres du modèle agrégé côté serveur.

Pour cela, nous utilisons la fonction global_evaluate définie dans server_app.py. Cette fonction est un rappel qui sera passé à la méthode start de notre stratégie. Cela signifie que la stratégie appellera cette fonction après chaque tour d’apprentissage fédéré en passant deux arguments : le tour actuel d’apprentissage fédéré et les paramètres du modèle agrégé.

Notre fonction global_evaluate effectue les étapes suivantes :

  1. Chargez les paramètres du modèle agrégé dans un modèle PyTorch

  2. Chargement de l’intégralité du jeu de test CIFAR-10

  3. Évaluez le modèle sur le jeu de test

  4. Retournez les métriques d’évaluation sous forme de MetricRecord

from flwr.app import ArrayRecord, MetricRecord


def global_evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
    """Evaluate model on central data."""

    # Load the model and initialize it with the received weights
    model = Net()
    model.load_state_dict(arrays.to_torch_state_dict())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load entire test set
    test_dataloader = load_centralized_dataset()

    # Evaluate the global model on the test set
    test_loss, test_acc = test(model, test_dataloader, device)

    # Return the evaluation metrics
    return MetricRecord({"accuracy": test_acc, "loss": test_loss})

Rappelez-vous que nous avons mentionné que cette global_evaluate sera appelée par la stratégie. Pour y parvenir, il faut passer l’objet à la méthode start de la stratégie, comme montré ci-dessous. L’application quickstart le fait déjà, assurez-vous donc que cette partie reste dans server_app.py après avoir passé à FedAdagrad.

@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""

    # ... unchanged

    # Start strategy, run FedAdagrad for `num_rounds`
    result = strategy.start(
        grid=grid,
        initial_arrays=arrays,
        train_config=ConfigRecord({"lr": lr}),
        num_rounds=num_rounds,
        evaluate_fn=global_evaluate,
    )

    # .. unchanged

À partir de là, nous allons exécuter localement afin que vous puissiez itérer plus rapidement pendant l’édition de l’application. Exécutez la simulation locale avec :

$ flwr run . local --stream

Vous noterez que les journaux du serveur enregistrent les métriques retournées par le callback après chaque tour. De plus, à la fin de l’exécution, notez le ServerApp-side Evaluate Metrics affiché :

INFO :          ServerApp-side Evaluate Metrics:
INFO :          { 0: {'accuracy': '1.0000e-01', 'loss': '2.3053e+00'},
INFO :            1: {'accuracy': '1.0000e-01', 'loss': '2.3203e+00'},
INFO :            2: {'accuracy': '2.3230e-01', 'loss': '2.0144e+00'},
INFO :            3: {'accuracy': '2.5720e-01', 'loss': '1.9258e+00'}}

Envoi de configurations aux clients depuis des stratégies

Dans certaines situations, nous voulons configurer l’exécution côté client (entraînement, évaluation) depuis le serveur. Un exemple de cela est le serveur demandant aux clients d’entraîner avec un taux d’apprentissage différent en fonction du numéro de tour actuel. Flower fournit une façon d’envoyer des valeurs de configuration depuis le serveur vers les clients comme partie de la Message que le ClientApp reçoit. Voyons comment nous pouvons faire cela.

Au start méthode de notre stratégie, on nous passe déjà un ConfigRecord spécifiant le taux d’apprentissage initial. Ce ConfigRecord sera envoyé aux clients dans toutes les Messages s’adressant à la fonction @app.train() du ClientApp. Supposons que nous voulions diminuer le taux d’apprentissage par un facteur de 0,5 tous les 5 tours, alors nous devons surcharger la méthode configure_train de notre stratégie et y intégrer une telle logique.

Pour cela, nous créons une nouvelle classe héritant de FedAdagrad et on surcharge la méthode configure_train. On utilise ensuite cette nouvelle stratégie dans notre ServerApp. Voyons comment cela ressemble en code. Créez un nouveau fichier appelé custom_strategy.py dans le répertoire pytorchexample et ajoutez le code suivant :

from typing import Iterable
from flwr.serverapp import Grid
from flwr.serverapp.strategy import FedAdagrad
from flwr.app import ArrayRecord, ConfigRecord, Message


class CustomFedAdagrad(FedAdagrad):
    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated training and maybe do LR decay."""
        # Decrease learning rate by a factor of 0.5 every 5 rounds
        if server_round % 5 == 0 and server_round > 0:
            config["lr"] *= 0.5
            print("LR decreased to:", config["lr"])
        # Pass the updated config and the rest of arguments to the parent class
        return super().configure_train(server_round, arrays, config, grid)

Ensuite, nous utilisons cette nouvelle stratégie dans notre ServerApp en l’important dans votre server_app.py et en l’utilisant au lieu de la stratégie standard FedAdagrad:

# ... unchanged
# add this to the imports
from pytorchexample.custom_strategy import CustomFedAdagrad

# ... unchanged


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""

    # ... unchanged

    # Initialize custom FedAdagrad strategy
    strategy = CustomFedAdagrad(fraction_evaluate=fraction_evaluate)

    # ... rest unchanged

Exécutez localement à nouveau, cette fois augmentez le nombre de tours à 15 pour voir la décadence du taux d’apprentissage en action.

$ flwr run . local --stream --run-config="num-server-rounds=15"

Vous noterez que dans l’étape configure_train des tours 5 et 10, le taux d’apprentissage est diminué par un facteur de 0,5 et le nouveau taux d’apprentissage est imprimé à la console.

Comment savons-nous que ClientApp utilise ce nouveau taux d’apprentissage ? Rappelez-vous que dans client_app.py, nous lisons le taux d’apprentissage du Message reçu par la fonction @app.train():

@app.train()
def train(msg: Message, context: Context):

    # ... setup

    # Call the training function
    train_loss = train_fn(
        model,
        trainloader,
        context.run_config["local-epochs"],
        msg.content["config"]["lr"],
        device,
    )

    # ... prepare reply Message
    return Message(content=content, reply_to=msg)

Félicitations ! Vous avez créé votre première stratégie personnalisée ajoutant de la dynamique à ce qui est envoyé aux clients.

Récapitulation

Dans ce tutoriel, nous avons vu comment nous pouvons améliorer progressivement notre système en personnalisant la stratégie, en choisissant une autre stratégie, en appliquant la décroissance de taux d’apprentissage au niveau de la stratégie et en évaluant les modèles sur le côté serveur. C’est tout à fait une grande flexibilité avec si peu de code, n’est-ce pas ?

In the later sections, we’ve seen how we can communicate arbitrary values between server and clients to fully customize client-side execution. With that capability, we built a larger Federated Learning simulation using the Flower Simulation Runtime.

Prochaines étapes

Avant de continuer, assurez-vous d’adhérer à la communauté Flower sur Flower Discuss (Join Flower Discuss) et Slack (Join Slack).

Il existe un canal Slack dédié si vous avez besoin d’aide, mais nous aimerions également entendre qui vous êtes dans #introductions !

Le tutoriel Flower Collaborative AI - Partie 5 : créer une stratégie à partir de zéro montre comment créer une Strategy entièrement personnalisée à partir de zéro.