Open in Colab

自定义客户端

欢迎来到 Flower 联邦学习教程的第四部分。在本教程的前几部分中,我们介绍了 PyTorch 和 Flower 的联邦学习(part 1),了解了如何使用策略来定制服务器和客户端的执行(part 2),并从头开始构建了我们自己的定制策略(part 3)。

在本笔记中,我们将重温 NumPyClient 并引入一个用于构建客户端的新基类,简单命名为 Client。在本教程的前几部分中,我们的客户端基于 NumPyClient,这是一个便捷类,可以让我们轻松地与具有良好 NumPy 互操作性的机器学习库协同工作。有了 Client,我们获得了很多以前没有的灵活性,但我们也必须做一些以前不需要做的事情。

Star Flower on GitHub ⭐️ and join the Flower community on Flower Discuss and the Flower Slack to connect, ask questions, and get help:

- Join Flower Discuss: we’d love to hear from you in the Introduction topic! If anything is unclear, post in Flower Help - Beginners.
- Join Flower Slack: we’d love to hear from you in the #introductions channel! If anything is unclear, head over to the #questions channel.

Let’s go deeper and see what it takes to move from NumPyClient to Client! 🌼

步骤 0:准备工作

在开始实际代码之前,让我们先确保我们已经准备好了所需的一切。

安装依赖项

首先,我们安装必要的软件包:

[ ]:
!pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision scipy

现在我们已经安装了所有依赖项,可以导入本教程所需的所有内容:

[ ]:
from collections import OrderedDict
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

# All training runs on the CPU by default; switch "cpu" to "cuda" to use a GPU.
DEVICE = torch.device("cpu")

print("Training on {}".format(DEVICE))
print("Flower {} / PyTorch {}".format(flwr.__version__, torch.__version__))

可以切换到已启用 GPU 加速的运行时(在 Google Colab 上:运行时 > 更改运行时类型 > 硬件加速:GPU > 保存)。但请注意,Google Colab 并非总能提供 GPU 加速。如果在以下部分中看到与 GPU 可用性相关的错误,请考虑通过设置 DEVICE = torch.device("cpu") 切回基于 CPU 的执行。上面的代码默认使用 CPU:只有当运行时已启用 GPU 加速并且你将 DEVICE 设置为 torch.device("cuda") 时,你才会看到输出 Training on cuda,否则会显示 Training on cpu。

数据加载

Let’s now define a loading function for the CIFAR-10 training and test sets: it partitions the training set into num_partitions smaller datasets (each split into a training and a validation set), loads the test set as a whole, and wraps everything in its own DataLoader.

[ ]:
def load_datasets(partition_id: int, num_partitions: int, batch_size: int = 32):
    """Load one CIFAR-10 training partition plus the shared test set.

    The ``train`` split is divided into ``num_partitions`` partitions; the
    partition selected by ``partition_id`` is split 80/20 into a local training
    set and a local validation set. The ``test`` split is loaded whole.

    Args:
        partition_id: Index of the partition to load.
        num_partitions: Number of partitions the ``train`` split is divided into.
        batch_size: Batch size used by all three loaders. Defaults to 32, the
            value this function always used before it was configurable.

    Returns:
        ``(trainloader, valloader, testloader)``. Each batch is a dict whose
        ``"img"`` entry holds the normalized image tensors and whose
        ``"label"`` entry holds the labels.
    """
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% validation (fixed seed so the
    # split is reproducible across runs)
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform)
        # we will use this function to dataset.with_transform(apply_transforms)
        # The transforms object is exactly the same
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=batch_size, shuffle=True
    )
    valloader = DataLoader(partition_train_test["test"], batch_size=batch_size)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=batch_size)
    return trainloader, valloader, testloader

模型训练/评估

让我们继续使用常见的模型定义(包括 set_parameters 和 get_parameters)、训练和测试函数:

[ ]:
class Net(nn.Module):
    """Small LeNet-style CNN that maps 3-channel 32x32 images to 10 logits.

    Two conv/ReLU/max-pool stages shrink a 32x32 input to 16 feature maps of
    5x5, which three fully connected layers turn into class scores.
    """

    def __init__(self) -> None:
        super().__init__()
        # Feature extractor (layers are created in the same order as before so
        # random initialization and state_dict key order are unchanged).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores for a batch of images."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(-1, 16 * 5 * 5)
        for hidden in (self.fc1, self.fc2):
            flat = F.relu(hidden(flat))
        return self.fc3(flat)


