Open in Colab

自定义客户端#

欢迎来到 Flower 联邦学习教程的第四部分。在本教程的前几部分中,我们介绍了 PyTorch 和 Flower 的联邦学习(part 1),了解了如何使用策略来定制服务器和客户端的执行(part 2),并从头开始构建了我们自己的定制策略(part 3)。

在本笔记中,我们将重温 NumPyClient 并引入一个用于构建客户端的新基类,简单命名为 Client。在本教程的前几部分中,我们的客户端基于``NumPyClient``,这是一个便捷类,可以让我们轻松地与具有良好 NumPy 互操作性的机器学习库协同工作。有了 Client,我们获得了很多以前没有的灵活性,但我们也必须做一些以前不需要做的事情。

Star Flower on GitHub ⭐️ 并加入 Slack 上的 Flower 社区,进行交流、提问并获得帮助: 加入 Slack <https://flower.ai/join-slack>`__ 🌼 我们希望在 #introductions 频道听到您的声音!如果有任何不清楚的地方,请访问 #questions 频道。

让我们深入了解一下从 NumPyClientClient 的过程!

步骤 0:准备工作#

在开始实际代码之前,让我们先确保我们已经准备好了所需的一切。

安装依赖项#

首先,我们安装必要的软件包:

[ ]:
!pip install -q flwr[simulation] torch torchvision scipy

现在我们已经安装了所有依赖项,可以导入本教程所需的所有内容:

[ ]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10

import flwr as fl

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(
    f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)

可以切换到已启用 GPU 加速的运行时(在 Google Colab 上: 运行时 > 更改运行时类型 > 硬件加速: GPU > 保存``)。但请注意,Google Colab 并非总能提供 GPU 加速。如果在以下部分中看到与 GPU 可用性相关的错误,请考虑通过设置 DEVICE = torch.device("cpu") 切回基于 CPU 的执行。如果运行时已启用 GPU 加速,你应该会看到输出``Training on cuda``,否则会显示``Training on cpu``。

数据加载#

现在,让我们加载 CIFAR-10 训练集和测试集,将它们分割成十个较小的数据集(每个数据集又分为训练集和验证集),并将所有数据都封装在各自的 DataLoader 中。

[ ]:
NUM_CLIENTS = 10


def load_datasets(num_clients: int):
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader


trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)

模型培训/评估#

让我们继续使用常见的模型定义(包括 set_parametersget_parameters)、训练和测试函数:

[ ]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

步骤 1:重温 NumPyClient#

到目前为止,我们通过子类化 flwr.client.NumPyClient 实现了我们的客户端。我们实现了三个方法:get_parameters, fit`, 和``evaluate。最后,我们用一个名为 client_fn 的函数来创建该类的实例:

[ ]:
class FlowerNumPyClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def numpyclient_fn(cid) -> FlowerNumPyClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerNumPyClient(cid, net, trainloader, valloader)

我们以前见过这种情况,目前没有什么新东西。与之前的笔记相比,唯一*小*的不同是命名,我们把 FlowerClient 改成了 FlowerNumPyClient,把 client_fn 改成了 numpyclient_fn。让我们运行它看看输出结果:

[ ]:
# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if DEVICE.type == "cuda":
    client_resources = {"num_gpus": 1}

fl.simulation.start_simulation(
    client_fn=numpyclient_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=3),
    client_resources=client_resources,
)

结果不出所料,两个客户端正在进行三轮联邦学习训练。

让我们再深入一点,讨论一下 Flower 是如何执行模拟的。每当一个客户端被选中进行工作时,start_simulation` 就会调用函数 numpyclient_fn 来创建我们的 FlowerNumPyClient 实例(同时加载模型和数据)。

但令人惊讶的部分也许就在这里: Flower 实际上并不直接使用 FlowerNumPyClient 对象。相反,它封装了该对象,使其看起来像 flwr.client.Client 的子类,而不是 flwr.client.NumPyClient。事实上,Flower 核心框架不知道如何处理 NumPyClient,它只知道如何处理 ClientNumPyClient 只是建立在``Client``之上的便捷抽象类。

