FedProx

class FedProx(fraction_train: float = 1.0, fraction_evaluate: float = 1.0, min_train_nodes: int = 2, min_evaluate_nodes: int = 2, min_available_nodes: int = 2, weighted_by_key: str = 'num-examples', arrayrecord_key: str = 'arrays', configrecord_key: str = 'config', train_metrics_aggr_fn: Callable[[list[RecordDict], str], MetricRecord] | None = None, evaluate_metrics_aggr_fn: Callable[[list[RecordDict], str], MetricRecord] | None = None, proximal_mu: float = 0.0)[source]

Bases : FedAvg

Federated Optimization strategy.

Implémentation basée sur https://arxiv.org/abs/1812.06127

FedProx étend FedAvg en introduisant un terme proximal dans l’objectif d’optimisation côté-client. La stratégie elle-même se comporte identiquement à FedAvg du côté serveur, mais chaque client DOIT ajouter un terme de régularisation proximale à la fonction de perte locale pendant l’entraînement:

\[\frac{\mu}{2} || w - w^t ||^2\]

Où $w^t$ désigne les paramètres globaux et $w$ désigne les poids locaux étant optimisés.

Cette stratégie envoie le terme proximal à l’intérieur du ConfigRecord sous forme de partie de la méthode configure_train sous clé "proximal-mu". Le client peut alors utiliser cette valeur pour ajouter le terme proximal à la fonction de perte.

Dans PyTorch, par exemple, la perte passerait de :

loss = criterion(net(inputs), labels)

En :

# Get proximal term weight from message
mu = msg.content["config"]["proximal-mu"]

# Compute proximal term
proximal_term = 0.0
for local_weights, global_weights in zip(net.parameters(), global_params):
    proximal_term += (local_weights - global_weights).norm(2)

# Update loss
loss = criterion(net(inputs), labels) + (mu / 2) * proximal_term

Avec global_params étant une copie des paramètres du modèle, créés après l’application des poids globaux reçus mais avant que le entraînement local ne commence.

global_params = copy.deepcopy(net).parameters()
Paramètres:
  • fraction_train (float (default: 1.0)) – Fraction des nœuds utilisés pendant l’entraînement. Dans le cas où min_train_nodes est supérieur à fraction_train * total_connected_nodes, min_train_nodes sera tout de même échantillonné.

  • fraction_evaluate (float (default: 1.0)) – Fraction des nœuds utilisés pendant la validation. Dans le cas où min_evaluate_nodes est supérieur à fraction_evaluate * total_connected_nodes, min_evaluate_nodes sera tout de même échantillonné.

  • min_train_nodes (int (default: 2)) – Nombre minimum de nœuds utilisés pendant l’entraînement.

  • min_evaluate_nodes (int (default: 2)) – Nombre minimum de nœuds utilisés pendant la validation.

  • min_available_nodes (int (default: 2)) – Nombre minimum de nœuds totaux dans le système.

  • weighted_by_key (str (default: "num-examples")) – La clé dans chaque MetricRecord dont la valeur est utilisée comme poids lors du calcul des moyennes pondérées pour les ArrayRecords et les MetricRecords.

  • arrayrecord_key (str (default: "arrays")) – Clé utilisée pour stocker l’ArrayRecord lors de la construction des Messages.

  • configrecord_key (str (default: "config")) – Clé utilisée pour stocker le ConfigRecord lors de la construction des Messages.

  • train_metrics_aggr_fn (Optional[callable] (default: None)) – Fonction avec signature (liste[RecordDict], str) -> MetricRecord, utilisée pour agréger les MetricRecords à partir des réponses de ronde d’entraînement. Si None, utilise par défaut aggregate_metricrecords, qui effectue une moyenne pondérée en utilisant la clé de facteur de poids fournie.

  • evaluate_metrics_aggr_fn (Optional[callable] (default: None)) – Fonction avec signature (liste[RecordDict], str) -> MetricRecord, utilisée pour agréger les MetricRecords à partir des réponses de ronde d’entraînement. Si None, utilise par défaut aggregate_metricrecords, qui effectue une moyenne pondérée en utilisant la clé de facteur de poids fournie.

  • proximal_mu (float (default: 0.0)) – La pondération du terme proximal utilisé dans l’optimisation. 0,0 fait que cette stratégie est équivalente à FedAvg, et plus la coefficient est élevée, plus la régularisation sera utilisée (c’est-à-dire que les paramètres des clients devront être plus proches des paramètres du serveur pendant l’entraînement).

