FedProx

class FedProx(fraction_train: float = 1.0, fraction_evaluate: float = 1.0, min_train_nodes: int = 2, min_evaluate_nodes: int = 2, min_available_nodes: int = 2, weighted_by_key: str = 'num-examples', arrayrecord_key: str = 'arrays', configrecord_key: str = 'config', train_metrics_aggr_fn: Callable[[list[RecordDict], str], MetricRecord] | None = None, evaluate_metrics_aggr_fn: Callable[[list[RecordDict], str], MetricRecord] | None = None, proximal_mu: float = 0.0)[source]

Bases: FedAvg

Federated Optimization strategy.

Implementation based on https://arxiv.org/abs/1812.06127

FedProx extends FedAvg by introducing a proximal term into the client-side optimization objective. The strategy itself behaves identically to FedAvg on the server side, but each client MUST add a proximal regularization term to its local loss function during training:

\[\frac{\mu}{2} || w - w^t ||^2\]

Where $w^t$ denotes the global parameters and $w$ denotes the local weights being optimized.

This strategy sends the proximal term coefficient to the clients inside the ConfigRecord built by the configure_train method, under the key "proximal-mu". Each client can then use this value to add the proximal term to its local loss function.

For example, in PyTorch, the local loss changes from the plain training loss to the training loss plus the proximal term, where global_params is a copy of the model parameters created after applying the received global weights but before local training begins.
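As a minimal, framework-free sketch of this term (hypothetical helper name `proximal_term`; in a real client, `mu` would be the value received under the "proximal-mu" config key):

```python
def proximal_term(local_weights, global_weights, mu):
    """Compute (mu / 2) * ||w - w^t||^2 over flattened parameter values."""
    squared_norm = sum((w - g) ** 2 for w, g in zip(local_weights, global_weights))
    return (mu / 2.0) * squared_norm


# mu as received under the "proximal-mu" key of the training ConfigRecord
mu = 0.01
local = [1.0, 2.0, 3.0]    # current local weights w
global_ = [1.0, 1.0, 1.0]  # snapshot of the received global weights w^t
term = proximal_term(local, global_, mu)  # added to the local training loss
```

In PyTorch, the same sum would typically run over `zip(net.parameters(), global_params)` and be added to the criterion loss before calling `backward()`.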

Parameters:
  • fraction_train (float (default: 1.0)) -- Fraction of nodes used during training. In case min_train_nodes is larger than fraction_train * total_connected_nodes, min_train_nodes will still be sampled.

  • fraction_evaluate (float (default: 1.0)) -- Fraction of nodes used during validation. In case min_evaluate_nodes is larger than fraction_evaluate * total_connected_nodes, min_evaluate_nodes will still be sampled.

  • min_train_nodes (int (default: 2)) -- Minimum number of nodes used during training.

  • min_evaluate_nodes (int (default: 2)) -- Minimum number of nodes used during validation.

  • min_available_nodes (int (default: 2)) -- Minimum number of total nodes in the system.

  • weighted_by_key (str (default: "num-examples")) -- The key within each MetricRecord whose value is used as the weight when computing weighted averages for both ArrayRecords and MetricRecords.

  • arrayrecord_key (str (default: "arrays")) -- Key used to store the ArrayRecord when constructing Messages.

  • configrecord_key (str (default: "config")) -- Key used to store the ConfigRecord when constructing Messages.

  • train_metrics_aggr_fn (Optional[callable] (default: None)) -- Function with signature (list[RecordDict], str) -> MetricRecord, used to aggregate MetricRecords from training round replies. If None, defaults to aggregate_metricrecords, which performs a weighted average using the provided weight factor key.

  • evaluate_metrics_aggr_fn (Optional[callable] (default: None)) -- Function with signature (list[RecordDict], str) -> MetricRecord, used to aggregate MetricRecords from evaluation round replies. If None, defaults to aggregate_metricrecords, which performs a weighted average using the provided weight factor key.

  • proximal_mu (float (default: 0.0)) -- The weight of the proximal term used in the optimization. A value of 0.0 makes this strategy equivalent to FedAvg, and the higher the coefficient, the more regularization is applied (that is, the client parameters are kept closer to the server parameters during training).
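To illustrate what the default metric aggregation computes, here is a simplified stand-in using plain dicts in place of RecordDicts and MetricRecords (hypothetical function `weighted_average`; the actual aggregate_metricrecords also handles ArrayRecords and Flower's record types):

```python
def weighted_average(metrics_list, weight_key="num-examples"):
    """Weighted average of numeric metrics, weighted by weight_key.

    Simplified sketch of the default MetricRecord aggregation behavior.
    """
    total = sum(m[weight_key] for m in metrics_list)
    aggregated = {}
    for m in metrics_list:
        weight = m[weight_key] / total
        for key, value in m.items():
            if key == weight_key:
                continue
            aggregated[key] = aggregated.get(key, 0.0) + weight * value
    aggregated[weight_key] = total  # keep the total example count
    return aggregated


replies = [
    {"num-examples": 100, "loss": 0.5},
    {"num-examples": 300, "loss": 0.1},
]
result = weighted_average(replies)  # loss = 0.25 * 0.5 + 0.75 * 0.1
```

A custom train_metrics_aggr_fn or evaluate_metrics_aggr_fn would follow the same shape but operate on the actual RecordDict replies.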

Methods

aggregate_evaluate(server_round, replies)

Aggregate MetricRecords in the received Messages.

aggregate_train(server_round, replies)

Aggregate ArrayRecords and MetricRecords in the received Messages.

configure_evaluate(server_round, arrays, ...)

Configure the next round of federated evaluation.

configure_train(server_round, arrays, ...)

Configure the next round of federated training.

start(grid, initial_arrays[, num_rounds, ...])

Execute the federated learning strategy.

summary()

Log summary configuration of the strategy.

aggregate_evaluate(server_round: int, replies: Iterable[Message]) MetricRecord | None

Aggregate MetricRecords in the received Messages.

aggregate_train(server_round: int, replies: Iterable[Message]) tuple[ArrayRecord | None, MetricRecord | None]

Aggregate ArrayRecords and MetricRecords in the received Messages.

configure_evaluate(server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid) Iterable[Message]

Configure the next round of federated evaluation.

configure_train(server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid) Iterable[Message][source]

Configure the next round of federated training.

start(grid: Grid, initial_arrays: ArrayRecord, num_rounds: int = 3, timeout: float = 3600, train_config: ConfigRecord | None = None, evaluate_config: ConfigRecord | None = None, evaluate_fn: Callable[[int, ArrayRecord], MetricRecord | None] | None = None) Result

Execute the federated learning strategy.

Runs the complete federated learning workflow for the specified number of rounds, including training, evaluation, and optional centralized evaluation.

Parameters:
  • grid (Grid) -- The Grid instance used to send/receive Messages from nodes executing a ClientApp.

  • initial_arrays (ArrayRecord) -- Initial model parameters (arrays) to be used for federated learning.

  • num_rounds (int (default: 3)) -- Number of federated learning rounds to execute.

  • timeout (float (default: 3600)) -- Timeout in seconds for waiting for node responses.

  • train_config (ConfigRecord, optional) -- Configuration to be sent to nodes during training rounds. If unset, an empty ConfigRecord will be used.

  • evaluate_config (ConfigRecord, optional) -- Configuration to be sent to nodes during evaluation rounds. If unset, an empty ConfigRecord will be used.

  • evaluate_fn (Callable[[int, ArrayRecord], Optional[MetricRecord]], optional) -- Optional function for centralized evaluation of the global model. Takes server round number and array record, returns a MetricRecord or None. If provided, will be called before the first round and after each round. Defaults to None.

Returns:

Result containing the final model arrays as well as training metrics, evaluation metrics, and, if evaluate_fn was provided, global evaluation metrics from all rounds.

Return type:

Result
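The overall workflow driven by start() can be pictured with this simplified, framework-free loop (hypothetical helper `run_rounds`; the real method constructs and exchanges Messages via the Grid and handles timeouts):

```python
def run_rounds(num_rounds, initial_arrays, configure_train, aggregate_train,
               configure_evaluate, aggregate_evaluate, evaluate_fn=None):
    """Sketch of the start() round loop; not Flower's actual implementation."""
    history = []
    arrays = initial_arrays
    if evaluate_fn is not None:
        # Centralized evaluation before the first round
        history.append((0, None, None, evaluate_fn(0, arrays)))
    for rnd in range(1, num_rounds + 1):
        train_replies = configure_train(rnd, arrays)            # send train Messages
        new_arrays, train_metrics = aggregate_train(rnd, train_replies)
        arrays = new_arrays if new_arrays is not None else arrays
        eval_replies = configure_evaluate(rnd, arrays)          # send eval Messages
        eval_metrics = aggregate_evaluate(rnd, eval_replies)
        central = evaluate_fn(rnd, arrays) if evaluate_fn is not None else None
        history.append((rnd, train_metrics, eval_metrics, central))
    return arrays, history
```

The callables here stand in for the strategy's configure_train, aggregate_train, configure_evaluate, and aggregate_evaluate methods described above.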

summary() None[source]

Log summary configuration of the strategy.