FedProx
- class FedProx(fraction_train: float = 1.0, fraction_evaluate: float = 1.0, min_train_nodes: int = 2, min_evaluate_nodes: int = 2, min_available_nodes: int = 2, weighted_by_key: str = 'num-examples', arrayrecord_key: str = 'arrays', configrecord_key: str = 'config', train_metrics_aggr_fn: Callable[[list[RecordDict], str], MetricRecord] | None = None, evaluate_metrics_aggr_fn: Callable[[list[RecordDict], str], MetricRecord] | None = None, proximal_mu: float = 0.0) [source]
Bases: FedAvg
Federated Optimization strategy.
Implementation based on https://arxiv.org/abs/1812.06127
FedProx extends FedAvg by introducing a proximal term into the client-side optimization objective. The strategy itself behaves identically to FedAvg on the server side, but each client MUST add a proximal regularization term to its local loss function during training:
\[\frac{\mu}{2} \lVert w - w^t \rVert^2\]
where $w^t$ denotes the global parameters and $w$ denotes the local weights being optimized.
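For reference, in the FedProx paper this penalty enters each client's local objective: client $k$ (inexactly) minimizes its local loss $F_k$ plus the proximal term,

\[\min_{w} \; h_k(w; w^t) = F_k(w) + \frac{\mu}{2} \lVert w - w^t \rVert^2.\]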
This strategy sends the proximal coefficient inside the ConfigRecord as part of the configure_train method, under the key "proximal-mu". The client can then use this value to add the proximal term to its loss function. In PyTorch, for example, the loss changes from the plain criterion loss to the criterion loss plus the proximal penalty computed against global_params, a copy of the model parameters created after applying the received global weights but before local training begins.
- Parameters:
fraction_train (float (default: 1.0)) – Fraction of nodes used during training. In case min_train_nodes is larger than fraction_train * total_connected_nodes, min_train_nodes will still be sampled.
fraction_evaluate (float (default: 1.0)) – Fraction of nodes used during validation. In case min_evaluate_nodes is larger than fraction_evaluate * total_connected_nodes, min_evaluate_nodes will still be sampled.
min_train_nodes (int (default: 2)) – Minimum number of nodes used during training.
min_evaluate_nodes (int (default: 2)) – Minimum number of nodes used during validation.
min_available_nodes (int (default: 2)) – Minimum number of total nodes in the system.
weighted_by_key (str (default: "num-examples")) – The key within each MetricRecord whose value is used as the weight when computing weighted averages for both ArrayRecords and MetricRecords.
arrayrecord_key (str (default: "arrays")) – Key used to store the ArrayRecord when constructing Messages.
configrecord_key (str (default: "config")) – Key used to store the ConfigRecord when constructing Messages.
train_metrics_aggr_fn (Optional[callable] (default: None)) – Function with signature (list[RecordDict], str) -> MetricRecord, used to aggregate MetricRecords from training round replies. If None, defaults to aggregate_metricrecords, which performs a weighted average using the provided weight factor key.
evaluate_metrics_aggr_fn (Optional[callable] (default: None)) – Function with signature (list[RecordDict], str) -> MetricRecord, used to aggregate MetricRecords from evaluate round replies. If None, defaults to aggregate_metricrecords, which performs a weighted average using the provided weight factor key.
proximal_mu (float (default: 0.0)) – The weight of the proximal term used in the optimization. 0.0 makes this strategy equivalent to FedAvg, and the higher the coefficient, the more regularization will be used (that is, the client parameters will need to be closer to the server parameters during training).
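The client-side loss modification described above can be sketched framework-free. The helper below is hypothetical (not part of Flower's API): plain Python lists stand in for model tensors, and a real client would read mu from the received ConfigRecord under "proximal-mu" and use its ML framework's tensor operations instead.

```python
def fedprox_loss(base_loss, local_weights, global_weights, mu):
    """Add the FedProx proximal penalty (mu/2) * ||w - w^t||^2 to a base loss.

    local_weights / global_weights are flat lists of floats standing in for
    model parameters; global_weights is the snapshot taken after applying
    the received global weights, before local training begins.
    """
    sq_norm = sum((w - wt) ** 2 for w, wt in zip(local_weights, global_weights))
    return base_loss + (mu / 2.0) * sq_norm

# mu would arrive under the "proximal-mu" config key
mu = 0.1
base = 1.0                    # loss from the usual criterion
local = [1.0, 2.0]            # weights currently being optimized
global_snapshot = [0.5, 1.5]  # copy taken before local training
loss = fedprox_loss(base, local, global_snapshot, mu)
print(loss)  # 1.0 + 0.05 * 0.5 = 1.025
```

With mu = 0.0 the penalty vanishes and the loss reduces to the plain criterion loss, matching the FedAvg-equivalence noted for proximal_mu above.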
Methods
- aggregate_evaluate(server_round, replies): Aggregate MetricRecords in the received Messages.
- aggregate_train(server_round, replies): Aggregate ArrayRecords and MetricRecords in the received Messages.
- configure_evaluate(server_round, arrays, ...): Configure the next round of federated evaluation.
- configure_train(server_round, arrays, ...): Configure the next round of federated training.
- start(grid, initial_arrays[, num_rounds, ...]): Execute the federated learning strategy.
- summary(): Log summary configuration of the strategy.
- aggregate_evaluate(server_round: int, replies: Iterable[Message]) → MetricRecord | None
Aggregate MetricRecords in the received Messages.
- aggregate_train(server_round: int, replies: Iterable[Message]) → tuple[ArrayRecord | None, MetricRecord | None]
Aggregate ArrayRecords and MetricRecords in the received Messages.
- configure_evaluate(server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid) → Iterable[Message]
Configure the next round of federated evaluation.
- configure_train(server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid) → Iterable[Message] [source]
Configure the next round of federated training.
- start(grid: Grid, initial_arrays: ArrayRecord, num_rounds: int = 3, timeout: float = 3600, train_config: ConfigRecord | None = None, evaluate_config: ConfigRecord | None = None, evaluate_fn: Callable[[int, ArrayRecord], MetricRecord | None] | None = None) → Result
Execute the federated learning strategy.
Runs the complete federated learning workflow for the specified number of rounds, including training, evaluation, and optional centralized evaluation.
- Parameters:
grid (Grid) – The Grid instance used to send/receive Messages from nodes executing a ClientApp.
initial_arrays (ArrayRecord) – Initial model parameters (arrays) to be used for federated learning.
num_rounds (int (default: 3)) – Number of federated learning rounds to execute.
timeout (float (default: 3600)) – Timeout in seconds for waiting for node responses.
train_config (ConfigRecord, optional) – Configuration to be sent to nodes during training rounds. If unset, an empty ConfigRecord will be used.
evaluate_config (ConfigRecord, optional) – Configuration to be sent to nodes during evaluation rounds. If unset, an empty ConfigRecord will be used.
evaluate_fn (Callable[[int, ArrayRecord], Optional[MetricRecord]], optional) – Optional function for centralized evaluation of the global model. Takes the server round number and an ArrayRecord, and returns a MetricRecord or None. If provided, it will be called before the first round and after each round. Defaults to None.
- Returns:
A Result containing the final model arrays, together with the training metrics, evaluation metrics, and (if an evaluate_fn was provided) global evaluation metrics from all rounds.
- Return type:
Result
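Custom train_metrics_aggr_fn / evaluate_metrics_aggr_fn hooks follow the (list[RecordDict], str) -> MetricRecord signature described above. The sketch below shows the weighted-average behavior that the default aggregate_metricrecords is documented to perform, using plain dicts as stand-ins for Flower's RecordDict/MetricRecord types (the function name and dict layout are illustrative, not Flower's actual implementation):

```python
def aggregate_weighted(replies, weight_key="num-examples"):
    """Weighted average of per-client metrics.

    Each reply is a flat dict standing in for a client's MetricRecord; the
    value under weight_key (the weighted_by_key of the strategy) acts as
    that client's weight, and every other key is averaged accordingly.
    """
    total_weight = sum(r[weight_key] for r in replies)
    metric_keys = [k for k in replies[0] if k != weight_key]
    return {
        k: sum(r[k] * r[weight_key] for r in replies) / total_weight
        for k in metric_keys
    }

# Two clients: 10 examples with loss 1.0, 30 examples with loss 2.0
replies = [
    {"num-examples": 10, "train_loss": 1.0},
    {"num-examples": 30, "train_loss": 2.0},
]
print(aggregate_weighted(replies))  # {'train_loss': 1.75}
```

Because FedProx inherits its server-side aggregation from FedAvg, swapping in a custom aggregation function changes only how replies are combined; the proximal behavior remains entirely client-side.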