FedAvg
- class FedAvg(fraction_train: float = 1.0, fraction_evaluate: float = 1.0, min_train_nodes: int = 2, min_evaluate_nodes: int = 2, min_available_nodes: int = 2, weighted_by_key: str = 'num-examples', arrayrecord_key: str = 'arrays', configrecord_key: str = 'config', train_metrics_aggr_fn: Callable[[list[RecordDict], str], MetricRecord] | None = None, evaluate_metrics_aggr_fn: Callable[[list[RecordDict], str], MetricRecord] | None = None)
Bases:
Strategy
Federated Averaging strategy.
Implementation based on https://arxiv.org/abs/1602.05629
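For reference, the server-side aggregation step of Federated Averaging can be written as the weighted average below. Here S_t is the set of nodes that replied in round t, w^k_{t+1} is the model returned by node k, and n_k is the weight that node reports (in this class, the value stored under weighted_by_key, "num-examples" by default). Normalizing over only the replying nodes is an assumption of this sketch, stated for readability rather than as a transcription of the paper's Algorithm 1.

```latex
w_{t+1} = \sum_{k \in S_t} \frac{n_k}{n} \, w^{k}_{t+1},
\qquad n = \sum_{k \in S_t} n_k
```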
- Parameters:
fraction_train (float (default: 1.0)) – Fraction of connected nodes sampled for training. If min_train_nodes is larger than fraction_train * total_connected_nodes, min_train_nodes nodes are sampled instead.
fraction_evaluate (float (default: 1.0)) – Fraction of connected nodes sampled for evaluation. If min_evaluate_nodes is larger than fraction_evaluate * total_connected_nodes, min_evaluate_nodes nodes are sampled instead.
min_train_nodes (int (default: 2)) – Minimum number of nodes used during training.
min_evaluate_nodes (int (default: 2)) – Minimum number of nodes used during evaluation.
min_available_nodes (int (default: 2)) – Minimum total number of nodes that must be available in the system.
weighted_by_key (str (default: "num-examples")) – The key within each MetricRecord whose value is used as the weight when computing weighted averages for both ArrayRecords and MetricRecords.
arrayrecord_key (str (default: "arrays")) – Key used to store the ArrayRecord when constructing Messages.
configrecord_key (str (default: "config")) – Key used to store the ConfigRecord when constructing Messages.
train_metrics_aggr_fn (Optional[callable] (default: None)) – Function with signature (list[RecordDict], str) -> MetricRecord, used to aggregate the MetricRecords in training round replies. If None, defaults to aggregate_metricrecords, which computes a weighted average using the given weight key.
evaluate_metrics_aggr_fn (Optional[callable] (default: None)) – Function with signature (list[RecordDict], str) -> MetricRecord, used to aggregate the MetricRecords in evaluation round replies. If None, defaults to aggregate_metricrecords, which computes a weighted average using the given weight key.
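The interaction between fraction_train (or fraction_evaluate), min_train_nodes (or min_evaluate_nodes) and min_available_nodes can be illustrated with the minimal sketch below. sample_node_ids is a hypothetical helper, not part of Flower: it assumes the fractional count is rounded down, and it raises an error when too few nodes are connected, whereas the real strategy may instead wait for nodes to become available.

```python
import random


def sample_node_ids(
    node_ids: list[int],
    fraction: float,
    min_nodes: int,
    min_available_nodes: int,
    seed: int = 0,
) -> list[int]:
    """Pick the nodes for one round using the fraction / min-nodes rule.

    Hypothetical helper that mirrors the documented behaviour, not Flower's
    actual sampling code.
    """
    if len(node_ids) < min_available_nodes:
        # Simplification: raise instead of waiting for more nodes.
        raise ValueError(
            f"only {len(node_ids)} nodes connected, need {min_available_nodes}"
        )
    # Nodes requested by the fraction (rounded down, an assumption).
    requested = int(fraction * len(node_ids))
    # min_nodes wins whenever it exceeds the fractional count.
    sample_size = max(requested, min_nodes)
    if sample_size > len(node_ids):
        raise ValueError(f"cannot sample {sample_size} of {len(node_ids)} nodes")
    rng = random.Random(seed)
    return rng.sample(node_ids, sample_size)


nodes = list(range(10))
print(len(sample_node_ids(nodes, 0.1, 2, 2)))  # 2: int(1.0) = 1 < min_nodes
print(len(sample_node_ids(nodes, 0.5, 2, 2)))  # 5: fraction dominates
```

With 10 connected nodes, fraction 0.1 would request a single node, so min_nodes raises the sample to 2; with fraction 0.5 the fraction alone already exceeds the minimum.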
Methods
aggregate_evaluate(server_round, replies) – Aggregate MetricRecords in the received Messages.
aggregate_train(server_round, replies) – Aggregate ArrayRecords and MetricRecords in the received Messages.
configure_evaluate(server_round, arrays, ...) – Configure the next round of federated evaluation.
configure_train(server_round, arrays, ...) – Configure the next round of federated training.
start(grid, initial_arrays[, num_rounds, ...]) – Execute the federated learning strategy.
summary() – Log a summary of the strategy's configuration.
- aggregate_evaluate(server_round: int, replies: Iterable[Message]) -> MetricRecord | None
Aggregate MetricRecords in the received Messages.
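The default metric aggregation (a weighted average keyed by weighted_by_key) can be pictured with the sketch below. Plain dicts stand in for the MetricRecords inside each reply, and weighted_average_metrics is a hypothetical function, not Flower's aggregate_metricrecords; skipping the weight key itself is a choice of this sketch.

```python
def weighted_average_metrics(
    replies: list[dict[str, float]], weighted_by_key: str = "num-examples"
) -> dict[str, float]:
    """Weighted average of every metric, using each reply's weight entry.

    Dicts stand in for MetricRecords; illustration only.
    """
    total = sum(reply[weighted_by_key] for reply in replies)
    if total <= 0:
        raise ValueError("total weight must be positive")
    aggregated: dict[str, float] = {}
    for reply in replies:
        # Fraction of the total weight contributed by this node.
        share = reply[weighted_by_key] / total
        for key, value in reply.items():
            if key == weighted_by_key:
                continue  # the weight drives the average; it is not averaged
            aggregated[key] = aggregated.get(key, 0.0) + share * value
    return aggregated


replies = [
    {"num-examples": 100, "accuracy": 0.80, "loss": 0.50},
    {"num-examples": 300, "accuracy": 0.90, "loss": 0.30},
]
print(weighted_average_metrics(replies))  # accuracy ~0.875, loss ~0.35
```

The node that trained on 300 examples contributes three quarters of each averaged metric, which is why the result sits closer to its values.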
- aggregate_train(server_round: int, replies: Iterable[Message]) -> tuple[ArrayRecord | None, MetricRecord | None]
Aggregate ArrayRecords and MetricRecords in the received Messages.
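The array half of aggregate_train is the Federated Averaging step itself: every parameter of the returned models is averaged, weighted by the value each node reports under weighted_by_key. The minimal sketch below assumes each node's model is a list of flattened layers (lists of floats) standing in for the arrays of an ArrayRecord; fedavg is a hypothetical name used only for illustration.

```python
def fedavg(
    client_layers: list[list[list[float]]], weights: list[float]
) -> list[list[float]]:
    """Element-wise weighted average of each layer across clients.

    client_layers[c][l] is layer l of client c, flattened to a list.
    weights[c] plays the role of client c's "num-examples" value.
    """
    if not weights or len(client_layers) != len(weights):
        raise ValueError("need one weight per client and at least one client")
    total = sum(weights)
    averaged: list[list[float]] = []
    for layer_idx in range(len(client_layers[0])):
        acc = [0.0] * len(client_layers[0][layer_idx])
        for layers, weight in zip(client_layers, weights):
            for i, value in enumerate(layers[layer_idx]):
                acc[i] += (weight / total) * value
        averaged.append(acc)
    return averaged


client_a = [[1.0, 2.0], [0.0]]  # two layers
client_b = [[3.0, 4.0], [4.0]]
print(fedavg([client_a, client_b], weights=[100, 300]))  # [[2.5, 3.5], [3.0]]
```

A production implementation would operate on NumPy or framework tensors rather than nested lists; the weighting rule is the same.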
- configure_evaluate(server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid) -> Iterable[Message]
Configure the next round of federated evaluation.
- configure_train(server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid) -> Iterable[Message]
Configure the next round of federated training.
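Both configure methods return one Message per sampled node, and arrayrecord_key and configrecord_key decide where the global model and the config are stored in each Message's content. The sketch below shows that layout with a hypothetical OutgoingMessage dataclass and build_messages helper; neither is Flower's Message API, and the sketch does not model any extra entries the real strategy may add to the config.

```python
from dataclasses import dataclass, field


@dataclass
class OutgoingMessage:
    """Hypothetical stand-in for a Message: destination plus content."""

    dst_node_id: int
    message_type: str  # "train" or "evaluate" in this sketch
    content: dict = field(default_factory=dict)


def build_messages(
    node_ids: list[int],
    arrays: list[list[float]],
    config: dict,
    message_type: str,
    arrayrecord_key: str = "arrays",
    configrecord_key: str = "config",
) -> list[OutgoingMessage]:
    """Build one message per sampled node, all carrying the same model."""
    return [
        OutgoingMessage(
            dst_node_id=node_id,
            message_type=message_type,
            # Copy the config so per-node changes cannot leak between messages.
            content={arrayrecord_key: arrays, configrecord_key: dict(config)},
        )
        for node_id in node_ids
    ]


msgs = build_messages([7, 9], [[0.1, 0.2]], {"lr": 0.01}, "train")
print([m.dst_node_id for m in msgs])  # [7, 9]
print(msgs[0].content["config"])      # {'lr': 0.01}
```

On the receiving side, a ClientApp would look up the model and config under the same two keys, so both ends must agree on them.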
- start(grid: Grid, initial_arrays: ArrayRecord, num_rounds: int = 3, timeout: float = 3600, train_config: ConfigRecord | None = None, evaluate_config: ConfigRecord | None = None, evaluate_fn: Callable[[int, ArrayRecord], MetricRecord] | None = None) -> Result
Execute the federated learning strategy.
Runs the complete federated learning workflow for the specified number of rounds, including training, evaluation, and optional centralized evaluation.
- Parameters:
grid (Grid) – The Grid instance used to send Messages to, and receive Messages from, the nodes executing a ClientApp.
initial_arrays (ArrayRecord) – Initial model parameters (arrays) to be used for federated learning.
num_rounds (int (default: 3)) – Number of federated learning rounds to execute.
timeout (float (default: 3600)) – Timeout in seconds for waiting for node responses.
train_config (ConfigRecord, optional) – Configuration sent to nodes during training rounds. If unset, an empty ConfigRecord is used.
evaluate_config (ConfigRecord, optional) – Configuration sent to nodes during evaluation rounds. If unset, an empty ConfigRecord is used.
evaluate_fn (Callable[[int, ArrayRecord], MetricRecord], optional) – Optional function for centralized evaluation of the global model. It takes the server round number and the global model's ArrayRecord and returns a MetricRecord. If provided, it is called once before the first round and again after each round. Defaults to None.
- Returns:
A Result containing the final model arrays together with the training metrics, evaluation metrics, and global (centralized) evaluation metrics, if evaluate_fn is provided, from all rounds.
- Return type:
Result
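The round structure that start describes can be summarized with the simplified loop below. It is a sketch under stated assumptions, not Flower's implementation: run_rounds and SketchResult are hypothetical names, the train_round and evaluate_round callables stand in for the configure/aggregate steps that go through the Grid, the pre-training centralized evaluation is recorded as round 0, and within each round training is assumed to precede federated and then centralized evaluation.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional

Arrays = list[float]
Metrics = dict[str, float]


@dataclass
class SketchResult:
    """Hypothetical stand-in for Result: final arrays plus per-round metrics."""

    arrays: Arrays
    train_metrics: dict[int, Metrics] = field(default_factory=dict)
    evaluate_metrics: dict[int, Metrics] = field(default_factory=dict)
    global_metrics: dict[int, Metrics] = field(default_factory=dict)


def run_rounds(
    initial_arrays: Arrays,
    train_round: Callable[[int, Arrays], tuple[Arrays, Metrics]],
    evaluate_round: Callable[[int, Arrays], Metrics],
    num_rounds: int = 3,
    evaluate_fn: Optional[Callable[[int, Arrays], Metrics]] = None,
) -> SketchResult:
    result = SketchResult(arrays=initial_arrays)
    if evaluate_fn is not None:
        # Centralized evaluation of the initial model, stored as round 0.
        result.global_metrics[0] = evaluate_fn(0, result.arrays)
    for rnd in range(1, num_rounds + 1):
        # Federated training: send the model out, aggregate the replies.
        result.arrays, result.train_metrics[rnd] = train_round(rnd, result.arrays)
        # Federated evaluation of the new global model on the nodes.
        result.evaluate_metrics[rnd] = evaluate_round(rnd, result.arrays)
        if evaluate_fn is not None:
            # Centralized evaluation after every round.
            result.global_metrics[rnd] = evaluate_fn(rnd, result.arrays)
    return result


def fake_train(rnd: int, arrays: Arrays) -> tuple[Arrays, Metrics]:
    # Pretend the aggregated update moves every weight by +1.
    return [w + 1.0 for w in arrays], {"train-loss": 1.0 / rnd}


def fake_eval(rnd: int, arrays: Arrays) -> Metrics:
    return {"eval-loss": 1.0 / rnd}


def central_eval(rnd: int, arrays: Arrays) -> Metrics:
    return {"sum": sum(arrays)}


res = run_rounds([0.0, 0.0], fake_train, fake_eval, num_rounds=3, evaluate_fn=central_eval)
print(res.arrays)                  # [3.0, 3.0]
print(sorted(res.global_metrics))  # [0, 1, 2, 3]
```

Because evaluate_fn runs once before training and once per round, a three-round run produces four centralized evaluations but only three entries of training and federated evaluation metrics.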