Quickstart 🤗 Transformers

Dans ce tutoriel d’apprentissage fédéré, nous allons apprendre à entraîner un grand modèle de langage (LLM) sur le IMDB dataset en utilisant Flower et la bibliothèque Transformers 🤗 Hugging Face. Il est recommandé de créer un environnement virtuel et d’exécuter tout dans un virtualenv.

Laissez-nous utiliser flwr new pour créer un projet complet Flower+🤗 Hugging Face. Cela générera tous les fichiers nécessaires pour exécuter une fédération de 10 nœuds en utilisant FedAvg. Par défaut, l’application générée utilise un profil de simulation local qui flwr run soumet à un SuperLink géré local, qui exécute ensuite l’exécution avec Flower Simulation Runtime. Le dataset sera partitionné en utilisant le Flower Datasets du IidPartitioner.

Maintenant que nous avons une idée approximative de ce que cet exemple est sur, allons-y. Tout d’abord, installez Flower dans votre nouvel environnement :

# In a new Python environment
$ pip install flwr[simulation]

Ensuite, exécutez la commande suivante :

$ flwr new @flwrlabs/quickstart-huggingface

Après l’avoir exécuté, vous remarquerez un nouveau répertoire nommé quickstart-huggingface qui a été créé. Il devrait avoir la structure suivante :

quickstart-huggingface
├── huggingface_example
│   ├── __init__.py
│   ├── client_app.py   # Defines your ClientApp   ├── server_app.py   # Defines your ServerApp   └── task.py         # Defines your model, training and data loading
├── pyproject.toml      # Project metadata like dependencies and configs
└── README.md

Si vous n’avez pas encore installé le projet et ses dépendances, vous pouvez le faire avec :

# From the directory where your pyproject.toml is
$ pip install -e .

Pour lancer le projet, faites:

# Run with default arguments and stream logs
$ flwr run . --stream

Le processus flwr run . soumet l’exécution, imprime l’ID d’exécution et retourne sans diffuser les journaux. Pour le workflow local complet, voir Exécuter Flower Localement avec un SuperLink Géré.

Avec les arguments par défaut, vous verrez une sortie en flux continu comme celle-ci :

Starting local SuperLink on 127.0.0.1:39093...
Successfully started run 1859953118041441032
INFO :      Starting FedAvg strategy:
INFO :          ├── Number of rounds: 3
INFO :      [ROUND 1/3]
INFO :      configure_train: Sampled 5 nodes (out of 10)
INFO :      aggregate_train: Received 5 results and 0 failures
INFO :          └──> Aggregated MetricRecord: {'train_loss': 0.6974}
INFO :      configure_evaluate: Sampled 10 nodes (out of 10)
INFO :      aggregate_evaluate: Received 10 results and 0 failures
INFO :          └──> Aggregated MetricRecord: {'val_loss': 0.0223, 'val_accuracy': 0.5024}
INFO :      [ROUND 2/3]
INFO :      ...
INFO :      [ROUND 3/3]
INFO :      ...
INFO :      Strategy execution finished in 151.02s
INFO :      Final results:
INFO :          ServerApp-side Evaluate Metrics:
INFO :          {}

Vous pouvez également lancer le projet avec GPU comme suit:

# Run with default arguments
$ flwr run . localhost-gpu --stream

Cela utilisera les arguments par défaut où chaque ClientApp utilisera 4 processeurs et au plus 4 ClientApps s’exécuteront sur un GPU donné.

Vous pouvez également surcharger les paramètres définis dans la section [tool.flwr.app.config] de pyproject.toml comme suit:

# Override some arguments
$ flwr run . --run-config "num-server-rounds=5 fraction-train=0.2"

Voici une explication de chaque composant dans le projet que vous venez de créer : partitionnement du jeu de données, modèle, définir la ClientApp et définir la ServerApp.

Les données

Ce tutoriel utilise Flower Datasets pour télécharger et partitionner facilement le jeu de données IMDB. Dans cet exemple, vous utiliserez IidPartitioner pour générer des partitions num_partitions. Vous pouvez choisir parmi les other partitioners disponibles dans Flower Datasets. Pour tokeniser le texte, nous chargerons également le tokenizer à partir du modèle Transformer pré-entraîné que nous utiliserons pendant l’entraînement - plus d’informations sur cela dans la section suivante. Chaque ClientApp appellera cette fonction pour créer des chargements de données avec les données correspondant à leur partition de données.

