Quickstart PyTorch¶
Dans ce tutoriel d’apprentissage fédéré, nous allons apprendre à entraîner un réseau de neurones convolutionnel sur CIFAR-10 en utilisant Flower et PyTorch. Il est recommandé de créer un environnement virtuel et d’exécuter tout dans un virtualenv.
Utilisez flwr new pour créer un projet complet Flower+PyTorch. Il générera tous les fichiers nécessaires pour exécuter une fédération de 10 nœuds en utilisant FedAvg. Par défaut, l’application générée utilise un profil de simulation local qui flwr run soumet à un SuperLink géré local, qui exécute ensuite l’exécution avec Flower Simulation Runtime. Le dataset sera partitionné en utilisant le IidPartitioner de Flower Dataset.
Maintenant que nous avons une idée approximative de ce que cet exemple est sur, allons-y. Tout d’abord, installez Flower dans votre nouvel environnement :
# In a new Python environment
$ pip install flwr[simulation]
Ensuite, exécutez la commande suivante :
$ flwr new @flwrlabs/quickstart-pytorch
Après avoir exécuté cela, vous remarquerez que un nouveau répertoire nommé quickstart-pytorch a été créé. Il devrait avoir la structure suivante :
quickstart-pytorch
├── pytorchexample
│ ├── __init__.py
│ ├── client_app.py # Defines your ClientApp
│ ├── server_app.py # Defines your ServerApp
│ └── task.py # Defines your model, training and data loading
├── pyproject.toml # Project metadata like dependencies and configs
└── README.md
Si vous n’avez pas encore installé le projet et ses dépendances, vous pouvez le faire avec :
# From the directory where your pyproject.toml is
$ pip install -e .
Pour lancer le projet, faites:
# Run with default arguments and stream logs
$ flwr run . --stream
Le processus flwr run . soumet l’exécution, imprime l’ID d’exécution et retourne sans diffuser les journaux. Pour le workflow local complet, voir Exécuter Flower Localement avec un SuperLink Géré.
Avec les arguments par défaut, vous verrez une sortie en flux continu comme celle-ci :
Starting local SuperLink on 127.0.0.1:39093...
Successfully started run 1859953118041441032
INFO : Starting FedAvg strategy:
INFO : ├── Number of rounds: 3
INFO : [ROUND 1/3]
INFO : configure_train: Sampled 5 nodes (out of 10)
INFO : aggregate_train: Received 5 results and 0 failures
INFO : └──> Aggregated MetricRecord: {'train_loss': 2.149280}
INFO : configure_evaluate: Sampled 10 nodes (out of 10)
INFO : aggregate_evaluate: Received 10 results and 0 failures
INFO : └──> Aggregated MetricRecord: {'eval_loss': 2.31319, 'eval_acc': 0.10004}
INFO : [ROUND 2/3]
INFO : ...
INFO : [ROUND 3/3]
INFO : ...
INFO : Strategy execution finished in 16.56s
INFO : Final results:
INFO : Server-side Evaluate Metrics:
INFO : {}
Vous pouvez également surcharger les paramètres définis dans la section [tool.flwr.app.config] de pyproject.toml comme suit:
# Override some arguments
$ flwr run . --run-config "num-server-rounds=5 local-epochs=3"
Voici une explication de chaque composant dans le projet que vous venez de créer : partitionnement du jeu de données, modèle, définir la ClientApp et définir la ServerApp.
Les données¶
Ce tutoriel utilise Flower Datasets pour télécharger et partitionner facilement le dataset CIFAR-10. Dans cet exemple, vous utiliserez IidPartitioner pour générer des partitions num_partitions. Vous pouvez choisir parmi les other partitioners disponibles dans Flower Datasets. Chaque ClientApp appellera cette fonction pour créer des chargements de données avec les données correspondant à leur partition de données.
partitioner = IidPartitioner(num_partitions=num_partitions)
fds = FederatedDataset(
dataset="uoft-cs/cifar10",
partitioners={"train": partitioner},
)
partition = fds.load_partition(partition_id)
# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
def apply_transforms(batch):
"""Apply transforms to the partition from FederatedDataset."""
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
return batch
# Construct dataloaders
partition_train_test = partition_train_test.with_transform(apply_transforms)
trainloader = DataLoader(
partition_train_test["train"], batch_size=batch_size, shuffle=True
)
testloader = DataLoader(partition_train_test["test"], batch_size=batch_size)
Le Modèle¶
Nous avons défini un réseau neuronal convolutionnel simple, mais vous êtes libres de le remplacer par un modèle plus sophistiqué si vous le souhaitez :
class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)
En outre de la définition de l’architecture du modèle, nous incluons également deux fonctions d’utilité pour effectuer à la fois l’entraînement (c’est-à-dire train()) et l’évaluation (c’est-à-dire test()) en utilisant le modèle ci-dessus. Ces fonctions devraient ressembler à des choses familières si vous avez une certaine expérience antérieure avec PyTorch. Notez que ces fonctions n’ont rien de spécifique à Flower. La fonction d’entraînement sera normalement appelée, comme nous le verrons plus tard, depuis un client Flower qui passe ses propres données. En résumé, vos clients peuvent utiliser des fonctions d’entraînement/test standard pour effectuer un entraînement local ou une évaluation :
def train(net, trainloader, epochs, lr, device):
"""Train the model on the training set."""
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
net.train()
running_loss = 0.0
for _ in range(epochs):
for batch in trainloader:
images = batch["img"].to(device)
labels = batch["label"].to(device)
optimizer.zero_grad()
loss = criterion(net(images), labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_trainloss = running_loss / (epochs * len(trainloader))
return avg_trainloss
def test(net, testloader, device):
"""Validate the model on the test set."""
net.to(device)
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for batch in testloader:
images = batch["img"].to(device)
labels = batch["label"].to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
loss = loss / len(testloader)
return loss, accuracy
L’Application Client¶
Les principales modifications que nous devons apporter pour utiliser PyTorch avec Flower ont affaire à la conversion du ArrayRecord reçu dans la partie Message en un dictionnaire de l’état PyTorch, et vice versa lorsqu’on génère la réponse Message à partir du ClientApp. Nous pouvons faire usage des méthodes intégrées dans le ArrayRecord pour effectuer ces conversions:
@app.train()
def train(msg: Message, context: Context):
# Instantiate a PyTorch model
model = Net()
# Extract ArrayRecord from Message and convert to PyTorch state_dict
state_dict = msg.content["arrays"].to_torch_state_dict()
# Load received state_dict into model
model.load_state_dict(state_dict)
# ...
# Convert state_dict back into an ArrayRecord
array_record = ArrayRecord(model.state_dict())
Le reste de la fonctionnalité est directement inspiré du cas centralisé. Le ClientApp comporte trois méthodes de base (train, evaluate, et query) que nous pouvons implémenter à des fins différentes. Par exemple : train pour entraîner le modèle reçu en utilisant les données locales ; evaluate pour évaluer la performance du modèle reçu sur un jeu de validation ; et query pour récupérer des informations sur le nœud exécutant le ClientApp. Dans ce tutoriel, nous ne ferons que faire usage de train et evaluate.
Voyons comment la méthode train peut être implémentée. Elle reçoit en arguments d’entrée un Message depuis le ServerApp. Par défaut, elle porte :
un
ArrayRecordavec les tableaux du modèle à fédérer. Par défaut, ils peuvent être récupérés avec la clé"arrays"lors de l’accès au contenu du message.une
ConfigRecordavec la configuration transmise depuis leServerApp. Par défaut, elle peut être récupérée avec la clé"config"lors de l’accès au contenu du message.
La méthode train reçoit également le Context, ce qui donne accès aux configurations pour votre exécution et votre nœud. Les hyperparamètres de la configuration d’exécution sont définis dans la partie pyproject.toml de votre application Flower. La configuration du nœud ne peut être définie qu’en exécutant Flower avec le Deployment Runtime et ce n’est pas directement configurable pendant les simulations.
# Flower ClientApp
app = ClientApp()
@app.train()
def train(msg: Message, context: Context):
"""Train the model on local data."""
# Load the model and initialize it with the received weights
model = Net()
model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
# Load the data
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
batch_size = context.run_config["batch-size"]
trainloader, _ = load_data(partition_id, num_partitions, batch_size)
# Call the training function
train_loss = train_fn(
model,
trainloader,
context.run_config["local-epochs"],
msg.content["config"]["lr"],
device,
)
# Construct and return reply Message
# Include the locally-trained model
model_record = ArrayRecord(model.state_dict())
# Include some statistics such as the training loss
# We also want to include the number of examples used for training
# so the strategy in the ServerApp can do FedAvg
metrics = {
"train_loss": train_loss,
"num-examples": len(trainloader.dataset),
}
metric_record = MetricRecord(metrics)
# RecordDict are the main payload type in Messages
# We insert both the ArrayRecord and the MetricRecord into it
content = RecordDict({"arrays": model_record, "metrics": metric_record})
return Message(content=content, reply_to=msg)
La méthode @app.evaluate() serait presque identique avec deux exceptions : (1) le modèle n’est pas entraîné localement, mais il est utilisé pour évaluer sa performance sur le jeu de validation local ; (2) inclure le modèle dans la réponse Message n’est plus nécessaire car il ne subit aucune modification locale.
L’Application Serveur¶
Pour construire un ServerApp, nous définissons sa méthode @app.main(). Cette méthode reçoit en arguments d’entrée :
Un objet
Gridqui sera utilisé pour interagir avec les nœuds s’exécutant laClientAppafin d’impliquer les clients dans une ronde d’entraînement/évaluation/requête ou autre.Un objet
Contextqui fournit accès à la configuration de l’exécution.
Dans cet exemple, nous utilisons FedAvg et la configurons avec une valeur spécifique de fraction_evaluate qui est lue à partir de la configuration d’exécution. Vous pouvez trouver la valeur par défaut définie dans la partie pyproject.toml. Ensuite, l’exécution de la stratégie est lancée lorsqu’on invoque sa méthode start. À cela on passe:
l’objet
Grid.un
ArrayRecordportant un modèle initialisé aléatoirement qui servira de modèle global pour fédérer.Un objet
ConfigRecordavec les hyperparamètres d’entraînement à envoyer aux clients. La stratégie insérera également le numéro actuel de ronde dans cette configuration avant de l’envoyer aux nœuds participants.Le paramètre
num_roundsspécifiant combien de rondes deFedAvgeffectuer.une fonction
evaluate_fnqui sera appelée pour évaluer le modèle global sur les données de test centralisées après chaque ronde.
# Create ServerApp
app = ServerApp()
@app.main()
def main(grid: Grid, context: Context) -> None:
"""Main entry point for the ServerApp."""
# Read run config
fraction_evaluate: float = context.run_config["fraction-evaluate"]
num_rounds: int = context.run_config["num-server-rounds"]
lr: float = context.run_config["learning-rate"]
# Load global model
global_model = Net()
arrays = ArrayRecord(global_model.state_dict())
# Initialize FedAvg strategy
strategy = FedAvg(fraction_evaluate=fraction_evaluate)
# Start strategy, run FedAvg for `num_rounds`
result = strategy.start(
grid=grid,
initial_arrays=arrays,
train_config=ConfigRecord({"lr": lr}),
num_rounds=num_rounds,
evaluate_fn=global_evaluate,
)
if context.run_config["save-model"]:
# Save final model to disk
print("\nSaving final model to disk...")
state_dict = result.arrays.to_torch_state_dict()
torch.save(state_dict, "final_model.pt")
Notez que la méthode start du stratégie retourne un objet Result. Cet objet contient toutes les informations pertinentes sur le processus FL, y compris les poids du modèle final sous forme de ArrayRecord, et des métriques d’entraînement et d’évaluation fédérées sous forme de MetricRecords.
Félicitations ! Vous avez réussi à créer et exécuter votre premier système d’apprentissage fédéré.
Astuce
Vérifiez la documentation de Run simulations pour en savoir plus sur la façon de configurer et d’exécuter les simulations Flower.
Note
Vérifiez la partie source code de l’édition étendue de ce tutoriel dans examples/quickstart-pytorch dans le dépôt GitHub Flower.