def get_parameters(net) -> List[np.ndarray]:
    """Return every ``state_dict`` tensor of ``net`` as a NumPy array, in key order."""
    state = net.state_dict()
    return [tensor.cpu().numpy() for tensor in state.values()]


def set_parameters(net, parameters: List[np.ndarray]):
    """Copy a list of NumPy arrays into ``net``'s ``state_dict``, in key order.

    Args:
        net: The ``torch.nn.Module`` to update in place.
        parameters: One array per ``state_dict`` entry, ordered like
            ``net.state_dict().keys()`` (the order ``get_parameters`` produces).

    Raises:
        ValueError: If more arrays are given than ``net`` has ``state_dict``
            entries (``zip`` would otherwise drop the extras silently).
        RuntimeError: From ``load_state_dict(strict=True)`` if arrays are
            missing or a shape does not match.
    """
    keys = list(net.state_dict().keys())
    if len(parameters) > len(keys):
        raise ValueError(
            f"Got {len(parameters)} parameter arrays but the model has only "
            f"{len(keys)} state_dict entries"
        )
    # torch.tensor copies each array and keeps its dtype; torch.Tensor(...)
    # would coerce everything to float32 and corrupt integer buffers (e.g.
    # BatchNorm's num_batches_tracked).
    state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int, device=None):
    """Train ``net`` in place with cross-entropy loss and Adam.

    After each epoch, prints the per-example average training loss and the
    training accuracy.

    Args:
        net: The ``torch.nn.Module`` to train.
        trainloader: Iterable of batches; each batch is a dict with ``"img"``
            and ``"label"`` tensors.
        epochs: Number of passes over ``trainloader``.
        device: Device to move each batch to. Defaults to the module-level
            ``DEVICE`` (the only device used before this was configurable).
    """
    if device is None:
        device = DEVICE
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images = batch["img"].to(device)
            labels = batch["label"].to(device)
            optimizer.zero_grad()
            # One forward pass; its logits feed both the loss and the metrics.
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics. CrossEntropyLoss returns the batch mean, so weight it by
            # the batch size to get a true per-example average. ``.item()``
            # yields a plain float so no autograd graph is kept alive.
            batch_size = labels.size(0)
            epoch_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
        # An empty loader reports 0.0 instead of raising ZeroDivisionError.
        epoch_loss = epoch_loss / total if total else 0.0
        epoch_acc = correct / total if total else 0.0
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

步骤 1:重温 NumPyClient

So far, we’ve implemented our client by subclassing flwr.client.NumPyClient. The three methods we implemented are get_parameters, fit, and evaluate.