partitioner = IidPartitioner(num_partitions=num_partitions)
fds = FederatedDataset(
    dataset="stanfordnlp/imdb",
    partitioners={"train": partitioner},
)
partition = fds.load_partition(partition_id)
# Divide data: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)

tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)


def tokenize_function(examples):
    return tokenizer(
        examples["text"], truncation=True, add_special_tokens=True
    )


partition_train_test = partition_train_test.map(tokenize_function, batched=True)
partition_train_test = partition_train_test.remove_columns("text")
partition_train_test = partition_train_test.rename_column("label", "labels")

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainloader = DataLoader(
    partition_train_test["train"],
    shuffle=True,
    batch_size=32,
    collate_fn=data_collator,
)

testloader = DataLoader(
    partition_train_test["test"], batch_size=32, collate_fn=data_collator
)

Le Modèle

Nous allons faire appel à 🤗 Hugging Face pour fédérer l’entraînement de modèles de langage sur plusieurs clients en utilisant Flower. Plus spécifiquement, nous allons affiner un modèle Transformer pré-entraîné (bert-tiny) pour la classification de séquences sur le jeu de données des notes IMDB. L’objectif final est de détecter si une note de film est positive ou négative. Si vous avez accès à des GPUs plus grandes, n’hésitez pas à utiliser des modèles plus grands !

net = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=2
)

Notez que ici, model_name est une chaîne qui sera chargée à partir de Context dans ClientApp et ServerApp.

En plus de charger les poids et l’architecture du modèle pré-entraîné, nous incluons également deux fonctions d’utilité pour effectuer à la fois l’entraînement (c’est-à-dire train()) et l’évaluation (c’est-à-dire test()) en utilisant le modèle ci-dessus. Ces fonctions devraient vous paraître familières si vous avez une certaine expérience antérieure avec PyTorch. Notez que ces fonctions n’ont rien de spécifique à Flower. La fonction d’entraînement sera normalement appelée, comme nous le verrons plus tard, depuis un client Flower passant ses propres données. En résumé, vos clients peuvent utiliser des fonctions d’entraînement/test standard pour effectuer l’entraînement local ou l’évaluation :

def train_fn(net, trainloader, epochs, device) -> None:
    optimizer = AdamW(net.parameters(), lr=5e-5)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()


def test_fn(net, testloader, device) -> tuple[Any | float, Any]:
    metric = load_metric("accuracy")
    loss = 0
    net.eval()
    for batch in testloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = net(**batch)
        logits = outputs.logits
        loss += outputs.loss.item()
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    loss /= len(testloader.dataset)
    accuracy = metric.compute()["accuracy"]
    return loss, accuracy

L’Application Client

Les principales modifications que nous devons apporter pour utiliser 🤗 Hugging Face avec Flower ont trait à la conversion du ArrayRecord reçu dans le Message en un state_dict PyTorch et vice versa lors de la génération de la réponse Message depuis l’Application Client. Nous pouvons nous servir des méthodes intégrées dans le ArrayRecord pour effectuer ces conversions :

# Load the model
model = get_model(model_name)

# Extract ArrayRecord from Message and convert to PyTorch state_dict
arrays = msg.content["arrays"]
# Load state_dict into the model
model.load_state_dict(arrays.to_torch_state_dict(), strict=True)

# ... do some training

# Convert state_dict back into an ArrayRecord
model_record = ArrayRecord(model.state_dict())

Le reste de la fonctionnalité est directement inspiré du cas centralisé. Le ClientApp comporte trois méthodes de base (train, evaluate, et query) que nous pouvons implémenter à des fins différentes. Par exemple : train pour entraîner le modèle reçu en utilisant les données locales ; evaluate pour évaluer la performance du modèle reçu sur un jeu de validation ; et query pour récupérer des informations sur le nœud exécutant le ClientApp. Dans ce tutoriel, nous ne ferons que faire usage de train et evaluate.

Voyons comment la méthode train peut être implémentée. Elle reçoit en arguments d’entrée un Message depuis le ServerApp. Par défaut, elle porte :

  • un ArrayRecord avec les tableaux du modèle à fédérer. Par défaut, ils peuvent être récupérés avec la clé "arrays" lors de l’accès au contenu du message.

  • une ConfigRecord avec la configuration transmise depuis le ServerApp. Par défaut, elle peut être récupérée avec la clé "config" lors de l’accès au contenu du message.

La méthode train reçoit également le Context, donnant accès aux configurations pour votre exécution et nœud. Les hyperparamètres de la configuration d’exécution sont définis dans la section pyproject.toml de votre application Flower. La configuration du nœud ne peut être configurée que lors de l’exécution de Flower avec le Deployment Runtime et ce, directement configurable pendant les simulations.

