Aggregate evaluation results

The Flower server does not prescribe a way to aggregate evaluation results, but it enables the user to fully customize result aggregation.
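
For simple cases, no strategy subclass is needed: the built-in FedAvg strategy accepts an evaluate_metrics_aggregation_fn callback that receives every client's (num_examples, metrics) pair and returns a single aggregated metrics dictionary. The following is a minimal sketch, assuming each client reports an "accuracy" metric (as in the example further below):

from typing import List, Tuple

from flwr.common import Metrics
from flwr.server.strategy import FedAvg


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Compute an example-weighted average of the clients' accuracy values."""
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}


strategy = FedAvg(evaluate_metrics_aggregation_fn=weighted_average)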

Aggregate custom evaluation results

The same Strategy-customization approach can be used to aggregate custom evaluation results coming from individual clients. Clients can return custom metrics to the server by returning a dictionary:

from flwr.client import NumPyClient


class FlowerClient(NumPyClient):
    def __init__(self, model, x_test, y_test):
        # Keep references to the local model and the locally held test set
        self.model = model
        self.x_test = x_test
        self.y_test = y_test

    def fit(self, parameters, config):
        # ...
        pass

    def evaluate(self, parameters, config):
        """Evaluate parameters on the locally held test set."""

        # Update local model with global parameters
        self.model.set_weights(parameters)

        # Evaluate global model parameters on the local test data
        loss, accuracy = self.model.evaluate(self.x_test, self.y_test)

        # Return results, including the custom accuracy metric
        num_examples_test = len(self.x_test)
        return float(loss), num_examples_test, {"accuracy": float(accuracy)}
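
The loss, example count, and metrics dictionary returned by evaluate() reach the server bundled in an EvaluateRes message for each client. For completeness, here is a minimal sketch of wiring this client into a ClientApp; load_model and load_data are hypothetical helpers, and in practice this code would live in the client module of the app:

from flwr.client import ClientApp
from flwr.common import Context


def client_fn(context: Context):
    # load_model() and load_data() are placeholders for project-specific setup
    model = load_model()
    x_test, y_test = load_data()
    return FlowerClient(model, x_test, y_test).to_client()


# Create ClientApp
app = ClientApp(client_fn=client_fn)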

The server can then use a custom strategy to aggregate the metrics provided in these dictionaries:

from typing import Dict, List, Optional, Tuple, Union

from flwr.common import Context, EvaluateRes, Scalar
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg


class AggregateCustomMetricStrategy(FedAvg):
    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation accuracy using weighted average."""

        if not results:
            return None, {}

        # Call aggregate_evaluate from base class (FedAvg) to aggregate loss and metrics
        aggregated_loss, aggregated_metrics = super().aggregate_evaluate(
            server_round, results, failures
        )

        # Weigh accuracy of each client by number of examples used
        accuracies = [r.metrics["accuracy"] * r.num_examples for _, r in results]
        examples = [r.num_examples for _, r in results]

        # Aggregate and print custom metric
        aggregated_accuracy = sum(accuracies) / sum(examples)
        print(
            f"Round {server_round} accuracy aggregated from client results: {aggregated_accuracy}"
        )

        # Return aggregated loss and metrics (i.e., aggregated accuracy)
        return float(aggregated_loss), {"accuracy": float(aggregated_accuracy)}


def server_fn(context: Context) -> ServerAppComponents:
    # Read federation rounds from config
    num_rounds = context.run_config["num-server-rounds"]
    config = ServerConfig(num_rounds=num_rounds)

    # Define strategy
    strategy = AggregateCustomMetricStrategy(
        # (same arguments as FedAvg here)
    )

    return ServerAppComponents(
        config=config,
        strategy=strategy,  # <-- pass the custom strategy here
    )


# Create ServerApp
app = ServerApp(server_fn=server_fn)
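
With the custom strategy passed to the ServerApp, starting the run as usual (for example with flwr run from the root of a standard Flower app) prints the aggregated accuracy on the server after every round of federated evaluation.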