Example: FedBN in PyTorch - From Centralized To Federated

This tutorial will show you how to use Flower to build a federated version of an existing machine learning workload with FedBN, a federated training strategy designed for non-IID data. We are using PyTorch to train a Convolutional Neural Network (with Batch Normalization layers) on the CIFAR-10 dataset. When applying FedBN, only a few changes are needed compared to Example: PyTorch - From Centralized To Federated.


All files are revised versions of those in Example: PyTorch - From Centralized To Federated. At this stage, the only file to modify is cifar.py; the revised part is shown below:

The model architecture defined in class Net() is extended with Batch Normalization layers accordingly.

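import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class Net(nn.Module):

    def __init__(self) -> None:
        super(Net, self).__init__()
        # Every conv/linear layer except the output layer is followed by a
        # BatchNorm layer; under FedBN these "bn" layers stay local to each client.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(6)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn4 = nn.BatchNorm1d(84)
        self.fc3 = nn.Linear(84, 10)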

    def forward(self, x: Tensor) -> Tensor:
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.bn3(self.fc1(x)))
        x = F.relu(self.bn4(self.fc2(x)))
        x = self.fc3(x)
        return x
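
To see what FedBN will keep on each client, the sketch below enumerates the state_dict keys that PyTorch would generate for Net and splits them with the same "bn" substring test that client.py applies later. It is plain Python so it runs without PyTorch: the layer shapes are copied from Net.__init__, and the key suffixes (weight, bias, running_mean, running_var, num_batches_tracked) follow PyTorch's naming for Conv2d, Linear and BatchNorm modules. LAYERS, state_dict_sizes and is_trainable are hypothetical helpers for illustration only.

```python
from typing import Dict, List, Tuple

# Layer specs copied from Net.__init__, in registration order (pool has no parameters).
LAYERS: List[Tuple[str, str, Tuple[int, ...]]] = [
    ("conv1", "conv", (3, 6, 5)),          # in_channels, out_channels, kernel_size
    ("bn1", "bn", (6,)),                   # num_features
    ("conv2", "conv", (6, 16, 5)),
    ("bn2", "bn", (16,)),
    ("fc1", "linear", (16 * 5 * 5, 120)),  # in_features, out_features
    ("bn3", "bn", (120,)),
    ("fc2", "linear", (120, 84)),
    ("bn4", "bn", (84,)),
    ("fc3", "linear", (84, 10)),
]


def state_dict_sizes() -> Dict[str, int]:
    """Return {state_dict key: number of elements}, mimicking PyTorch's key names."""
    sizes: Dict[str, int] = {}
    for name, kind, shape in LAYERS:
        if kind == "conv":
            c_in, c_out, k = shape
            sizes[f"{name}.weight"] = c_out * c_in * k * k
            sizes[f"{name}.bias"] = c_out
        elif kind == "linear":
            f_in, f_out = shape
            sizes[f"{name}.weight"] = f_out * f_in
            sizes[f"{name}.bias"] = f_out
        else:  # BatchNorm: affine parameters plus running-statistics buffers
            (c,) = shape
            sizes[f"{name}.weight"] = c
            sizes[f"{name}.bias"] = c
            sizes[f"{name}.running_mean"] = c
            sizes[f"{name}.running_var"] = c
            sizes[f"{name}.num_batches_tracked"] = 1  # scalar counter
    return sizes


def is_trainable(key: str) -> bool:
    # Trainable tensors are weights and biases; running stats and the counter are buffers.
    return key.endswith((".weight", ".bias"))


sizes = state_dict_sizes()
shared = [k for k in sizes if "bn" not in k]  # exchanged with the server
local = [k for k in sizes if "bn" in k]       # kept on the client under FedBN

shared_params = sum(sizes[k] for k in shared)
local_params = sum(sizes[k] for k in local if is_trainable(k))

print(len(shared), len(local))      # 10 20
print(shared_params, local_params)  # 62006 452
```

Under this split, the 10 non-BN entries (62,006 trainable values) are exchanged with the server, while the 20 BN entries, including 452 trainable affine values, never leave the client. Because the filter is a substring match, any layer whose name happens to contain "bn" would also be kept local.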


You can now run the (still centralized) machine learning workload:

python3 cifar.py

So far this should all look fairly familiar if you've used PyTorch before. Let's take the next step and use what we've built to create a federated learning system with FedBN. The system consists of one server and two clients.


If you have read Example: PyTorch - From Centralized To Federated, the following parts are easy to follow: only the get_parameters and set_parameters functions in client.py need to be revised. If not, please read Example: PyTorch - From Centralized To Federated first.

Our example consists of one server and two clients. With FedBN, server.py stays unchanged, so we can start the server directly:

python3 server.py
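
The server can stay unchanged because it only aggregates the list of arrays each client sends, without knowing which layers they came from. Assuming the server uses a FedAvg-style strategy, aggregation is an average weighted by each client's number of training examples. The plain-Python sketch below illustrates this with a hypothetical weighted_average helper, using lists of floats in place of NumPy arrays; since FedBN clients never send BN entries, those entries never enter the average.

```python
from typing import List, Tuple

# One client's update: (flattened parameter tensors, number of training examples)
ClientUpdate = Tuple[List[List[float]], int]


def weighted_average(updates: List[ClientUpdate]) -> List[List[float]]:
    """FedAvg-style aggregation: average each tensor, weighted by example count."""
    total_examples = sum(num_examples for _, num_examples in updates)
    num_tensors = len(updates[0][0])
    aggregated: List[List[float]] = []
    for t in range(num_tensors):
        acc = [0.0] * len(updates[0][0][t])
        for params, num_examples in updates:
            weight = num_examples / total_examples
            for i, value in enumerate(params[t]):
                acc[i] += weight * value
        aggregated.append(acc)
    return aggregated


# Two clients, each sending two tiny non-BN tensors
client_a = ([[1.0, 2.0], [0.0]], 100)
client_b = ([[3.0, 4.0], [1.0]], 300)
print(weighted_average([client_a, client_b]))  # [[2.5, 3.5], [0.75]]
```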

Finally, we revise our client logic by changing get_parameters and set_parameters in client.py so that batch normalization parameters are excluded from the model parameter list when parameters are sent to, or received from, the server.

from collections import OrderedDict
from typing import List

import flwr as fl
import numpy as np
import torch


class CifarClient(fl.client.NumPyClient):
    """Flower client implementing CIFAR-10 image classification using
    PyTorch."""

    # ... __init__, fit, and evaluate are unchanged from
    # Example: PyTorch - From Centralized To Federated

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays, excluding parameters of BN layers when using FedBN
        return [
            val.cpu().numpy()
            for name, val in self.model.state_dict().items()
            if "bn" not in name
        ]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Set model parameters from a list of NumPy ndarrays.
        # BN entries are absent, so strict=False leaves the local BN layers untouched.
        keys = [k for k in self.model.state_dict().keys() if "bn" not in k]
        params_dict = zip(keys, parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=False)
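
The effect of these two methods can be checked without PyTorch or Flower. In the sketch below, a plain OrderedDict of floats stands in for model.state_dict(), and the hypothetical get_shared and set_shared helpers mirror get_parameters and set_parameters: only non-BN entries leave the client, and loading the aggregated values overwrites exactly those entries while the BN entries keep their local values (the behaviour strict=False permits in load_state_dict). The pairing of values with keys relies on both methods walking the state_dict in the same order.

```python
from collections import OrderedDict
from typing import Dict, List

# Stand-in for model.state_dict(): one float per entry instead of a tensor
local_state: Dict[str, float] = OrderedDict(
    [
        ("conv1.weight", 0.1),
        ("bn1.weight", 1.0),
        ("bn1.running_mean", 0.5),
        ("fc1.weight", 0.2),
    ]
)


def get_shared(state: Dict[str, float]) -> List[float]:
    # Mirrors get_parameters: skip every entry whose name contains "bn"
    return [val for name, val in state.items() if "bn" not in name]


def set_shared(state: Dict[str, float], parameters: List[float]) -> None:
    # Mirrors set_parameters: pair values with the non-BN keys in the same order,
    # then overwrite only those keys (like load_state_dict(..., strict=False))
    keys = [k for k in state.keys() if "bn" not in k]
    state.update(zip(keys, parameters))


sent = get_shared(local_state)
print(sent)  # [0.1, 0.2]

# Pretend the server returned averaged values for the shared entries
set_shared(local_state, [0.3, 0.4])
print(dict(local_state))
# {'conv1.weight': 0.3, 'bn1.weight': 1.0, 'bn1.running_mean': 0.5, 'fc1.weight': 0.4}
```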



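Now open two additional terminal windows and run the following command in each of them:

python3 client.py

Make sure the server is still running before you do so. You will then see your (previously centralized) PyTorch project run federated learning with the FedBN strategy across two clients. Congratulations!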


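The full source code for this example can be found at https://github.com/adap/flower/blob/main/examples/pytorch-from-centralized-to-federated. Our example is, of course, somewhat over-simplified, because both clients load the exact same dataset, which isn't realistic. You're now prepared to explore this topic further: how about using a different subset of CIFAR-10 on each client, or increasing the number of clients?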
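
As a possible starting point for giving each client a different CIFAR-10 subset, the sketch below sorts sample indices by label and cuts them into one contiguous shard per client, so each client sees only a few classes: a label-skewed, non-IID split of the kind FedBN is designed for. partition_by_label is a hypothetical helper, not part of the example code, and the labels are toy values; the resulting index lists could, for instance, be wrapped with torch.utils.data.Subset in the client's data loading.

```python
from typing import List


def partition_by_label(labels: List[int], num_clients: int) -> List[List[int]]:
    """Sort sample indices by label and cut them into contiguous shards.

    Each client gets one shard, so it sees only a few classes (label skew).
    """
    order = sorted(range(len(labels)), key=lambda i: labels[i])  # stable sort
    shard_size, remainder = divmod(len(order), num_clients)
    shards: List[List[int]] = []
    start = 0
    for c in range(num_clients):
        # The first `remainder` clients take one extra sample each
        end = start + shard_size + (1 if c < remainder else 0)
        shards.append(order[start:end])
        start = end
    return shards


# Toy labels for 12 samples from 4 classes (CIFAR-10 itself has 10 classes)
labels = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
parts = partition_by_label(labels, num_clients=2)
for cid, idx in enumerate(parts):
    print(cid, sorted({labels[i] for i in idx}))
# 0 [0, 1]
# 1 [2, 3]
```

Sorting by label produces an extreme skew in which the clients share no classes at all; mixing in some randomly assigned samples gives a milder non-IID setting.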