Use with TensorFlow

Let’s integrate flwr-datasets with TensorFlow. We show you three ways to convert the data into the formats that TensorFlow models expect. Please note that, especially for smaller datasets, the performance of these methods is very close. We recommend that you choose the method you are most comfortable with.

Create a FederatedDataset:

from flwr_datasets import FederatedDataset

fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
partition = fds.load_partition(0, "train")
centralized_dataset = fds.load_split("test")

Inspect the names of the features:

partition.features

In the case of CIFAR10, you should see the following output.

{'img': Image(decode=True, id=None),
 'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog',
                            'frog', 'horse', 'ship', 'truck'], id=None)}

We will use the keys of the partition features to construct a tf.data.Dataset. Let’s move on to the transformations.

NumPy

The first way is to transform the data into NumPy arrays. It’s a straightforward option that is commonly used. Feel free to follow the Use with NumPy tutorial, especially if you are a beginner.
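
As a quick illustration, here is a minimal sketch of that route (it assumes you have already defined and compiled a Keras model named model, as in the later sections):

partition_np = partition.with_format("numpy")
# Assuming you have defined your model and compiled it
model.fit(partition_np["img"], partition_np["label"], epochs=20, batch_size=64)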

TensorFlow Dataset

Transform the data to TensorFlow Dataset:

tf_dataset = partition.to_tf_dataset(columns="img", label_cols="label", batch_size=64,
                                     shuffle=True)
# Assuming you have defined your model and compiled it
model.fit(tf_dataset, epochs=20)

TensorFlow Tensors

Transform the data to TensorFlow tf.Tensor objects (note that this is not a TensorFlow dataset):

data_tf = partition.with_format("tf")
# Assuming you have defined your model and compiled it
model.fit(data_tf["img"], data_tf["label"], epochs=20, batch_size=64)
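
If you also want to evaluate on the centralized test split in the same format, a minimal sketch could look like this (assuming the model above has already been trained):

# Convert the centralized test split to tf.Tensor objects in the same way
test_tf = centralized_dataset.with_format("tf")
# Evaluate the trained model on the held-out test data
model.evaluate(test_tf["img"], test_tf["label"], batch_size=64)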

CNN Keras Model

Here’s a quick example of how you can use that data with a simple CNN model (it assumes you created the TensorFlow dataset as shown in the TensorFlow Dataset section above):

import tensorflow as tf
from tensorflow.keras import layers, models

model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D(2, 2),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D(2, 2),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(tf_dataset, epochs=20)
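
To check how the trained model performs on the centralized test split, you can convert it the same way and evaluate it (a sketch under the same assumptions as above):

# Build a tf.data.Dataset from the centralized test split for evaluation
test_dataset = centralized_dataset.to_tf_dataset(columns="img", label_cols="label",
                                                 batch_size=64, shuffle=False)
model.evaluate(test_dataset)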