[ ]:
class FlowerNumPyClient(NumPyClient):
    """``NumPyClient`` that trains and evaluates ``net`` on one data partition.

    Weights are exchanged as lists of NumPy arrays; Flower wraps this class
    (via ``to_client``) and handles converting them to and from the wire format.
    """

    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id  # only used to tag the log lines
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current local model weights as a list of NumPy arrays."""
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the global weights, train for one epoch, return the new weights.

        Returns:
            ``(weights, num_examples, metrics)``. ``num_examples`` is the number
            of local training examples, not the number of batches
            (``len(trainloader)``), because strategies such as FedAvg use it as
            this client's aggregation weight.
        """
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the global weights on the local validation set.

        Returns:
            ``(loss, num_examples, {"accuracy": accuracy})`` where
            ``num_examples`` is the number of validation examples, not batches.
        """
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}

Then, we define the function numpyclient_fn that is used by Flower to create the FlowerNumPyClient instances on demand. Finally, we create the ClientApp and pass the numpyclient_fn to it.

[ ]:
def numpyclient_fn(context: Context) -> Client:
    """Build the ``FlowerNumPyClient`` for the partition assigned to this node.

    The partition id and total partition count are read from the node's
    ``node_config``; the NumPy client is wrapped with ``to_client()``.
    """
    model = Net().to(DEVICE)
    node_cfg = context.node_config
    partition_id, num_partitions = node_cfg["partition-id"], node_cfg["num-partitions"]
    trainloader, valloader, _testloader = load_datasets(partition_id, num_partitions)
    numpy_client = FlowerNumPyClient(partition_id, model, trainloader, valloader)
    return numpy_client.to_client()


# ClientApp that creates clients on demand through numpyclient_fn
numpyclient = ClientApp(client_fn=numpyclient_fn)

We’ve seen this before, there’s nothing new so far. The only tiny difference compared to the previous notebook is naming, we’ve changed FlowerClient to FlowerNumPyClient and client_fn to numpyclient_fn. Next, we configure the number of federated learning rounds using ServerConfig and create the ServerApp with this config:

[ ]:
def server_fn(context: Context) -> ServerAppComponents:
    """Return the server components: a config for 3 federated rounds (no strategy passed)."""
    return ServerAppComponents(config=ServerConfig(num_rounds=3))


# ServerApp assembled from server_fn
server = ServerApp(server_fn=server_fn)

Finally, we specify the resources for each client and run the simulation to see the output we get:

[ ]:
# Per-client resources for the simulation backend. ``None`` keeps the backend's
# default allocation (2 CPUs and 0 GPUs per client); on CUDA, each client
# requests one GPU instead.
backend_config = {
    "client_resources": {"num_gpus": 1} if DEVICE.type == "cuda" else None
}

NUM_PARTITIONS = 10

# Launch the simulation with one SuperNode per data partition
run_simulation(
    server_app=server,
    client_app=numpyclient,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

This works as expected, ten clients are training for three rounds of federated learning.

Let’s dive a little bit deeper and discuss how Flower executes this simulation. Whenever a client is selected to do some work, run_simulation launches the ClientApp object which in turn calls the function numpyclient_fn to create an instance of our FlowerNumPyClient (along with loading the model and the data).

但令人惊讶的部分也许就在这里:Flower 实际上并不直接使用 FlowerNumPyClient 对象。相反,它封装了该对象,使其看起来像 flwr.client.Client 的子类,而不是 flwr.client.NumPyClient。事实上,Flower 核心框架不知道如何处理 NumPyClient,它只知道如何处理 Client。NumPyClient 只是建立在 Client 之上的便捷抽象类。

与其在 NumPyClient 上构建,我们可以直接在 Client 上构建。

步骤 2:从 NumPyClient 移至 Client

让我们尝试使用 Client 代替 NumPyClient 做同样的事情。

[ ]:
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Status,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)


class FlowerClient(Client):
    """``Client`` subclass that (de)serializes its parameters explicitly.

    Incoming ``Parameters`` are decoded with ``parameters_to_ndarrays`` and
    outgoing weights are encoded with ``ndarrays_to_parameters``; every method
    takes a single ``*Ins`` object and returns a single ``*Res`` object.
    """

    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id  # only used to tag the log lines
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        """Return the local model weights serialized into a ``Parameters`` object."""
        print(f"[Client {self.partition_id}] get_parameters")

        # Get parameters as a list of NumPy ndarray's
        ndarrays: List[np.ndarray] = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object
        parameters = ndarrays_to_parameters(ndarrays)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return GetParametersRes(
            status=status,
            parameters=parameters,
        )

    def fit(self, ins: FitIns) -> FitRes:
        """Deserialize the global weights, train one epoch, return serialized updates."""
        print(f"[Client {self.partition_id}] fit, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's
        parameters_original = ins.parameters
        ndarrays_original = parameters_to_ndarrays(parameters_original)

        # Update local model, train, get updated parameters
        set_parameters(self.net, ndarrays_original)
        train(self.net, self.trainloader, epochs=1)
        ndarrays_updated = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object
        parameters_updated = ndarrays_to_parameters(ndarrays_updated)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return FitRes(
            status=status,
            parameters=parameters_updated,
            # Number of training examples, not batches (len(trainloader)):
            # strategies such as FedAvg use it as this client's weight.
            num_examples=len(self.trainloader.dataset),
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        """Deserialize the global weights and evaluate them on the validation set."""
        print(f"[Client {self.partition_id}] evaluate, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's
        parameters_original = ins.parameters
        ndarrays_original = parameters_to_ndarrays(parameters_original)

        set_parameters(self.net, ndarrays_original)
        loss, accuracy = test(self.net, self.valloader)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return EvaluateRes(
            status=status,
            loss=float(loss),
            # Number of validation examples, not batches
            num_examples=len(self.valloader.dataset),
            metrics={"accuracy": float(accuracy)},
        )


def client_fn(context: Context) -> Client:
    """Create the ``Client``-based ``FlowerClient`` for this node's data partition."""
    model = Net().to(DEVICE)
    settings = context.node_config
    partition_id = settings["partition-id"]
    trainloader, valloader, _testloader = load_datasets(
        partition_id, settings["num-partitions"]
    )
    flower_client = FlowerClient(partition_id, model, trainloader, valloader)
    return flower_client.to_client()


