示例: PyTorch 中的 FedBN - 从集中式到联邦式#

This tutorial will show you how to use Flower to build a federated version of an existing machine learning workload with FedBN, a federated training strategy designed for non-iid data. We are using PyTorch to train a Convolutional Neural Network(with Batch Normalization layers) on the CIFAR-10 dataset. When applying FedBN, only few changes needed compared to Example: PyTorch - From Centralized To Federated.

集中式训练#

All files are revised based on Example: PyTorch - From Centralized To Federated. The only thing to do is modifying the file called cifar.py, revised part is shown below:

类 Net() 中定义的模型架构会相应添加Batch Normalization层。

class Net(nn.Module):

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(6)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn4 = nn.BatchNorm1d(84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.bn3(self.fc1(x)))
        x = F.relu(self.bn4(self.fc2(x)))
        x = self.fc3(x)
        return x

现在,您可以运行您的机器学习工作了:

python3 cifar.py

So far this should all look fairly familiar if you've used PyTorch before. Let's take the next step and use what we've built to create a federated learning system within FedBN, the sytstem consists of one server and two clients.

联邦培训#

If you have read Example: PyTorch - From Centralized To Federated, the following parts are easy to follow, onyl get_parameters and set_parameters function in client.py needed to revise. If not, please read the Example: PyTorch - From Centralized To Federated. first.

我们的示例包括一个*服务器*和两个*客户端*。在 FedBN 中,server.py 保持不变,我们可以直接启动服务器。

python3 server.py

最后,我们将修改 client 的逻辑,修改 client.py 中的 get_parametersset_parameters,在向服务器发送或从服务器接收时,我们将从模型参数列表中排除batch normalization层的参数。

class CifarClient(fl.client.NumPyClient):
    """Flower client implementing CIFAR-10 image classification using
    PyTorch."""

    ...

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays, excluding parameters of BN layers when using FedBN
        return [val.cpu().numpy() for name, val in self.model.state_dict().items() if 'bn' not in name]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Set model parameters from a list of NumPy ndarrays
        keys = [k for k in self.model.state_dict().keys() if 'bn' not in k]
        params_dict = zip(keys, parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=False)

    ...

现在,您可以打开另外两个终端窗口并运行程序

python3 client.py

确保服务器仍在运行后,然后您就能看到您的 PyTorch 项目(之前是集中式的)通过 FedBN 策略在两个客户端上运行联合学习。祝贺!

下一步工作#

本示例的完整源代码可在 <https://github.com/adap/flower/blob/main/examples/pytorch-from-centralized-to-federated>`_ 找到。当然,我们的示例有些过于简单,因为两个客户端都加载了完全相同的数据集,这并不真实。让我们准备好进一步探讨这一主题。如在每个客户端使用不同的 CIFAR-10 子集,或者增加客户端的数量。