Federated SpeechLLM with Flower
This app implements federated learning for a Speech Large Language Model (SpeechLLM), enabling privacy-preserving training of a multimodal speech understanding system across multiple clients using the Flower framework. It combines a WavLM audio encoder, a lightweight connector module, and a TinyLlama LLM to perform joint audio understanding tasks — transcription; speaker gender, emotion, age, and accent classification; and speech activity detection — without raw audio data ever leaving the local client.
Key Features
- 🎙️ Multimodal Architecture — Combines WavLM audio encoder + connector + TinyLlama LLM for end-to-end speech understanding across multiple tasks simultaneously
- 🔒 Privacy-Preserving — Only trainable model parameters (LoRA weights + connector) are shared; raw audio data never leaves the client
- 🔁 Custom FedAvg Strategy — Server-side learning rate decay per round, configurable client sampling, and automatic model checkpointing after every aggregation
- ⚡ LoRA Fine-Tuning — Parameter-efficient federation: only LoRA adapters and the connector are trained and communicated, drastically reducing communication cost
- 📊 Rich Metric Logging — Tracks WER (transcription), gender, emotion, age, accent accuracy, and speech activity per validation round via W&B
- 💾 Checkpoint Resumption — Supports resuming federation from a pretrained .ckpt file via pyproject.toml config
- 🖥️ GPU Support — Full CUDA acceleration with gradient clipping and gradient accumulation
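To make the parameter-efficient exchange concrete: only LoRA adapter and connector entries of the model's state dict are communicated. A minimal sketch, assuming name-based filtering with the substrings "lora_" and "connector" (the actual app selects parameters differently, e.g. via requires_grad — this is illustrative only):

```python
def filter_trainable(state_dict: dict) -> dict:
    """Keep only LoRA adapter and connector entries from a model state dict.

    The substrings used here ("lora_", "connector") are assumptions for
    illustration, not taken from the app's source code.
    """
    return {
        name: value
        for name, value in state_dict.items()
        if "lora_" in name or "connector" in name
    }

# Hypothetical state dict: the full model vs. what actually gets communicated
full = {
    "encoder.layers.0.weight": "frozen",       # WavLM weight, stays local
    "llm.layers.0.lora_A.weight": "trained",   # LoRA adapter, shared
    "connector.proj.weight": "trained",        # connector weight, shared
}
shared = filter_trainable(full)  # only the last two entries survive
```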
Architecture
| File | Description |
|---|---|
| client_app.py | ClientApp with @app.train() and @app.evaluate() handlers — loads weights, runs local PyTorch Lightning training, returns updated parameters and metrics |
| server_app.py | ServerApp with SpeechLLMFedAvg strategy — manages LR decay per round, hierarchical aggregation, and checkpoint saving |
| trainer.py | SpeechLLMLightning — PyTorch Lightning module defining the full SpeechLLM model, forward pass, training/validation/test steps, and metric logging |
| dataset.py | InstructionalAudioDataset, MyCollator, and build_dataloaders_from_csvs — loads partitioned audio CSV datasets per client |
| pyproject.toml | All federation, model, training, data, and checkpoint config in one place |
Federated Learning Process
- Initialization — Server loads global SpeechLLMLightning model (optionally from a pretrained checkpoint) and extracts only trainable parameters (LoRA + connector)
- Round Config — Server computes a decayed learning rate for the round and broadcasts it alongside model weights to sampled clients
- Local Training — Each client loads the received weights and trains locally with PyTorch Lightning for local-epochs epochs, each capped at train-batch-per-epoch steps
- Aggregation — Server performs weighted FedAvg over client updates proportional to dataset sizes
- Checkpointing — Aggregated model is saved to disk after every round; final model saved as final_model.pt
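The aggregation step above is standard FedAvg: each client's update is weighted by its local dataset size. A plain-Python sketch of the idea (the app's actual implementation lives in the SpeechLLMFedAvg strategy):

```python
def fedavg(updates):
    """Weighted average of client parameter vectors.

    updates: list of (params, num_examples) pairs, where params is a flat
    list of floats. Each client's contribution is proportional to its
    local dataset size.
    """
    total = sum(n for _, n in updates)
    dim = len(updates[0][0])
    return [
        sum(params[i] * n for params, n in updates) / total
        for i in range(dim)
    ]

# Two hypothetical clients with 100 and 300 local examples
aggregated = fedavg([([1.0, 2.0], 100), ([3.0, 4.0], 300)])
# effective weights 0.25 and 0.75 → [2.5, 3.5]
```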
Model Architecture
Audio Input (waveform)
│
▼
WavLM Encoder (frozen or finetuned)
│
▼
Connector (Linear / LinearPool / CNN)
│
▼
[Pre-prompt embeddings] + [Speech embeddings] + [Post-prompt embeddings]
│
▼
TinyLlama LLM (LoRA fine-tuned)
│
▼
Structured JSON output:
{ "Transcript": "...", "Gender": "male", "Emotion": "neutral", ... }
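Because the LLM emits a single JSON object covering all tasks, downstream code can recover each field with a standard parse. A small illustrative example (the field values here are made up):

```python
import json

# Example model output for one utterance (values are illustrative)
raw = '{"Transcript": "hello world", "Gender": "male", "Emotion": "neutral"}'
prediction = json.loads(raw)

transcript = prediction["Transcript"]  # used for WER scoring
gender = prediction["Gender"]          # used for classification accuracy
```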
Fetch the App
Install Flower:
pip install flwr
Fetch the app:
flwr new @mnabih/speech_llm_fl
This will create the following structure:
speech_llm_fl/
├── speech_llm_fl/
│ ├── __init__.py
│ ├── client_app.py # ClientApp — local train & evaluate handlers
│ ├── server_app.py # ServerApp — SpeechLLMFedAvg strategy + main
│ ├── trainer.py # SpeechLLMLightning model definition
│ └── dataset.py # Dataset, collator, and dataloader utilities
├── pyproject.toml # All project metadata and Flower config
└── README.md
Prerequisites
- Python 3.10+
- CUDA-capable GPU (strongly recommended for WavLM + LLM training)
- Audio data partitioned as CSV files per client, each row containing audio_path and label columns (transcript, gender, emotion, age, accent, isspeech)
- Pretrained model weights accessible (WavLM and TinyLlama will be downloaded from HuggingFace on first run)
Install Dependencies
cd speech_llm_fl && pip install -e .
Data Preparation
Each client needs a CSV file with the following columns:
| Column | Description |
|---|---|
| audio_path | Absolute path to the .wav audio file (16kHz mono) |
| transcript | Ground-truth transcription text |
| gender | Speaker gender (male / female) |
| emotion | Emotion label (e.g. neutral, happy, sad) |
| age | Age group label |
| accent | Accent label |
| isspeech | Boolean — whether audio contains speech |
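For reference, a client CSV with these columns could be generated like this (the paths and label values below are placeholders, not real data):

```python
import csv

# Write a one-row client partition with the expected columns (placeholder values)
columns = ["audio_path", "transcript", "gender", "emotion", "age", "accent", "isspeech"]
row = ["/data/audio/utt_0001.wav", "hello world", "female", "neutral", "20-30", "us", True]

with open("client_0.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(columns)
    writer.writerow(row)
```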
Organize client partitions into a directory, one CSV per client:
fl_multilingual/
├── client_0.csv
├── client_1.csv
├── client_2.csv
└── ...
Set the paths in pyproject.toml:
csv-train-dir = "./fl_multilingual"
csv-dev-dir = "./fl_MLS_dev_speaker"
Run the App
Simulation (Single Machine)
Run with default settings:
flwr run .
Override specific settings at runtime:
flwr run . --run-config "num-server-rounds=10 local-epochs=5 max-lr=0.00005"
Resume from a pretrained checkpoint:
flwr run . --run-config "pretrained-checkpoint=/path/to/Checkpoint-round-420.ckpt checkpoint-offset=420"
Note: Simulation runs all clients on the same machine via Ray. Ensure sufficient GPU memory or reduce train-batch-per-epoch and train-batch-size for smaller GPUs.
Deployment Engine (Multi-Machine)
1. Start the SuperLink (Terminal 1):
flower-superlink --insecure
2. Start SuperNodes — one per client machine (separate terminals):
# Client 0
flower-supernode --insecure --superlink 127.0.0.1:9092 \
    --clientappio-api-address 127.0.0.1:9094 \
    --node-config "partition-id=0 num-partitions=4"

# Client 1
flower-supernode --insecure --superlink 127.0.0.1:9092 \
    --clientappio-api-address 127.0.0.1:9095 \
    --node-config "partition-id=1 num-partitions=4"
3. Run the federation (Terminal 4):
flwr run . supernode-deployment
Configuration
All settings are controlled via pyproject.toml. No code changes needed to run experiments.
Federation Settings
[tool.flwr.app.config]
num-server-rounds = 200   # Total FL rounds
fraction-fit = 0.3        # Fraction of clients sampled per round
fraction-evaluate = 0.0   # Fraction of clients used for evaluation
min-fit-clients = 2       # Minimum clients required to start a round
min-evaluate-clients = 2
Model Settings
audio-encoder-name = "microsoft/wavlm-large"       # HuggingFace model ID
llm-name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
connector-name = "linear"                          # "linear", "linear-pool", or "cnn"
audio-enc-dim = 1024
llm-dim = 2048
use-lora = true
lora-r = 8
lora-alpha = 16
finetune-encoder = false
Training Settings
local-epochs = 10            # Epochs per client per round
train-batch-size = 4
train-batch-per-epoch = 200  # Steps per epoch (limits dataset length)
grad-accumulate-steps = 4
max-lr = 0.0001
Learning Rate Decay
lr-decay-factor = 0.9  # Multiply LR by this value every N rounds
lr-decay-every = 10    # Decay interval in rounds
The effective LR at round r is: max-lr × lr-decay-factor ^ (r // lr-decay-every)
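In code, this stepwise exponential schedule amounts to the following (parameter names mirror the pyproject keys):

```python
def effective_lr(round_num: int, max_lr: float = 1e-4,
                 decay_factor: float = 0.9, decay_every: int = 10) -> float:
    """Server-side LR for a given round: stepwise exponential decay."""
    return max_lr * decay_factor ** (round_num // decay_every)

# Rounds 0-9 use max-lr; round 10 is the first decayed round
lr_round_0 = effective_lr(0)    # 1e-4
lr_round_10 = effective_lr(10)  # 9e-5
```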
Checkpoint Settings
checkpoint-dir = "FL_SLAM_checkpoints"  # Directory for round checkpoints
checkpoint-offset = 0                   # Added to the round number in filenames
pretrained-checkpoint = ""              # Path to a .ckpt file to resume from
Simulation Settings
[tool.flwr.federations.local-simulation]
options.num-supernodes = 316                   # Total number of simulated clients
options.backend.client-resources.num-cpus = 1
options.backend.client-resources.num-gpus = 1
Metrics Tracked
| Metric | Description |
|---|---|
| train_loss | Cross-entropy loss on local training data |
| val/loss | Validation loss |
| val/wer | Word Error Rate on transcript predictions |
| val/gender | Gender classification accuracy |
| val/emotion | Emotion classification accuracy |
| val/age | Age group classification accuracy |
| val/accent | Accent classification accuracy |
| val/speech_activity | Speech activity detection accuracy |
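For context, val/wer is the standard word-level edit distance normalized by reference length. A minimal pure-Python version for illustration (the app itself presumably computes WER via a library such as torchmetrics or jiwer — that is an assumption):

```python
def wer(reference: str, hypothesis: str) -> float:
    """Word Error Rate: word-level Levenshtein distance / reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    # dp[i][j] = edit distance between ref[:i] and hyp[:j]
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i
    for j in range(len(hyp) + 1):
        dp[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,         # deletion
                           dp[i][j - 1] + 1,         # insertion
                           dp[i - 1][j - 1] + cost)  # substitution
    return dp[len(ref)][len(hyp)] / len(ref)

score = wer("the cat sat", "the cat sit")  # one substitution in three words
```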
Results
Performance comparison of WavLM vs. Whisper encoders, measured in Word Error Rate (WER ↓) on LibriSpeech (LS) and Multilingual LibriSpeech (MLS) test sets. Central training serves as the upper bound.
| Setting | WavLM (Round=100) LS | WavLM (Round=100) MLS | Whisper (Round=40) LS | Whisper (Round=40) MLS |
|---|---|---|---|---|
| Central Training ⭐ | 6.1 | 18.4 | 6.4 | 16.4 |
| FL Sample Cluster | 9.7 | 19.6 | 7.7 | 16.4 |
⭐ Central training is the upper bound (non-federated). Lower WER is better.
Key takeaway: WavLM with federated learning (FL Sample Cluster) achieves competitive WER on both benchmarks, with only a modest gap vs. central training — demonstrating that the federation does not significantly degrade model quality while preserving data privacy.
More Information
- Flower Framework Docs — Flower framework documentation
- WavLM (Microsoft) — Audio encoder used in this app
- TinyLlama — LLM backbone
- LoRA / PEFT — Parameter-efficient fine-tuning library
- PyTorch Lightning — Training framework
License
Apache License 2.0