Aggregate evaluation results

The Flower server does not prescribe a way to aggregate evaluation results, but it enables the user to fully customize result aggregation.

Aggregate custom evaluation results

The same Strategy-customization approach can be used to aggregate custom evaluation results coming from individual clients. Clients can send custom metrics to the server by returning a dictionary:

import flwr as fl


class CifarClient(fl.client.NumPyClient):

    def get_parameters(self, config):
        ...  # implementation omitted

    def fit(self, parameters, config):
        ...  # implementation omitted

    def evaluate(self, parameters, config):
        """Evaluate parameters on the locally held test set."""

        # Update local model with global parameters
        self.model.set_weights(parameters)

        # Evaluate global model parameters on the local test data
        loss, accuracy = self.model.evaluate(self.x_test, self.y_test)

        # Return results, including the custom accuracy metric
        num_examples_test = len(self.x_test)
        return loss, num_examples_test, {"accuracy": accuracy}

The server can then use a customized strategy to aggregate the metrics provided in these dictionaries:

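from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
from flwr.common import EvaluateRes, Scalar
from flwr.server.client_proxy import ClientProxy


class AggregateCustomMetricStrategy(fl.server.strategy.FedAvg):
    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]: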
        """Aggregate evaluation accuracy using weighted average."""

        if not results:
            return None, {}

        # Call aggregate_evaluate from base class (FedAvg) to aggregate loss and metrics
        aggregated_loss, aggregated_metrics = super().aggregate_evaluate(server_round, results, failures)

        # Weigh accuracy of each client by number of examples used
        accuracies = [r.metrics["accuracy"] * r.num_examples for _, r in results]
        examples = [r.num_examples for _, r in results]

        # Aggregate and print custom metric
        aggregated_accuracy = sum(accuracies) / sum(examples)
        print(f"Round {server_round} accuracy aggregated from client results: {aggregated_accuracy}")

        # Return aggregated loss and metrics (i.e., aggregated accuracy)
        return aggregated_loss, {"accuracy": aggregated_accuracy}

# Create strategy and run server
strategy = AggregateCustomMetricStrategy(
    # (same arguments as FedAvg here)
)
fl.server.start_server(strategy=strategy)
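
To make the arithmetic in aggregate_evaluate concrete, the following is a minimal standalone sketch of the same example-weighted average, written in plain Python with no Flower dependency. The function name weighted_average, the key parameter, and the sample numbers are hypothetical and chosen for illustration only. Unlike the strategy above, this sketch also skips clients that did not report the metric and returns an empty dictionary when no examples were contributed, which is an added assumption rather than something the strategy does.

```python
from typing import Dict, List, Tuple


def weighted_average(
    client_metrics: List[Tuple[int, Dict[str, float]]],
    key: str = "accuracy",
) -> Dict[str, float]:
    """Average one metric across clients, weighted by example count."""
    # Keep only clients that actually reported the requested metric
    reported = [(n, m[key]) for n, m in client_metrics if key in m]
    total_examples = sum(n for n, _ in reported)

    # Guard against division by zero when no examples were contributed
    if total_examples == 0:
        return {}

    weighted_sum = sum(n * value for n, value in reported)
    return {key: weighted_sum / total_examples}


# Hypothetical results from three clients: (num_examples, metrics)
example = [(100, {"accuracy": 0.5}), (300, {"accuracy": 0.75}), (0, {})]
print(weighted_average(example))  # {'accuracy': 0.6875}
```

The larger client pulls the result toward its own accuracy: the weighted value is 0.6875, whereas an unweighted mean of the two reported accuracies would be 0.625. A function taking this list of (num_examples, metrics) pairs has the shape that FedAvg's evaluate_metrics_aggregation_fn argument is understood to accept in recent Flower releases; treat that as an assumption to check against the installed version.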