# ClientApp backed by the Client-based FlowerClient
client = ClientApp(client_fn=client_fn)

在详细讨论代码之前,让我们试着运行它!必须确保我们基于 Client 的新客户端能正常运行,对吗?

[ ]:
# Re-run the same simulation, this time with the Client-based ClientApp
run_simulation(
    client_app=client,
    server_app=server,
    backend_config=backend_config,
    num_supernodes=NUM_PARTITIONS,
)

就是这样,我们现在开始使用 Client。它看起来可能与我们使用 NumPyClient 所做的类似。那么有什么不同呢?

首先,它的代码更多。但为什么呢?区别在于 Client 希望我们处理参数的序列化和反序列化。Flower 要想通过网络发送参数,最终需要将这些参数转化为字节。把参数(例如 NumPy 的 ndarray 参数)变成原始字节叫做序列化。将原始字节转换成更有用的东西(如 NumPy ndarray)称为反序列化。Flower 需要同时做这两件事:它需要在服务器端序列化参数并将其发送到客户端,客户端需要反序列化参数以便将其用于本地训练,然后再次序列化更新后的参数并将其发送回服务器,服务器(最后)再次反序列化参数以便将其与从其他客户端接收到的更新汇总在一起。

Client 与 NumPyClient 之间唯一真正的区别在于,NumPyClient 会为你处理序列化和反序列化。NumPyClient 之所以能做到这一点,是因为它预计你会以 NumPy ndarray 的形式返回参数,而且它知道如何处理这些参数。这使得与具有良好 NumPy 支持的大多数机器学习库一起工作变得轻而易举。

在 API 方面,有一个主要区别:Client 中的所有方法都只接受一个参数(例如,Client.fit 中的 FitIns),并只返回一个值(例如,Client.fit 中的 FitRes)。另一方面,NumPyClient 中的方法有多个参数(例如,NumPyClient.fit 中的 parameters 和 config)和多个返回值(例如,NumPyClient.fit 中的 parameters、num_examples 和 metrics)。Client 中的这些 *Ins 和 *Res 对象封装了你在 NumPyClient 中习惯使用的所有单个值。

步骤 3:自定义序列化

下面我们将通过一个简单的示例来探讨如何实现自定义序列化。

首先,什么是序列化?序列化只是将对象转换为原始字节的过程,同样重要的是,反序列化是将原始字节转换回对象的过程。这对网络通信非常有用。事实上,如果没有序列化,你就无法通过互联网传输一个 Python 对象。

通过在客户端和服务器之间来回发送 Python 对象,联邦学习在很大程度上依赖于互联网通信进行训练。这意味着序列化是联邦学习的重要组成部分。

在下面的章节中,我们将编写一个基本示例:在发送包含参数的 ndarray 之前,我们先将其转换为稀疏矩阵,而不是直接发送其稠密的序列化版本。这种技术可以用来节省带宽,因为在某些情况下,模型的参数是稀疏的(包含许多 0 条目),将它们转换成稀疏矩阵可以大大减少它们所占的字节数。

我们的自定义序列化/反序列化函数

真正的序列化/反序列化就发生在这里,尤其是在用于序列化的 ndarray_to_sparse_bytes 和用于反序列化的 sparse_bytes_to_ndarray 中。

请注意,为了转换数组,我们使用了 PyTorch 的稀疏 CSR 张量(Tensor.to_sparse_csr 和 torch.sparse_csr_tensor)。

