Implement FedBN

This tutorial will show you how to use Flower to build a federated version of an existing machine learning workload with FedBN, a federated training method designed for non-IID data. We are using PyTorch to train a Convolutional Neural Network (with Batch Normalization layers) on the CIFAR-10 dataset. When applying FedBN, only minor changes are needed compared to Quickstart PyTorch.

Model

A full introduction to federated learning with PyTorch and Flower can be found in Quickstart PyTorch. This how-to guide varies only a few details in task.py. FedBN requires a model architecture (defined in class Net()) that uses Batch Normalization layers:

class Net(nn.Module):

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(6)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn4 = nn.BatchNorm1d(84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.bn3(self.fc1(x)))
        x = F.relu(self.bn4(self.fc2(x)))
        x = self.fc3(x)
        return x

Try editing the model architecture, then run the project to ensure everything still works:

flwr run .

So far this should all look fairly familiar if you’ve used Flower with PyTorch before.

FedBN

To adopt FedBN, only the get_parameters and set_parameters functions in task.py need to be revised. FedBN only changes the client-side by excluding batch normalization parameters from being exchanged with the server.

We revise the client logic by changing get_parameters and set_parameters in task.py. The batch normalization parameters are excluded from model parameter list when sending to or receiving from the server:

class FlowerClient(NumPyClient):
    """Flower client for CIFAR-10 image classification using PyTorch."""

    # ... [other FlowerClient methods]

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays
        # Exclude parameters of BN layers when using FedBN
        return [
            val.cpu().numpy()
            for name, val in self.model.state_dict().items()
            if "bn" not in name
        ]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Set model parameters from a list of NumPy ndarrays
        keys = [k for k in self.model.state_dict().keys() if "bn" not in k]
        params_dict = zip(keys, parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=False)

    ...

To test the new appraoch, run the project again:

flwr run .

Your PyTorch project now runs federated learning with FedBN. Congratulations!

다음 단계

The example is of course over-simplified since all clients load the exact same dataset. This isn’t realistic. You now have the tools to explore this topic further. How about using different subsets of CIFAR-10 on each client? How about adding more clients?