Build a strategy from scratch

Welcome to the third part of the Flower federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and the Flower framework (part 1) and we learned how strategies can be used to customize the execution on both the server and the clients (part 2).

In this notebook, we’ll continue to customize the federated learning system we built previously by creating a custom version of FedAvg using the Flower framework, Flower Datasets, and PyTorch.

Let’s build a new Strategy from scratch! 🌼


Before we begin with the actual code, let’s make sure that we have everything we need.

Installing dependencies

First, we install the necessary packages:

!pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision

Now that all dependencies are installed, we can import everything we need for this tutorial:

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import Strategy
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: Runtime > Change runtime type > Hardware accelerator: GPU > Save). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting DEVICE = torch.device("cpu"). If the runtime has GPU acceleration enabled, you should see the output Training on cuda, otherwise it will say Training on cpu.

Data loading

Let’s now load the CIFAR-10 training and test sets, partition the training set into ten smaller datasets (each split into a training and a validation set), and wrap each of them in its own DataLoader.

def load_datasets(partition_id, num_partitions: int):
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform)
        # we will use this function to dataset.with_transform(apply_transforms)
        # The transforms object is exactly the same
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
    valloader = DataLoader(partition_train_test["test"], batch_size=32)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=32)
    return trainloader, valloader, testloader
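
To make the partitioning step concrete, here is a minimal pure-Python sketch of the same idea applied to index lists: the sample indices are dealt into equally sized partitions, and one partition is then split 80/20. The helper names iid_partition and split_train_val are hypothetical and only illustrate the idea; in the tutorial, FederatedDataset and train_test_split do this work, and their internals may differ.

```python
import random
from typing import List, Tuple


def iid_partition(num_samples: int, num_partitions: int, seed: int = 42) -> List[List[int]]:
    """Shuffle sample indices and deal them into equally sized partitions."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    # Partition i takes every num_partitions-th index, starting at offset i
    return [indices[i::num_partitions] for i in range(num_partitions)]


def split_train_val(indices: List[int], val_fraction: float = 0.2) -> Tuple[List[int], List[int]]:
    """Hold out the last val_fraction of a partition for validation."""
    num_val = int(len(indices) * val_fraction)
    split_at = len(indices) - num_val
    return indices[:split_at], indices[split_at:]


# CIFAR-10 has 50,000 training images, split here into ten partitions
partitions = iid_partition(num_samples=50_000, num_partitions=10)
train_idx, val_idx = split_train_val(partitions[0])
print(len(partitions), len(train_idx), len(val_idx))
```

Every index ends up in exactly one partition, which is what lets each simulated client train on its own disjoint slice of the data.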

Model training/evaluation

Let’s continue with the usual model definition (including set_parameters and get_parameters) and the training and test functions:

class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)

def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")

def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

Flower client

To implement the Flower client, we (again) create a subclass of flwr.client.NumPyClient and implement the three methods get_parameters, fit, and evaluate. Here, we also pass the partition_id to the client and use it to log additional details. We then create an instance of ClientApp and pass it the client_fn.

class FlowerClient(NumPyClient):
    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

def client_fn(context: Context) -> Client:
    net = Net().to(DEVICE)
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerClient(partition_id, net, trainloader, valloader).to_client()

# Create the ClientApp
client = ClientApp(client_fn=client_fn)

Let’s test what we have so far before we continue:

def server_fn(context: Context) -> ServerAppComponents:
    # Configure the server for just 3 rounds of training
    config = ServerConfig(num_rounds=3)
    # If no strategy is provided, by default, ServerAppComponents will use FedAvg
    return ServerAppComponents(config=config)

# Create the ServerApp
server = ServerApp(server_fn=server_fn)

# Specify the resources each of your clients need
# If set to none, by default, each client will be allocated 2x CPU and 0x GPUs
backend_config = {"client_resources": None}
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_gpus": 1}}

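# Run simulation (one supernode per partition; ten partitions as described above)
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=10,
    backend_config=backend_config,
)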

Build a Strategy from scratch

Let’s override the configure_fit method so that it passes a higher learning rate (and potentially other hyperparameters) to the optimizer of a fraction of the clients. We will keep the client sampling as it is in FedAvg and then change the configuration dictionary (one of the FitIns attributes).
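
Before the full strategy, the configuration split can be sketched in isolation. This is only an illustration: plain strings stand in for ClientProxy objects and plain dictionaries stand in for FitIns, and the helper name assign_configs is hypothetical.

```python
from typing import Dict, List, Tuple


def assign_configs(client_ids: List[str]) -> List[Tuple[str, Dict[str, float]]]:
    """Give the first half of the clients the standard lr, the rest a higher one."""
    half = len(client_ids) // 2
    standard_config = {"lr": 0.001}
    higher_lr_config = {"lr": 0.003}
    return [
        (cid, standard_config if idx < half else higher_lr_config)
        for idx, cid in enumerate(client_ids)
    ]


assignments = assign_configs(["c0", "c1", "c2", "c3", "c4"])
for cid, config in assignments:
    print(cid, config)
```

Because the split uses integer division, an odd number of sampled clients puts the extra client in the higher learning rate group.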

from typing import Union

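from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)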
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg

class FedCustom(Strategy):
    def __init__(
        self,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
    ) -> None:
        super().__init__()
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients

    def __repr__(self) -> str:
        return "FedCustom"

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters."""
        net = Net()
        ndarrays = get_parameters(net)
        return ndarrays_to_parameters(ndarrays)

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Create custom configs
        n_clients = len(clients)
        half_clients = n_clients // 2
        standard_config = {"lr": 0.001}
        higher_lr_config = {"lr": 0.003}
        fit_configurations = []
        for idx, client in enumerate(clients):
            if idx < half_clients:
                fit_configurations.append((client, FitIns(parameters, standard_config)))
            else:
                fit_configurations.append(
                    (client, FitIns(parameters, higher_lr_config))
                )
        return fit_configurations

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""

        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
        metrics_aggregated = {}
        return parameters_aggregated, metrics_aggregated

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        if self.fraction_evaluate == 0.0:
            return []
        config = {}
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average."""

        if not results:
            return None, {}

        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        metrics_aggregated = {}
        return loss_aggregated, metrics_aggregated

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate global model parameters using an evaluation function."""

        # Let's assume we won't perform the global model evaluation on the server side.
        return None

    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return sample size and required number of clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients
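
Both aggregate_fit and aggregate_evaluate weight each client's contribution by its number of examples. The following is a minimal pure-Python sketch of that weighted average on flat lists of floats. The function names weighted_average and weighted_loss are hypothetical; the aggregate and weighted_loss_avg helpers used above operate on lists of NumPy arrays and on (num_examples, loss) pairs respectively, and their internals may differ from this sketch.

```python
from typing import List, Tuple


def weighted_average(results: List[Tuple[List[float], int]]) -> List[float]:
    """Average parameter vectors, weighting each by its number of examples."""
    total_examples = sum(num_examples for _, num_examples in results)
    num_params = len(results[0][0])
    return [
        sum(params[i] * num_examples for params, num_examples in results) / total_examples
        for i in range(num_params)
    ]


def weighted_loss(results: List[Tuple[int, float]]) -> float:
    """Average losses, weighting each by its number of examples."""
    total_examples = sum(num_examples for num_examples, _ in results)
    return sum(num_examples * loss for num_examples, loss in results) / total_examples


# Two clients: the second holds three times as many examples, so it dominates
avg_params = weighted_average([([1.0, 2.0], 100), ([5.0, 6.0], 300)])
avg_loss = weighted_loss([(100, 0.8), (300, 0.4)])
print(avg_params, avg_loss)
```

Note the argument order: the parameter tuples are (values, num_examples), while the loss tuples are (num_examples, loss), matching how the strategy above builds them.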

The only thing left is to use the newly created custom Strategy FedCustom when starting the experiment:

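def server_fn(context: Context) -> ServerAppComponents:
    # Configure the server for just 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(
        config=config,
        strategy=FedCustom(),  # <-- pass the new strategy here
    )

# Create the ServerApp
server = ServerApp(server_fn=server_fn)

# Run simulation (one supernode per partition; ten partitions as described above)
run_simulation(server_app=server, client_app=client, num_supernodes=10, backend_config=backend_config)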
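
Note that the FlowerClient defined earlier only prints the config it receives; train always builds Adam with its default learning rate, so the "lr" values sent by FedCustom have no effect on training yet. A minimal sketch of how a client could pick up the value, assuming the key is named "lr" as in FedCustom. The helper resolve_lr and the lr argument to train are hypothetical and not part of the tutorial code.

```python
from typing import Dict, Union

# Assumed to mirror the value types a Flower config dictionary can carry
Scalar = Union[bool, bytes, float, int, str]


def resolve_lr(config: Dict[str, Scalar], default: float = 0.001) -> float:
    """Return the learning rate sent by the strategy, or a default if none was sent."""
    return float(config.get("lr", default))


# Inside FlowerClient.fit, one could then write (not run here):
#   lr = resolve_lr(config)
#   train(self.net, self.trainloader, epochs=1, lr=lr)
# with train extended to accept lr and to build
#   torch.optim.Adam(net.parameters(), lr=lr)

print(resolve_lr({"lr": 0.003}))
print(resolve_lr({}))
```

Falling back to a default keeps the client usable with strategies that send an empty config, such as the default FedAvg run earlier in this notebook.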


In this notebook, we’ve seen how to implement a custom strategy. A custom strategy enables granular control over client node configuration, result aggregation, and more. To define a custom strategy, you only have to override the abstract methods of the (abstract) base class Strategy. To make custom strategies even more powerful, you can pass custom functions to the constructor of your new class (__init__) and then call these functions whenever needed.
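
The constructor pattern mentioned above can be sketched as follows. The class ConfigurableStrategySketch is a hypothetical stand-in that does not subclass Strategy, so that the example runs without Flower installed; the parameter name on_fit_config_fn mirrors the one FedAvg uses, and the decaying_lr schedule is made up for illustration.

```python
from typing import Callable, Dict, Optional

ConfigFn = Callable[[int], Dict[str, float]]


class ConfigurableStrategySketch:
    """Stand-in for a Strategy subclass; only the constructor pattern is shown."""

    def __init__(self, on_fit_config_fn: Optional[ConfigFn] = None) -> None:
        self.on_fit_config_fn = on_fit_config_fn

    def fit_config(self, server_round: int) -> Dict[str, float]:
        """Build the config for this round, falling back to an empty dict."""
        if self.on_fit_config_fn is None:
            return {}
        return self.on_fit_config_fn(server_round)


def decaying_lr(server_round: int) -> Dict[str, float]:
    """Hypothetical schedule: halve the learning rate every round."""
    return {"lr": 0.004 / (2 ** server_round)}


strategy = ConfigurableStrategySketch(on_fit_config_fn=decaying_lr)
print(strategy.fit_config(server_round=1))
print(strategy.fit_config(server_round=2))
```

In a real strategy, configure_fit would call such a function once per round and place the returned dictionary into the FitIns sent to each client.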

Next steps

Before you continue, make sure to join the Flower community on Flower Discuss (Join Flower Discuss) and on Slack (Join Slack).

There is a dedicated channel for questions if you need help, but we would also love to hear who you are in #introductions!

The Flower Federated Learning Tutorial - Part 4 introduces Client, the flexible API underlying NumPyClient.

