Quickstart TensorFlowΒΆ
In this tutorial we will learn how to train a Convolutional Neural Network on CIFAR-10 using the Flower framework and TensorFlow. First of all, it is recommended to create a virtual environment and run everything within a virtualenv.
Letβs use flwr new to create a complete Flower+TensorFlow project. It will generate
all the files needed to run, by default with the Flower Simulation Engine, a federation
of 10 nodes using FedAvg. The dataset will be partitioned using Flower Datasetβs
IidPartitioner.
Now that we have a rough idea of what this example is about, letβs get started. First, install Flower in your new environment:
# In a new Python environment
$ pip install flwr
Then, run the command below. You will be prompted to select one of the available
templates (choose TensorFlow), give a name to your project, and type in your
developer name:
$ flwr new
After running it youβll notice a new directory with your project name has been created. It should have the following structure:
<your-project-name>
βββ <your-project-name>
β   βββ __init__.py
β   βββ client_app.py   # Defines your ClientApp
β   βββ server_app.py   # Defines your ServerApp
β   βββ task.py         # Defines your model, training and data loading
βββ pyproject.toml      # Project metadata like dependencies and configs
βββ README.md
If you havenβt yet installed the project and its dependencies, you can do so by:
# From the directory where your pyproject.toml is
$ pip install -e .
To run the project, do:
# Run with default arguments
$ flwr run .
With default arguments you will see an output like this one:
Loading project configuration...
Success
INFO :      Starting FedAvg strategy:
INFO :          βββ Number of rounds: 3
INFO :          βββ ArrayRecord (0.16 MB)
INFO :          βββ ConfigRecord (train): (empty!)
INFO :          βββ ConfigRecord (evaluate): (empty!)
INFO :          βββ> Sampling:
INFO :          β       βββFraction: train (0.50) | evaluate ( 1.00)
INFO :          β       βββMinimum nodes: train (2) | evaluate (2)
INFO :          β       βββMinimum available nodes: 2
INFO :          βββ> Keys in records:
INFO :                  βββ Weighted by: 'num-examples'
INFO :                  βββ ArrayRecord key: 'arrays'
INFO :                  βββ ConfigRecord key: 'config'
INFO :
INFO :
INFO :      [ROUND 1/3]
INFO :      configure_train: Sampled 5 nodes (out of 10)
INFO :      aggregate_train: Received 5 results and 0 failures
INFO :          βββ> Aggregated MetricRecord: {'train_loss': 2.0013, 'train_acc': 0.2624}
INFO :      configure_evaluate: Sampled 10 nodes (out of 10)
INFO :      aggregate_evaluate: Received 10 results and 0 failures
INFO :          βββ> Aggregated MetricRecord: {'eval_acc': 0.1216, 'eval_loss': 2.2686}
INFO :
INFO :      [ROUND 2/3]
INFO :      configure_train: Sampled 5 nodes (out of 10)
INFO :      aggregate_train: Received 5 results and 0 failures
INFO :          βββ> Aggregated MetricRecord: {'train_loss': 1.8099, 'train_acc': 0.3373}
INFO :      configure_evaluate: Sampled 10 nodes (out of 10)
INFO :      aggregate_evaluate: Received 10 results and 0 failures
INFO :          βββ> Aggregated MetricRecord: {'eval_acc': 0.4273, 'eval_loss': 1.6684}
INFO :
INFO :      [ROUND 3/3]
INFO :      configure_train: Sampled 5 nodes (out of 10)
INFO :      aggregate_train: Received 5 results and 0 failures
INFO :          βββ> Aggregated MetricRecord: {'train_loss': 1.6749, 'train_acc': 0.3965}
INFO :      configure_evaluate: Sampled 10 nodes (out of 10)
INFO :      aggregate_evaluate: Received 10 results and 0 failures
INFO :          βββ> Aggregated MetricRecord: {'eval_acc': 0.4281, 'eval_loss': 1.5807}
INFO :
INFO :      Strategy execution finished in 16.60s
INFO :
INFO :      Final results:
INFO :
INFO :          Global Arrays:
INFO :                  ArrayRecord (0.163 MB)
INFO :
INFO :          Aggregated ClientApp-side Train Metrics:
INFO :          { 1: {'train_acc': '2.6240e-01', 'train_loss': '2.0014e+00'},
INFO :            2: {'train_acc': '3.3725e-01', 'train_loss': '1.8099e+00'},
INFO :            3: {'train_acc': '3.9655e-01', 'train_loss': '1.6750e+00'}}
INFO :
INFO :          Aggregated ClientApp-side Evaluate Metrics:
INFO :          { 1: {'eval_acc': '1.2160e-01', 'eval_loss': '2.2686e+00'},
INFO :            2: {'eval_acc': '4.2730e-01', 'eval_loss': '1.6684e+00'},
INFO :            3: {'eval_acc': '4.2810e-01', 'eval_loss': '1.5807e+00'}}
INFO :
INFO :          ServerApp-side Evaluate Metrics:
INFO :          {}
INFO :
Saving final model to disk as final_model.keras...
You can also override the parameters defined in the [tool.flwr.app.config] section
in pyproject.toml like this:
# Override some arguments
$ flwr run . --run-config "num-server-rounds=5 batch-size=16"
The DataΒΆ
This tutorial uses Flower Datasets to easily
download and partition the CIFAR-10 dataset. In this example youβll make use of the
IidPartitioner
to generate num_partitions partitions. You can choose other partitioners available in
Flower Datasets. Each ClientApp will call this function to create the NumPy
arrays that correspond to their data partition.
partitioner = IidPartitioner(num_partitions=num_partitions)
fds = FederatedDataset(
    dataset="uoft-cs/cifar10",
    partitioners={"train": partitioner},
)
partition = fds.load_partition(partition_id, "train")
partition.set_format("numpy")
# Divide data on each node: 80% train, 20% test
partition = partition.train_test_split(test_size=0.2)
x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"]
x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"]
The ModelΒΆ
Next, we need a model. We defined a simple Convolutional Neural Network (CNN), but feel free to replace it with a more sophisticated model if youβd like:
def load_model(learning_rate: float = 0.001):
    # Define a simple CNN for CIFAR-10 and set Adam optimizer
    model = keras.Sequential(
        [
            keras.Input(shape=(32, 32, 3)),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(10, activation="softmax"),
        ]
    )
    model.compile(
        "adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model
The ClientAppΒΆ
The main changes we have to make to use Tensorflow with Flower have to do with
converting the ArrayRecord received in the Message into numpy ndarrays
for use with the built-in set_weights() function. After training, the
get_weights() function can be used to extract then pack the updated numpy ndarrays
into a Message from the ClientApp. We can make use of built-in methods in the
ArrayRecord to make these conversions:
@app.train()
def train(msg: Message, context: Context):
    # Load the model
    model = load_model(context.run_config["learning-rate"])
    # Extract the ArrayRecord from Message and convert to numpy ndarrays
    model.set_weights(msg.content["arrays"].to_numpy_ndarrays())
    # Train the model
    ...
    # Pack the model weights into an ArrayRecord
    model_record = ArrayRecord(model.get_weights())
The rest of the functionality is directly inspired by the centralized case. The
ClientApp comes with three core methods (train, evaluate, and query)
that we can implement for different purposes. For example: train to train the
received model using the local data; evaluate to assess its performance of the
received model on a validation set; and query to retrieve information about the node
executing the ClientApp. In this tutorial we will only make use of train and
evaluate.
Letβs see how the train method can be implemented. It receives as input arguments a
Message from the ServerApp. By default it carries:
- an - ArrayRecordwith the arrays of the model to federate. By default they can be retrieved with key- "arrays"when accessing the message content.
- a - ConfigRecordwith the configuration sent from the- ServerApp. By default it can be retrieved with key- "config"when accessing the message content.
The train method also receives the Context, giving access to configs for your
run and node. The run config hyperparameters are defined in the pyproject.toml of
your Flower App. The node config can only be set when running Flower with the Deployment
Runtime and is not directly configurable during simulations.
# Flower ClientApp
app = ClientApp()
@app.train()
def train(msg: Message, context: Context):
    """Train the model on local data."""
    # Reset local Tensorflow state
    keras.backend.clear_session()
    # Load the data
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    x_train, y_train, _, _ = load_data(partition_id, num_partitions)
    # Load the model
    model = load_model(context.run_config["learning-rate"])
    model.set_weights(msg.content["arrays"].to_numpy_ndarrays())
    epochs = context.run_config["local-epochs"]
    batch_size = context.run_config["batch-size"]
    verbose = context.run_config.get("verbose")
    # Train the model
    history = model.fit(
        x_train,
        y_train,
        epochs=epochs,
        batch_size=batch_size,
        verbose=verbose,
    )
    # Get training metrics
    train_loss = history.history["loss"][-1] if "loss" in history.history else None
    train_acc = (
        history.history["accuracy"][-1] if "accuracy" in history.history else None
    )
    # Pack and send the model weights and metrics as a message
    model_record = ArrayRecord(model.get_weights())
    metrics = {"num-examples": len(x_train)}
    if train_loss is not None:
        metrics["train_loss"] = train_loss
    if train_acc is not None:
        metrics["train_acc"] = train_acc
    content = RecordDict({"arrays": model_record, "metrics": MetricRecord(metrics)})
    return Message(content=content, reply_to=msg)
The @app.evaluate() method would be near identical with two exceptions: (1) the
model is not locally trained, instead it is used to evaluate its performance on the
locally held-out validation set; (2) including the model in the reply Message is no
longer needed because it is not locally modified.
The ServerAppΒΆ
To construct a ServerApp we define its @app.main() method. This method
receive as input arguments:
- a - Gridobject that will be used to interface with the nodes running the- ClientAppto involve them in a round of train/evaluate/query or other.
- a - Contextobject that provides access to the run configuration.
In this example we use the FedAvg and configure it with a specific value of
fraction_train which is read from the run config. You can find the default value
defined in the pyproject.toml. Then, the execution of the strategy is launched when
invoking its start method. To it we pass:
- the - Gridobject.
- an - ArrayRecordcarrying a randomly initialized model that will serve as the global model to federated.
- the - num_roundsparameter specifying how many rounds of- FedAvgto perform.
# Create the ServerApp
app = ServerApp()
@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""
    # Load config
    num_rounds = context.run_config["num-server-rounds"]
    fraction_train = context.run_config["fraction-train"]
    # Load initial model
    model = load_model()
    arrays = ArrayRecord(model.get_weights())
    # Define and start FedAvg strategy
    strategy = FedAvg(
        fraction_train=fraction_train,
    )
    result = strategy.start(
        grid=grid,
        initial_arrays=arrays,
        num_rounds=num_rounds,
    )
    # Save the final model
    ndarrays = result.arrays.to_numpy_ndarrays()
    final_model_name = "final_model.keras"
    print(f"Saving final model to disk as {final_model_name}...")
    model.set_weights(ndarrays)
    model.save(final_model_name)
Note the start method of the strategy returns a result object. This object contains
all the relevant information about the FL process, including the final model weights as
an ArrayRecord, and federated training and evaluation metrics as MetricRecords.
You can easily log the metrics using Pythonβs pprint and save the final model weights using
Tensorflowβs save() function.
Congratulations! Youβve successfully built and run your first federated learning system.
Note
Check the source code of the extended version of this tutorial in
examples/quickstart-tensorflow in the Flower GitHub repository.