与其在 NumPyClient 上构建,我们可以直接在 Client 上构建。

步骤 2:从 NumPyClient 移至 Client#

让我们尝试使用 Client 代替 NumPyClient 做同样的事情。

[ ]:
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Status,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)


class FlowerClient(fl.client.Client):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        print(f"[Client {self.cid}] get_parameters")

        # Get parameters as a list of NumPy ndarray's
        ndarrays: List[np.ndarray] = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object
        parameters = ndarrays_to_parameters(ndarrays)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return GetParametersRes(
            status=status,
            parameters=parameters,
        )

    def fit(self, ins: FitIns) -> FitRes:
        print(f"[Client {self.cid}] fit, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's
        parameters_original = ins.parameters
        ndarrays_original = parameters_to_ndarrays(parameters_original)

        # Update local model, train, get updated parameters
        set_parameters(self.net, ndarrays_original)
        train(self.net, self.trainloader, epochs=1)
        ndarrays_updated = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object
        parameters_updated = ndarrays_to_parameters(ndarrays_updated)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return FitRes(
            status=status,
            parameters=parameters_updated,
            num_examples=len(self.trainloader),
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        print(f"[Client {self.cid}] evaluate, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's
        parameters_original = ins.parameters
        ndarrays_original = parameters_to_ndarrays(parameters_original)

        set_parameters(self.net, ndarrays_original)
        loss, accuracy = test(self.net, self.valloader)
        # return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return EvaluateRes(
            status=status,
            loss=float(loss),
            num_examples=len(self.valloader),
            metrics={"accuracy": float(accuracy)},
        )


def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

在详细讨论代码之前,让我们试着运行它!必须确保我们基于 Client 的新客户端能正常运行,对吗?

[ ]:
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=3),
    client_resources=client_resources,
)

就是这样,我们现在开始使用 Client。它看起来可能与我们使用 NumPyClient 所做的类似。那么有什么不同呢?

