Federated Scikit-learn Using Flower

Kaushik Amar Das
Senior Analyst


Scikit-learn needs no introduction. It is one of the most beloved machine learning libraries out there. However, it lacks direct support for federated learning (FL). We can easily fix that by combining scikit-learn with Flower! In this post, we will walk through an example of how to leverage Flower's framework-agnostic API to train a federated scikit-learn model. Let's get started!

Training Scenario

Since this is just an example, let us keep things simple. We will train a logistic regression model on the MNIST dataset using federated learning, with only two clients participating. The MNIST training set will be artificially split into 10 partitions, and each client will randomly pick one of them as its local training dataset. The example is meant to be run locally on a single machine that hosts both the clients and the server. Make sure to install openml and scikit-learn alongside Flower (pip install openml scikit-learn), as we will need both. You can find the complete code used in this blog post here.

This example comprises three scripts: one for the server, one for the clients, and a utility module (imported as utils) containing some helper functions needed for training. The following sections discuss how each of these scripts is written. After that, we will run them to perform the federated learning.

Client code

The code for a Flower client training a scikit-learn model isn't too different from that of a Flower client using, for instance, TensorFlow. If you have worked through the other Flower examples, things should look pretty familiar.

Begin by importing the following modules in the client script. Don't worry about the utils module for now; its functions will be discussed in a later section.

import warnings
import flwr as fl
import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

import utils

Load the train and test splits of the MNIST dataset. The train set is divided into 10 partitions, and one of them, chosen at random, is used for training.

# Load MNIST dataset from OpenML
(X_train, y_train), (X_test, y_test) = utils.load_mnist()

# Split train set into 10 partitions and randomly use one for training.
partition_id = np.random.choice(10)
(X_train, y_train) = utils.partition(X_train, y_train, 10)[partition_id]

Initialize the logistic regression model in the client. Let's have the client train for just a single iteration in each round by setting max_iter=1. Also, don't forget to set warm_start=True, otherwise, the model's parameters get refreshed when we call .fit. We don't want to reset the global parameters sent by the server.
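To see why warm_start matters, the following sketch presets "global" parameters on a model before a one-iteration fit and compares the two settings. It uses a small synthetic 3-class dataset from make_classification as a stand-in for MNIST, and the helper one_local_epoch is hypothetical, written only for this illustration. With warm_start=False the preset parameters are discarded and fitting starts from zeros; with warm_start=True fitting starts from them.

```python
import warnings

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Small synthetic 3-class problem standing in for MNIST (illustrative assumption)
X, y = make_classification(
    n_samples=200, n_features=5, n_informative=3, n_redundant=0,
    n_classes=3, random_state=0,
)

rng = np.random.default_rng(0)
# Hypothetical "global" parameters, as if they had just been received from the server
global_coef = rng.normal(size=(3, 5))
global_intercept = rng.normal(size=3)


def one_local_epoch(warm_start, coef, intercept):
    """Preset the parameters, then run a single fitting iteration."""
    model = LogisticRegression(max_iter=1, warm_start=warm_start)
    model.coef_ = coef.copy()
    model.intercept_ = intercept.copy()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # max_iter=1 triggers a ConvergenceWarning
        model.fit(X, y)
    return model.coef_.copy()


zeros_coef, zeros_intercept = np.zeros((3, 5)), np.zeros(3)

# warm_start=False: the preset parameters are ignored, so both runs start from zeros
cold_global = one_local_epoch(False, global_coef, global_intercept)
cold_zeros = one_local_epoch(False, zeros_coef, zeros_intercept)
print("warm_start=False ignores preset params:", np.allclose(cold_global, cold_zeros))

# warm_start=True: the optimizer starts from the preset (global) parameters
warm_global = one_local_epoch(True, global_coef, global_intercept)
print("warm_start=True uses preset params:", not np.allclose(warm_global, cold_global))
```

In a federated round the preset parameters are exactly what the server sent, so without warm_start=True each client would silently train from scratch.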

# Create LogisticRegression Model
model = LogisticRegression(
    max_iter=1,  # local epoch
    warm_start=True,  # prevent refreshing weights when fitting
)

Next, we have to set the initial parameters of the model since the instance attributes used to save the model's parameters aren't created until .fit is called. But the server might want to set them or request them before fitting as is usually the case in federated learning. So, we create the parameter attributes and zero-initialize them using utils.set_initial_params(model).

# Setting initial parameters, akin to model.compile for keras models
utils.set_initial_params(model)

Now it is time to define the Flower client. The client is derived from the class fl.client.NumPyClient. It needs to define the following three methods:

  1. get_parameters: Returns the current local model parameters. The utility function get_model_parameters does this for us.
  2. fit: Defines the steps to train the model on the locally held dataset. It receives the global model parameters and other configuration information from the server. We update the local model with the received global parameters using utils.set_model_params(model, parameters) and then train it on the local dataset. The method returns the locally trained parameters, the size of the training set, and a dict for communicating arbitrary values back to the server.
  3. evaluate: Evaluates the provided parameters on a locally held dataset. It returns the loss along with other details, such as the size of the test set and the accuracy, back to the server. Here, we compute the loss explicitly with sklearn.metrics.log_loss, because LogisticRegression has no public attribute that stores the loss value (unlike, for instance, a TensorFlow model's training history). Make sure to use the loss function appropriate for your model.

class MnistClient(fl.client.NumPyClient):
    def get_parameters(self): # type: ignore
        return utils.get_model_parameters(model)

    def fit(self, parameters, config): # type: ignore
        utils.set_model_params(model, parameters)
        # Ignore convergence failure due to low local epochs
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model.fit(X_train, y_train)
            print(f"Training finished for round {config['rnd']}")
        return utils.get_model_parameters(model), len(X_train), {}

    def evaluate(self, parameters, config): # type: ignore
        utils.set_model_params(model, parameters)
        loss = log_loss(y_test, model.predict_proba(X_test))
        accuracy = model.score(X_test, y_test)
        return loss, len(X_test), {"accuracy": accuracy}

Finally, the script starts the client.

fl.client.start_numpy_client("", client=MnistClient())
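Before pointing the client at a running server, it can help to exercise the same get_parameters/fit/evaluate contract locally. The sketch below uses LocalClientSketch, a hypothetical Flower-free stand-in that is not part of Flower or of this example. It runs on a small synthetic dataset and returns results with the same structure as MnistClient: a list of NumPy arrays, the number of examples, and a metrics dict.

```python
import warnings

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

# Synthetic 3-class data standing in for a client's local MNIST partition
X, y = make_classification(n_samples=300, n_features=8, n_informative=4,
                           n_classes=3, random_state=1)
X_train, y_train, X_test, y_test = X[:200], y[:200], X[200:], y[200:]


class LocalClientSketch:
    """Hypothetical stand-in mirroring MnistClient's methods, without Flower."""

    def __init__(self):
        self.model = LogisticRegression(max_iter=1, warm_start=True)
        # Zero-initialize so the parameters exist before the first fit
        self.model.classes_ = np.arange(3)
        self.model.coef_ = np.zeros((3, X.shape[1]))
        self.model.intercept_ = np.zeros(3)

    def get_parameters(self):
        return [self.model.coef_, self.model.intercept_]

    def set_parameters(self, parameters):
        self.model.coef_, self.model.intercept_ = parameters[0], parameters[1]

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # one iteration never "converges"
            self.model.fit(X_train, y_train)
        return self.get_parameters(), len(X_train), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss = log_loss(y_test, self.model.predict_proba(X_test))
        accuracy = self.model.score(X_test, y_test)
        return loss, len(X_test), {"accuracy": accuracy}


# Simulate three rounds in which the "server" simply sends back the client's own update
client = LocalClientSketch()
params = client.get_parameters()
for rnd in range(1, 4):
    params, n_train, _ = client.fit(params, {"rnd": rnd})
    loss, n_test, metrics = client.evaluate(params, {})
    print(f"round {rnd}: loss={loss:.3f}, accuracy={metrics['accuracy']:.2f}")
```

A dry run like this catches shape or attribute mistakes in the parameter helpers without starting any Flower processes.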

Utility functions

We used a few utility functions in the client code, which we define in this section. The functions that handle the model parameters depend heavily on the particular scikit-learn model you use, so they have to be written carefully, following the model's documentation. In our case, we follow the documentation for LogisticRegression.

The utility functions in the utils module require the following imports and type hints.

from typing import Tuple, Union, List
import numpy as np
from sklearn.linear_model import LogisticRegression
import openml

XY = Tuple[np.ndarray, np.ndarray]
Dataset = Tuple[XY, XY]
LogRegParams = Union[XY, Tuple[np.ndarray]]
XYList = List[XY]

The get_model_parameters function returns the model parameters. For LogisticRegression, these are stored in the coef_ and intercept_ attributes (the latter only when fit_intercept is True).

def get_model_parameters(model):
    """Returns the parameters of a sklearn LogisticRegression model"""
    if model.fit_intercept:
        params = (model.coef_, model.intercept_)
    else:
        params = (model.coef_,)
    return params

The set_model_params function sets/updates the model's parameters. Here care needs to be taken to set the parameters using the same order/index in which they were returned by get_model_parameters.

def set_model_params(
    model: LogisticRegression, params: LogRegParams
) -> LogisticRegression:
    """Sets the parameters of a sklearn LogisticRegression model"""
    model.coef_ = params[0]
    if model.fit_intercept:
        model.intercept_ = params[1]
    return model
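The ordering contract between the two functions can be checked with a quick round trip. The sketch below copies their logic and applies it to a synthetic 3-class dataset, which stands in for MNIST. It shows that the number of exchanged arrays depends on fit_intercept. It also shows that a fresh model receiving parameters needs classes_ from elsewhere, because set_model_params does not set it. In this example, set_initial_params covers that.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


def get_model_parameters(model):
    """Same logic as the utils function: coef_ first, then intercept_ if fit_intercept."""
    if model.fit_intercept:
        return (model.coef_, model.intercept_)
    return (model.coef_,)


def set_model_params(model, params):
    """Same logic as the utils function: read the arrays back in the same order."""
    model.coef_ = params[0]
    if model.fit_intercept:
        model.intercept_ = params[1]
    return model


X, y = make_classification(n_samples=150, n_features=6, n_informative=4,
                           n_classes=3, random_state=0)

# The number and shapes of exchanged arrays depend on fit_intercept
shapes = {}
for fit_intercept in (True, False):
    fitted = LogisticRegression(fit_intercept=fit_intercept, max_iter=500).fit(X, y)
    shapes[fit_intercept] = [p.shape for p in get_model_parameters(fitted)]
print(shapes)

# Round trip into a fresh model; classes_ must be provided separately
trained = LogisticRegression(max_iter=500).fit(X, y)
receiver = LogisticRegression()
receiver.classes_ = trained.classes_
set_model_params(receiver, get_model_parameters(trained))
print(np.array_equal(receiver.predict(X), trained.predict(X)))
```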

The function set_initial_params zero-initializes the parameters of the model. This requires prior information about the attribute names, the number of classes and features of your dataset to calculate the size of the parameter matrices of the model. An alternative method for initializing the parameters could be to fit the model using a few dummy samples that mimic the dimensions of the actual dataset.

def set_initial_params(model: LogisticRegression):
    """Sets initial parameters as zeros"""
    n_classes = 10  # MNIST has 10 classes
    n_features = 784  # Number of features in dataset
    model.classes_ = np.arange(n_classes)

    model.coef_ = np.zeros((n_classes, n_features))
    if model.fit_intercept:
        model.intercept_ = np.zeros((n_classes,))
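The alternative mentioned above, fitting on a few dummy samples, might look like the following sketch. set_initial_params_by_dummy_fit is a hypothetical helper, not part of the example's utils. It fits on one all-zero 784-feature sample per class so that scikit-learn creates classes_, coef_ and intercept_ with the right shapes, then zeroes the parameters. The sketch also shows a useful property of zero-initialization: every class gets the same score, so predict_proba returns 0.1 for each of the 10 classes.

```python
import warnings

import numpy as np
from sklearn.linear_model import LogisticRegression

N_CLASSES = 10  # MNIST digits
N_FEATURES = 784  # 28 x 28 pixels


def set_initial_params_by_dummy_fit(model):
    """Hypothetical alternative to set_initial_params: create the fitted
    attributes by fitting on one all-zero sample per class, then zero them."""
    X_dummy = np.zeros((N_CLASSES, N_FEATURES))
    y_dummy = np.arange(N_CLASSES)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model.fit(X_dummy, y_dummy)
    # Discard whatever the dummy fit learned; start from zeros like set_initial_params
    model.coef_ = np.zeros_like(model.coef_)
    if model.fit_intercept:
        model.intercept_ = np.zeros_like(model.intercept_)
    return model


model = set_initial_params_by_dummy_fit(LogisticRegression(max_iter=1, warm_start=True))
print(model.classes_, model.coef_.shape, model.intercept_.shape)

# All-zero parameters give every class the same score, hence uniform probabilities
proba = model.predict_proba(np.random.default_rng(0).random((5, N_FEATURES)))
print(proba[0])
```

One side effect of the dummy fit is that scikit-learn also records n_features_in_. As a result, later calls with the wrong number of features are rejected, which the manual zero-initialization does not give you.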

The rest of the utility functions load the dataset and partition it, and they don't require much explanation.

def load_mnist() -> Dataset:
    """Loads the MNIST dataset (OpenML dataset ID 554) using OpenML."""
    mnist_openml = openml.datasets.get_dataset(554)
    Xy, _, _, _ = mnist_openml.get_data(dataset_format="array")
    X = Xy[:, :-1] # the last column contains labels
    y = Xy[:, -1]
    # First 60000 samples consist of the train set
    x_train, y_train = X[:60000], y[:60000]
    x_test, y_test = X[60000:], y[60000:]
    return (x_train, y_train), (x_test, y_test)

def partition(X: np.ndarray, y: np.ndarray, num_partitions: int) -> XYList:
    """Split X and y into a number of partitions."""
    return list(
        zip(np.array_split(X, num_partitions), np.array_split(y, num_partitions))
    )
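np.array_split, unlike np.split, accepts a number of partitions that does not divide the number of samples evenly. In that case, the first len(X) % num_partitions partitions get one extra sample. For the 60,000 MNIST training samples and 10 partitions, every partition has 6,000 samples. The split is contiguous and not shuffled, so each partition's class mix depends on the order of the original data. The sketch below demonstrates the sizes on toy data. It also includes shuffled_partition, a hypothetical variant that shuffles X and y together with a fixed seed first, in case that matters for your experiment.

```python
import numpy as np


def partition(X, y, num_partitions):
    """Same approach as utils.partition: contiguous, near-equal slices."""
    return list(
        zip(np.array_split(X, num_partitions), np.array_split(y, num_partitions))
    )


def shuffled_partition(X, y, num_partitions, seed=0):
    """Hypothetical variant: shuffle rows (X and y together) before splitting."""
    idx = np.random.default_rng(seed).permutation(len(X))
    return partition(X[idx], y[idx], num_partitions)


# 23 toy samples with 2 features; row i is [2i, 2i + 1] and its label is i
X = np.arange(46).reshape(23, 2)
y = np.arange(23)

parts = partition(X, y, 10)
print([len(part_y) for _, part_y in parts])  # [3, 3, 3, 2, 2, 2, 2, 2, 2, 2]

shuffled = shuffled_partition(X, y, 10)
print([part_y.tolist() for _, part_y in shuffled][:2])
```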

Server code

Lastly, we will write the server script. This includes defining the strategy for federation and its initialization parameters. Flower allows you to define your own callback functions to customize an existing strategy. We will use the FedAvg strategy with custom callbacks for evaluation and fit configuration. You can read more about how they work here.
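At its core, FedAvg combines the clients' parameters by averaging them, weighting each client by the number of training examples it reports from fit. The following NumPy sketch applies that weighted average to the (coef_, intercept_) arrays exchanged in this example. It only illustrates the idea and is not Flower's internal implementation. fedavg_aggregate is a hypothetical name, and the client values are made up.

```python
import numpy as np


def fedavg_aggregate(results):
    """Weighted average of per-client parameter lists.

    results: list of (parameters, num_examples) pairs, where parameters is a
    list of NumPy arrays in the same order for every client (coef_, intercept_).
    """
    total = sum(num_examples for _, num_examples in results)
    n_arrays = len(results[0][0])
    return [
        sum(params[i] * num_examples for params, num_examples in results) / total
        for i in range(n_arrays)
    ]


# Two hypothetical clients with 3-class, 2-feature models and different data sizes
client_a = ([np.ones((3, 2)), np.zeros(3)], 6000)
client_b = ([np.full((3, 2), 4.0), np.full(3, 2.0)], 2000)

coef, intercept = fedavg_aggregate([client_a, client_b])
print(coef[0, 0], intercept[0])  # 1.75 0.5
```

Since both MNIST clients here train on 6,000-sample partitions, the weighting reduces to a plain mean in this particular setup.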

The following imports are needed by the server script.

import flwr as fl
import utils
from sklearn.metrics import log_loss
from sklearn.linear_model import LogisticRegression
from typing import Dict

The fit_round callback will be used to send the round number to the client. We will pass this callback as the on_fit_config_fn parameter of the strategy. We do this simply to demonstrate the use of the on_fit_config_fn parameter.

def fit_round(rnd: int) -> Dict:
    """Send round number to client"""
    return {"rnd": rnd}

The get_eval_fn callback will be used for server-side evaluation of the global model after each round.

def get_eval_fn(model: LogisticRegression):
    """Return an evaluation function for server-side evaluation."""

    # Load test data here to avoid the overhead of doing it in
    # `evaluate` itself
    _, (X_test, y_test) = utils.load_mnist()

    # The `evaluate` function will be called after every round
    def evaluate(parameters: fl.common.Weights):
        # Update model with the latest parameters
        utils.set_model_params(model, parameters)
        loss = log_loss(y_test, model.predict_proba(X_test))
        accuracy = model.score(X_test, y_test)
        return loss, {"accuracy": accuracy}

    return evaluate
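A quick sanity check for an evaluation function like this is to feed it zero-initialized parameters. With all-zero parameters every class receives probability 1/K, so the log loss should equal ln K (about 2.303 for MNIST's 10 classes). The sketch below reproduces the closure pattern without Flower on a synthetic 3-class dataset. make_evaluate is a hypothetical name, and the parameters are assigned directly instead of through utils.set_model_params.

```python
import math

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

N_CLASSES, N_FEATURES = 3, 6


def make_evaluate(model, X_test, y_test):
    """Hypothetical Flower-free version of get_eval_fn: the test data is
    captured once by the closure instead of being reloaded every round."""

    def evaluate(parameters):
        model.coef_, model.intercept_ = parameters[0], parameters[1]
        loss = log_loss(y_test, model.predict_proba(X_test))
        accuracy = model.score(X_test, y_test)
        return loss, {"accuracy": accuracy}

    return evaluate


X_test, y_test = make_classification(n_samples=120, n_features=N_FEATURES,
                                     n_informative=4, n_classes=N_CLASSES,
                                     random_state=2)
model = LogisticRegression()
model.classes_ = np.arange(N_CLASSES)  # needed by predict_proba and score

evaluate = make_evaluate(model, X_test, y_test)
zero_params = [np.zeros((N_CLASSES, N_FEATURES)), np.zeros(N_CLASSES)]
loss, metrics = evaluate(zero_params)
print(f"loss={loss:.4f}, ln(3)={math.log(N_CLASSES):.4f}")
```

If the first server-side evaluation in a real run reports a loss far from ln 10, the parameter helpers or the class labels are a likely place to look.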

Next, we initialize the model and strategy and start the server. We will configure it to run for five rounds.

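model = LogisticRegression()
utils.set_initial_params(model)  # creates classes_, coef_ and intercept_ for evaluation
strategy = fl.server.strategy.FedAvg(
    min_available_clients=2, eval_fn=get_eval_fn(model), on_fit_config_fn=fit_round
)
fl.server.start_server("", strategy=strategy, config={"num_rounds": 5})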

And we have finished writing the scripts. All that's left is to run them. So let's do that next.


For this part, open three terminals: one for the Flower server and one for each of the two clients. In the first terminal, we will run the server:

$ python3

In the second terminal, use the following command to start the first client:

$ python3

And lastly, in the third terminal, start the second client in the same way:

$ python3

And voilà! Flower will initiate the federated learning and train you a federated scikit-learn model. I hope you enjoyed reading through this example and are excited to build your own.

Stay tuned for more examples in the future. If you have an exciting scikit-learn Flower recipe that you think people can learn from, feel free to make a pull request at the Flower repo.
