Federated SpeechLLM with Flower

This app implements federated learning for a Speech Large Language Model (SpeechLLM), enabling privacy-preserving training of a multimodal speech understanding system across multiple clients using the Flower framework. It combines a WavLM audio encoder, a lightweight connector module, and a TinyLlama LLM to perform joint audio understanding tasks — including transcription and the classification of speaker gender, emotion, age, accent, and speech activity — without raw audio data ever leaving the local client.

Key Features

  • 🎙️ Multimodal Architecture — Combines WavLM audio encoder + connector + TinyLlama LLM for end-to-end speech understanding across multiple tasks simultaneously
  • 🔒 Privacy-Preserving — Only trainable model parameters (LoRA weights + connector) are shared; raw audio data never leaves the client
  • 🔁 Custom FedAvg Strategy — Server-side learning rate decay per round, configurable client sampling, and automatic model checkpointing after every aggregation
  • LoRA Fine-Tuning — Parameter-efficient federation: only LoRA adapters and the connector are trained and communicated, drastically reducing communication cost
  • 📊 Rich Metric Logging — Tracks WER (transcription), gender, emotion, age, accent accuracy, and speech activity per validation round via W&B
  • 💾 Checkpoint Resumption — Supports resuming federation from a pretrained .ckpt file via pyproject.toml config
  • 🖥️ GPU Support — Full CUDA acceleration with gradient clipping and gradient accumulation
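
As a sketch of the privacy-preserving exchange above: only state-dict entries belonging to the LoRA adapters and the connector would ever be communicated. The key substrings below are illustrative, not the app's actual parameter names.

```python
def filter_trainable(state_dict, keywords=("lora", "connector")):
    """Keep only entries whose key names a trainable component (illustrative)."""
    return {k: v for k, v in state_dict.items() if any(w in k for w in keywords)}


# Toy state dict: only the LoRA adapter and connector entries survive.
state = {
    "llm.layers.0.attn.weight": 0,         # frozen base LLM weight -> stays local
    "llm.layers.0.attn.lora_A.weight": 1,  # LoRA adapter -> shared
    "connector.proj.weight": 2,            # connector -> shared
}
shared = filter_trainable(state)
```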

Architecture

| File | Description |
|------|-------------|
| client_app.py | ClientApp with @app.train() and @app.evaluate() handlers — loads weights, runs local PyTorch Lightning training, returns updated parameters and metrics |
| server_app.py | ServerApp with SpeechLLMFedAvg strategy — manages LR decay per round, hierarchical aggregation, and checkpoint saving |
| trainer.py | SpeechLLMLightning — PyTorch Lightning module defining the full SpeechLLM model, forward pass, training/validation/test steps, and metric logging |
| dataset.py | InstructionalAudioDataset, MyCollator, and build_dataloaders_from_csvs — loads partitioned audio CSV datasets per client |
| pyproject.toml | All federation, model, training, data, and checkpoint config in one place |

Federated Learning Process

  1. Initialization — Server loads global SpeechLLMLightning model (optionally from a pretrained checkpoint) and extracts only trainable parameters (LoRA + connector)
  2. Round Config — Server computes a decayed learning rate for the round and broadcasts it alongside model weights to sampled clients
  3. Local Training — Each client loads the received weights and trains locally with PyTorch Lightning for local-epochs epochs, each capped at train-batch-per-epoch steps
  4. Aggregation — Server performs weighted FedAvg over client updates proportional to dataset sizes
  5. Checkpointing — Aggregated model is saved to disk after every round; final model saved as final_model.pt
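
Step 4 is standard weighted FedAvg. A minimal pure-Python sketch, where flat float lists stand in for the per-layer NumPy arrays Flower actually exchanges:

```python
def weighted_fedavg(updates):
    """Aggregate client updates, weighting each by its number of training examples.

    updates: list of (num_examples, layers) pairs, where layers is a list of
    flat float lists (a toy stand-in for per-layer parameter arrays).
    """
    total = sum(n for n, _ in updates)
    n_layers = len(updates[0][1])
    aggregated = []
    for li in range(n_layers):
        acc = [0.0] * len(updates[0][1][li])
        for n, layers in updates:
            weight = n / total  # proportional to client dataset size
            for i, v in enumerate(layers[li]):
                acc[i] += weight * v
        aggregated.append(acc)
    return aggregated
```

A client holding three times as much data contributes three times the weight to the average.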

Model Architecture

Audio Input (waveform)
        │
        ▼
 WavLM Encoder (frozen or finetuned)
        │
        ▼
   Connector (Linear / LinearPool / CNN)
        │
        ▼
  [Pre-prompt embeddings] + [Speech embeddings] + [Post-prompt embeddings]
        │
        ▼
   TinyLlama LLM (LoRA fine-tuned)
        │
        ▼
  Structured JSON output:
  { "Transcript": "...", "Gender": "male", "Emotion": "neutral", ... }
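
Because the LLM is trained to emit structured JSON, downstream code should parse it defensively; a malformed generation is possible. A sketch (the field set beyond Transcript/Gender/Emotion is abbreviated above, so the default field list here is partly assumed):

```python
import json


def parse_speechllm_output(text, fields=("Transcript", "Gender", "Emotion")):
    """Parse the model's JSON output; return None if the generation is malformed."""
    try:
        out = json.loads(text)
    except json.JSONDecodeError:
        return None  # caller may retry generation or skip the sample
    return {f: out.get(f) for f in fields}
```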

Fetch the App

Install Flower:

pip install flwr

Fetch the app:

flwr new @mnabih/speech_llm_fl

This will create the following structure:

speech_llm_fl/
├── speech_llm_fl/
│   ├── __init__.py
│   ├── client_app.py       # ClientApp — local train & evaluate handlers
│   ├── server_app.py       # ServerApp — SpeechLLMFedAvg strategy + main
│   ├── trainer.py          # SpeechLLMLightning model definition
│   └── dataset.py          # Dataset, collator, and dataloader utilities
├── pyproject.toml          # All project metadata and Flower config
└── README.md

Prerequisites

  1. Python 3.10+
  2. CUDA-capable GPU (strongly recommended for WavLM + LLM training)
  3. Audio data partitioned as CSV files per client, each row containing audio_path and label columns (transcript, gender, emotion, age, accent, isspeech)
  4. Pretrained model weights accessible (WavLM and TinyLlama will be downloaded from HuggingFace on first run)

Install Dependencies

cd speech_llm_fl && pip install -e .

Data Preparation

Each client needs a CSV file with the following columns:

| Column | Description |
|--------|-------------|
| audio_path | Absolute path to the .wav audio file (16 kHz mono) |
| transcript | Ground-truth transcription text |
| gender | Speaker gender (male / female) |
| emotion | Emotion label (e.g. neutral, happy, sad) |
| age | Age group label |
| accent | Accent label |
| isspeech | Boolean — whether the audio contains speech |

Organize client partitions into a directory, one CSV per client:

fl_multilingual/
├── client_0.csv
├── client_1.csv
├── client_2.csv
└── ...

Set the paths in pyproject.toml:

csv-train-dir = "./fl_multilingual"
csv-dev-dir   = "./fl_MLS_dev_speaker"

Run the App

Simulation (Single Machine)

Run with default settings:

flwr run .

Override specific settings at runtime:

flwr run . --run-config "num-server-rounds=10 local-epochs=5 max-lr=0.00005"

Resume from a pretrained checkpoint:

flwr run . --run-config "pretrained-checkpoint=/path/to/Checkpoint-round-420.ckpt checkpoint-offset=420"

Note: Simulation runs all clients on the same machine via Ray. Ensure sufficient GPU memory or reduce train-batch-per-epoch and train-batch-size for smaller GPUs.


Deployment Engine (Multi-Machine)

1. Start the SuperLink (Terminal 1):

flower-superlink --insecure

2. Start SuperNodes — one per client machine (separate terminals):

# Client 0
flower-supernode --insecure --superlink 127.0.0.1:9092 \
    --clientappio-api-address 127.0.0.1:9094 \
    --node-config "partition-id=0 num-partitions=4"

# Client 1
flower-supernode --insecure --superlink 127.0.0.1:9092 \
    --clientappio-api-address 127.0.0.1:9095 \
    --node-config "partition-id=1 num-partitions=4"

3. Run the federation (Terminal 4):

flwr run . supernode-deployment

Configuration

All settings are controlled via pyproject.toml; no code changes are needed to run experiments.

Federation Settings

[tool.flwr.app.config]
num-server-rounds    = 200     # Total FL rounds
fraction-fit         = 0.3    # Fraction of clients sampled per round
fraction-evaluate    = 0.0    # Fraction of clients used for evaluation
min-fit-clients      = 2      # Minimum clients required to start a round
min-evaluate-clients = 2

Model Settings

audio-encoder-name = "microsoft/wavlm-large"   # HuggingFace model ID
llm-name           = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
connector-name     = "linear"                  # "linear", "linear-pool", or "cnn"
audio-enc-dim      = 1024
llm-dim            = 2048
use-lora           = true
lora-r             = 8
lora-alpha         = 16
finetune-encoder   = false

Training Settings

local-epochs          = 10     # Epochs per client per round
train-batch-size      = 4
train-batch-per-epoch = 200    # Steps per epoch (limits dataset length)
grad-accumulate-steps = 4
max-lr                = 0.0001

Learning Rate Decay

lr-decay-factor = 0.9    # Multiply LR by this value every N rounds
lr-decay-every  = 10     # Decay interval in rounds

The effective LR at round r is: max-lr × lr-decay-factor ^ (r // lr-decay-every)
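
A sketch of that schedule (the function name is illustrative; in the app this is computed server-side and broadcast with the round config):

```python
def decayed_lr(max_lr, server_round, decay_factor=0.9, decay_every=10):
    """Effective learning rate sent to clients at a given round."""
    return max_lr * decay_factor ** (server_round // decay_every)
```

With the defaults above, rounds 0–9 use max-lr, rounds 10–19 use 0.9 × max-lr, rounds 20–29 use 0.81 × max-lr, and so on.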

Checkpoint Settings

checkpoint-dir        = "FL_SLAM_checkpoints"   # Directory for round checkpoints
checkpoint-offset     = 0                        # Add offset to round number in filenames
pretrained-checkpoint = ""                       # Path to .ckpt to resume from
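
The offset matters when resuming: filenames like Checkpoint-round-420.ckpt (as in the resume example earlier) continue numbering from the pretrained run. A sketch, assuming that naming pattern:

```python
import os


def checkpoint_path(server_round, offset=0, directory="FL_SLAM_checkpoints"):
    """Path where the aggregated model for a round would be saved (pattern assumed)."""
    return os.path.join(directory, f"Checkpoint-round-{server_round + offset}.ckpt")
```

Resuming from round 400 with checkpoint-offset=400 makes local round 20 save as Checkpoint-round-420.ckpt.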

Simulation Settings

[tool.flwr.federations.local-simulation]
options.num-supernodes = 316   # Total number of simulated clients

[tool.flwr.federations.local-simulation.options]
backend.client-resources.num-cpus = 1
backend.client-resources.num-gpus = 1

Metrics Tracked

| Metric | Description |
|--------|-------------|
| train_loss | Cross-entropy loss on local training data |
| val/loss | Validation loss |
| val/wer | Word Error Rate on transcript predictions |
| val/gender | Gender classification accuracy |
| val/emotion | Emotion classification accuracy |
| val/age | Age group classification accuracy |
| val/accent | Accent classification accuracy |
| val/speech_activity | Speech activity detection accuracy |
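
val/wer is word-level edit distance normalized by reference length. A minimal implementation sketch (the app's actual metric code may differ):

```python
def wer(reference, hypothesis):
    """Word Error Rate: word-level Levenshtein distance / reference word count."""
    r, h = reference.split(), hypothesis.split()
    # d[i][j] = edit distance between the first i reference words
    # and the first j hypothesis words
    d = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        d[i][0] = i
    for j in range(len(h) + 1):
        d[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            sub = 0 if r[i - 1] == h[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,        # deletion
                          d[i][j - 1] + 1,        # insertion
                          d[i - 1][j - 1] + sub)  # substitution / match
    return d[len(r)][len(h)] / max(len(r), 1)
```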

Results

Performance comparison of WavLM vs. Whisper encoders, measured in Word Error Rate (WER ↓) on LibriSpeech (LS) and Multilingual LibriSpeech (MLS) test sets. Central training serves as the upper bound.

| Setting | WavLM (Round=100) LS | WavLM (Round=100) MLS | Whisper (Round=40) LS | Whisper (Round=40) MLS |
|---------|----------------------|-----------------------|-----------------------|------------------------|
| Central Training | 6.1 | 18.4 | 6.4 | 16.4 |
| FL Sample Cluster | 9.7 | 19.6 | 7.7 | 16.4 |

⭐ Central training is the upper bound (non-federated). Lower WER is better.

Key takeaway: WavLM with federated learning (FL Sample Cluster) achieves competitive WER on both benchmarks, with only a modest gap vs. central training — demonstrating that the federation does not significantly degrade model quality while preserving data privacy.


License

Apache License 2.0