예시: PyTorch에서 FedBN - 중앙 집중식에서 연합식으로

이 튜토리얼에서는 non-iid data를 위해 설계된 federated 훈련 전략인 FedBN <https://github.com/med-air/FedBN>`_으로 기존 머신러닝 워크로드의 federated 버전을 구축하기 위해 Flower를 사용하는 방법을 보여드립니다. 우리는 PyTorch를 사용하여 CIFAR-10 데이터 세트에서 컨볼루션 신경망(일괄 정규화 레이어 포함)을 훈련하고 있습니다. FedBN을 적용할 때, :doc:`예제: 파이토치 -중앙 집중식에서 연합식으로 <example-pytorch-from-centralized-to-federated> 와 비교했을 때 몇 가지 사항만 변경 하면 됩니다.

중앙 집중식 훈련

All files are revised based on Example: PyTorch - From Centralized To Federated. The only thing to do is modifying the file called cifar.py, revised part is shown below:

Net() 클래스에 정의된 모델 아키텍처는 그에 따라 배치 정규화 레이어가 추가됩니다.

class Net(nn.Module):

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(6)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn4 = nn.BatchNorm1d(84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.bn3(self.fc1(x)))
        x = F.relu(self.bn4(self.fc2(x)))
        x = self.fc3(x)
        return x

이제 머신 러닝 워크로드를 실행할 수 있습니다:

python3 cifar.py

지금까지는 파이토치를 사용해 본 적이 있다면 상당히 익숙하게 보일 것입니다. 다음 단계로 넘어가서 우리가 구축한 것을 사용하여 FedBN 내에서 하나의 서버와 두 개의 클라이언트로 구성된 연합학습 시스템을 만들어 보겠습니다.

연합 훈련

If you have read Example: PyTorch - From Centralized To Federated, the following parts are easy to follow, only get_parameters and set_parameters function in client.py needed to revise. If not, please read the Example: PyTorch - From Centralized To Federated. first.

Our example consists of one server and two clients. In FedBN, server.py keeps unchanged, we can start the server directly.

python3 server.py

Finally, we will revise our client logic by changing get_parameters and set_parameters in client.py, we will exclude batch normalization parameters from model parameter list when sending to or receiving from the server.

class CifarClient(fl.client.NumPyClient):
    """Flower client implementing CIFAR-10 image classification using
    PyTorch."""

    ...

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays, excluding parameters of BN layers when using FedBN
        return [
            val.cpu().numpy()
            for name, val in self.model.state_dict().items()
            if "bn" not in name
        ]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Set model parameters from a list of NumPy ndarrays
        keys = [k for k in self.model.state_dict().keys() if "bn" not in k]
        params_dict = zip(keys, parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=False)

    ...

이제 두 개의 터미널 창을 추가로 열고 다음을 실행할 수 있습니다

python3 client.py

를 입력하고(클릭하기 전에 서버가 계속 실행 중인지 확인하세요), (이전에 중앙 집중된) PyTorch 프로젝트가 두 클라이언트에서 FedBN으로 연합 학습을 실행하는 것을 확인합니다. 축하합니다!

다음 단계

이 예제의 전체 소스 코드는 ‘여기 <https://github.com/adap/flower/blob/main/examples/pytorch-from-centralized-to-federated>`_’에서 확인할 수 있습니다. 물론 이 예제는 두 클라이언트가 완전히 동일한 데이터 세트를 로드하기 때문에 다소 지나치게 단순화되어 있으며, 이는 현실적이지 않습니다. 이제 이 주제를 더 자세히 살펴볼 준비가 되셨습니다. 각 클라이언트에서 서로 다른 CIFAR-10의 하위 집합을 사용해 보는 것은 어떨까요? 클라이언트를 더 추가하는 것은 어떨까요?