# Flower ClientApp
app = ClientApp()


@app.train()
def train(msg: Message, context: Context) -> Message:
    """Train the model on local data."""

    # Get this client's dataset partition
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    model_name = context.run_config["model-name"]
    trainloader, _ = load_data(partition_id, num_partitions, model_name)

    # Load model
    model = get_model(model_name)

    # Initialize it with the received weights
    arrays = msg.content["arrays"]
    model.load_state_dict(arrays.to_torch_state_dict(), strict=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Train the model on local data
    train_fn(model, trainloader, epochs=1, device=device)

    # Construct and return reply Message
    model_record = ArrayRecord(model.state_dict())
    metrics = MetricRecord({"num-examples": len(trainloader)})
    # Construct RecordDict and add ArrayRecord and MetricRecord
    content = RecordDict({"arrays": model_record, "metrics": metrics})
    return Message(content=content, reply_to=msg)

La méthode @app.evaluate() serait presque identique avec deux exceptions : (1) le modèle n’est pas entraîné localement, mais il est utilisé pour évaluer sa performance sur le jeu de validation local ; (2) inclure le modèle dans la réponse Message n’est plus nécessaire car il ne subit aucune modification locale.

L’Application Serveur

Pour construire un ServerApp, nous définissons sa méthode @app.main(). Cette méthode reçoit en arguments d’entrée :

  • Un objet Grid qui sera utilisé pour interagir avec les nœuds s’exécutant la ClientApp afin d’impliquer les clients dans une ronde d’entraînement/évaluation/requête ou autre.

  • Un objet Context qui fournit accès à la configuration de l’exécution.

Dans cet exemple, nous utilisons le FedAvg et le configurons avec une valeur spécifique de fraction_train qui est chargée à partir de la configuration d’exécution. Vous pouvez trouver la valeur par défaut définie dans pyproject.toml. Ensuite, l’exécution de la stratégie est lancée lorsqu’on invoque sa méthode start. À cela on passe:

  • l’objet Grid.

  • Un objet ArrayRecord portant un modèle initialisé aléatoirement qui servira de modèle global à fédérer.

  • Un objet ConfigRecord avec les hyperparamètres d’entraînement à envoyer aux clients. La stratégie insérera également le numéro actuel de ronde dans cette configuration avant de l’envoyer aux nœuds participants.

  • Le paramètre num_rounds spécifiant combien de rondes de FedAvg effectuer.

# Create ServerApp
app = ServerApp()


@app.main()
def main(grid: Grid, context: Context) -> None:

    # Define model to federate and extract parameters
    model_name = context.run_config["model-name"]
    model = get_model(model_name)
    arrays = ArrayRecord(model.state_dict())

    # Instantiate strategy
    fraction_train = context.run_config["fraction-train"]
    fraction_evaluate = context.run_config["fraction-evaluate"]
    strategy = FedAvg(
        fraction_train=fraction_train,
        fraction_evaluate=fraction_evaluate,
    )

    num_rounds = context.run_config["num-server-rounds"]
    # Start the strategy
    result = strategy.start(
        grid=grid,
        initial_arrays=arrays,
        num_rounds=num_rounds,
    )

    if context.run_config["save-model"]:
        # Save final model to disk
        print("\nSaving final model to disk...")
        state_dict = result.arrays.to_torch_state_dict()
        torch.save(state_dict, "final_model.pt")

Notez que la méthode start de la stratégie retourne un objet résultat. Cet objet contient toutes les informations pertinentes sur le processus FL, y compris les poids du modèle final sous forme de ArrayRecord, et des métriques d’entraînement et d’évaluation fédérées sous forme de MetricRecords. Vous pouvez facilement logger les métriques à l’aide de Python’s pprint et, si save-model est défini sur true, sauvegarder le modèle global state_dict en utilisant torch.save.

Félicitations ! Vous avez réussi à créer et à lancer votre premier système d’apprentissage fédéré pour un LLM.

Astuce

Vérifiez la documentation de Run simulations pour en savoir plus sur la façon de configurer et d’exécuter les simulations Flower.

Note

Vérifiez le code source de la version étendue de ce tutoriel dans examples/quickstart-huggingface dans le dépôt GitHub Flower. Pour un exemple complet d’un fine-tuning fédéré d’un LLM avec Flower, reportez-vous à l’exemple FlowerTune LLM dans le dépôt GitHub Flower.