Exemple : FedBN dans PyTorch - De la centralisation à la fédération#

Ce tutoriel te montrera comment utiliser Flower pour construire une version fédérée d’une charge de travail d’apprentissage automatique existante avec FedBN, une stratégie de formation fédérée conçue pour les données non-identifiées. Nous utilisons PyTorch pour former un réseau neuronal convolutif (avec des couches de normalisation par lots) sur l’ensemble de données CIFAR-10. Lors de l’application de FedBN, seules quelques modifications sont nécessaires par rapport à Exemple : PyTorch - De la centralisation à la fédération.

Formation centralisée#

Tous les fichiers sont révisés sur la base de Exemple : PyTorch - From Centralized To Federated. La seule chose à faire est de modifier le fichier appelé cifar.py, la partie révisée est montrée ci-dessous :

L’architecture du modèle définie dans la classe Net() est ajoutée avec les couches de normalisation par lots en conséquence.

class Net(nn.Module):

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(6)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn4 = nn.BatchNorm1d(84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.bn3(self.fc1(x)))
        x = F.relu(self.bn4(self.fc2(x)))
        x = self.fc3(x)
        return x

Tu peux maintenant exécuter ta charge de travail d’apprentissage automatique :

python3 cifar.py

Jusqu’à présent, tout ceci devrait te sembler assez familier si tu as déjà utilisé PyTorch. Passons à l’étape suivante et utilisons ce que nous avons construit pour créer un système d’apprentissage fédéré au sein de FedBN, le système se compose d’un serveur et de deux clients.

Formation fédérée#

Si vous avez lu Exemple : PyTorch - From Centralized To Federated, les parties suivantes sont faciles à suivre, seules les fonctions get_parameters et set_parameters dans client.py ont besoin d’être révisées. Si ce n’est pas le cas, veuillez lire Exemple : PyTorch - From Centralized To Federated <https://flower.dev/docs/example-pytorch-from-centralized-to-federated.html>. d’abord.

Notre exemple consiste en un serveur et deux clients. Dans FedBN, server.py reste inchangé, nous pouvons démarrer le serveur directement.

python3 server.py

Enfin, nous allons réviser notre logique client en modifiant get_parameters et set_parameters dans client.py, nous allons exclure les paramètres de normalisation des lots de la liste des paramètres du modèle lors de l’envoi ou de la réception depuis le serveur.

class CifarClient(fl.client.NumPyClient):
    """Flower client implementing CIFAR-10 image classification using
    PyTorch."""

    ...

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays, excluding parameters of BN layers when using FedBN
        return [val.cpu().numpy() for name, val in self.model.state_dict().items() if 'bn' not in name]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Set model parameters from a list of NumPy ndarrays
        keys = [k for k in self.model.state_dict().keys() if 'bn' not in k]
        params_dict = zip(keys, parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=False)

    ...

Tu peux maintenant ouvrir deux autres fenêtres de terminal et lancer

python3 client.py

dans chaque fenêtre (assure-toi que le serveur est toujours en cours d’exécution avant de le faire) et tu verras ton projet PyTorch (auparavant centralisé) exécuter l’apprentissage fédéré avec la stratégie FedBN sur deux clients. Félicitations !

Prochaines étapes#

The full source code for this example can be found here. Our example is of course somewhat over-simplified because both clients load the exact same dataset, which isn’t realistic. You’re now prepared to explore this topic further. How about using different subsets of CIFAR-10 on each client? How about adding more clients?