Aggregate evaluation results¶

Flower strategies (e.g. FedAvg and all that derive from it) automatically aggregate the metrics in the MetricRecord in the Messages replied by the ClientApps. By default, a weighted aggregation is performed for all metrics using as weight the value assigned to the weighted_by_key attribute of a strategy.

When constructing your strategy, you can set both the key used to perform weighted aggregation but also the callback function used to aggregate metrics.

Note

By default, Flower strategies use as weighted_by_key="num-examples". If you are interested, see the full implementation of how the default weighted aggregation callback works here.

from flwr.serverapp.strategy import FedAvg
from flwr.serverapp.strategy.strategy_utils import aggregate_metricrecords

strategy = FedAvg(
    # ... other parameters ...
    weighted_by_key="your-key",  # Key to use for weighted averaging
    evaluate_metrics_aggr_fn=my_metrics_aggr_function,  # Custom aggregation function
)

Let’s see how we can define a custom aggregation function for MetricRecord objects received in the reply of an evaluation round.

Note

Note that Flower strategies also have a train_metrics_aggr_fn attribute that allows you to define a custom aggregation function for received MetricRecord objects in reply messages of a training round. By default, it performs weighted averaging using the value assigned to the weighted_by_key exactly as the evaluate_metrics_aggr_fn presented earlier.

Using a custom metrics aggregation function¶

The evaluate_metrics_aggr_fn can be customized to support any evaluation results aggregation logic you need. Its definition is:

Callable[[list[RecordDict], str], MetricRecord]

It takes a list of RecordDict and a weighting key as inputs and returns a MetricRecord. For example, the function below extracts and returns the minimum value for each metric key across all Message:

from flwr.app import MetricRecord, RecordDict


def custom_metrics_aggregation_fn(
    records: list[RecordDict], weighting_metric_name: str
) -> MetricRecord:
    """Extract the minimum value for each metric key."""
    aggregated_metrics = MetricRecord()

    # Track current minimum per key in a plain dict,
    # then copy into MetricRecord at the end
    mins = {}

    for record in records:
        for record_item in record.metric_records.values():
            for key, value in record_item.items():
                if key == weighting_metric_name:
                    # We exclude the weighting key from the aggregated MetricRecord
                    continue

                if key in mins:
                    if value < mins[key]:
                        mins[key] = value
                else:
                    mins[key] = value

    for key, value in mins.items():
        aggregated_metrics[key] = value

    return aggregated_metrics