Open in Colab

Customize the client¶

Welcome to the fourth part of the Flower federated learning tutorial. In the previous parts of this tutorial, we introduced federated learning with PyTorch and Flower (part 1), we learned how strategies can be used to customize the execution on both the server and the clients (part 2), and we built our own custom strategy from scratch (part 3).

In this notebook, we revisit NumPyClient and introduce a new base class for building clients, simply named Client. In the previous parts of this tutorial, we based our client on NumPyClient, a convenience class that makes it easy to work with machine learning libraries that have good NumPy interoperability. With Client, we gain a lot of flexibility that we didn’t have before, but we’ll also have to do a few things that we didn’t have to do before.

Star Flower on GitHub ⭐ and join the Flower community on Flower Discuss and the Flower Slack to connect, ask questions, and get help:

- Join Flower Discuss We’d love to hear from you in the Introduction topic! If anything is unclear, post in Flower Help - Beginners.
- Join Flower Slack We’d love to hear from you in the #introductions channel! If anything is unclear, head over to the #questions channel.

Let’s go deeper and see what it takes to move from NumPyClient to Client! đŸŒŒ

Step 0: Preparation¶

Before we begin with the actual code, let’s make sure that we have everything we need.

Installing dependencies¶

First, we install the necessary packages:

[ ]:
!pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision scipy

Now that we have all dependencies installed, we can import everything we need for this tutorial:

[ ]:
from collections import OrderedDict
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: Runtime > Change runtime type > Hardware accelerator: GPU > Save). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting DEVICE = torch.device("cpu"). If the runtime has GPU acceleration enabled, you should see the output Training on cuda; otherwise it will say Training on cpu.
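
If you’d rather not edit the code by hand when switching runtimes, one optional variation (not part of the original notebook) is to select the device automatically based on availability:

[ ]:
# Optional: pick the device based on availability instead of hard-coding it
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")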

Data loading¶

Let’s now define a loading function for the CIFAR-10 training and test sets, partition them into num_partitions smaller datasets (each split into a training and a validation set), and wrap everything in their own DataLoader.

[ ]:
def load_datasets(partition_id: int, num_partitions: int):
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform)
        # we will use this function to dataset.with_transform(apply_transforms)
        # The transforms object is exactly the same
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
    valloader = DataLoader(partition_train_test["test"], batch_size=32)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=32)
    return trainloader, valloader, testloader

Model training/evaluation¶

Let’s continue with the usual model definition (including set_parameters and get_parameters), and the training and test functions:

[ ]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)  # reuse the forward pass from above
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

Step 1: Revisit NumPyClient¶

So far, we’ve implemented our client by subclassing flwr.client.NumPyClient. The three methods we implemented are get_parameters, fit, and evaluate.

[ ]:
class FlowerNumPyClient(NumPyClient):
    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

Then, we define the function numpyclient_fn that is used by Flower to create the FlowerNumPyClient instances on demand. Finally, we create the ClientApp and pass the numpyclient_fn to it.

[ ]:
def numpyclient_fn(context: Context) -> Client:
    net = Net().to(DEVICE)
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerNumPyClient(partition_id, net, trainloader, valloader).to_client()


# Create the ClientApp
numpyclient = ClientApp(client_fn=numpyclient_fn)

We’ve seen this before; there’s nothing new so far. The only tiny difference compared to the previous notebook is naming: we’ve changed FlowerClient to FlowerNumPyClient and client_fn to numpyclient_fn. Next, we configure the number of federated learning rounds using ServerConfig and create the ServerApp with this config:

[ ]:
def server_fn(context: Context) -> ServerAppComponents:
    # Configure the server for 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(config=config)


# Create ServerApp
server = ServerApp(server_fn=server_fn)

Finally, we specify the resources for each client and run the simulation to see the output we get:

[ ]:
# Specify the resources each of your clients needs
# If set to None, each client will by default be allocated 2x CPUs and 0x GPUs
backend_config = {"client_resources": None}
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_gpus": 1}}

NUM_PARTITIONS = 10

# Run simulation
run_simulation(
    server_app=server,
    client_app=numpyclient,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

This works as expected: ten clients are training for three rounds of federated learning.

Let’s dive a little bit deeper and discuss how Flower executes this simulation. Whenever a client is selected to do some work, run_simulation launches the ClientApp object, which in turn calls the function numpyclient_fn to create an instance of our FlowerNumPyClient (along with loading the model and the data).

But here’s the perhaps surprising part: Flower doesn’t use the FlowerNumPyClient object directly. Instead, it wraps the object to make it look like a subclass of flwr.client.Client, not flwr.client.NumPyClient. In fact, the Flower core framework doesn’t know how to handle NumPyClient’s; it only knows how to handle Client’s. NumPyClient is just a convenience abstraction built on top of Client.
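
We can see this wrapping in action: calling to_client() on our NumPyClient subclass yields an object that is an instance of Client, not of NumPyClient. A minimal sketch (the data loaders are irrelevant here, so we pass None):

[ ]:
# Illustrative only: the object handed to Flower is a Client, not a NumPyClient
wrapped = FlowerNumPyClient(0, Net().to(DEVICE), None, None).to_client()
print(isinstance(wrapped, Client))       # True
print(isinstance(wrapped, NumPyClient))  # False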

Instead of building on top of NumPyClient, we can directly build on top of Client.

Step 2: Move from NumPyClient to Client¶

Let’s try to do the same thing using Client instead of NumPyClient.

[ ]:
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Status,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)


class FlowerClient(Client):
    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        print(f"[Client {self.partition_id}] get_parameters")

        # Get parameters as a list of NumPy ndarray's
        ndarrays: List[np.ndarray] = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object
        parameters = ndarrays_to_parameters(ndarrays)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return GetParametersRes(
            status=status,
            parameters=parameters,
        )

    def fit(self, ins: FitIns) -> FitRes:
        print(f"[Client {self.partition_id}] fit, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's
        parameters_original = ins.parameters
        ndarrays_original = parameters_to_ndarrays(parameters_original)

        # Update local model, train, get updated parameters
        set_parameters(self.net, ndarrays_original)
        train(self.net, self.trainloader, epochs=1)
        ndarrays_updated = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object
        parameters_updated = ndarrays_to_parameters(ndarrays_updated)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return FitRes(
            status=status,
            parameters=parameters_updated,
            num_examples=len(self.trainloader),
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        print(f"[Client {self.partition_id}] evaluate, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's
        parameters_original = ins.parameters
        ndarrays_original = parameters_to_ndarrays(parameters_original)

        set_parameters(self.net, ndarrays_original)
        loss, accuracy = test(self.net, self.valloader)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return EvaluateRes(
            status=status,
            loss=float(loss),
            num_examples=len(self.valloader),
            metrics={"accuracy": float(accuracy)},
        )


def client_fn(context: Context) -> Client:
    net = Net().to(DEVICE)
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerClient(partition_id, net, trainloader, valloader).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

Before we discuss the code in more detail, let’s try to run it! We have to make sure our new Client-based client works, right?

[ ]:
# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

That’s it, we’re now using Client. It probably looks similar to what we did with NumPyClient. So what’s the difference?

First of all, it’s more code. But why? The difference comes from the fact that Client expects us to take care of parameter serialization and deserialization. For Flower to be able to send parameters over the network, it eventually needs to turn these parameters into bytes. Turning parameters (e.g., NumPy ndarray’s) into raw bytes is called serialization. Turning raw bytes into something more useful (like NumPy ndarray’s) is called deserialization. Flower needs to do both: it needs to serialize parameters on the server-side and send them to the client, the client needs to deserialize them to use them for local training, and then serialize the updated parameters again to send them back to the server, which (finally!) deserializes them again in order to aggregate them with the updates received from other clients.
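
To make the round trip concrete, here is a small sketch using the same helper functions Flower relies on (ndarrays_to_parameters and parameters_to_ndarrays, both imported above):

[ ]:
# Serialize a list of NumPy ndarrays into a Parameters object (raw bytes)...
original = [np.arange(6, dtype=np.float32).reshape(2, 3)]
serialized = ndarrays_to_parameters(original)
print(type(serialized.tensors[0]))  # <class 'bytes'>

# ...and deserialize the bytes back into ndarrays without any loss
deserialized = parameters_to_ndarrays(serialized)
assert np.array_equal(original[0], deserialized[0])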

La seule vraie diffĂ©rence entre Client et NumPyClient est que NumPyClient s’occupe de la sĂ©rialisation et de la dĂ©sĂ©rialisation pour toi. Il peut le faire parce qu’il s’attend Ă  ce que tu renvoies des paramĂštres sous forme de NumPy ndarray, et il sait comment les gĂ©rer. Cela permet de travailler avec des bibliothĂšques d’apprentissage automatique qui ont une bonne prise en charge de NumPy (la plupart d’entre elles) en un clin d’Ɠil.

In terms of API, there’s one major difference: all methods in Client take exactly one argument (e.g., FitIns in Client.fit) and return exactly one value (e.g., FitRes in Client.fit). The methods in NumPyClient, on the other hand, have multiple arguments (e.g., parameters and config in NumPyClient.fit) and multiple return values (e.g., parameters, num_examples, and metrics in NumPyClient.fit) if there are multiple things to handle. These *Ins and *Res objects in Client wrap all the individual values you’re used to from NumPyClient.
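
For example, FitIns simply bundles the parameters and config values that NumPyClient.fit receives as separate arguments into a single object. A short sketch (the config key is arbitrary and just for illustration):

[ ]:
# One *Ins object wraps everything NumPyClient.fit got as separate arguments
params = ndarrays_to_parameters([np.zeros((2, 2), dtype=np.float32)])
ins = FitIns(parameters=params, config={"local_epochs": 1})
print(type(ins.parameters).__name__, ins.config)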

Step 3: Custom serialization¶

Here, we will explore how to implement custom serialization with a simple example.

But first, what is serialization? Serialization is simply the process of converting an object into raw bytes; equally important, deserialization is the process of converting raw bytes back into an object. This is very useful for network communication: without serialization, you could not send a Python object over the internet.

Federated learning relies heavily on network communication for training, sending Python objects back and forth between the clients and the server, which means that serialization is an essential part of federated learning.

In the following section, we will write a basic example where, instead of sending a serialized version of the ndarray’s containing our parameters, we will first convert the ndarray’s into sparse matrices before sending them. This technique can be used to save bandwidth: when the weights of a model are sparse (containing many 0 entries), converting them to a sparse matrix can greatly reduce their size in bytes.
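
As a rough illustration of those savings (a standalone sketch, not part of the tutorial pipeline): for a mostly-zero matrix, the CSR representation needs far fewer bytes than the dense one:

[ ]:
from io import BytesIO

import numpy as np
import torch

dense = np.zeros((100, 100), dtype=np.float32)
dense[0, :10] = 1.0  # only 10 out of 10,000 entries are non-zero

# Dense serialization: every entry is written out
dense_buf = BytesIO()
np.save(dense_buf, dense, allow_pickle=False)

# Sparse CSR serialization: only indices and non-zero values are written out
sparse = torch.tensor(dense).to_sparse_csr()
sparse_buf = BytesIO()
np.savez(
    sparse_buf,
    crow_indices=sparse.crow_indices(),
    col_indices=sparse.col_indices(),
    values=sparse.values(),
)
print(f"dense: {len(dense_buf.getvalue())} bytes")
print(f"sparse: {len(sparse_buf.getvalue())} bytes")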

Our custom serialization/deserialization functions¶

This is where the actual serialization/deserialization happens, specifically in ndarray_to_sparse_bytes for serialization and sparse_bytes_to_ndarray for deserialization.

Note that the conversion itself relies on PyTorch’s sparse CSR tensors (via to_sparse_csr) to convert our arrays.

[ ]:
from io import BytesIO
from typing import cast

import numpy as np

from flwr.common.typing import NDArray, NDArrays, Parameters


def ndarrays_to_sparse_parameters(ndarrays: NDArrays) -> Parameters:
    """Convert NumPy ndarrays to parameters object."""
    tensors = [ndarray_to_sparse_bytes(ndarray) for ndarray in ndarrays]
    return Parameters(tensors=tensors, tensor_type="numpy.ndarray")


def sparse_parameters_to_ndarrays(parameters: Parameters) -> NDArrays:
    """Convert parameters object to NumPy ndarrays."""
    return [sparse_bytes_to_ndarray(tensor) for tensor in parameters.tensors]


def ndarray_to_sparse_bytes(ndarray: NDArray) -> bytes:
    """Serialize NumPy ndarray to bytes."""
    bytes_io = BytesIO()

    if len(ndarray.shape) > 1:
        # We convert our ndarray into a sparse matrix
        ndarray = torch.tensor(ndarray).to_sparse_csr()

        # And send it by utilizing the sparse matrix attributes.
        # Note: np.savez treats extra keyword arguments as arrays to store, so
        # it has no allow_pickle flag; pickle loading is instead disabled on
        # the receiving side (see sparse_bytes_to_ndarray below).
        # WARNING: NEVER load NumPy data with allow_pickle set to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        np.savez(
            bytes_io,  # type: ignore
            crow_indices=ndarray.crow_indices(),
            col_indices=ndarray.col_indices(),
            values=ndarray.values(),
        )
    else:
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        np.save(bytes_io, ndarray, allow_pickle=False)
    return bytes_io.getvalue()


def sparse_bytes_to_ndarray(tensor: bytes) -> NDArray:
    """Deserialize NumPy ndarray from bytes."""
    bytes_io = BytesIO(tensor)
    # WARNING: NEVER set allow_pickle to true.
    # Reason: loading pickled data can execute arbitrary code
    # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
    loader = np.load(bytes_io, allow_pickle=False)  # type: ignore

    if "crow_indices" in loader:
        # We convert our sparse matrix back to a ndarray, using the attributes we sent
        ndarray_deserialized = (
            torch.sparse_csr_tensor(
                crow_indices=loader["crow_indices"],
                col_indices=loader["col_indices"],
                values=loader["values"],
            )
            .to_dense()
            .numpy()
        )
    else:
        ndarray_deserialized = loader
    return cast(NDArray, ndarray_deserialized)
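
As a quick sanity check, we can round-trip an array through the two functions and confirm that nothing is lost (a small optional test; note that the deserializer infers the matrix shape from the indices, which works here because every entry is non-zero):

[ ]:
# Round-trip test: serialize to sparse bytes, then deserialize back
arr = np.random.randn(4, 4).astype(np.float32)
restored = sparse_bytes_to_ndarray(ndarray_to_sparse_bytes(arr))
assert np.allclose(arr, restored)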

Client-side¶

Pour pouvoir sĂ©rialiser nos ndarray en paramĂštres sparse, il nous suffira d’appeler nos fonctions personnalisĂ©es dans notre flwr.client.Client.

Indeed, in get_parameters we need to serialize the parameters we got from our network using our custom ndarrays_to_sparse_parameters defined above.

In fit, we first need to deserialize the parameters coming from the server using our custom sparse_parameters_to_ndarrays, and then we need to serialize our local results with ndarrays_to_sparse_parameters.

In evaluate, we only need to deserialize the global parameters with our custom function.

[ ]:
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Status,
)


class FlowerClient(Client):
    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        print(f"[Client {self.partition_id}] get_parameters")

        # Get parameters as a list of NumPy ndarray's
        ndarrays: List[np.ndarray] = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object using our custom function
        parameters = ndarrays_to_sparse_parameters(ndarrays)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return GetParametersRes(
            status=status,
            parameters=parameters,
        )

    def fit(self, ins: FitIns) -> FitRes:
        print(f"[Client {self.partition_id}] fit, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's using our custom function
        parameters_original = ins.parameters
        ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)

        # Update local model, train, get updated parameters
        set_parameters(self.net, ndarrays_original)
        train(self.net, self.trainloader, epochs=1)
        ndarrays_updated = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object using our custom function
        parameters_updated = ndarrays_to_sparse_parameters(ndarrays_updated)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return FitRes(
            status=status,
            parameters=parameters_updated,
            num_examples=len(self.trainloader),
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        print(f"[Client {self.partition_id}] evaluate, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's using our custom function
        parameters_original = ins.parameters
        ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)

        set_parameters(self.net, ndarrays_original)
        loss, accuracy = test(self.net, self.valloader)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return EvaluateRes(
            status=status,
            loss=float(loss),
            num_examples=len(self.valloader),
            metrics={"accuracy": float(accuracy)},
        )


def client_fn(context: Context) -> Client:
    net = Net().to(DEVICE)
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerClient(partition_id, net, trainloader, valloader).to_client()

Server-side¶

For this example, we will just use FedAvg as the strategy. To change the serialization and deserialization here, we only need to reimplement the evaluate and aggregate_fit functions of FedAvg. The other functions of the strategy will be inherited from the FedAvg superclass.

As you can see, only one line was changed in evaluate:

parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)

And for aggregate_fit, we first deserialize every result we received:

weights_results = [
    (sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
    for _, fit_res in results
]

And then serialize the aggregated result:

parameters_aggregated = ndarrays_to_sparse_parameters(aggregate(weights_results))
[ ]:
from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import FitRes, MetricsAggregationFn, NDArrays, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg
from flwr.server.strategy.aggregate import aggregate

WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""


class FedSparse(FedAvg):
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        """Custom FedAvg strategy with sparse matrices.

        Parameters
        ----------
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 1.0.
        fraction_evaluate : float, optional
            Fraction of clients used during validation. Defaults to 1.0.
        min_fit_clients : int, optional
            Minimum number of clients used during training. Defaults to 2.
        min_evaluate_clients : int, optional
            Minimum number of clients used during validation. Defaults to 2.
        min_available_clients : int, optional
            Minimum number of total clients in the system. Defaults to 2.
        evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
            Optional function used for validation. Defaults to None.
        on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure training. Defaults to None.
        on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure validation. Defaults to None.
        accept_failures : bool, optional
            Whether or not to accept rounds containing failures. Defaults to True.
        initial_parameters : Parameters, optional
            Initial global model parameters.
        """

        if (
            min_fit_clients > min_available_clients
            or min_evaluate_clients > min_available_clients
        ):
            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function."""
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None

        # We deserialize using our custom method
        parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)

        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # We deserialize each of the results with our custom method
        weights_results = [
            (sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]

        # We serialize the aggregated result using our custom method
        parameters_aggregated = ndarrays_to_sparse_parameters(
            aggregate(weights_results)
        )

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated

We can now run our custom serialization example!

[ ]:
def server_fn(context: Context) -> ServerAppComponents:
    # Configure the server for just 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(
        config=config,
        strategy=FedSparse(),  # <-- pass the new strategy here
    )


# Create the ServerApp
server = ServerApp(server_fn=server_fn)

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

Recap¶

In this part of the tutorial, we’ve seen how to build clients by subclassing either NumPyClient or Client. NumPyClient is a convenience abstraction that makes it easier to work with machine learning libraries that have good NumPy interoperability. Client is a more flexible abstraction that allows us to do things that are not possible with NumPyClient; in exchange, it requires us to handle parameter serialization and deserialization ourselves.

Next steps¶

Before you continue, make sure to join the Flower community on Flower Discuss (Join Flower Discuss) and on Slack (Join Slack).

There’s a dedicated #questions channel if you need help, but we’d also love to hear who you are in #introductions!

This is the final part of the Flower tutorial (for now!), congratulations! You’re now well equipped to understand the rest of the documentation. There are many topics we didn’t cover in this tutorial; we recommend the following resources:


Open in Colab