Design stateful ClientApps

By design, ClientApp objects are stateless. This means that the ClientApp object is recreated each time a new Message is to be processed. This behaviour is the same in Flower's Simulation Engine and Deployment Engine. For the former, it allows us to simulate a large number of nodes on a single machine or across multiple machines. For the latter, it enables each SuperNode to be part of multiple runs, each running a different ClientApp.

When a ClientApp is executed, it receives a Context. This context is unique to each ClientApp running on a given node, meaning that subsequent executions of the same ClientApp on the same node receive the same Context object. The .state attribute of the Context can be used to store information that you would like the ClientApp to have access to for the duration of the run. This could be anything from intermediate results, such as the history of training losses (e.g., a list of float values with a new entry appended each time the ClientApp is executed), to certain parts of the model that should persist on the client side, or other arbitrary Python objects. These items need to be serialized before being saved into the context.
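The lifecycle described above can be sketched in plain Python. This is only an analogy, not Flower's implementation: FakeContext, SimpleClient and run_clientapp below are hypothetical stand-ins, and a plain dict plays the role of context.state. A new client object is built on every execution, while the simulated runtime keeps one context per node and hands it back on each execution of that node.

```python
import random
from dataclasses import dataclass, field


@dataclass
class FakeContext:
    """Hypothetical stand-in for a per-node Context; `state` is a plain dict here."""

    node_id: int
    state: dict = field(default_factory=dict)


class SimpleClient:
    def __init__(self, context: FakeContext):
        self.context = context
        self.local_values = []  # lives on the object, so it is lost after each run

    def evaluate(self) -> None:
        n = random.randint(0, 10)
        self.local_values.append(n)
        # Values stored in the context survive because the context outlives the client
        self.context.state.setdefault("n_val", []).append(n)


def run_clientapp(context: FakeContext) -> SimpleClient:
    """Simulate one execution: a fresh client object is created every time."""
    client = SimpleClient(context)
    client.evaluate()
    return client


# The simulated runtime keeps one context per node for the whole run
contexts = {node_id: FakeContext(node_id) for node_id in (0, 1)}
for round_num in range(1, 4):
    for ctx in contexts.values():
        client = run_clientapp(ctx)
        print(
            f"round {round_num}, node {ctx.node_id}: "
            f"local={client.local_values} context={ctx.state['n_val']}"
        )
```

Each node's n_val list grows by one entry per round, while local_values on the freshly created client always holds a single entry.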

Saving metrics to the context

This section demonstrates how to save metrics such as accuracy or loss values to the Context so they can be used in subsequent executions of the ClientApp. If your ClientApp makes use of NumPyClient, then the entire NumPyClient object is also re-created for each call to methods like fit() or evaluate().

Let's begin with a simple setting in which the ClientApp is defined as follows. The evaluate() method only generates a random integer and appends it to a list stored as an attribute of the client (self.n_val).

Tip

You can create a PyTorch project with ready-to-use ClientApp and other components by running flwr new.

import random
from flwr.app import Context, ConfigRecord
from flwr.client import NumPyClient
from flwr.clientapp import ClientApp


class SimpleClient(NumPyClient):

    def __init__(self):
        self.n_val = []

    def evaluate(self, parameters, config):
        n = random.randint(0, 10)  # Generate a random integer between 0 and 10
        self.n_val.append(n)
        # At this point self.n_val holds the integer generated above, but the
        # list will be re-initialized to an empty list the next time this
        # `ClientApp` runs
        return float(0.0), 1, {}


def client_fn(context: Context):
    return SimpleClient().to_client()


# Finally, construct the ClientApp instance by means of the `client_fn` callback
app = ClientApp(client_fn=client_fn)

Let's say we want to save that randomly generated integer and append it to a list that persists in the context. To do that, you'll need to do two key things:

  1. Make the context.state reachable within your client class

  2. Initialise the appropriate record type (in this example we use ConfigRecord) and save/read your entry when required.

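import random
from flwr.app import ConfigRecord, Context
from flwr.client import NumPyClient
from flwr.clientapp import ClientApp


class SimpleClient(NumPyClient):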

    def __init__(self, context: Context):
        self.client_state = (
            context.state
        )  # add a reference to the state of your ClientApp
        if "eval_metrics" not in self.client_state.config_records:
            self.client_state.config_records["eval_metrics"] = ConfigRecord()

        # Print content of the state
        # You'll see it persists previous entries of `n_val`
        print(self.client_state.config_records)

    def evaluate(self, parameters, config):
        n = random.randint(0, 10)  # Generate a random integer between 0 and 10
        # Add results into a `ConfigRecord` object under the "n_val" key
        # Note a `ConfigRecord` is a special type of python Dictionary
        eval_metrics = self.client_state.config_records["eval_metrics"]
        if "n_val" not in eval_metrics:
            eval_metrics["n_val"] = [n]
        else:
            eval_metrics["n_val"].append(n)

        return float(0.0), 1, {}


def client_fn(context: Context):
    return SimpleClient(context).to_client()  # Note we pass the context


# Finally, construct the ClientApp instance by means of the `client_fn` callback
app = ClientApp(client_fn=client_fn)

If you run the app, you'll see output similar to the one below. Notice how, after each round, the n_val entry in each node's context gets one additional integer. Note that the order in which the ClientApps on different nodes log these messages might differ between rounds.

# round 1 (.evaluate() hasn't been executed yet, so that's why it's empty)
config_records={'eval_metrics': {}}
config_records={'eval_metrics': {}}

# round 2 (note `eval_metrics` has results added in round 1)
config_records={'eval_metrics': {'n_val': [2]}}
config_records={'eval_metrics': {'n_val': [8]}}

# round 3 (note `eval_metrics` has results added in round 1&2)
config_records={'eval_metrics': {'n_val': [8, 2]}}
config_records={'eval_metrics': {'n_val': [2, 9]}}

# round 4 (note `eval_metrics` has results added in round 1&2&3)
config_records={'eval_metrics': {'n_val': [2, 9, 4]}}
config_records={'eval_metrics': {'n_val': [8, 2, 5]}}

Saving model parameters to the context

Using ConfigRecord or MetricRecord to save "simple" components is fine: ConfigRecord supports float, integer, boolean, string, bytes, and lists of these types, while MetricRecord only supports float, integer, and lists of these types. For model parameters, or data arrays more generally, Flower provides a dedicated record type: ArrayRecord.
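As a rough illustration of the value types listed above, the hypothetical helper below suggests a record type for a given value. It only encodes the rules stated in this section and does not replicate every check Flower performs when you insert values into a record. It also treats bool as config-only, because Python's bool is a subclass of int; that choice is an assumption of this sketch.

```python
import numpy as np

METRIC_SCALARS = (int, float)
CONFIG_SCALARS = (int, float, bool, str, bytes)


def _is_scalar_or_list_of(value, types, exclude_bool=False) -> bool:
    """True if `value` is one of `types`, or a list whose items all are."""

    def ok(x):
        if exclude_bool and isinstance(x, bool):
            return False
        return isinstance(x, types)

    if isinstance(value, list):
        return all(ok(x) for x in value)
    return ok(value)


def suggest_record_type(value) -> str:
    """Hypothetical helper: suggest a Flower record type from the documented rules."""
    if isinstance(value, np.ndarray):
        return "ArrayRecord"
    if _is_scalar_or_list_of(value, METRIC_SCALARS, exclude_bool=True):
        return "MetricRecord"
    if _is_scalar_or_list_of(value, CONFIG_SCALARS):
        return "ConfigRecord"
    raise TypeError(f"Unsupported value type: {type(value).__name__}")


print(suggest_record_type(0.93))         # MetricRecord
print(suggest_record_type("sgd"))        # ConfigRecord
print(suggest_record_type(np.zeros(3)))  # ArrayRecord
```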

Let's see a couple of examples of how to save NumPy arrays first and then how to save parameters of PyTorch and TensorFlow models.

Note

The examples below omit the definition of a ClientApp to keep the code blocks concise, and they assume that context refers to the Context object your ClientApp receives. To make use of ArrayRecord objects in your ClientApp, you can follow the same principles as outlined earlier.

Saving NumPy arrays to the context

Elements stored in an ArrayRecord are of type Array, which is a data structure that holds bytes and metadata that can be used for deserialization. Let's see how to create an Array from a NumPy array and insert it into an ArrayRecord.

Note

Array objects carry bytes as their main payload and additional metadata to use for deserialization. You can also implement your own serialization/deserialization.

Let's see how to store a NumPy array in the context.

import numpy as np
from flwr.app import Array, ArrayRecord, Context


# Let's create a simple NumPy array
arr_np = np.random.randn(3, 3)

# If we print it
# array([[-1.84242409, -1.01539537, -0.46528405],
#        [ 0.32991896,  0.55540414,  0.44085534],
#        [-0.10758364,  1.97619858, -0.37120501]])

# Now, let's serialize it and construct an Array
arr = Array(arr_np)

# If we print it (note the binary data)
# Array(dtype='float64', shape=[3, 3], stype='numpy.ndarray', data=b'\x93NUMPY\x01\x00v\x00...)

# It can be inserted in an ArrayRecord like this
arr_record = ArrayRecord()
arr_record["my_array"] = arr
# You can also do it via the constructor
# arr_record = ArrayRecord({"my_array": arr})

# If you don't need the keys, you can also pass a list of NumPy arrays
# arr_record = ArrayRecord([arr_np])

# Then, it can be added to the state of the Context your ClientApp receives
context.state["some_parameters"] = arr_record

To extract the data in an ArrayRecord, you just need to deserialize the array of interest. For example, following the example above:

# Get Array from context
arr = context.state["some_parameters"]["my_array"]

# If you constructed the ArrayRecord with a list of NumPy arrays, then do
# arr = context.state["some_parameters"].to_numpy_ndarrays()[0]  # get first array

# Deserialize it
arr_deserialized = arr.numpy()

# If we print it (it should show the exact same values as earlier)
# array([[-1.84242409, -1.01539537, -0.46528405],
#        [ 0.32991896,  0.55540414,  0.44085534],
#        [-0.10758364,  1.97619858, -0.37120501]])
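As mentioned in the earlier note, you can also implement your own serialization and deserialization. The sketch below shows one hypothetical scheme (SerializedArray, serialize and deserialize are illustrative names, and this is not how Array works internally): keep the raw bytes of the array together with the dtype and shape needed to rebuild it.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class SerializedArray:
    """Hypothetical container: raw bytes plus the metadata needed to decode them."""

    dtype: str
    shape: tuple
    data: bytes


def serialize(arr: np.ndarray) -> SerializedArray:
    # tobytes() emits C-order bytes by default, which matches the reshape below
    return SerializedArray(str(arr.dtype), arr.shape, arr.tobytes())


def deserialize(s: SerializedArray) -> np.ndarray:
    # frombuffer returns a read-only view over the bytes; copy to get a writable array
    return np.frombuffer(s.data, dtype=np.dtype(s.dtype)).reshape(s.shape).copy()


arr_np = np.random.randn(3, 3)
restored = deserialize(serialize(arr_np))
print(np.array_equal(arr_np, restored))  # True
```

Since ConfigRecord accepts strings, bytes, and lists of integers, the three fields could, for instance, be stored under separate keys of a ConfigRecord in context.state (with the shape converted to a list).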

Saving PyTorch parameters to the context

Flower offers one-liner utilities to convert PyTorch model parameters to/from ArrayRecord objects. Let's see how to do that.

import torch
import torch.nn as nn
import torch.nn.functional as F
from flwr.app import ArrayRecord


class Net(nn.Module):
    """A very simple model"""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 32, 5)
        self.fc = nn.Linear(1024, 10)

    def forward(self, x):
        x = F.relu(self.conv(x))
        return self.fc(x)


# Instantiate model as usual
model = Net()

# Save the state_dict into a single ArrayRecord
arr_record = ArrayRecord(model.state_dict())

# Add it to the state of the Context your ClientApp receives
context.state["net_parameters"] = arr_record

Let's say you now want to apply the parameters stored in your context to a new instance of the model (as happens each time a ClientApp is executed). You will need to:

  1. Retrieve the ArrayRecord from the context

  2. Construct a state_dict and load it

# Extract record from context
arr_record = context.state["net_parameters"]

# Deserialize the parameters into a PyTorch state_dict
state_dict = arr_record.to_torch_state_dict()

# Apply state dict to a new model instance
model_ = Net()
model_.load_state_dict(state_dict)
# now this model has the exact same parameters as the one created earlier
# You can verify this by doing
for p, p_ in zip(model.state_dict().values(), model_.state_dict().values()):
    assert torch.allclose(p, p_), "`state_dict`s do not match"

And that's it! Recall that even though this example shows how to store the entire state_dict in an ArrayRecord, you can also save just part of it. The process would be identical, but you might need to adjust how it is loaded into an existing model using PyTorch APIs.
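As an example of saving only part of the model, the sketch below keeps just the parameters of the last layer and loads them into a fresh model with strict=False, which tells PyTorch to tolerate keys that are missing from the dictionary. The small nn.Sequential model built by the hypothetical make_model() function and the "2." key prefix are assumptions for illustration; they are not the Net defined above.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical toy model: keys are "0.weight", "0.bias", "2.weight", "2.bias"
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


model = make_model()

# Keep only the parameters of the last layer (module index 2 in the Sequential)
partial_state = OrderedDict(
    (k, v.detach().clone()) for k, v in model.state_dict().items() if k.startswith("2.")
)

# Load into a fresh model; strict=False tolerates the keys that were not saved
model_ = make_model()
result = model_.load_state_dict(partial_state, strict=False)
print(result.missing_keys)  # ['0.weight', '0.bias']
```

In a ClientApp, partial_state is what you would wrap in an ArrayRecord before adding it to the context; after retrieving and deserializing it, load it with strict=False as shown.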

Saving TensorFlow/Keras parameters to the context

Follow the same steps as above, but replace the state_dict logic with get_weights(), which returns the model parameters as a list of NumPy arrays that can then be saved into an ArrayRecord. After deserialization, use set_weights() to apply the parameters to a model.

import tensorflow as tf
from flwr.app import ArrayRecord

# Define a simple model
model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10),
    ]
)

# Save model weights into an ArrayRecord and add to a context
context.state["model_weights"] = ArrayRecord(model.get_weights())

...

# Extract record from context and apply to the model
model.set_weights(context.state["model_weights"].to_numpy_ndarrays())