Quickstart ๐Ÿค— Transformersยถ

In this federated learning tutorial we will learn how to train a large language model (LLM) on the IMDB dataset using Flower and the ๐Ÿค— Hugging Face Transformers library. It is recommended to create a virtual environment and run everything within a virtualenv.

Letโ€™s use flwr new to create a complete Flower+๐Ÿค— Hugging Face project. It will generate all the files needed to run, by default with the Flower Simulation Engine, a federation of 10 nodes using FedAvg The dataset will be partitioned using Flower Datasetsโ€™s IidPartitioner.

Now that we have a rough idea of what this example is about, letโ€™s get started. First, install Flower in your new environment:

# In a new Python environment
$ pip install flwr

Then, run the command below. You will be prompted to select one of the available templates (choose HuggingFace), give a name to your project, and type in your developer name:

$ flwr new

After running it youโ€™ll notice a new directory with your project name has been created. It should have the following structure:

<your-project-name>
โ”œโ”€โ”€ <your-project-name>
โ”‚   โ”œโ”€โ”€ __init__.py
โ”‚   โ”œโ”€โ”€ client_app.py   # Defines your ClientApp
โ”‚   โ”œโ”€โ”€ server_app.py   # Defines your ServerApp
โ”‚   โ””โ”€โ”€ task.py         # Defines your model, training and data loading
โ”œโ”€โ”€ pyproject.toml      # Project metadata like dependencies and configs
โ””โ”€โ”€ README.md

If you havenโ€™t yet installed the project and its dependencies, you can do so by:

# From the directory where your pyproject.toml is
$ pip install -e .

To run the project, do:

# Run with default arguments
$ flwr run .

With default arguments you will see an output like this one:

Loading project configuration...
Success
INFO :      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
INFO :
INFO :      [INIT]
INFO :      Using initial global parameters provided by strategy
INFO :      Starting evaluation of initial global parameters
INFO :      Evaluation returned no results (`None`)
INFO :
INFO :      [ROUND 1]
INFO :      configure_fit: strategy sampled 2 clients (out of 10)
INFO :      aggregate_fit: received 2 results and 0 failures
WARNING :   No fit_metrics_aggregation_fn provided
INFO :      configure_evaluate: strategy sampled 10 clients (out of 10)
INFO :      aggregate_evaluate: received 10 results and 0 failures
WARNING :   No evaluate_metrics_aggregation_fn provided
INFO :
INFO :      [ROUND 2]
INFO :      configure_fit: strategy sampled 5 clients (out of 10)
INFO :      aggregate_fit: received 5 results and 0 failures
INFO :      configure_evaluate: strategy sampled 10 clients (out of 10)
INFO :      aggregate_evaluate: received 10 results and 0 failures
INFO :
INFO :      [ROUND 3]
INFO :      configure_fit: strategy sampled 5 clients (out of 10)
INFO :      aggregate_fit: received 5 results and 0 failures
INFO :      configure_evaluate: strategy sampled 10 clients (out of 10)
INFO :      aggregate_evaluate: received 10 results and 0 failures
INFO :
INFO :      [SUMMARY]
INFO :      Run finished 3 round(s) in 249.11s
INFO :          History (loss, distributed):
INFO :                  round 1: 0.02111011856794357
INFO :                  round 2: 0.019722302150726317
INFO :                  round 3: 0.018227258533239362
INFO :

You can also run the project with GPU as follows:

# Run with default arguments
$ flwr run . localhost-gpu

This will use the default arguments where each ClientApp will use 2 CPUs and at most 4 ClientApps will run in a given GPU.

You can also override the parameters defined in the [tool.flwr.app.config] section in pyproject.toml like this:

# Override some arguments
$ flwr run . --run-config "num-server-rounds=5 fraction-fit=0.2"

What follows is an explanation of each component in the project you just created: dataset partition, the model, defining the ClientApp and defining the ServerApp.

The Dataยถ

This tutorial uses Flower Datasets to easily download and partition the IMDB dataset. In this example youโ€™ll make use of the IidPartitioner to generate num_partitions partitions. You can choose other partitioners available in Flower Datasets. To tokenize the text, we will also load the tokenizer from the pre-trained Transformer model that weโ€™ll use during training - more on that in the next section. Each ClientApp will call this function to create dataloaders with the data that correspond to their data partition.

partitioner = IidPartitioner(num_partitions=num_partitions)
fds = FederatedDataset(
    dataset="stanfordnlp/imdb",
    partitioners={"train": partitioner},
)
partition = fds.load_partition(partition_id)
# Divide data: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)

tokenizer = AutoTokenizer.from_pretrained(model_name)


def tokenize_function(examples):
    return tokenizer(
        examples["text"], truncation=True, add_special_tokens=True, max_length=512
    )


partition_train_test = partition_train_test.map(tokenize_function, batched=True)
partition_train_test = partition_train_test.remove_columns("text")
partition_train_test = partition_train_test.rename_column("label", "labels")

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainloader = DataLoader(
    partition_train_test["train"],
    shuffle=True,
    batch_size=32,
    collate_fn=data_collator,
)

