# FlowerTune LLM on Medical Dataset

Project: `@sainzpardo/ai4os-fedllm-medical-v2` (scaffolded with `flwr new`)
## Introduction

This directory conducts federated instruction tuning with a pretrained ContactDoctor/Bio-Medical-Llama-3-8B model on a Medical dataset. We use Flower Datasets to download, partition, and preprocess the dataset. Flower's Simulation Engine simulates the LLM fine-tuning process in a federated way, which allows users to perform the training on a single GPU.
Evaluation on the four benchmark datasets with the proposed approach:

|         | PubMedQA | MedMCQA | MedQA | CareQA | Avg   |
|---------|----------|---------|-------|--------|-------|
| Acc (%) | 66.20    | 60.29   | 68.42 | 53.64  | 62.14 |
Communication budget: 1040.31 MB*

*Note that this value was obtained when running the experiment on an NVIDIA Tesla V100-PCIE-32GB GPU.
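As a quick sanity check, the `Avg` column in the table above is the unweighted mean of the four per-dataset accuracies:

```python
# Per-dataset accuracies from the table above (%)
accs = {"PubMedQA": 66.20, "MedMCQA": 60.29, "MedQA": 68.42, "CareQA": 53.64}

# Unweighted mean, rounded to two decimals
avg = sum(accs.values()) / len(accs)
print(f"{avg:.2f}")  # 62.14
```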
## Changes from baseline

- Following the advances obtained with the approach presented by the Gachon Cognitive Computing Lab, we use the fine-tuned ContactDoctor/Bio-Medical-Llama-3-8B model as the base model.
- We train the model for 5 rounds (`num-server-rounds = 5`); see `peft_5`.
- We train the model locally for 3 epochs: `train.training-arguments.num-train-epochs = 3`.
- We set `train.learning-rate-max = 5e-6` and `train.learning-rate-min = 1e-7`.
- We use FedAvgOpt as the aggregation function.
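Taken together, these settings would appear in `pyproject.toml` roughly as follows. This is an illustrative sketch: the `[tool.flwr.app.config]` table name follows the standard Flower app template, but the exact key layout in this project may differ.

```toml
[tool.flwr.app.config]
num-server-rounds = 5
train.learning-rate-max = 5e-6
train.learning-rate-min = 1e-7
train.training-arguments.num-train-epochs = 3
```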
## Methodology

This baseline performs federated LLM fine-tuning with LoRA using the 🤗PEFT library.
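LoRA freezes the pretrained weights and trains only a low-rank update, which is what keeps per-round communication small. A minimal, dependency-free sketch of the idea (toy shapes for illustration; this is not the actual PEFT implementation):

```python
import random

def matmul(X, Y):
    """Naive matrix multiply for plain nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]

def lora_effective_weight(W, A, B, alpha):
    """Return W + (alpha / r) * B @ A, the merged LoRA weight.

    W: frozen (d_out x d_in) pretrained weight.
    B: trainable (d_out x r) matrix, initialised to zeros.
    A: trainable (r x d_in) matrix, randomly initialised.
    """
    r = len(A)  # LoRA rank
    delta = matmul(B, A)
    scale = alpha / r
    return [[w + scale * d for w, d in zip(wr, dr)] for wr, dr in zip(W, delta)]

# Toy example: d_out = d_in = 4, rank r = 2.
d, r, alpha = 4, 2, 16
W = [[random.gauss(0, 0.02) for _ in range(d)] for _ in range(d)]
B = [[0.0] * r for _ in range(d)]  # zeros, so the update starts at 0
A = [[random.gauss(0, 0.02) for _ in range(d)] for _ in range(r)]

# With B = 0 the effective weight equals W, so fine-tuning starts
# exactly at the pretrained model.
assert lora_effective_weight(W, A, B, alpha) == W
```

In the federated setting, only the small `A` and `B` matrices need to be exchanged each round, which is consistent with the ~1 GB communication budget reported above for an 8B-parameter model.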
## Environment setup

Project dependencies are defined in `pyproject.toml`. Install them in an activated Python environment with:

```shell
pip install -e .
```
## Experimental setup

The dataset is divided into 20 partitions in an IID fashion, and each partition is assigned to one ClientApp. We randomly sample a fraction (0.1) of the total nodes to participate in each round, for a total of 5 rounds. All settings are defined in `pyproject.toml`.
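The partitioning and per-round sampling described above can be sketched in plain Python. This is illustrative only: the project itself relies on Flower Datasets for IID partitioning and on Flower's built-in client sampling, and the dataset size and seed here are made up.

```python
import random

def iid_partition(num_examples, num_partitions, seed=0):
    """Shuffle example indices and split them into equal-sized IID shards."""
    rng = random.Random(seed)
    idx = list(range(num_examples))
    rng.shuffle(idx)
    shard = num_examples // num_partitions
    return [idx[i * shard:(i + 1) * shard] for i in range(num_partitions)]

def sample_clients(num_clients, fraction, rng):
    """Sample max(1, fraction * num_clients) distinct clients for one round."""
    k = max(1, int(fraction * num_clients))
    return sorted(rng.sample(range(num_clients), k))

rng = random.Random(42)
partitions = iid_partition(num_examples=10_000, num_partitions=20)
for rnd in range(1, 6):  # 5 federated rounds
    participants = sample_clients(20, fraction=0.1, rng=rng)
    print(f"round {rnd}: clients {participants}")
```

With 20 clients and a fraction of 0.1, two clients train per round, so only 10 client updates are aggregated across the whole experiment.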
## Running the experiment

First, log in to Hugging Face:

```shell
huggingface-cli login
```

Then, run the experiment:

```shell
flwr run .
```
To evaluate on the four benchmark datasets:

```shell
python eval.py \
  --base-model-name-path="ContactDoctor/Bio-Medical-Llama-3-8B" \
  --peft-path="peft_5" \
  --batch-size=16 \
  --quantization=4 \
  --datasets=pubmedqa,medmcqa,medqa,careqa
```