[ ]:
from io import BytesIO
from typing import cast

import numpy as np

from flwr.common.typing import NDArray, NDArrays, Parameters


def ndarrays_to_sparse_parameters(ndarrays: NDArrays) -> Parameters:
    """Serialize a list of NumPy arrays into a Flower ``Parameters`` object.

    Every array is encoded independently with ``ndarray_to_sparse_bytes``.
    """
    encoded = list(map(ndarray_to_sparse_bytes, ndarrays))
    return Parameters(tensors=encoded, tensor_type="numpy.ndarray")


def sparse_parameters_to_ndarrays(parameters: Parameters) -> NDArrays:
    """Decode each tensor of a ``Parameters`` object with ``sparse_bytes_to_ndarray``."""
    return list(map(sparse_bytes_to_ndarray, parameters.tensors))


def ndarray_to_sparse_bytes(ndarray: "NDArray") -> bytes:
    """Serialize a NumPy ndarray to bytes, CSR-encoding arrays with ndim > 1.

    Arrays with more than one dimension are viewed as a 2-D matrix of shape
    ``(shape[0], prod(shape[1:]))`` and stored in an ``.npz`` archive as the
    CSR components ``crow_indices``, ``col_indices`` and ``values`` plus the
    original ``shape``. The shape must be recorded: CSR data alone cannot tell
    how many trailing all-zero columns a matrix had. Flattening to 2-D (instead
    of letting PyTorch build a batched CSR tensor for ndim > 2) also avoids
    batched CSR's requirement that every batch hold the same number of
    non-zeros.

    0-D and 1-D arrays are stored densely with ``np.save``.
    """
    bytes_io = BytesIO()

    if ndarray.ndim > 1:
        shape = ndarray.shape
        n_rows = shape[0]
        n_cols = int(np.prod(shape[1:], dtype=np.int64))
        # We convert our ndarray into a 2-D sparse matrix
        matrix = torch.tensor(ndarray).reshape(n_rows, n_cols).to_sparse_csr()

        # And send it by utilizing the sparse matrix attributes.
        # np.savez stores keyword arguments as arrays (older NumPy versions have
        # no allow_pickle parameter for it), so allow_pickle is not passed here.
        # Every array written is numeric, so no pickling is involved; safety is
        # enforced when loading (allow_pickle=False in sparse_bytes_to_ndarray).
        np.savez(
            bytes_io,  # type: ignore
            crow_indices=matrix.crow_indices().numpy(),
            col_indices=matrix.col_indices().numpy(),
            values=matrix.values().numpy(),
            shape=np.asarray(shape, dtype=np.int64),
        )
    else:
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        np.save(bytes_io, ndarray, allow_pickle=False)
    return bytes_io.getvalue()


def sparse_bytes_to_ndarray(tensor: bytes) -> "NDArray":
    """Deserialize bytes produced by ``ndarray_to_sparse_bytes`` into an ndarray.

    Accepts dense ``.npy`` payloads, CSR ``.npz`` payloads that carry the
    original ``shape``, and older CSR payloads without ``shape`` (whose dense
    size is then inferred by PyTorch from the indices, as before).
    """
    # WARNING: NEVER set allow_pickle to true.
    # Reason: loading pickled data can execute arbitrary code
    # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
    loaded = np.load(BytesIO(tensor), allow_pickle=False)  # type: ignore

    # np.load returns an ndarray for .npy payloads and an NpzFile for .npz ones.
    # Checking the type is reliable; ``"crow_indices" in loaded`` would turn into
    # an element-wise comparison when ``loaded`` is an ndarray.
    if isinstance(loaded, np.ndarray):
        return loaded

    # Close the archive once its arrays have been copied out.
    with loaded as archive:
        csr_parts = {
            "crow_indices": torch.tensor(archive["crow_indices"]),
            "col_indices": torch.tensor(archive["col_indices"]),
            "values": torch.tensor(archive["values"]),
        }
        if "shape" not in archive.files:
            # Payload without a recorded shape: infer the size from the indices
            return torch.sparse_csr_tensor(**csr_parts).to_dense().numpy()
        shape = tuple(int(dim) for dim in archive["shape"])

    # We convert our sparse matrix back to a ndarray, using the attributes we sent
    n_rows = shape[0]
    n_cols = int(np.prod(shape[1:], dtype=np.int64))
    matrix = torch.sparse_csr_tensor(**csr_parts, size=(n_rows, n_cols))
    return matrix.to_dense().reshape(shape).numpy()

