Implémentez FedBN

Cette étape démonstrative vous montrera comment utiliser Flower pour construire une version fédérée d’un workload de ML existant avec FedBN, un méthode d’entraînement fédéré conçue pour des données non IID. Nous utilisons PyTorch pour entraîner une Réseau Neural Convolutionnel (avec couches de Normalisation par Batch) sur le jeu de données CIFAR-10. Lorsque vous appliquez FedBN, seules des modifications mineures sont nécessaires par rapport à Quickstart PyTorch.

Model

Une introduction complète à l’apprentissage fédéré avec PyTorch et Flower peut être trouvée dans Quickstart PyTorch. Ce guide vous montre uniquement quelques détails modifiés dans task.py. FedBN nécessite une architecture de modèle (définie dans la classe Net()) qui utilise des couches de normalisation par batch :

class Net(nn.Module):

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(6)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn4 = nn.BatchNorm1d(84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.bn3(self.fc1(x)))
        x = F.relu(self.bn4(self.fc2(x)))
        x = self.fc3(x)
        return x

Essayez d’éditer l’architecture du modèle, puis exécutez le projet pour vous assurer que tout fonctionne toujours :

flwr run .

Cela devrait ressembler à quelque chose de familier si vous avez déjà utilisé Flower avec PyTorch avant.

FedBN

Pour adopter FedBN, nous révisons la méthode train dans ClientApp. Les paramètres de normalisation par batch sont exclus du dictionnaire d’état du modèle lorsqu’ils sont envoyés ou reçus vers le ServerApp:

app = ClientApp()


@app.train()
def train(msg: Message, context: Context):
    """Train the model on local data."""

    # Load the model and initialize it with the received weights
    model = Net()
    state_dict = msg.content["arrays"].to_torch_state_dict()

    # Exclude batch normalization parameters
    state_dict_wo_bn = OrderedDict(
        (k, v) for k, v in state_dict.items() if "bn" not in k
    )
    model.load_state_dict(state_dict_wo_bn, strict=True)

    # ... [Perform training]

    # Construct and return reply Message
    state_dict_wo_bn = OrderedDict(
        (k, v) for k, v in model.state_dict().items() if "bn" not in k
    )
    model_record = ArrayRecord(state_dict_wo_bn)

    ...

Pour tester la nouvelle approche, exécutez le projet à nouveau :

flwr run .

Votre projet PyTorch fonctionne maintenant avec l’apprentissage fédéré en utilisant FedBN. Félicitations !

Prochaines étapes

L’exemple est certainement trop simplifié puisque tous les ClientApps chargent exactement le même jeu de données. C’est loin d’être réaliste. Vous avez maintenant les outils pour explorer ce sujet plus en profondeur. Qu’en pensez-vous de l’utilisation de sous-ensembles différents de CIFAR-10 sur chaque client ? Qu’en pensez-vous de l’ajout de clients supplémentaires ?