.. _quickstart-jax:

Quickstart JAX
==============

In this federated learning tutorial we will learn how to train a linear regression model
using Flower and `JAX <https://jax.readthedocs.io/en/latest/>`_. It is recommended to
create a virtual environment and run everything within a :doc:`virtualenv
<contributor-how-to-set-up-a-virtual-env>`.

Let's use ``flwr new`` to create a complete Flower+JAX project. It will generate all the
files needed to run, by default with the Flower Simulation Engine, a federation of 10
nodes using |fedavg|_. A random regression dataset will be loaded from scikit-learn's
|makeregression|_ function.

Now that we have a rough idea of what this example is about, let's get started. First,
install Flower in your new environment:

.. code-block:: shell

    # In a new Python environment
    $ pip install flwr

Then, run the command below. You will be prompted to select one of the available
templates (choose ``JAX``), give a name to your project, and type in your developer
name:

.. code-block:: shell

    $ flwr new

After running it you'll notice a new directory with your project name has been created.
It should have the following structure:

.. code-block:: shell

    ├── <your-project-name>
    │   ├── __init__.py
    │   ├── client_app.py   # Defines your ClientApp
    │   ├── server_app.py   # Defines your ServerApp
    │   └── task.py         # Defines your model, training and data loading
    ├── pyproject.toml      # Project metadata like dependencies and configs
    └── README.md

If you haven't yet installed the project and its dependencies, you can do so by:

.. code-block:: shell

    # From the directory where your pyproject.toml is
    $ pip install -e .

To run the project, do:

.. code-block:: shell

    # Run with default arguments
    $ flwr run .

With default arguments you will see an output like this one:

.. code-block:: shell

    Loading project configuration...
    Success
    INFO : Starting Flower ServerApp, config: num_rounds=3, no round_timeout
    INFO :
    INFO : [INIT]
    INFO : Requesting initial parameters from one random client
    INFO : Received initial parameters from one random client
    INFO : Starting evaluation of initial global parameters
    INFO : Evaluation returned no results (`None`)
    INFO :
    INFO : [ROUND 1]
    INFO : configure_fit: strategy sampled 10 clients (out of 10)
    INFO : aggregate_fit: received 10 results and 0 failures
    WARNING : No fit_metrics_aggregation_fn provided
    INFO : configure_evaluate: strategy sampled 10 clients (out of 10)
    INFO : aggregate_evaluate: received 10 results and 0 failures
    WARNING : No evaluate_metrics_aggregation_fn provided
    INFO :
    INFO : [ROUND 2]
    INFO : configure_fit: strategy sampled 10 clients (out of 10)
    INFO : aggregate_fit: received 10 results and 0 failures
    INFO : configure_evaluate: strategy sampled 10 clients (out of 10)
    INFO : aggregate_evaluate: received 10 results and 0 failures
    INFO :
    INFO : [ROUND 3]
    INFO : configure_fit: strategy sampled 10 clients (out of 10)
    INFO : aggregate_fit: received 10 results and 0 failures
    INFO : configure_evaluate: strategy sampled 10 clients (out of 10)
    INFO : aggregate_evaluate: received 10 results and 0 failures
    INFO :
    INFO : [SUMMARY]
    INFO : Run finished 3 round(s) in 6.07s
    INFO : History (loss, distributed):
    INFO : round 1: 0.29372873306274416
    INFO : round 2: 5.820648354415425e-08
    INFO : round 3: 1.526226667528834e-14
    INFO :

You can also override the parameters defined in the ``[tool.flwr.app.config]`` section
in ``pyproject.toml`` like this:

.. code-block:: shell

    # Override some arguments
    $ flwr run . --run-config "num-server-rounds=5 input-dim=5"
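
For reference, the run config lives in the ``[tool.flwr.app.config]`` section of
``pyproject.toml``. A minimal sketch of what that section might look like follows (the
exact defaults are set by the template ``flwr new`` generated, so check your own file):

.. code-block:: toml

    [tool.flwr.app.config]
    num-server-rounds = 3
    input-dim = 3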

What follows is an explanation of each component in the project you just created: the
data, the model, defining the ``ClientApp``, and defining the ``ServerApp``.

The Data
--------

This tutorial uses scikit-learn's |makeregression|_ function to generate a random
regression problem.

.. code-block:: python

    from sklearn.datasets import make_regression
    from sklearn.model_selection import train_test_split


    def load_data():
        # Generate a synthetic regression dataset and split it into
        # training and test partitions
        X, y = make_regression(n_features=3, random_state=0)
        X, X_test, y, y_test = train_test_split(X, y)
        return X, y, X_test, y_test
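
With scikit-learn's defaults, ``make_regression()`` produces 100 samples and
``train_test_split()`` holds out 25% of them, so a quick sanity check (a hypothetical
snippet, not part of the generated project) would look like this:

.. code-block:: python

    X, y, X_test, y_test = load_data()
    print(X.shape, X_test.shape)  # (75, 3) (25, 3) with scikit-learn defaults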

The Model
---------

We defined a simple linear regression model to demonstrate how to create a JAX model,
but feel free to replace it with a more sophisticated JAX model if you'd like (such as
a neural network built with `Flax <https://flax.readthedocs.io/en/latest/>`_):

.. code-block:: python

    import jax

    key = jax.random.PRNGKey(0)


    def load_model(model_shape):
        # Initialize model parameters: a scalar bias and a weight
        # vector matching the number of input features
        params = {"b": jax.random.uniform(key), "w": jax.random.uniform(key, model_shape)}
        return params

In addition to defining the model architecture, we also include two utility functions
that perform training (``train()``) and evaluation (``evaluation()``) using the above
model.

.. code-block:: python

    import jax.numpy as jnp


    def loss_fn(params, X, y):
        # Return MSE as loss
        err = jnp.dot(X, params["w"]) + params["b"] - y
        return jnp.mean(jnp.square(err))


    def train(params, grad_fn, X, y):
        num_examples = X.shape[0]
        # Run 50 epochs of plain gradient descent with a 0.05 learning rate
        for _ in range(50):
            grads = grad_fn(params, X, y)
            params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
        loss = loss_fn(params, X, y)
        return params, loss, num_examples


    def evaluation(params, grad_fn, X_test, y_test):
        num_examples = X_test.shape[0]
        # loss_fn already returns the mean squared error
        loss_test = loss_fn(params, X_test, y_test)
        return loss_test, num_examples
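
To see how these pieces fit together, here is a minimal centralized sketch (a
hypothetical snippet, not part of the generated project) that trains and then evaluates
the model on the synthetic data:

.. code-block:: python

    import jax

    # Load data, initialize the model, and build the gradient function
    train_x, train_y, test_x, test_y = load_data()
    params = load_model(train_x.shape[1:])
    grad_fn = jax.grad(loss_fn)

    # Train on the training partition, then evaluate on the test partition
    params, train_loss, _ = train(params, grad_fn, train_x, train_y)
    test_loss, _ = evaluation(params, grad_fn, test_x, test_y)
    print(f"train loss: {float(train_loss):.6f}, test loss: {float(test_loss):.6f}")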

The ClientApp
-------------

The main changes we have to make to use JAX with Flower are in the ``get_params()`` and
``set_params()`` functions. In ``get_params()``, the JAX model parameters are extracted
and represented as a list of NumPy arrays. The ``set_params()`` function is the
opposite: given a list of NumPy arrays, it applies them to an existing JAX model.

.. note::

    The ``get_params()`` and ``set_params()`` functions here are conceptually similar to
    the ``get_weights()`` and ``set_weights()`` functions that we defined in the
    :doc:`Quickstart PyTorch <tutorial-quickstart-pytorch>` tutorial.

.. code-block:: python

    import numpy as np


    def get_params(params):
        # Convert each JAX array in the params dict to a NumPy array
        parameters = []
        for _, val in params.items():
            parameters.append(np.array(val))
        return parameters


    def set_params(local_params, global_params):
        # Overwrite local values with the global ones, relying on the
        # dict preserving its insertion order ("b", then "w")
        for key, value in list(zip(local_params.keys(), global_params)):
            local_params[key] = value
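
As a quick illustration (a hypothetical snippet, not part of the generated project), the
two functions are inverses of one another: extracting the parameters and applying them
back leaves the model unchanged.

.. code-block:: python

    import numpy as np

    params = load_model((3,))
    ndarrays = get_params(params)  # [b, w] as a list of NumPy arrays
    set_params(params, ndarrays)  # write them back in key order
    assert all(np.allclose(a, b) for a, b in zip(ndarrays, get_params(params)))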

The rest of the functionality is directly inspired by the centralized case. The
``fit()`` method in the client trains the model using the local dataset. Similarly, the
``evaluate()`` method is used to evaluate the model received on a held-out validation
set that the client might have:

.. code-block:: python

    class FlowerClient(NumPyClient):
        def __init__(self, input_dim):
            self.train_x, self.train_y, self.test_x, self.test_y = load_data()
            self.grad_fn = jax.grad(loss_fn)
            model_shape = self.train_x.shape[1:]

            self.params = load_model(model_shape)

        def fit(self, parameters, config):
            set_params(self.params, parameters)
            self.params, loss, num_examples = train(
                self.params, self.grad_fn, self.train_x, self.train_y
            )
            parameters = get_params(self.params)
            return parameters, num_examples, {"loss": float(loss)}

        def evaluate(self, parameters, config):
            set_params(self.params, parameters)
            loss, num_examples = evaluation(
                self.params, self.grad_fn, self.test_x, self.test_y
            )
            return float(loss), num_examples, {"loss": float(loss)}

Finally, we can construct a ``ClientApp`` using the ``FlowerClient`` defined above by
means of a ``client_fn()`` callback. Note that ``context`` enables you to get access to
hyperparameters defined in your ``pyproject.toml`` to configure the run. In this
tutorial we access the ``input-dim`` setting, which sets the number of input features
the model expects, and pass it to the ``FlowerClient``. You could define additional
hyperparameters in ``pyproject.toml`` and access them here.

.. code-block:: python

    def client_fn(context: Context):
        input_dim = context.run_config["input-dim"]

        # Return Client instance
        return FlowerClient(input_dim).to_client()


    # Flower ClientApp
    app = ClientApp(client_fn)

The ServerApp
-------------

To construct a ``ServerApp`` we define a ``server_fn()`` callback with an identical
signature to that of ``client_fn()``, but the return type is |serverappcomponents|_ as
opposed to a |client|_. In this example we use the ``FedAvg`` strategy. To it we pass a
randomly initialized model that will serve as the global model to federate. Note that
the value of ``input_dim`` is read from the run config. You can find the default value
defined in the ``pyproject.toml``.

.. code-block:: python

    def server_fn(context: Context):
        # Read from config
        num_rounds = context.run_config["num-server-rounds"]
        input_dim = context.run_config["input-dim"]

        # Initialize global model
        params = get_params(load_model((input_dim,)))
        initial_parameters = ndarrays_to_parameters(params)

        # Define strategy
        strategy = FedAvg(initial_parameters=initial_parameters)
        config = ServerConfig(num_rounds=num_rounds)

        return ServerAppComponents(strategy=strategy, config=config)


    # Create ServerApp
    app = ServerApp(server_fn=server_fn)
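
For reference, the names used in the two callbacks above come from the following Flower
modules (a sketch of the imports; the generated ``client_app.py`` and ``server_app.py``
already contain the exact ones):

.. code-block:: python

    from flwr.client import ClientApp, NumPyClient
    from flwr.common import Context, ndarrays_to_parameters
    from flwr.server import ServerApp, ServerAppComponents, ServerConfig
    from flwr.server.strategy import FedAvg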

Congratulations! You've successfully built and run your first federated learning system
for JAX with Flower!

.. note::

    Check the source code of the extended version of this tutorial in
    |quickstart_jax_link|_ in the Flower GitHub repository.

.. |client| replace:: ``Client``

.. |fedavg| replace:: ``FedAvg``

.. |makeregression| replace:: ``make_regression()``

.. |quickstart_jax_link| replace:: ``examples/quickstart-jax``

.. |serverappcomponents| replace:: ``ServerAppComponents``

.. _client: ref-api/flwr.client.Client.html#client

.. _fedavg: ref-api/flwr.server.strategy.FedAvg.html#flwr.server.strategy.FedAvg

.. _makeregression: https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html

.. _quickstart_jax_link: https://github.com/adap/flower/tree/main/examples/quickstart-jax

.. _serverappcomponents: ref-api/flwr.server.ServerAppComponents.html#serverappcomponents

.. meta::
    :description: Check out this Federated Learning quickstart tutorial for using Flower with JAX to train a linear regression model on a scikit-learn dataset.