客户端

为了能够将我们的 ndarray 序列化为稀疏参数,我们只需在 flwr.client.Client 中调用我们的自定义函数。

事实上,在 get_parameters 中,我们需要使用上文定义的自定义函数 ndarrays_to_sparse_parameters,对从我们的神经网络模型中获取的参数进行序列化。

fit 中,我们首先需要使用自定义的 sparse_parameters_to_ndarrays 反序列化来自服务器的参数,然后使用 ndarrays_to_sparse_parameters 序列化本地结果。

evaluate 中,我们只需要用自定义函数反序列化全局参数。

[ ]:
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Status,
)


class FlowerClient(Client):
    """``Client`` that exchanges parameters in the custom sparse format.

    Incoming ``Parameters`` are decoded with ``sparse_parameters_to_ndarrays``
    and outgoing weights are encoded with ``ndarrays_to_sparse_parameters``.
    """

    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id  # only used to tag the log lines
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        """Return the local model weights serialized with the sparse encoding."""
        print(f"[Client {self.partition_id}] get_parameters")

        # Get parameters as a list of NumPy ndarray's
        ndarrays: List[np.ndarray] = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object using our custom function
        parameters = ndarrays_to_sparse_parameters(ndarrays)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return GetParametersRes(
            status=status,
            parameters=parameters,
        )

    def fit(self, ins: FitIns) -> FitRes:
        """Decode the global weights, train one epoch, return sparse-encoded updates."""
        print(f"[Client {self.partition_id}] fit, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's using our custom function
        parameters_original = ins.parameters
        ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)

        # Update local model, train, get updated parameters
        set_parameters(self.net, ndarrays_original)
        train(self.net, self.trainloader, epochs=1)
        ndarrays_updated = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object using our custom function
        parameters_updated = ndarrays_to_sparse_parameters(ndarrays_updated)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return FitRes(
            status=status,
            parameters=parameters_updated,
            # Number of training examples, not batches (len(trainloader)):
            # aggregate_fit uses it as this client's weight in the average.
            num_examples=len(self.trainloader.dataset),
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        """Decode the global weights and evaluate them on the validation set."""
        print(f"[Client {self.partition_id}] evaluate, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's using our custom function
        parameters_original = ins.parameters
        ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)

        set_parameters(self.net, ndarrays_original)
        loss, accuracy = test(self.net, self.valloader)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return EvaluateRes(
            status=status,
            loss=float(loss),
            # Number of validation examples, not batches
            num_examples=len(self.valloader.dataset),
            metrics={"accuracy": float(accuracy)},
        )


def client_fn(context: Context) -> Client:
    """Create the sparse-serializing ``FlowerClient`` for this node's partition."""
    model = Net().to(DEVICE)
    cfg = context.node_config
    pid = cfg["partition-id"]
    n_parts = cfg["num-partitions"]
    train_dl, val_dl, _test_dl = load_datasets(pid, n_parts)
    return FlowerClient(pid, model, train_dl, val_dl).to_client()

服务器端

在本例中,我们将只使用 FedAvg 作为策略。要改变这里的序列化和反序列化,我们只需重新实现 FedAvg 的 evaluate 和 aggregate_fit 函数。策略的其他函数将从超类 FedAvg 继承。

正如你所看到的,evaluate 中只修改了一行:

parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)

而对于 aggregate_fit,我们将首先反序列化收到的每个结果:

