实施策略#

策略抽象类可以实现完全定制的策略。策略基本上就是在服务器上运行的联邦学习算法。策略决定如何对客户端进行采样、如何配置客户端进行训练、如何聚合参数更新以及如何评估模型。Flower 提供了一些内置策略,这些策略基于下文所述的相同 API。

:code:`策略 ` 抽象类#

所有策略实现均源自抽象基类 flwr.server.strategy.Strategy,包括内置实现和第三方实现。这意味着自定义策略实现与内置实现具有完全相同的功能。

策略抽象定义了一些需要实现的抽象方法:

class Strategy(ABC):
    """Abstract base class for server strategy implementations."""

    @abstractmethod
    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize the (global) model parameters."""

    @abstractmethod
    def configure_fit(
        self,
        server_round: int,
        parameters: Parameters,
        client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""

    @abstractmethod
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate training results."""

    @abstractmethod
    def configure_evaluate(
        self,
        server_round: int,
        parameters: Parameters,
        client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""

    @abstractmethod
    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation results."""

    @abstractmethod
    def evaluate(
        self, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate the current model parameters."""

创建一个新策略意味着要实现一个新的 class`(从抽象基类 :code:`Strategy 派生),该类要实现前面显示的抽象方法:

class SotaStrategy(Strategy):
    def initialize_parameters(self, client_manager):
        # Your implementation here

    def configure_fit(self, server_round, parameters, client_manager):
        # Your implementation here

    def aggregate_fit(self, server_round, results, failures):
        # Your implementation here

    def configure_evaluate(self, server_round, parameters, client_manager):
        # Your implementation here

    def aggregate_evaluate(self, server_round, results, failures):
        # Your implementation here

    def evaluate(self, parameters):
        # Your implementation here

Flower 服务器按以下顺序调用这些方法:

sequenceDiagram participant Strategy participant S as Flower Server<br/>start_server participant C1 as Flower Client participant C2 as Flower Client Note left of S: Get initial <br/>model parameters S->>Strategy: initialize_parameters activate Strategy Strategy-->>S: Parameters deactivate Strategy Note left of S: Federated<br/>Training rect rgb(249, 219, 130) S->>Strategy: configure_fit activate Strategy Strategy-->>S: List[Tuple[ClientProxy, FitIns]] deactivate Strategy S->>C1: FitIns activate C1 S->>C2: FitIns activate C2 C1-->>S: FitRes deactivate C1 C2-->>S: FitRes deactivate C2 S->>Strategy: aggregate_fit<br/>List[FitRes] activate Strategy Strategy-->>S: Aggregated model parameters deactivate Strategy end Note left of S: Centralized<br/>Evaluation rect rgb(249, 219, 130) S->>Strategy: evaluate activate Strategy Strategy-->>S: Centralized evaluation result deactivate Strategy end Note left of S: Federated<br/>Evaluation rect rgb(249, 219, 130) S->>Strategy: configure_evaluate activate Strategy Strategy-->>S: List[Tuple[ClientProxy, EvaluateIns]] deactivate Strategy S->>C1: EvaluateIns activate C1 S->>C2: EvaluateIns activate C2 C1-->>S: EvaluateRes deactivate C1 C2-->>S: EvaluateRes deactivate C2 S->>Strategy: aggregate_evaluate<br/>List[EvaluateRes] activate Strategy Strategy-->>S: Aggregated evaluation results deactivate Strategy end Note left of S: Next round, continue<br/>with federated training

下文将详细介绍每种方法。

初始化参数 方法#

initialize_parameters 只调用一次,即在执行开始时。它负责以序列化形式(即 Parameters 对象)提供初始全局模型参数。

内置策略会返回用户提供的初始参数。下面的示例展示了如何将初始参数传递给 FedAvg

import flwr as fl
import tensorflow as tf

# Load model for server-side parameter initialization
model = tf.keras.applications.EfficientNetB0(
    input_shape=(32, 32, 3), weights=None, classes=10
)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])

# Get model weights as a list of NumPy ndarray's
weights = model.get_weights()

# Serialize ndarrays to `Parameters`
parameters = fl.common.ndarrays_to_parameters(weights)

# Use the serialized parameters as the initial global parameters
strategy = fl.server.strategy.FedAvg(
    initial_parameters=parameters,
)
fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3), strategy=strategy)

Flower 服务器将调用 initialize_parameters,返回传给 initial_parameters 的参数或 None。如果 initialize_parameters 没有返回任何参数(即 None),服务器将随机选择一个客户端并要求其提供参数。这只是一个便捷的功能,在实际应用中并不推荐使用,但在原型开发中可能很有用。在实践中,建议始终使用服务器端参数初始化。

备注

