Aggréguez les résultats d’évaluation

Flower stratégies (par exemple FedAvg et toutes celles dérivées de celle-ci) agrègent automatiquement les métriques dans le MetricRecord du Messages retourné par le ClientApps. Par défaut, une agrégation pondérée est effectuée pour toutes les métriques en utilisant la valeur attribuée à l’attribut weighted_by_key d’une stratégie.

Lorsque vous construisez votre stratégie, vous pouvez définir à la fois la clé utilisée pour effectuer une agrégation pondérée mais aussi la fonction de rappel utilisée pour aggréger les métriques.

Note

Par défaut, les stratégies Flower utilisent comme weighted_by_key="num-examples". Si vous êtes intéressé, voir l’implémentation complète de comment la fonction d’agrégation pondérée par défaut fonctionne here.

from flwr.serverapp.strategy import FedAvg
from flwr.serverapp.strategy.strategy_utils import aggregate_metricrecords

strategy = FedAvg(
    # ... other parameters ...
    weighted_by_key="your-key",  # Key to use for weighted averaging
    evaluate_metrics_aggr_fn=my_metrics_aggr_function,  # Custom aggregation function
)

Voyons comment nous pouvons définir une fonction d’agrégation personnalisée pour les objets MetricRecord reçus dans le retour d’une ronde d’évaluation.

Note

Notez que les stratégies Flower ont également un attribut train_metrics_aggr_fn qui vous permet de définir une fonction d’agrégation personnalisée pour les objets MetricRecord reçus dans les messages de réponse d’un tour d’entraînement. Par défaut, elle effectue une moyenne pondérée en utilisant la valeur attribuée à l’attribut weighted_by_key exactement comme présenté plus tôt.

Utilisez une fonction d’agrégation de métriques personnalisée

Le evaluate_metrics_aggr_fn peut être personnalisé pour supporter toute logique d’agrégation des résultats d’évaluation que vous avez besoin. Sa définition est :

Callable[[list[RecordDict], str], MetricRecord]

Elle prend en entrée une liste de RecordDict et une clé de pondération comme arguments et renvoie un MetricRecord. Par exemple, la fonction ci-dessous extrait et renvoie la valeur minimale pour chaque clé de métrique sur tous les Message:

from flwr.app import MetricRecord, RecordDict


def custom_metrics_aggregation_fn(
    records: list[RecordDict], weighting_metric_name: str
) -> MetricRecord:
    """Extract the minimum value for each metric key."""
    aggregated_metrics = MetricRecord()

    # Track current minimum per key in a plain dict,
    # then copy into MetricRecord at the end
    mins = {}

    for record in records:
        for record_item in record.metric_records.values():
            for key, value in record_item.items():
                if key == weighting_metric_name:
                    # We exclude the weighting key from the aggregated MetricRecord
                    continue

                if key in mins:
                    if value < mins[key]:
                        mins[key] = value
                else:
                    mins[key] = value

    for key, value in mins.items():
        aggregated_metrics[key] = value

    return aggregated_metrics