FedAdam¶
- class FedAdam(*, fraction_train: float = 1.0, fraction_evaluate: float = 1.0, min_train_nodes: int = 2, min_evaluate_nodes: int = 2, min_available_nodes: int = 2, weighted_by_key: str = 'num-examples', arrayrecord_key: str = 'arrays', configrecord_key: str = 'config', train_metrics_aggr_fn: Callable[[list[RecordDict], str], MetricRecord] | None = None, evaluate_metrics_aggr_fn: Callable[[list[RecordDict], str], MetricRecord] | None = None, eta: float = 0.1, eta_l: float = 0.1, beta_1: float = 0.9, beta_2: float = 0.99, tau: float = 0.001)[source]¶
Bases :
FedOptFedAdam - Adaptive Federated Optimization using Adam.
Implementation based on https://arxiv.org/abs/2003.00295v5
- Paramètres:
fraction_train (float (default: 1.0)) – Fraction des nœuds utilisés pendant l’entraînement. Dans le cas où min_train_nodes est supérieur à fraction_train * total_connected_nodes, min_train_nodes sera tout de même échantillonné.
fraction_evaluate (float (default: 1.0)) – Fraction des nœuds utilisés pendant la validation. Dans le cas où min_evaluate_nodes est supérieur à fraction_evaluate * total_connected_nodes, min_evaluate_nodes sera tout de même échantillonné.
min_train_nodes (int (default: 2)) – Nombre minimum de nœuds utilisés pendant l’entraînement.
min_evaluate_nodes (int (default: 2)) – Nombre minimum de nœuds utilisés pendant la validation.
min_available_nodes (int (default: 2)) – Nombre minimum de nœuds totaux dans le système.
weighted_by_key (str (default: "num-examples")) – La clé dans chaque MetricRecord dont la valeur est utilisée comme poids lors du calcul des moyennes pondérées pour les ArrayRecords et les MetricRecords.
arrayrecord_key (str (default: "arrays")) – Clé utilisée pour stocker l’ArrayRecord lors de la construction des Messages.
configrecord_key (str (default: "config")) – Clé utilisée pour stocker le ConfigRecord lors de la construction des Messages.
train_metrics_aggr_fn (Optional[callable] (default: None)) – Fonction avec signature (liste[RecordDict], str) -> MetricRecord, utilisée pour agréger les MetricRecords à partir des réponses de ronde d’entraînement. Si None, utilise par défaut aggregate_metricrecords, qui effectue une moyenne pondérée en utilisant la clé de facteur de poids fournie.
evaluate_metrics_aggr_fn (Optional[callable] (default: None)) – Fonction avec signature (liste[RecordDict], str) -> MetricRecord, utilisée pour agréger les MetricRecords à partir des réponses de ronde d’entraînement. Si None, utilise par défaut aggregate_metricrecords, qui effectue une moyenne pondérée en utilisant la clé de facteur de poids fournie.
eta (float, optional) – Taux d’apprentissage côté serveur. Vaut 1e-1 par défaut.
eta_l (float, optional) – Taux d’apprentissage côté client. Vaut 1e-1 par défaut.
beta_1 (float, optional) – Paramètre de momentum. Vaut 0,9 par défaut.
beta_2 (float, optional) – Paramètre du second moment. Vaut 0,99 par défaut.
tau (float, optional) – Contrôle le degré d’adaptabilité de l’algorithme. Par défaut, 1e-3.
Methods
aggregate_evaluate(server_round, replies)Aggrégation des MetricRecords reçus dans les Messages.
aggregate_train(server_round, replies)Aggrégation des ArrayRecords et des MetricRecords reçus dans les Messages.
configure_evaluate(server_round, arrays, ...)Configuration du prochain tour d'évaluation fédérée.
configure_train(server_round, arrays, ...)Configuration du prochain tour d'entraînement fédéré.
start(grid, initial_arrays[, num_rounds, ...])Exécution de la stratégie d'apprentissage fédéré.
summary()Loguer la configuration sommaire de la stratégie.
- aggregate_evaluate(server_round: int, replies: Iterable[Message]) MetricRecord | None¶
Aggrégation des MetricRecords reçus dans les Messages.
- aggregate_train(server_round: int, replies: Iterable[Message]) tuple[ArrayRecord | None, MetricRecord | None][source]¶
Aggrégation des ArrayRecords et des MetricRecords reçus dans les Messages.
- configure_evaluate(server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid) Iterable[Message]¶
Configuration du prochain tour d’évaluation fédérée.
- configure_train(server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid) Iterable[Message]¶
Configuration du prochain tour d’entraînement fédéré.
- start(grid: Grid, initial_arrays: ArrayRecord, num_rounds: int = 3, timeout: float = 3600, train_config: ConfigRecord | None = None, evaluate_config: ConfigRecord | None = None, evaluate_fn: Callable[[int, ArrayRecord], MetricRecord | None] | None = None) Result¶
Exécution de la stratégie d’apprentissage fédéré.
Exécuter le flux de travail complet d’apprentissage fédéré pour un nombre spécifié de tours, y compris l’entraînement, l’évaluation et l’évaluation centralisée facultative.
- Paramètres:
grid (Grid) – Instance du Grid utilisée pour envoyer/recevoir des Messages à partir de nœuds exécutant une application ClientApp.
initial_arrays (ArrayRecord) – Paramètres initiaux du modèle (tableaux) à utiliser pour l’apprentissage fédéré.
num_rounds (int (default: 3)) – Nombre de tours d’apprentissage fédéré à exécuter.
timeout (float (default: 3600)) – Délai en secondes pour attendre les réponses des nœuds.
train_config (ConfigRecord, optional) – Configuration à envoyer aux nœuds pendant les tours d’entraînement. Si non défini, un ConfigRecord vide sera utilisé.
evaluate_config (ConfigRecord, optional) – Configuration à envoyer aux nœuds pendant les tours d’évaluation. Si non défini, un ConfigRecord vide sera utilisé.
evaluate_fn (Callable[[int, ArrayRecord], Optional[MetricRecord]], optional) – Fonction facultative pour l’évaluation centralisée du modèle global. Prend le numéro de tour serveur et le record tableau, retourne un MetricRecord ou None. Si fourni, sera appelé avant la première ronde et après chaque ronde. Par défaut à None.
- Renvoie:
Résultats contenant les tableaux de modèles finaux ainsi que les métriques d’entraînement, les métriques d’évaluation et les métriques d’évaluation globales (si fourni) de toutes les rondes.
- Type renvoyé:
Results
- summary() None¶
Loguer la configuration sommaire de la stratégie.