testloader = DataLoader(
    partition_train_test["test"], batch_size=32, collate_fn=data_collator
)

The Modelยถ

We will leverage ๐Ÿค— Hugging Face to federate the training of language models over multiple clients using Flower. More specifically, we will fine-tune a pre-trained Transformer model (bert-tiny) for sequence classification over the dataset of IMDB ratings. The end goal is to detect if a movie rating is positive or negative. If you have access to larger GPUs, feel free to use larger models!

net = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=num_labels
)

Note that here, model_name is a string that will be loaded from the Context in the ClientApp and ServerApp.

In addition to loading the pretrained model weights and architecture, we also include two utility functions to perform both training (i.e. train()) and evaluation (i.e. test()) using the above model. These functions should look fairly familiar if you have some prior experience with PyTorch. Note these functions do not have anything specific to Flower. That being said, the training function will normally be called, as weโ€™ll see later, from a Flower client passing its own data. In summary, your clients can use standard training/testing functions to perform local training or evaluation:

def train(net, trainloader, epochs, device):
    optimizer = AdamW(net.parameters(), lr=5e-5)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()


def test(net, testloader, device):
    metric = load_metric("accuracy")
    loss = 0
    net.eval()
    for batch in testloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = net(**batch)
        logits = outputs.logits
        loss += outputs.loss.item()
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    loss /= len(testloader.dataset)
    accuracy = metric.compute()["accuracy"]
    return loss, accuracy

The ClientAppยถ

The main changes we have to make to use ๐Ÿค— Hugging Face with Flower will be found in the get_weights() and set_weights() functions. Under the hood, the transformers library uses PyTorch, which means we can reuse the get_weights() and set_weights() code that we defined in the Quickstart PyTorch tutorial. As a reminder, in get_weights(), PyTorch model parameters are extracted and represented as a list of NumPy arrays. The set_weights() function thatโ€™s the opposite: given a list of NumPy arrays it applies them to an existing PyTorch model. Doing this in fairly easy in PyTorch.

์ฐธ๊ณ 

The specific implementation of get_weights() and set_weights() depends on the type of models you use. The ones shown below work for a wide range of PyTorch models but you might need to adjust them if you have more exotic model architectures.

def get_weights(net):
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_weights(net, parameters):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)

The rest of the functionality is directly inspired by the centralized case. The fit() method in the client trains the model using the local dataset. Similarly, the evaluate() method is used to evaluate the model received on a held-out validation set that the client might have:

class FlowerClient(NumPyClient):
    def __init__(self, net, trainloader, testloader, local_epochs):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader
        self.local_epochs = local_epochs
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.net.to(self.device)

    def fit(self, parameters, config):
        set_weights(self.net, parameters)
        train(self.net, self.trainloader, epochs=self.local_epochs, device=self.device)
        return get_weights(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        set_weights(self.net, parameters)
        loss, accuracy = test(self.net, self.testloader, self.device)
        return float(loss), len(self.testloader), {"accuracy": accuracy}

Finally, we can construct a ClientApp using the FlowerClient defined above by means of a client_fn() callback. Note that the context enables you to get access to hyperparemeters defined in your pyproject.toml to configure the run. In this tutorial we access the local-epochs setting to control the number of epochs a ClientApp will perform when running the fit() method. You could define additional hyperparameters in pyproject.toml and access them here.

def client_fn(context: Context):

    # Get this client's dataset partition
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    model_name = context.run_config["model-name"]
    trainloader, valloader = load_data(partition_id, num_partitions, model_name)

    # Load model
    num_labels = context.run_config["num-labels"]
    net = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )

    local_epochs = context.run_config["local-epochs"]

    # Return Client instance
    return FlowerClient(net, trainloader, valloader, local_epochs).to_client()


# Flower ClientApp
app = ClientApp(client_fn)

The ServerAppยถ

To construct a ServerApp we define a server_fn() callback with an identical signature to that of client_fn() but the return type is ServerAppComponents as opposed to a Client In this example we use the FedAvg strategy. To it we pass a randomly initialized model that will server as the global model to federated. Note that the value of fraction_fit is read from the run config. You can find the default value defined in the pyproject.toml.

def server_fn(context: Context):
    # Read from config
    num_rounds = context.run_config["num-server-rounds"]
    fraction_fit = context.run_config["fraction-fit"]

    # Initialize global model
    model_name = context.run_config["model-name"]
    num_labels = context.run_config["num-labels"]
    net = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )

    weights = get_weights(net)
    initial_parameters = ndarrays_to_parameters(weights)

    # Define strategy
    strategy = FedAvg(
        fraction_fit=fraction_fit,
        fraction_evaluate=1.0,
        initial_parameters=initial_parameters,
    )
    config = ServerConfig(num_rounds=num_rounds)

    return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)

Congratulations! Youโ€™ve successfully built and run your first federated learning system for an LLM.

์ฐธ๊ณ 

Check the source code of the extended version of this tutorial in examples/quickstart-huggingface in the Flower GitHub repository. For a comprehensive example of a federated fine-tuning of an LLM with Flower, refer to the FlowerTune LLM example in the Flower GitHub repository.