
Federated Learning with MLX and Flower

Charles Beauville
Data Scientist at Flower Labs


MLX is a NumPy-like array framework designed for efficient and flexible machine learning on Apple Silicon.

Quoting from the MLX website:

The Python API closely follows NumPy with a few exceptions. MLX also has a fully featured C++ API which closely follows the Python API.

The main differences between MLX and NumPy are:

Composable function transformations: MLX has composable function transformations for automatic differentiation, automatic vectorization, and computation graph optimization.

Lazy computation: Computations in MLX are lazy. Arrays are only materialized when needed.

Multi-device: Operations can run on any of the supported devices (CPU, GPU, …)

The design of MLX is inspired by frameworks like PyTorch, JAX, and ArrayFire. A notable difference between MLX and these frameworks is the unified memory model. Arrays in MLX live in shared memory. Operations on MLX arrays can be performed on any of the supported device types without performing data copies. Currently supported device types are the CPU and GPU.

The Federated MLX example

In the new quickstart-mlx example, we implement a simple multi-class classifier for handwritten digit recognition, trained on the MNIST dataset. The centralized case can be found here.

The data

We use flwr_datasets to download and partition the MNIST dataset, then flatten and normalize the images before converting them to MLX arrays:

import mlx.core
import numpy as np
from flwr_datasets import FederatedDataset

# `args` holds the client's command-line arguments (--node-id)
# Split the MNIST train set into 3 partitions and load this client's one
fds = FederatedDataset(dataset="mnist", partitioners={"train": 3})
partition = fds.load_partition(node_id=args.node_id)
# Keep 20% of the partition as a local test set
partition_splits = partition.train_test_split(test_size=0.2)

partition_splits["train"].set_format("numpy")
partition_splits["test"].set_format("numpy")

# Flatten each 28x28 image into a 784-dim float32 vector scaled to [0, 1]
train_partition = partition_splits["train"].map(
    lambda img: {
        "img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0
    },
    input_columns="image",
)
test_partition = partition_splits["test"].map(
    lambda img: {
        "img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0
    },
    input_columns="image",
)

data = (
    train_partition["img"],
    train_partition["label"].astype(np.uint32),
    test_partition["img"],
    test_partition["label"].astype(np.uint32),
)

# Convert all NumPy arrays to MLX arrays
train_images, train_labels, test_images, test_labels = map(mlx.core.array, data)
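To make the partitioning step concrete, here is a rough NumPy-only sketch of the same pipeline. It assumes that passing the integer 3 in partitioners produces three IID partitions of roughly equal size. The helpers iid_partition, train_test_indices and preprocess are hypothetical and not part of flwr_datasets, and random images stand in for MNIST so nothing needs to be downloaded.

```python
import numpy as np


def iid_partition(num_examples, num_partitions, seed=42):
    """Hypothetical helper: shuffle indices, then cut them into near-equal chunks."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_examples)
    return np.array_split(perm, num_partitions)


def train_test_indices(indices, test_fraction_denominator=5, seed=42):
    """Hypothetical helper: hold out 1/5 (20%) of one partition for testing."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(indices)
    n_test = len(shuffled) // test_fraction_denominator
    return shuffled[n_test:], shuffled[:n_test]


def preprocess(images):
    """Flatten 28x28 uint8 images into float32 vectors scaled to [0, 1]."""
    return images.reshape(-1, 28 * 28).astype(np.float32) / 255.0


# Synthetic stand-in for MNIST: 600 random 28x28 grayscale images
images = np.random.default_rng(0).integers(0, 256, size=(600, 28, 28), dtype=np.uint8)

partitions = iid_partition(len(images), num_partitions=3)
train_idx, test_idx = train_test_indices(partitions[0])
train_images = preprocess(images[train_idx])
print(train_images.shape, len(test_idx))  # (160, 784) 40
```

With 600 synthetic images, each partition holds 200 examples, of which 40 (20%) are held out for local testing.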

The model

We define the model as in the centralized MLX example: a simple MLP.

import mlx.core
import mlx.nn


class MLP(mlx.nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        # e.g. [784, 32, 32, 10] for num_layers=2, hidden_dim=32, 10 classes
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            mlx.nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        # ReLU after every hidden layer
        for layer in self.layers[:-1]:
            x = mlx.core.maximum(layer(x), 0.0)
        # The last layer returns raw logits (no activation)
        return self.layers[-1](x)
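The MLP class builds num_layers + 1 linear layers and applies mlx.core.maximum(..., 0.0), i.e. a ReLU, between them. The NumPy sketch below mirrors that forward pass and counts the parameters for the configuration used later (num_layers = 2, hidden_dim = 32, 784 inputs, 10 classes). It assumes MLX's Linear layer stores its weight as (output_dim, input_dim) and computes x @ W.T + b; init_mlp and forward are hypothetical helpers, and the initialisation is illustrative only.

```python
import numpy as np


def init_mlp(num_layers, input_dim, hidden_dim, output_dim, seed=0):
    """Build weight/bias dicts with the same layer sizes as the MLP class."""
    rng = np.random.default_rng(seed)
    layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
    layers = []
    for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]):
        scale = 1.0 / np.sqrt(idim)
        layers.append(
            {
                # Assumed (out_dim, in_dim) layout, as in mlx.nn.Linear
                "weight": rng.uniform(-scale, scale, size=(odim, idim)).astype(np.float32),
                "bias": np.zeros(odim, dtype=np.float32),
            }
        )
    return layers


def forward(layers, x):
    """ReLU on every hidden layer, raw logits from the last layer."""
    for layer in layers[:-1]:
        x = np.maximum(x @ layer["weight"].T + layer["bias"], 0.0)
    last = layers[-1]
    return x @ last["weight"].T + last["bias"]


layers = init_mlp(num_layers=2, input_dim=784, hidden_dim=32, output_dim=10)
logits = forward(layers, np.zeros((5, 784), dtype=np.float32))
num_params = sum(layer["weight"].size + layer["bias"].size for layer in layers)
print(logits.shape, num_params)  # (5, 10) 26506
```

The count breaks down as (784 × 32 + 32) + (32 × 32 + 32) + (32 × 10 + 10) = 25,120 + 1,056 + 330 = 26,506 parameters, which is what each client sends to the server every round.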

We also define utility functions to compute the loss, measure accuracy, and iterate over shuffled mini-batches:

def loss_fn(model, X, y):
    return mlx.core.mean(mlx.nn.losses.cross_entropy(model(X), y))


def eval_fn(model, X, y):
    return mlx.core.mean(mlx.core.argmax(model(X), axis=1) == y)


def batch_iterate(batch_size, X, y):
    perm = mlx.core.array(np.random.permutation(y.size))
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]
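For readers without an Apple Silicon machine at hand, here is a NumPy-only sketch of the same three utilities. It is an illustration, not MLX code: cross_entropy_loss and accuracy are hypothetical stand-ins for loss_fn and eval_fn that take logits instead of a model, and the loss is computed with a log-softmax.

```python
import numpy as np


def cross_entropy_loss(logits, y):
    """Mean cross-entropy of integer labels y under softmax(logits)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(y)), y].mean()


def accuracy(logits, y):
    """Fraction of examples whose highest-scoring class equals the label."""
    return (logits.argmax(axis=1) == y).mean()


def batch_iterate(batch_size, X, y, seed=0):
    """Yield shuffled mini-batches; the last one may be smaller than batch_size."""
    perm = np.random.default_rng(seed).permutation(y.size)
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]


# Uniform logits over 10 classes give a loss of log(10)
logits = np.zeros((4, 10))
labels = np.array([0, 1, 2, 3])
loss = cross_entropy_loss(logits, labels)
acc = accuracy(logits, labels)  # all-zero logits always predict class 0

# 10 examples in batches of 4 -> batch sizes 4, 4, 2
X, y = np.arange(20).reshape(10, 2), np.arange(10)
batches = list(batch_iterate(4, X, y))
batch_sizes = [len(yb) for _, yb in batches]
```

Because the permutation is drawn fresh for each call, every epoch visits all examples exactly once, in a different order.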

The client

The main changes needed to use MLX with Flower are in the get_parameters and set_parameters methods. Flower's NumPyClient exchanges model parameters as a list of np.ndarrays (the format its message serialization expects), whereas MLX keeps them in a nested dictionary of MLX arrays, and MLX doesn't provide an easy way to convert between the two, so we do it ourselves.

For our MLP, model.parameters() returns a nested dictionary with one weight/bias pair per linear layer:

{
    "layers": [
        {"weight": mlx.core.array, "bias": mlx.core.array},
        {"weight": mlx.core.array, "bias": mlx.core.array},
        ...,
        {"weight": mlx.core.array, "bias": mlx.core.array}
    ]
}

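Therefore, to get our list of NumPy arrays, we extract each array layer by layer and convert it with np.array. The resulting list holds each layer's weight followed by its bias, which is the order set_parameters relies on: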

def get_parameters(self, config):
    layers = self.model.parameters()["layers"]
    return [np.array(val) for layer in layers for _, val in layer.items()]
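To see what list Flower actually receives, the sketch below runs the same comprehension on a stand-in dictionary that copies the nesting shown above, with NumPy arrays in place of mlx.core.array and the layer shapes of the 784-32-32-10 MLP used later (assuming weights are stored as (output_dim, input_dim)). The helper name flatten_layers is hypothetical. Each layer contributes two arrays, weight first, because Python dictionaries preserve insertion order.

```python
import numpy as np

# Stand-in for model.parameters(): same nesting, NumPy arrays instead of MLX arrays
params = {
    "layers": [
        {"weight": np.ones((32, 784)), "bias": np.zeros(32)},
        {"weight": np.ones((32, 32)), "bias": np.zeros(32)},
        {"weight": np.ones((10, 32)), "bias": np.zeros(10)},
    ]
}


def flatten_layers(params):
    """Same comprehension as get_parameters: weight, bias, weight, bias, ..."""
    return [np.array(val) for layer in params["layers"] for _, val in layer.items()]


flat = flatten_layers(params)
print([a.shape for a in flat])
# [(32, 784), (32,), (32, 32), (32,), (10, 32), (10,)]
```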

For the set_parameters function, we perform the reverse operation: we receive a flat list of arrays and convert it back into MLX parameters. To do so, we iterate over the list in pairs and assign each pair to the weight and bias keys of one layer dict:

def set_parameters(self, parameters):
    new_params = {}
    new_params["layers"] = [
        {"weight": mlx.core.array(parameters[i]), "bias": mlx.core.array(parameters[i + 1])}
        for i in range(0, len(parameters), 2)
    ]
    self.model.update(new_params)
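The pairing logic is easy to check in isolation. The NumPy sketch below (hypothetical helpers to_nested and to_flat, no MLX required) regroups a flat list the same way set_parameters does and confirms that flattening it again returns the original arrays in the original order. The even-length check is an extra safeguard that is not part of the example above.

```python
import numpy as np


def to_nested(parameters):
    """Regroup a flat [w0, b0, w1, b1, ...] list into the nested layout."""
    if len(parameters) % 2 != 0:
        raise ValueError("expected an even number of arrays (weight/bias pairs)")
    return {
        "layers": [
            {"weight": parameters[i], "bias": parameters[i + 1]}
            for i in range(0, len(parameters), 2)
        ]
    }


def to_flat(params):
    """Inverse of to_nested: walk the layers and emit weight, then bias."""
    return [val for layer in params["layers"] for val in layer.values()]


flat = [np.full((32, 784), 1.0), np.full(32, 2.0), np.full((10, 32), 3.0), np.full(10, 4.0)]
restored = to_flat(to_nested(flat))
print(all(np.array_equal(a, b) for a, b in zip(flat, restored)))  # True
```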

The remaining methods closely follow the centralized example. First, fit:

def fit(self, parameters, config):
    self.set_parameters(parameters)
    for _ in range(self.num_epochs):
        for X, y in batch_iterate(
            self.batch_size, self.train_images, self.train_labels
        ):
            loss, grads = self.loss_and_grad_fn(self.model, X, y)
            self.optimizer.update(self.model, grads)
            mlx.core.eval(self.model.parameters(), self.optimizer.state)
    return self.get_parameters(config={}), len(self.train_images), {}

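Here, we first load the parameters sent by the server, then train for num_epochs local epochs exactly as in the centralized case, and finally return the updated parameters together with the number of training examples, which FedAvg uses to weight each client's contribution.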
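Since fit runs num_epochs passes of batch_iterate, the amount of local work per round follows directly from the partition size. The arithmetic below assumes the 60,000-image MNIST train split is divided evenly into the three partitions requested with partitioners={"train": 3}, and uses the hyperparameters defined further down; a final partial batch counts as one optimizer step.

```python
import math

mnist_train_size = 60_000  # size of the MNIST "train" split
num_partitions = 3         # partitioners={"train": 3}
batch_size = 256
num_epochs = 1

partition_size = mnist_train_size // num_partitions  # 20,000 examples per client
num_test = partition_size // 5                       # test_size=0.2 -> 4,000
num_train = partition_size - num_test                # 16,000
steps_per_round = num_epochs * math.ceil(num_train / batch_size)
print(num_train, steps_per_round)  # 16000 63
```

Because every client reports the same number of training examples here, FedAvg ends up giving the two participating clients equal weight.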

And for the evaluate function:

def evaluate(self, parameters, config):
    self.set_parameters(parameters)
    accuracy = eval_fn(self.model, self.test_images, self.test_labels)
    loss = loss_fn(self.model, self.test_images, self.test_labels)
    return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}

Again, we start by loading the parameters sent by the server, then compute the loss and accuracy on the local test set using loss_fn and eval_fn defined above.

Putting everything together, we have:

class FlowerClient(fl.client.NumPyClient):
    def __init__(
        self, model, optim, loss_and_grad_fn, data, num_epochs, batch_size
    ) -> None:
        self.model = model
        self.optimizer = optim
        self.loss_and_grad_fn = loss_and_grad_fn
        self.train_images, self.train_labels, self.test_images, self.test_labels = data
        self.num_epochs = num_epochs
        self.batch_size = batch_size

    def get_parameters(self, config):
        layers = self.model.parameters()["layers"]
        return [np.array(val) for layer in layers for _, val in layer.items()]

    def set_parameters(self, parameters):
        new_params = {}
        new_params["layers"] = [
            {"weight": mlx.core.array(parameters[i]), "bias": mlx.core.array(parameters[i + 1])}
            for i in range(0, len(parameters), 2)
        ]
        self.model.update(new_params)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        for _ in range(self.num_epochs):
            for X, y in batch_iterate(
                self.batch_size, self.train_images, self.train_labels
            ):
                loss, grads = self.loss_and_grad_fn(self.model, X, y)
                self.optimizer.update(self.model, grads)
                mlx.core.eval(self.model.parameters(), self.optimizer.state)
        return self.get_parameters(config={}), len(self.train_images), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        accuracy = eval_fn(self.model, self.test_images, self.test_labels)
        loss = loss_fn(self.model, self.test_images, self.test_labels)
        return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}

As you can see, our client is ready with only a few lines of code! Before we can instantiate it, we define the hyperparameters and create the model, the loss-and-gradient function, and the optimizer:

num_layers = 2
hidden_dim = 32
num_classes = 10
batch_size = 256
num_epochs = 1
learning_rate = 1e-1

model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)

loss_and_grad_fn = mlx.nn.value_and_grad(model, loss_fn)
optimizer = mlx.optimizers.SGD(learning_rate=learning_rate)

Finally, we instantiate the client and connect it to the server with the start_client function:

# Start Flower client
fl.client.start_client(
    server_address="127.0.0.1:8080",
    client=FlowerClient(
        model,
        optimizer,
        loss_and_grad_fn,
        (train_images, train_labels, test_images, test_labels),
        num_epochs,
        batch_size,
    ).to_client(),
)

The server

On the server side, nothing MLX-specific is needed. The weighted_average function only serves to aggregate the accuracies reported by the clients into a single accuracy, weighting each client by its number of test examples.

from typing import List, Tuple

import flwr as fl
from flwr.common import Metrics


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}


# Define strategy: FedAvg, using weighted_average to aggregate client accuracies
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)

# Start Flower server for 3 rounds of federated learning
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)
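A quick way to see why the accuracy is weighted: with unequal test sets, the plain mean of client accuracies and the weighted average differ. The sketch below reuses the weighted_average function above unchanged, but swaps flwr.common.Metrics for a simplified Dict[str, float] alias (an assumption made only so it runs without Flower installed). The client sizes and accuracies are made up for illustration.

```python
from typing import Dict, List, Tuple

Metrics = Dict[str, float]  # simplified stand-in for flwr.common.Metrics


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}


# Two hypothetical clients: 100 test examples at 90% accuracy, 300 at 80%
result = weighted_average([(100, {"accuracy": 0.9}), (300, {"accuracy": 0.8})])
print(result)  # (100 * 0.9 + 300 * 0.8) / 400 ≈ 0.825, versus a plain mean of 0.85
```

In this example both clients hold equally sized test sets, so the two averages coincide; the weighting only matters once partitions differ in size.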

Running the example

Once both scripts are written, start the server in a terminal:

$ python server.py

Note that, apart from the custom metric aggregation function, our server uses the FedAvg strategy with its default parameters, which is fine for this introductory example with just two clients.
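Conceptually, FedAvg combines the parameter lists returned by each client's fit into a weighted average, using the example counts that fit reports as weights. The NumPy function below is a simplified sketch of that idea for illustration only, not Flower's own code; the name fedavg and the toy client arrays are hypothetical.

```python
import numpy as np


def fedavg(results):
    """Average each parameter array across clients, weighted by num_examples.

    results: list of (parameters, num_examples), where parameters is the flat
    list of NumPy arrays a client returned from fit().
    """
    total = sum(num_examples for _, num_examples in results)
    num_arrays = len(results[0][0])
    return [
        sum(params[i] * (num_examples / total) for params, num_examples in results)
        for i in range(num_arrays)
    ]


# Two toy clients, each with one 2x2 "weight" and one 2-element "bias"
client_a = ([np.full((2, 2), 1.0), np.full(2, 1.0)], 100)
client_b = ([np.full((2, 2), 3.0), np.full(2, 5.0)], 300)
aggregated = fedavg([client_a, client_b])
print(aggregated[0])  # each entry: (100 * 1.0 + 300 * 3.0) / 400 = 2.5
```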

Next, we open a new terminal and start the first client:

$ python client.py --node-id 0

Finally, we open another new terminal and start the second client:

$ python client.py --node-id 1

The --node-id argument determines which data partition the client loads.
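The client code reads args.node_id, but the argument parsing itself is not shown. A minimal argparse sketch that would accept --node-id 0 or --node-id 1 might look like this; the choices and required settings are assumptions, chosen to match the three partitions requested from FederatedDataset. Note that argparse turns the --node-id flag into the attribute args.node_id.

```python
import argparse

parser = argparse.ArgumentParser(description="Flower client using MLX")
parser.add_argument(
    "--node-id",
    type=int,
    choices=[0, 1, 2],  # one value per partition created by FederatedDataset
    required=True,
    help="Index of the MNIST partition this client trains on",
)

# In client.py this would simply be parser.parse_args(); here argv is passed explicitly
args = parser.parse_args(["--node-id", "1"])
print(args.node_id)  # 1
```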

The MLX example is now running federated through Flower. There is of course much more to learn; this was just a first glimpse of how Flower lets you federate existing MLX projects with only a few changes.

The next thing you could try is to use another dataset or model, start more clients, or even define your own strategy!

