Quickstart MLX#
In this federated learning tutorial we will learn how to train simple MLP on MNIST using Flower and MLX. It is recommended to create a virtual environment and run everything within a virtualenv.
Letโs use flwr new to create a complete Flower+MLX project. It will generate all the files needed to run, by default with the Simulation Engine, a federation of 10 nodes using FedAvg. The dataset will be partitioned using Flower Datasetโs IidPartitioner.
Now that we have a rough idea of what this example is about, letโs get started. First, install Flower in your new environment:
# In a new Python environment
$ pip install flwr
Then, run the command below. You will be prompted to select of the
available templates (choose MLX
), give a name to your project, and
type in your developer name:
$ flwr new
After running it youโll notice a new directory with your project name has been created. It should have the following structure:
<your-project-name>
โโโ <your-project-name>
โ โโโ __init__.py
โ โโโ client_app.py # Defines your ClientApp
โ โโโ server_app.py # Defines your ServerApp
โ โโโ task.py # Defines your model, training and data loading
โโโ pyproject.toml # Project metadata like dependencies and configs
โโโ README.md
If you havenโt yet installed the project and its dependencies, you can do so by:
# From the directory where your pyproject.toml is
$ pip install -e .
To run the project do:
# Run with default arguments
$ flwr run .
With default arguments you will see an output like this one:
Loading project configuration...
Success
INFO : Starting Flower ServerApp, config: num_rounds=3, no round_timeout
INFO :
INFO : [INIT]
INFO : Requesting initial parameters from one random client
WARNING : FAB ID is not provided; the default ClientApp will be loaded.
INFO : Received initial parameters from one random client
INFO : Evaluating initial global parameters
INFO :
INFO : [ROUND 1]
INFO : configure_fit: strategy sampled 10 clients (out of 10)
INFO : aggregate_fit: received 10 results and 0 failures
WARNING : No fit_metrics_aggregation_fn provided
INFO : configure_evaluate: strategy sampled 10 clients (out of 10)
INFO : aggregate_evaluate: received 10 results and 0 failures
WARNING : No evaluate_metrics_aggregation_fn provided
INFO :
INFO : [ROUND 2]
INFO : configure_fit: strategy sampled 10 clients (out of 10)
INFO : aggregate_fit: received 10 results and 0 failures
INFO : configure_evaluate: strategy sampled 10 clients (out of 10)
INFO : aggregate_evaluate: received 10 results and 0 failures
INFO :
INFO : [ROUND 3]
INFO : configure_fit: strategy sampled 10 clients (out of 10)
INFO : aggregate_fit: received 10 results and 0 failures
INFO : configure_evaluate: strategy sampled 10 clients (out of 10)
INFO : aggregate_evaluate: received 10 results and 0 failures
INFO :
INFO : [SUMMARY]
INFO : Run finished 3 round(s) in 8.15s
INFO : History (loss, distributed):
INFO : round 1: 2.243802046775818
INFO : round 2: 2.101812958717346
INFO : round 3: 1.7419301986694335
INFO :
You can also override the parameters defined in
[tool.flwr.app.config]
section in the pyproject.toml
like this:
# Override some arguments
$ flwr run . --run-config "num-server-rounds=5 lr=0.05"
What follows is an explanation of each component in the project you just
created: dataset partition, the model, defining the ClientApp
and
defining the ServerApp
.
The Data#
We will use Flower Datasets to easily download and partition the MNIST dataset. In this example youโll make use of the IidPartitioner to generate num_partitions partitions. You can choose other partitioners available in Flower Datasets:
partitioner = IidPartitioner(num_partitions=num_partitions)
fds = FederatedDataset(
dataset="ylecun/mnist",
partitioners={"train": partitioner},
)
partition = fds.load_partition(partition_id)
partition_splits = partition.train_test_split(test_size=0.2, seed=42)
partition_splits["train"].set_format("numpy")
partition_splits["test"].set_format("numpy")
train_partition = partition_splits["train"].map(
lambda img: {"img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0},
input_columns="image",
)
test_partition = partition_splits["test"].map(
lambda img: {"img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0},
input_columns="image",
)
data = (
train_partition["img"],
train_partition["label"].astype(np.uint32),
test_partition["img"],
test_partition["label"].astype(np.uint32),
)
train_images, train_labels, test_images, test_labels = map(mx.array, data)
The Model#
We define the model as in the centralized MLX example, itโs a simple MLP:
class MLP(nn.Module):
"""A simple MLP."""
def __init__(
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
):
super().__init__()
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
self.layers = [
nn.Linear(idim, odim)
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
]
def __call__(self, x):
for l in self.layers[:-1]:
x = mx.maximum(l(x), 0.0)
return self.layers[-1](x)
We also define some utility functions to test our model and to iterate over batches.
def loss_fn(model, X, y):
return mx.mean(nn.losses.cross_entropy(model(X), y))
def eval_fn(model, X, y):
return mx.mean(mx.argmax(model(X), axis=1) == y)
def batch_iterate(batch_size, X, y):
perm = mx.array(np.random.permutation(y.size))
for s in range(0, y.size, batch_size):
ids = perm[s : s + batch_size]
yield X[ids], y[ids]
The ClientApp#
The main changes we have to make to use MLX with Flower will be
found in the get_params()
and set_params()
functions. Indeed,
MLX doesnโt provide an easy way to convert the model parameters into a
list of np.array
objects (the format we need for the serialization
of the messages to work).
The way MLX stores its parameters is as follows:
{
"layers": [
{"weight": mlx.core.array, "bias": mlx.core.array},
{"weight": mlx.core.array, "bias": mlx.core.array},
...,
{"weight": mlx.core.array, "bias": mlx.core.array}
]
}
Therefore, to get our list of np.array
objects, we need to extract
each array and convert them into a NumPy array:
def get_params(model):
layers = model.parameters()["layers"]
return [np.array(val) for layer in layers for _, val in layer.items()]
For the set_params()
function, we perform the reverse operation. We
receive a list of NumPy arrays and want to convert them into MLX
parameters. Therefore, we iterate through pairs of parameters and assign
them to the weight and bias keys of each layer dict:
def set_params(model, parameters):
new_params = {}
new_params["layers"] = [
{"weight": mx.array(parameters[i]), "bias": mx.array(parameters[i + 1])}
for i in range(0, len(parameters), 2)
]
model.update(new_params)
The rest of the functionality is directly inspired by the centralized
case. The fit()
method in the client trains the model using the
local dataset:
def fit(self, parameters, config):
self.set_parameters(parameters)
for _ in range(self.num_epochs):
for X, y in batch_iterate(
self.batch_size, self.train_images, self.train_labels
):
_, grads = self.loss_and_grad_fn(self.model, X, y)
self.optimizer.update(self.model, grads)
mx.eval(self.model.parameters(), self.optimizer.state)
return self.get_parameters(config={}), len(self.train_images), {}
Here, after updating the parameters, we perform the training as in the centralized case, and return the new parameters.
And for the evaluate()
method of the client:
def evaluate(self, parameters, config):
self.set_parameters(parameters)
accuracy = eval_fn(self.model, self.test_images, self.test_labels)
loss = loss_fn(self.model, self.test_images, self.test_labels)
return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
We also begin by updating the parameters with the ones sent by the
server, and then we compute the loss and accuracy using the functions
defined above. In the constructor of the FlowerClient
we instantiate
the MLP model as well as other components such as the optimizer.
Putting everything together we have:
class FlowerClient(NumPyClient):
def __init__(
self,
data,
num_layers,
hidden_dim,
num_classes,
batch_size,
learning_rate,
num_epochs,
):
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.num_classes = num_classes
self.batch_size = batch_size
self.learning_rate = learning_rate
self.num_epochs = num_epochs
self.train_images, self.train_labels, self.test_images, self.test_labels = data
self.model = MLP(
num_layers, self.train_images.shape[-1], hidden_dim, num_classes
)
self.optimizer = optim.SGD(learning_rate=learning_rate)
self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
self.num_epochs = num_epochs
self.batch_size = batch_size
def get_parameters(self, config):
return get_params(self.model)
def set_parameters(self, parameters):
set_params(self.model, parameters)
def fit(self, parameters, config):
self.set_parameters(parameters)
for _ in range(self.num_epochs):
for X, y in batch_iterate(
self.batch_size, self.train_images, self.train_labels
):
_, grads = self.loss_and_grad_fn(self.model, X, y)
self.optimizer.update(self.model, grads)
mx.eval(self.model.parameters(), self.optimizer.state)
return self.get_parameters(config={}), len(self.train_images), {}
def evaluate(self, parameters, config):
self.set_parameters(parameters)
accuracy = eval_fn(self.model, self.test_images, self.test_labels)
loss = loss_fn(self.model, self.test_images, self.test_labels)
return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
Finally, we can construct a ClientApp
using the FlowerClient
defined above by means of a client_fn()
callback. Note that
context
enables you to get access to hyperparemeters defined in
pyproject.toml
to configure the run. In this tutorial we access,
among other hyperparameters, the local-epochs
setting to control the
number of epochs a ClientApp
will perform when running the fit()
method.
def client_fn(context: Context):
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
data = load_data(partition_id, num_partitions)
num_layers = context.run_config["num-layers"]
hidden_dim = context.run_config["hidden-dim"]
num_classes = 10
batch_size = context.run_config["batch-size"]
learning_rate = context.run_config["lr"]
num_epochs = context.run_config["local-epochs"]
# Return Client instance
return FlowerClient(
data, num_layers, hidden_dim, num_classes, batch_size, learning_rate, num_epochs
).to_client()
# Flower ClientApp
app = ClientApp(client_fn)
The ServerApp#
To construct a ServerApp
, we define a server_fn()
callback with
an identical signature to that of client_fn()
, but the return type
is ServerAppComponents
as opposed to Client.
In this example we use the FedAvg
strategy.
def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]
# Define strategy
strategy = FedAvg()
config = ServerConfig(num_rounds=num_rounds)
return ServerAppComponents(strategy=strategy, config=config)
# Create ServerApp
app = ServerApp(server_fn=server_fn)
Congratulations! Youโve successfully built and run your first federated learning system.
์ฐธ๊ณ
Check the source code
of the extended version of this tutorial in
examples/quickstart-mlx
in the Flower GitHub repository.