weights_results = [
    (sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
    for _, fit_res in results
]

然后将汇总结果序列化:

parameters_aggregated = ndarrays_to_sparse_parameters(aggregate(weights_results))
[ ]:
from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import FitRes, MetricsAggregationFn, NDArrays, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg
from flwr.server.strategy.aggregate import aggregate

# Warning text for a misconfiguration where ``min_available_clients`` is
# smaller than ``min_fit_clients`` or ``min_evaluate_clients``.
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""


class FedSparse(FedAvg):
    """FedAvg variant that exchanges model parameters in the sparse format.

    Only the two places where the server turns ``Parameters`` into NumPy arrays
    are overridden, ``evaluate`` and ``aggregate_fit``. Both use
    ``sparse_parameters_to_ndarrays`` / ``ndarrays_to_sparse_parameters``;
    everything else is inherited from ``FedAvg``.
    """

    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        """Custom FedAvg strategy with sparse matrices.

        Parameters
        ----------
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 1.0.
        fraction_evaluate : float, optional
            Fraction of clients used during validation. Defaults to 1.0.
        min_fit_clients : int, optional
            Minimum number of clients used during training. Defaults to 2.
        min_evaluate_clients : int, optional
            Minimum number of clients used during validation. Defaults to 2.
        min_available_clients : int, optional
            Minimum number of total clients in the system. Defaults to 2.
        evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
            Optional server-side evaluation function; it receives the
            parameters already decoded to NumPy arrays. Defaults to None.
        on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure training. Defaults to None.
        on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure validation. Defaults to None.
        accept_failures : bool, optional
            Whether or not accept rounds containing failures. Defaults to True.
        initial_parameters : Parameters, optional
            Initial global model parameters. If given, they must be encoded with
            ``ndarrays_to_sparse_parameters`` so that
            ``sparse_parameters_to_ndarrays`` can decode them.
        fit_metrics_aggregation_fn : MetricsAggregationFn, optional
            Aggregates the metrics returned by the clients' ``fit``.
            Defaults to None.
        evaluate_metrics_aggregation_fn : MetricsAggregationFn, optional
            Aggregates the metrics returned by the clients' ``evaluate``.
            Defaults to None.
        """
        # FedAvg.__init__ already logs WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW when
        # min_available_clients is smaller than min_fit_clients or
        # min_evaluate_clients. Repeating the check here logged it twice.
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate sparse-encoded global parameters with ``evaluate_fn``.

        Returns ``None`` when no ``evaluate_fn`` was configured or when it
        returns ``None``; otherwise ``(loss, metrics)``.
        """
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None

        # We deserialize using our custom method
        parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)

        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results with an average weighted by ``num_examples``.

        Client updates are decoded with ``sparse_parameters_to_ndarrays`` and
        the aggregate is re-encoded with ``ndarrays_to_sparse_parameters``.
        Returns ``(None, {})`` when there are no results, or when there are
        failures and ``accept_failures`` is False.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # We deserialize each of the results with our custom method
        weights_results = [
            (sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]

        # We serialize the aggregated result using our custom method
        parameters_aggregated = ndarrays_to_sparse_parameters(
            aggregate(weights_results)
        )

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: Dict[str, Scalar] = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated

现在我们可以运行自定义序列化示例!

[ ]:
def server_fn(context: Context) -> ServerAppComponents:
    """Return server components: 3 rounds driven by the FedSparse strategy."""
    config = ServerConfig(num_rounds=3)
    strategy = FedSparse()  # sparse (de)serialization on the server side
    return ServerAppComponents(config=config, strategy=strategy)


# ServerApp that uses FedSparse through server_fn
server = ServerApp(server_fn=server_fn)

# Run the sparse-serialization example end to end
run_simulation(
    client_app=client,
    server_app=server,
    backend_config=backend_config,
    num_supernodes=NUM_PARTITIONS,
)

回顾

在本部分教程中,我们已经了解了如何通过子类化 NumPyClient 或 Client 来构建客户端。NumPyClient 是一个便捷的抽象类,可以让我们更容易地与具有良好 NumPy 互操作性的机器学习库一起工作。Client 是一个更灵活的抽象类,允许我们做一些在 NumPyClient 中做不到的事情。为此,它要求我们自己处理参数序列化和反序列化。

接下来的步骤

Before you continue, make sure to join the Flower community on Flower Discuss (Join Flower Discuss) and on Slack (Join Slack).

如果您需要帮助,我们有专门的 #questions 频道,但我们也很乐意在 #introductions 中了解您是谁!

这暂时是 Flower 教程的最后一部分,恭喜您!您现在已经具备了理解其余文档的能力。本教程还有许多内容没有涉及,我们推荐您参考以下资源:


Open in Colab