首先,它的代码更多。但为什么呢?区别在于 Client 希望我们处理参数的序列化和反序列化。Flower 要想通过网络发送参数,最终需要将这些参数转化为 字节。把参数(例如 NumPy 的 ndarray 参数)变成原始字节叫做序列化。将原始字节转换成更有用的东西(如 NumPy ``ndarray`)称为反序列化。Flower 需要同时做这两件事:它需要在服务器端序列化参数并将其发送到客户端,客户端需要反序列化参数以便将其用于本地训练,然后再次序列化更新后的参数并将其发送回服务器,服务器(最后)再次反序列化参数以便将其与从其他客户端接收到的更新汇总在一起。

Client 与 NumPyClient 之间的唯一**真正区别在于,NumPyClient 会为你处理序列化和反序列化。NumPyClient之所以能做到这一点,是因为它预计你会以NumPy ndarray的形式返回参数,而且它知道如何处理这些参数。这使得与具有良好 NumPy 支持的大多数机器学习库一起工作变得轻而易举。

在 API 方面,有一个主要区别:Client 中的所有方法都只接受一个参数(例如,Client.fit 中的 FitIns),并只返回一个值(例如,Client.fit 中的 FitRes)。另一方面,NumPyClient``中的方法有多个参数(例如,``NumPyClient.fit``中的``parameters``和``config)和多个返回值(例如,NumPyClient.fit``中的``parametersnum_example``和``metrics)。在 Client 中的这些 *Ins*Res 对象封装了你在 NumPyClient 中习惯使用的所有单个值。

步骤 3:自定义序列化#

下面我们将通过一个简单的示例来探讨如何实现自定义序列化。

首先,什么是序列化?序列化只是将对象转换为原始字节的过程,同样重要的是,反序列化是将原始字节转换回对象的过程。这对网络通信非常有用。事实上,如果没有序列化,你就无法通过互联网传输一个 Python 对象。

通过在客户端和服务器之间来回发送 Python 对象,联合学习在很大程度上依赖于互联网通信进行训练。这意味着序列化是联邦学习的重要组成部分。

在下面的章节中,我们将编写一个基本示例,在发送包含参数的 ndarray 前,我们将首先把 ndarray 转换为稀疏矩阵,而不是发送序列化版本。这种技术可以用来节省带宽,因为在某些情况下,模型的参数是稀疏的(包含许多 0 条目),将它们转换成稀疏矩阵可以大大提高它们的字节数。

我们的定制序列化/反序列化功能#

这才是真正的序列化/反序列化,尤其是在用于序列化的 ndarray_too_sparse_bytes 和用于反序列化的 sparse_bytes_too_ndarray 中。

请注意,为了转换数组,我们导入了 scipy.sparse 库。

[ ]:
from io import BytesIO
from typing import cast

import numpy as np

from flwr.common.typing import NDArray, NDArrays, Parameters


def ndarrays_to_sparse_parameters(ndarrays: NDArrays) -> Parameters:
    """Convert NumPy ndarrays to parameters object."""
    tensors = [ndarray_to_sparse_bytes(ndarray) for ndarray in ndarrays]
    return Parameters(tensors=tensors, tensor_type="numpy.ndarray")


def sparse_parameters_to_ndarrays(parameters: Parameters) -> NDArrays:
    """Convert parameters object to NumPy ndarrays."""
    return [sparse_bytes_to_ndarray(tensor) for tensor in parameters.tensors]


def ndarray_to_sparse_bytes(ndarray: NDArray) -> bytes:
    """Serialize NumPy ndarray to bytes."""
    bytes_io = BytesIO()

    if len(ndarray.shape) > 1:
        # We convert our ndarray into a sparse matrix
        ndarray = torch.tensor(ndarray).to_sparse_csr()

        # And send it byutilizing the sparse matrix attributes
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        np.savez(
            bytes_io,  # type: ignore
            crow_indices=ndarray.crow_indices(),
            col_indices=ndarray.col_indices(),
            values=ndarray.values(),
            allow_pickle=False,
        )
    else:
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        np.save(bytes_io, ndarray, allow_pickle=False)
    return bytes_io.getvalue()


def sparse_bytes_to_ndarray(tensor: bytes) -> NDArray:
    """Deserialize NumPy ndarray from bytes."""
    bytes_io = BytesIO(tensor)
    # WARNING: NEVER set allow_pickle to true.
    # Reason: loading pickled data can execute arbitrary code
    # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
    loader = np.load(bytes_io, allow_pickle=False)  # type: ignore

    if "crow_indices" in loader:
        # We convert our sparse matrix back to a ndarray, using the attributes we sent
        ndarray_deserialized = (
            torch.sparse_csr_tensor(
                crow_indices=loader["crow_indices"],
                col_indices=loader["col_indices"],
                values=loader["values"],
            )
            .to_dense()
            .numpy()
        )
    else:
        ndarray_deserialized = loader
    return cast(NDArray, ndarray_deserialized)

客户端#

为了能够将我们的 ndarray 序列化为稀疏参数,我们只需在 flwr.client.Client 中调用我们的自定义函数。

事实上,在 get_parameters 中,我们需要使用上文定义的自定义 ndarrays_too_sparse_parameters 序列化从网络中获取的参数。

fit 中,我们首先需要使用自定义的 sparse_parameters_to_ndarrays 反序列化来自服务器的参数,然后使用 ndarrays_to_sparse_parameters 序列化本地结果。

evaluate 中,我们只需要用自定义函数反序列化全局参数。

[ ]:
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Status,
)


class FlowerClient(fl.client.Client):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        print(f"[Client {self.cid}] get_parameters")

        # Get parameters as a list of NumPy ndarray's
        ndarrays: List[np.ndarray] = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object using our custom function
        parameters = ndarrays_to_sparse_parameters(ndarrays)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return GetParametersRes(
            status=status,
            parameters=parameters,
        )

    def fit(self, ins: FitIns) -> FitRes:
        print(f"[Client {self.cid}] fit, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's using our custom function
        parameters_original = ins.parameters
        ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)

        # Update local model, train, get updated parameters
        set_parameters(self.net, ndarrays_original)
        train(self.net, self.trainloader, epochs=1)
        ndarrays_updated = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object using our custom function
        parameters_updated = ndarrays_to_sparse_parameters(ndarrays_updated)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return FitRes(
            status=status,
            parameters=parameters_updated,
            num_examples=len(self.trainloader),
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        print(f"[Client {self.cid}] evaluate, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's using our custom function
        parameters_original = ins.parameters
        ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)

        set_parameters(self.net, ndarrays_original)
        loss, accuracy = test(self.net, self.valloader)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return EvaluateRes(
            status=status,
            loss=float(loss),
            num_examples=len(self.valloader),
            metrics={"accuracy": float(accuracy)},
        )


def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

服务器端#

在本例中,我们将只使用 FedAvg 作为策略。要改变这里的序列化和反序列化,我们只需重新实现 FedAvgevaluateaggregate_fit 函数。策略的其他函数将从超类 FedAvg 继承。

正如你所看到的,``evaluate``中只修改了一行:

parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)

而对于 aggregate_fit,我们将首先反序列化收到的每个结果:

weights_results = [
    (sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
    for _, fit_res in results
]

然后将汇总结果序列化:

parameters_aggregated = ndarrays_to_sparse_parameters(aggregate(weights_results))
[ ]:
from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import FitRes, MetricsAggregationFn, NDArrays, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg
from flwr.server.strategy.aggregate import aggregate

WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""


class FedSparse(FedAvg):
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        """Custom FedAvg strategy with sparse matrices.

        Parameters
        ----------
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 0.1.
        fraction_evaluate : float, optional
            Fraction of clients used during validation. Defaults to 0.1.
        min_fit_clients : int, optional
            Minimum number of clients used during training. Defaults to 2.
        min_evaluate_clients : int, optional
            Minimum number of clients used during validation. Defaults to 2.
        min_available_clients : int, optional
            Minimum number of total clients in the system. Defaults to 2.
        evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
            Optional function used for validation. Defaults to None.
        on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure training. Defaults to None.
        on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure validation. Defaults to None.
        accept_failures : bool, optional
            Whether or not accept rounds containing failures. Defaults to True.
        initial_parameters : Parameters, optional
            Initial global model parameters.
        """

        if (
            min_fit_clients > min_available_clients
            or min_evaluate_clients > min_available_clients
        ):
            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function."""
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None

        # We deserialize using our custom method
        parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)

        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # We deserialize each of the results with our custom method
        weights_results = [
            (sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]

        # We serialize the aggregated result using our custom method
        parameters_aggregated = ndarrays_to_sparse_parameters(
            aggregate(weights_results)
        )

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated

现在我们可以运行自定义序列化示例!

[ ]:
strategy = FedSparse()

fl.simulation.start_simulation(
    strategy=strategy,
    client_fn=client_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=3),
    client_resources=client_resources,
)

回顾#

在本部分教程中,我们已经了解了如何通过子类化 NumPyClientClient 来构建客户端。NumPyClient "是一个便捷的抽象类,可以让我们更容易地与具有良好NumPy互操作性的机器学习库一起工作。``Client``是一个更灵活的抽象类,允许我们做一些在`NumPyClient``中做不到的事情。为此,它要求我们自己处理参数序列化和反序列化。

接下来的步骤#

在继续之前,请务必加入 Slack 上的 Flower 社区:Join Slack

如果您需要帮助,我们有专门的 #questions 频道,但我们也很乐意在 #introductions 中了解您是谁!

这暂时是 Flower 教程的最后一部分,恭喜您!您现在已经具备了理解其余文档的能力。本教程还有许多内容没有涉及,我们推荐您参考以下资源:


Open in Colab