Methods

aggregate_evaluate(server_round, replies)

Aggrégation des MetricRecords reçus dans les Messages.

aggregate_train(server_round, replies)

Aggrégation des ArrayRecords et des MetricRecords reçus dans les Messages.

configure_evaluate(server_round, arrays, ...)

Configuration du prochain tour d'évaluation fédérée.

configure_train(server_round, arrays, ...)

Configuration du prochain tour d'entraînement fédéré.

start(grid, initial_arrays[, num_rounds, ...])

Exécution de la stratégie d'apprentissage fédéré.

summary()

Loguer la configuration sommaire de la stratégie.

aggregate_evaluate(server_round: int, replies: Iterable[Message]) MetricRecord | None

Aggrégation des MetricRecords reçus dans les Messages.

aggregate_train(server_round: int, replies: Iterable[Message]) tuple[ArrayRecord | None, MetricRecord | None]

Aggrégation des ArrayRecords et des MetricRecords reçus dans les Messages.

configure_evaluate(server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid) Iterable[Message]

Configuration du prochain tour d’évaluation fédérée.

configure_train(server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid) Iterable[Message][source]

Configuration du prochain tour d’entraînement fédéré.

start(grid: Grid, initial_arrays: ArrayRecord, num_rounds: int = 3, timeout: float = 3600, train_config: ConfigRecord | None = None, evaluate_config: ConfigRecord | None = None, evaluate_fn: Callable[[int, ArrayRecord], MetricRecord | None] | None = None) Result

Exécution de la stratégie d’apprentissage fédéré.

Exécuter le flux de travail complet d’apprentissage fédéré pour un nombre spécifié de tours, y compris l’entraînement, l’évaluation et l’évaluation centralisée facultative.

Paramètres:
  • grid (Grid) – Instance du Grid utilisée pour envoyer/recevoir des Messages à partir de nœuds exécutant une application ClientApp.

  • initial_arrays (ArrayRecord) – Paramètres initiaux du modèle (tableaux) à utiliser pour l’apprentissage fédéré.

  • num_rounds (int (default: 3)) – Nombre de tours d’apprentissage fédéré à exécuter.

  • timeout (float (default: 3600)) – Délai en secondes pour attendre les réponses des nœuds.

  • train_config (ConfigRecord, optional) – Configuration à envoyer aux nœuds pendant les tours d’entraînement. Si non défini, un ConfigRecord vide sera utilisé.

  • evaluate_config (ConfigRecord, optional) – Configuration à envoyer aux nœuds pendant les tours d’évaluation. Si non défini, un ConfigRecord vide sera utilisé.

  • evaluate_fn (Callable[[int, ArrayRecord], Optional[MetricRecord]], optional) – Fonction facultative pour l’évaluation centralisée du modèle global. Prend le numéro de tour serveur et le record tableau, retourne un MetricRecord ou None. Si fourni, sera appelé avant la première ronde et après chaque ronde. Par défaut à None.

Renvoie:

Résultats contenant les tableaux de modèles finaux ainsi que les métriques d’entraînement, les métriques d’évaluation et les métriques d’évaluation globales (si fourni) de toutes les rondes.

Type renvoyé:

Results

summary() None[source]

Loguer la configuration sommaire de la stratégie.