# @sainzpardo/ai4os-fedllm-medical

## FlowerTune LLM on Medical Dataset

```shell
flwr new @sainzpardo/ai4os-fedllm-medical
```
Evaluation on the three baseline datasets:
| | PubMedQA | MedMCQA | MedQA | Avg |
|---|---|---|---|---|
| Acc (%) | 72.60 | 58.64 | 63.39 | 64.88 |
Communication budget: 1040.31 MB used up to the 5th round.
Evaluation of the originally proposed baseline model:
| | PubMedQA | MedMCQA | MedQA | Avg |
|---|---|---|---|---|
| Acc (%) | 59.00 | 23.69 | 27.10 | 36.60 |
## Introduction
This directory conducts federated instruction tuning with a pretrained Mistral-7B model on a Medical dataset. We use Flower Datasets to download, partition, and preprocess the dataset. Flower's Simulation Engine simulates the LLM fine-tuning process in a federated way, which allows users to perform the training on a single GPU.
## Changes from the baseline
- Following the improvements obtained with the approach presented by the Gachon Cognitive Computing Lab, we use the fine-tuned model ContactDoctor/Bio-Medical-Llama-3-8B as the base model.
- We train the model for 5 rounds. Although we set `num-server-rounds = 20`, we take the checkpoint obtained in round 5 (`peft_5`).
- We train the model locally for 5 epochs: `train.training-arguments.num-train-epochs = 5`.
- We use the FedAvgOpt aggregation function.
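The project's actual FedAvgOpt implementation is not shown in this README. As an illustrative sketch only (all names here are hypothetical, not Flower APIs), a FedOpt-style server update (a weighted FedAvg over client parameter deltas, followed by a server-side momentum step) can be expressed as:

```python
# Hedged sketch, NOT the project's actual aggregation code.
# FedAvg averages client deltas weighted by dataset size; a server-side
# optimizer (here SGD with momentum) then applies the averaged delta.

def fedavg_delta(client_deltas, num_examples):
    """Weighted average of client parameter deltas (FedAvg)."""
    total = sum(num_examples)
    avg = [0.0] * len(client_deltas[0])
    for delta, n in zip(client_deltas, num_examples):
        weight = n / total
        for i, d in enumerate(delta):
            avg[i] += weight * d
    return avg

class ServerMomentum:
    """Server optimizer: SGD with momentum applied to the aggregated delta."""
    def __init__(self, lr=1.0, beta=0.9):
        self.lr, self.beta = lr, beta
        self.velocity = None

    def step(self, params, avg_delta):
        if self.velocity is None:
            self.velocity = [0.0] * len(params)
        self.velocity = [self.beta * v + d
                         for v, d in zip(self.velocity, avg_delta)]
        return [p + self.lr * v for p, v in zip(params, self.velocity)]

# Usage: two clients, the second holding three times as much data.
avg = fedavg_delta([[1.0, 2.0], [3.0, 4.0]], num_examples=[1, 3])
params = ServerMomentum().step([0.0, 0.0], avg)  # first step: params == avg
```

The averaging step matches plain FedAvg; the server-side optimizer is what distinguishes FedOpt-style strategies from it.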
## Methodology
This baseline performs federated LLM fine-tuning with LoRA using the 🤗 PEFT library. The clients' models are aggregated with the FedAvg strategy. This provides a baseline performance for the leaderboard of the Medical challenge.
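LoRA freezes each base weight matrix W and trains a low-rank update scaled by alpha / r, so only the small A and B factors are trained and exchanged each round, which is what keeps the communication budget low. A minimal plain-Python sketch of the merged weight (illustrative arithmetic only, not the actual 🤗 PEFT internals):

```python
# Illustrative LoRA arithmetic, not 🤗 PEFT code.
# Effective weight: W_eff = W + (alpha / r) * B @ A, with A of shape
# (r x d_in) and B of shape (d_out x r), so trainable parameters scale
# with the rank r instead of d_in * d_out.

def matmul(B, A):
    """Naive matrix product of B (m x r) and A (r x n)."""
    m, r, n = len(B), len(A), len(A[0])
    return [[sum(B[i][k] * A[k][j] for k in range(r)) for j in range(n)]
            for i in range(m)]

def lora_merge(W, A, B, alpha, r):
    """Return W + (alpha / r) * B @ A."""
    scale = alpha / r
    BA = matmul(B, A)
    return [[W[i][j] + scale * BA[i][j] for j in range(len(W[0]))]
            for i in range(len(W))]

# Example: d_out = d_in = 2, rank r = 1, alpha = 2 (scale = 2.0).
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]      # r x d_in
B = [[0.5], [0.25]]   # d_out x r
W_eff = lora_merge(W, A, B, alpha=2, r=1)
```

In this toy case the adapter holds r * (d_in + d_out) = 4 trainable values versus the 4 frozen ones in W; for 8B-parameter transformer layers the ratio is far more favorable.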
## Environment setup
Project dependencies are defined in `pyproject.toml`. Install them in an activated Python environment with:

```shell
pip install -e .
```
## Experimental setup
The dataset is divided into 20 partitions in an IID fashion, and one partition is assigned to each ClientApp. We randomly sample a fraction (0.1) of the total nodes to participate in each round, for a total of 20 rounds (although we take the checkpoint from round 5). All settings are defined in `pyproject.toml`.
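As a hedged sketch of the setup above (the actual project delegates this to Flower Datasets and the Flower strategy; function names here are illustrative), IID partitioning and per-round client sampling amount to:

```python
# Illustrative sketch of IID partitioning and client sampling, not the
# project's actual code (which uses Flower Datasets / Flower strategies).
import random

def iid_partition(num_examples, num_partitions, seed=0):
    """Shuffle example indices and split them into equal IID partitions."""
    idx = list(range(num_examples))
    random.Random(seed).shuffle(idx)
    size = num_examples // num_partitions
    return [idx[i * size:(i + 1) * size] for i in range(num_partitions)]

def sample_clients(num_partitions, fraction, rnd, seed=0):
    """Sample a fraction of the clients to participate in a given round."""
    k = max(1, round(num_partitions * fraction))
    return sorted(random.Random(seed + rnd).sample(range(num_partitions), k))

# With 20 partitions and fraction 0.1, each round trains on 2 clients.
partitions = iid_partition(num_examples=1000, num_partitions=20)
clients_round_1 = sample_clients(20, fraction=0.1, rnd=1)
```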
## Running the experiment
First, log in to Hugging Face:

```shell
huggingface-cli login
```
Then, run the experiment:

```shell
flwr run .
```
To evaluate the resulting checkpoint on the three baseline datasets:

```shell
python eval.py --base-model-name-path="ContactDoctor/Bio-Medical-Llama-3-8B" --peft-path="peft_5" --batch-size=16 --quantization=4 --datasets=pubmedqa,medmcqa,medqa
```