服务器端参数初始化是一种强大的机制。例如,它可以用来从先前保存的检查点恢复训练。它也是实现混合方法所需的基本能力,例如,使用联邦学习对预先训练好的模型进行微调。

:code:`configure_fit`方法#

configure_fit 负责配置即将开始的一轮训练。*配置*在这里是什么意思?配置一轮训练意味着选择客户并决定向这些客户发送什么指令。configure_fit 说明了这一点:

@abstractmethod
def configure_fit(
    self,
    server_round: int,
    parameters: Parameters,
    client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
    """Configure the next round of training."""

返回值是一个元组列表,每个元组代表将发送到特定客户端的指令。策略实现通常在 configure_fit 中执行以下步骤:

  • 使用 client_manager 随机抽样所有(或部分)可用客户端(每个客户端都表示为 ClientProxy 对象)

  • 将每个 ClientProxy 与持有当前全局模型 parametersconfig dict 的 FitIns 配对

More sophisticated implementations can use configure_fit to implement custom client selection logic. A client will only participate in a round if the corresponding ClientProxy is included in the list returned from configure_fit.

备注

该返回值的结构为用户提供了很大的灵活性。由于指令是按客户端定义的,因此可以向每个客户端发送不同的指令。这使得自定义策略成为可能,例如在不同的客户端上训练不同的模型,或在不同的客户端上使用不同的超参数(通过 config dict)。

aggregate_fit 方法#

aggregate_fit 负责汇总在 configure_fit 中选择并要求训练的客户端所返回的结果。

@abstractmethod
def aggregate_fit(
    self,
    server_round: int,
    results: List[Tuple[ClientProxy, FitRes]],
    failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
    """Aggregate training results."""

当然,失败是有可能发生的,因此无法保证服务器会从它发送指令(通过 configure_fit)的所有客户端获得结果。因此 aggregate_fit 会收到 results 的列表,但也会收到 failures 的列表。

aggregate_fit 返回一个可选的 Parameters 对象和一个聚合度量的字典。Parameters 返回值是可选的,因为 aggregate_fit 可能会认为所提供的结果不足以进行聚合(例如,失败次数过多)。

:code:`configure_evaluate`方法#

configure_evaluate 负责配置下一轮评估。*配置*在这里是什么意思?配置一轮评估意味着选择客户端并决定向这些客户端发送什么指令。configure_evaluate 说明了这一点:

@abstractmethod
def configure_evaluate(
    self,
    server_round: int,
    parameters: Parameters,
    client_manager: ClientManager
) -> List[Tuple[ClientProxy, EvaluateIns]]:
    """Configure the next round of evaluation."""

返回值是一个元组列表,每个元组代表将发送到特定客户端的指令。策略实现通常在 configure_evaluate 中执行以下步骤:

  • 使用 client_manager 随机抽样所有(或部分)可用客户端(每个客户端都表示为 ClientProxy 对象)

  • 将每个 ClientProxy 与持有当前全局模型 parametersconfig dict 的 EvaluateIns 配对

More sophisticated implementations can use configure_evaluate to implement custom client selection logic. A client will only participate in a round if the corresponding ClientProxy is included in the list returned from configure_evaluate.

备注

该返回值的结构为用户提供了很大的灵活性。由于指令是按客户端定义的,因此可以向每个客户端发送不同的指令。这使得自定义策略可以在不同客户端上评估不同的模型,或在不同客户端上使用不同的超参数(通过 config dict)。

aggregate_evaluate 方法#

aggregate_evaluate 负责汇总在 configure_evaluate 中选择并要求评估的客户端返回的结果。

@abstractmethod
def aggregate_evaluate(
    self,
    server_round: int,
    results: List[Tuple[ClientProxy, EvaluateRes]],
    failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[float], Dict[str, Scalar]]:
    """Aggregate evaluation results."""

当然,失败是有可能发生的,因此无法保证服务器会从它发送指令(通过 configure_evaluate)的所有客户端获得结果。因此, aggregate_evaluate 会接收 results 的列表,但也会接收 failures 的列表。

aggregate_evaluate 返回一个可选的 float`(损失值)和一个聚合指标字典。:code:`float 返回值是可选的,因为 aggregate_evaluate 可能会认为所提供的结果不足以进行聚合(例如,失败次数过多)。

:code:`evaluate`方法#

evaluate 负责在服务器端评估模型参数。除了 configure_evaluate/aggregate_evaluate 之外,evaluate 可以使策略同时执行服务器端和客户端(联邦)评估。

@abstractmethod
def evaluate(
    self, parameters: Parameters
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
    """Evaluate the current model parameters."""

返回值也是可选的,因为策略可能不需要执行服务器端评估,或者因为用户定义的 evaluate 方法可能无法成功完成(例如,它可能无法加载服务器端评估数据)。