Design stateful ClientApps¶
By design, ClientApp objects are stateless. This means that the ClientApp object is
recreated each time a new Message is to be processed. This behavior is the same in both
Flower’s Simulation Engine and Deployment Engine. For the former, it allows us to
simulate a large number of nodes on a single machine or across multiple machines. For
the latter, it enables each SuperNode to be part of multiple runs, each running a
different ClientApp.
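To make the consequence of statelessness concrete, here is a minimal plain-Python sketch that does not use Flower at all. ToyApp and process are hypothetical stand-ins for an app object that is rebuilt for every message; they only illustrate why an attribute set during one execution is gone in the next.

```python
class ToyApp:
    """Hypothetical stand-in for an app object that is rebuilt per message."""

    def __init__(self):
        self.calls = 0  # instance attribute, lost when the object is discarded

    def handle(self, message):
        self.calls += 1
        return f"{message}: call #{self.calls}"


def process(messages):
    """Recreate the app for every message, as a stateless design would."""
    replies = []
    for message in messages:
        app = ToyApp()  # fresh object each time, so no memory of earlier calls
        replies.append(app.handle(message))
    return replies


replies = process(["round-1", "round-2", "round-3"])
print(replies)
# ['round-1: call #1', 'round-2: call #1', 'round-3: call #1']
```

Every reply reports call #1 because the counter never survives from one message to the next. Anything that must outlive a single execution has to be stored outside the object.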
When a ClientApp is executed, it receives a Context. This context is unique for each
ClientApp, meaning that subsequent executions of the same ClientApp from the same node
will receive the same Context object. In the Context, the .state attribute (of type
RecordDict) can be used to store information that you would like the ClientApp to have
access to for the duration of the run. This could be anything from intermediate results
such as the history of training losses (e.g. as a list of float values with a new entry
appended each time the ClientApp is executed), certain parts of the model that should
persist on the client side, or some other arbitrary Python objects. These items need to
be serialized before being saved into the context.
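As a rough illustration of the serialization step for arbitrary Python objects, the sketch below uses the standard-library pickle module to turn an object into bytes and back. The state dictionary is a hypothetical stand-in for a record that accepts bytes values (this page notes further down that ConfigRecord supports bytes), and the lr-schedule key and its contents are made up for the example.

```python
import pickle

# Hypothetical stand-in for a record that accepts bytes values; a plain dict is
# used here so the sketch runs without Flower installed.
state = {}

# An arbitrary Python object we want to keep between executions
schedule = {"warmup_steps": 100, "milestones": (30, 60, 90)}

# Serialize to bytes before storing
state["lr-schedule"] = pickle.dumps(schedule)

# Later execution: read the bytes back and deserialize
restored = pickle.loads(state["lr-schedule"])
print(restored == schedule)  # True
```

pickle is only one option, and it should only be used to load data you trust. Any scheme that produces bytes and can reverse the process would fit the same pattern.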
Saving metrics to the context¶
This section will demonstrate how to save metrics such as accuracy/loss values to the
Context so they can be used in subsequent executions of the ClientApp.
Let’s begin with a simple setting in which the ClientApp is defined as follows. The
train() function only generates a random number, prints it, and returns an empty
message.
Tip
You can create a PyTorch project with a ready-to-use ClientApp and other components by
running flwr new.
import random

from flwr.app import Context, Message, RecordDict
from flwr.clientapp import ClientApp

# Flower ClientApp
app = ClientApp()


@app.train()
def train(msg: Message, context: Context):
    """Train the model on local data."""
    # Generate a random integer between 0 and 10 (inclusive)
    n = random.randint(0, 10)
    print(n)
    return Message(RecordDict(), reply_to=msg)
With the minimal ClientApp above, each time a Message is addressed to this train
function, a new random integer will be generated and printed. Let’s say we want to save
that randomly generated integer and append it to a list that persists in the Context.
This way, each time the function executes, it prints the history of random integers.
Let’s see how this looks in code:
Tip
Recall that the state attribute of a Context object is of type RecordDict, which is a
special dictionary for the different types of records available in Flower. This means
that you can save to it not just MetricRecord as in the example below, but also
ArrayRecord and ConfigRecord objects.
import random

from flwr.app import Context, Message, MetricRecord, RecordDict
from flwr.clientapp import ClientApp

# Flower ClientApp
app = ClientApp()


@app.train()
def train(msg: Message, context: Context):
    """Train the model on local data."""
    # Generate a random integer between 0 and 10 (inclusive)
    n = random.randint(0, 10)
    print(n)

    # Append to list in context or initialize if it doesn't exist
    if "random-metrics" not in context.state:
        # Initialize MetricRecord in state
        context.state["random-metrics"] = MetricRecord({"random-ints": []})

    # Append to record
    context.state["random-metrics"]["random-ints"].append(n)

    # Print history
    print(context.state["random-metrics"])
    return Message(RecordDict(), reply_to=msg)
If you run a Flower App that includes the above logic in your ClientApp, with just two
clients in your federation sampled in each round, you’ll see output similar to the one
below. Notice how after each round the random-metrics record in the Context gets one
additional integer. Note that, in Simulation Runtime, the order of log messages may
change each round due to the random ordering of simulated clients.
# round 1
config_records={'random-metrics': {'random-ints': [2]}}
config_records={'random-metrics': {'random-ints': [7]}}
# round 2
config_records={'random-metrics': {'random-ints': [2, 5]}}
config_records={'random-metrics': {'random-ints': [7, 4]}}
# round 3
config_records={'random-metrics': {'random-ints': [2, 5, 1]}}
config_records={'random-metrics': {'random-ints': [7, 4, 2]}}
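The growth pattern in the log above can be mimicked with a plain-Python sketch that runs without Flower. Here node_states is a hypothetical stand-in for the per-node Context state, with one ordinary dictionary per node, and the node IDs are invented for the example. The actual integers differ on every run, so no output is shown.

```python
import random

# Hypothetical stand-in: one persistent state dict per node ID
node_states = {"node-a": {}, "node-b": {}}


def train(state):
    """Mimic the train function above using a plain dict as state."""
    n = random.randint(0, 10)
    # Initialize the list on first execution, then append
    state.setdefault("random-ints", []).append(n)
    return state["random-ints"]


for round_num in range(1, 4):
    print(f"# round {round_num}")
    for node_id, state in node_states.items():
        print(node_id, train(state))
```

Each node’s list grows by exactly one entry per round and the two histories never mix, which is the behavior the per-node Context provides.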
Saving model parameters to the context¶
Using ConfigRecord or MetricRecord is fine for saving “simple” components (e.g., float, integer, boolean, string, bytes, and lists of these types). Note that MetricRecord only supports float, integer, and lists of these types. For storing model parameters or, more generally, data arrays, Flower has a specific type of record: ArrayRecord.
Let’s see a couple of examples of how to save NumPy arrays first and then how to save parameters of PyTorch and TensorFlow models.
Note
The examples below omit the definition of a ClientApp to keep the code blocks concise.
To make use of ArrayRecord objects in your ClientApp, you can follow the same
principles as outlined earlier.
Saving NumPy arrays to the context¶
Elements stored in an ArrayRecord are of type Array, which is a data structure that
holds bytes and metadata that can be used for deserialization. Let’s see how to create
an Array from a NumPy array and insert it into an ArrayRecord.
Note
Array objects carry bytes as their main payload and additional metadata to use for deserialization. You can also implement your own serialization/deserialization.
Let’s see how to store a NumPy array in the context.
import numpy as np
from flwr.app import Array, ArrayRecord, Context
# Let's create a simple NumPy array
arr_np = np.random.randn(3, 3)
# If we print it
# array([[-1.84242409, -1.01539537, -0.46528405],
# [ 0.32991896, 0.55540414, 0.44085534],
# [-0.10758364, 1.97619858, -0.37120501]])
# Now, let's serialize it and construct an Array
arr = Array(arr_np)
# If we print it (note the binary data)
# Array(dtype='float64', shape=[3, 3], stype='numpy.ndarray', data=b'\x93NUMPY\x01\x00v\x00...)
# It can be inserted in an ArrayRecord like this
arr_record = ArrayRecord()
arr_record["my_array"] = arr
# You can also do it via the constructor
# arr_record = ArrayRecord({"my_array": arr})
# If you don't need the keys, you can also pass a list of NumPy arrays
# arr_record = ArrayRecord([arr_np])
# Then, it can be added to the state in the context
context.state["some_parameters"] = arr_record
To extract the data in an ArrayRecord, you just need to deserialize the array of
interest. For example, following the example above:
# Get Array from context
arr = context.state["some_parameters"]["my_array"]
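# If you constructed the ArrayRecord with a list of NumPy arrays, you can instead do
# arr_deserialized = context.state["some_parameters"].to_numpy_ndarrays()[0]  # first array
# which already returns NumPy arrays, so the deserialization step below is not needed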
# Deserialize it
arr_deserialized = arr.numpy()
# If we print it (it should show the exact same values as earlier)
# array([[-1.84242409, -1.01539537, -0.46528405],
# [ 0.32991896, 0.55540414, 0.44085534],
# [-0.10758364, 1.97619858, -0.37120501]])
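The idea of an Array as bytes plus the metadata needed to rebuild the data can be illustrated with a small standard-library sketch. This is not Flower’s serialization format. The helpers to_payload and from_payload are hypothetical and show one way a custom serialization/deserialization pair could be structured, using nested lists in place of NumPy arrays.

```python
import struct


def to_payload(matrix):
    """Flatten a list of rows into bytes plus the metadata needed to rebuild it."""
    rows, cols = len(matrix), len(matrix[0])
    flat = [value for row in matrix for value in row]
    data = struct.pack(f"<{rows * cols}d", *flat)  # little-endian float64
    return {"dtype": "float64", "shape": [rows, cols], "data": data}


def from_payload(payload):
    """Rebuild the list of rows from the bytes and the shape metadata."""
    rows, cols = payload["shape"]
    flat = struct.unpack(f"<{rows * cols}d", payload["data"])
    return [list(flat[r * cols:(r + 1) * cols]) for r in range(rows)]


matrix = [[1.0, 2.5], [-3.0, 4.25]]
payload = to_payload(matrix)
restored = from_payload(payload)
print(len(payload["data"]), restored == matrix)  # 32 True
```

The bytes alone are not enough to recover the matrix. The shape (and, in general, the dtype) must travel with them, which is why the metadata is stored alongside the payload.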
Saving PyTorch parameters to the context¶
Flower offers one-liner utilities to convert PyTorch model parameters to/from
ArrayRecord objects. Let’s see how to do that.
import torch
import torch.nn as nn
import torch.nn.functional as F
from flwr.app import ArrayRecord


class Net(nn.Module):
    """A very simple model"""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 32, 5)
        self.fc = nn.Linear(1024, 10)

    def forward(self, x):
        x = F.relu(self.conv(x))
        return self.fc(x)


# Instantiate model as usual
model = Net()

# Save the state_dict into a single ArrayRecord
arr_record = ArrayRecord(model.state_dict())

# Add to a context
context.state["net_parameters"] = arr_record
Let’s say you now want to apply the parameters stored in your context to a new instance
of the model (as happens each time a ClientApp is executed). You will need to:
- Retrieve the ArrayRecord from the context
- Construct a state_dict and load it
# Extract record from context
arr_record = context.state["net_parameters"]

# Deserialize the parameters
state_dict = arr_record.to_torch_state_dict()

# Apply state dict to a new model instance
model_ = Net()
model_.load_state_dict(state_dict)

# Now this model has the exact same parameters as the one created earlier.
# You can verify this by doing
for p, p_ in zip(model.state_dict().values(), model_.state_dict().values()):
    assert torch.allclose(p, p_), "`state_dict`s do not match"
And that’s it! Recall that even though this example shows how to store the entire
state_dict in an ArrayRecord, you can just save part of it. The process would be
identical, but you might need to adjust how it is loaded into an existing model using
PyTorch APIs.
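One way to pick out part of a state_dict is to filter its keys by layer prefix. The sketch below uses a plain dictionary with list values as a hypothetical stand-in for model.state_dict(), so it runs without PyTorch. The helper select_by_prefix is made up for this example, and the keys mirror the conv and fc layers of the Net class above.

```python
def select_by_prefix(state_dict, prefix):
    """Keep only the entries whose key starts with the given prefix."""
    return {k: v for k, v in state_dict.items() if k.startswith(prefix)}


# Hypothetical stand-in for model.state_dict(); real values would be tensors
full_state = {
    "conv.weight": [0.1, 0.2],
    "conv.bias": [0.0],
    "fc.weight": [0.3, 0.4],
    "fc.bias": [0.5],
}

# Keep only the final layer, e.g. a head that should stay on the client
head_state = select_by_prefix(full_state, "fc.")
print(sorted(head_state))  # ['fc.bias', 'fc.weight']
```

With a real PyTorch model, a partial dictionary like this can be loaded with model.load_state_dict(head_state, strict=False), which tolerates the keys that are missing from the partial dictionary.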
Saving TensorFlow/Keras parameters to the context¶
Follow the same steps as above, but replace the state_dict logic with get_weights() to
convert the model parameters to a list of NumPy arrays that can then be saved into an
ArrayRecord. Then, after deserialization, use set_weights() to apply the new parameters
to a model.
import tensorflow as tf
from flwr.app import ArrayRecord

# Define a simple model
model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10),
    ]
)

# Save model weights into an ArrayRecord and add to a context
context.state["model_weights"] = ArrayRecord(model.get_weights())
...
# Extract record from context and apply to the model
model.set_weights(context.state["model_weights"].to_numpy_ndarrays())