@addyk/flowernnunet
Federated nnU-Net with Flower
This app implements federated learning for nnU-Net, enabling privacy-preserving medical image segmentation across multiple institutions using the Flower framework. It supports modality-aware aggregation across imaging types (CT, MR, PET, US), multi-dataset federation with heterogeneous anatomies, and nnU-Net v2's native training pipeline including 3D full-resolution segmentation with deep supervision.
Key Features
- Modality-Aware Aggregation: Automatically detects client imaging modalities (CT, MR, PET, US) and aggregates hierarchically, first within each modality and then across modalities as a weighted combination
- Multi-Dataset Federation: Train across heterogeneous datasets with different anatomies, modalities, and label sets in a single federation
- Native nnU-Net v2 Integration: Uses nnU-Net's full training pipeline, including nnUNetDataLoader3D, native augmentation transforms, deep supervision (6 output scales), and automatic architecture configuration
- Privacy-Preserving: Only model parameters and dataset fingerprints are shared; raw medical imaging data never leaves the local site
- W&B Integration: Optional Weights & Biases logging for experiment tracking across federated rounds
- Model Checkpointing: Automatic saving of the best local and global models based on validation Dice scores
- GPU Support: Full CUDA acceleration with automatic mixed precision training
Architecture
| Component | Description |
|---|---|
| server_app_modality.py | ModalityAwareFederatedStrategy: multi-phase server with modality grouping and hierarchical aggregation |
| server_app.py | NnUNetFederatedStrategy: standard FedAvg baseline strategy |
| client_app.py | NnUNet3DFullresClient: handles fingerprint collection, local training, backbone parameter filtering, and model saving |
| task.py | FedNnUNetTrainer: extends nnU-Net's nnUNetTrainer for federated scenarios with validation and PyTorch model export |
| wandb_integration.py | W&B logging utilities for federated experiment tracking |
| dataset_compatibility.py | Dataset validation and compatibility checking for multi-dataset federation |
| federation_config.py | Federation configuration management |
Federated Learning Process
- Fingerprint Phase (Round -2): Clients share dataset statistics (shapes, spacings, intensity properties, modality info)
- Initialization Phase (Round -1): Server merges fingerprints and distributes the global fingerprint and initial model parameters
- Training Phases (Round 0+): Iterative local training with native nnU-Net methods, backbone parameter extraction, and modality-aware global aggregation
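The three phases can be pictured with a minimal sketch. It assumes each client reports a dictionary with `spacings`, `shapes`, and `modality` entries; these field names and the helpers `merge_fingerprints` and `phase_for_round` are hypothetical and are not taken from the app's code.

```python
from collections import Counter


def phase_for_round(server_round):
    """Map the round numbers described above to a phase name."""
    if server_round == -2:
        return "fingerprint"
    if server_round == -1:
        return "initialization"
    if server_round >= 0:
        return "training"
    raise ValueError(f"unexpected round number: {server_round}")


def merge_fingerprints(client_fingerprints):
    """Pool per-case statistics from all clients into one global fingerprint.

    Hypothetical structure: each fingerprint holds per-case 'spacings',
    per-case 'shapes', and a single 'modality' label such as 'CT' or 'MR'.
    """
    merged = {"spacings": [], "shapes": [], "modalities": Counter()}
    for fp in client_fingerprints:
        merged["spacings"].extend(fp["spacings"])
        merged["shapes"].extend(fp["shapes"])
        merged["modalities"][fp["modality"]] += 1
    merged["num_cases"] = len(merged["spacings"])
    return merged


site_a = {
    "spacings": [(1.0, 1.0, 1.0), (1.5, 1.0, 1.0)],
    "shapes": [(64, 128, 128), (48, 128, 128)],
    "modality": "CT",
}
site_b = {
    "spacings": [(3.0, 0.6, 0.6)],
    "shapes": [(20, 320, 320)],
    "modality": "MR",
}
global_fp = merge_fingerprints([site_a, site_b])
print(phase_for_round(-2), global_fp["num_cases"], dict(global_fp["modalities"]))
```

Only summary statistics appear in the merged result, which is the property the fingerprint phase relies on: no image data is needed to plan a shared configuration.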
Aggregation Strategies
- Modality-Aware (default): CT clients are aggregated into a CT model and MR clients into an MR model; a weighted combination of the per-modality models then yields the global model
- Standard FedAvg: Traditional weighted average by number of training examples across all clients
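The difference between the two strategies can be sketched with plain Python lists standing in for model parameters. This is an illustration under stated assumptions, not the code of ModalityAwareFederatedStrategy: the function names are hypothetical, and the inter-modality weights are assumed to default to each modality's total example count.

```python
def weighted_average(vectors, weights):
    """Element-wise weighted average of equally sized parameter vectors."""
    total = float(sum(weights))
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    return [
        sum(w * vec[i] for vec, w in zip(vectors, weights)) / total
        for i in range(len(vectors[0]))
    ]


def modality_aware_aggregate(updates, modality_weights=None):
    """Aggregate (modality, num_examples, params) tuples in two steps.

    Step 1: FedAvg within each modality, weighted by example count.
    Step 2: combine the per-modality models, weighted by modality_weights
    if given, otherwise by each modality's total example count (assumption).
    """
    groups = {}
    for modality, num_examples, params in updates:
        groups.setdefault(modality, []).append((num_examples, params))

    modality_models, modality_sizes = {}, {}
    for modality, members in groups.items():
        counts = [n for n, _ in members]
        modality_models[modality] = weighted_average([p for _, p in members], counts)
        modality_sizes[modality] = sum(counts)

    source = modality_weights if modality_weights is not None else modality_sizes
    names = sorted(modality_models)
    global_model = weighted_average(
        [modality_models[m] for m in names], [source[m] for m in names]
    )
    return global_model, modality_models


updates = [
    ("CT", 30, [1.0, 2.0]),
    ("CT", 10, [3.0, 4.0]),
    ("MR", 20, [5.0, 6.0]),
]
global_default, per_modality = modality_aware_aggregate(updates)
global_custom, _ = modality_aware_aggregate(updates, {"CT": 0.6, "MR": 0.4})
flat_fedavg = weighted_average([p for _, _, p in updates], [n for _, n, _ in updates])
print(per_modality, global_default, global_custom, flat_fedavg)
```

When the inter-modality weights are the example counts, the two-step result coincides with flat FedAvg. The hierarchy only changes the outcome once explicit modality weights shift the balance between groups.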
Fetch the App
Install Flower:
pip install flwr
Fetch the app:
flwr new @addyk/flowernnunet
This will create a new directory called flowernnunet with the following structure:
flowernnunet
├── flowernnunet
│   ├── __init__.py
│   ├── client_app.py               # Defines your ClientApp
│   ├── server_app.py               # Basic federated strategy (FedAvg)
│   ├── server_app_modality.py      # Modality-aware federated strategy
│   ├── task.py                     # FedNnUNetTrainer (extends nnUNet)
│   ├── wandb_integration.py        # W&B logging support
│   ├── dataset_compatibility.py
│   └── federation_config.py
├── pyproject.toml                  # Project metadata and Flower configs
├── README.md
├── DEPLOYMENT_GUIDE.md             # Detailed deployment instructions
├── MULTI_DATASET_GUIDE.md          # Multi-dataset federation guide
├── setup_flwr_nnunet.sh            # Automated setup & preprocessing script
└── run_federated_deployment.sh     # Automated deployment script
Prerequisites
- Python 3.10+ with conda or pip
- nnU-Net v2 installed and configured (pip install nnunetv2)
- Preprocessed data in nnU-Net's standard format (.npz/.b2nd + .pkl)
- GPU with CUDA support (required for nnU-Net training)
Environment Setup
# Set required nnU-Net paths
export nnUNet_raw="/path/to/nnUNet_raw"
export nnUNet_preprocessed="/path/to/nnUNet_preprocessed"
export nnUNet_results="/path/to/nnUNet_results"

# Optional: Set model saving directory
export OUTPUT_ROOT="./federated_models"
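Before starting a run, it can help to confirm that the three required paths are set and point at real directories. The following is a small optional check, not part of the app; `missing_nnunet_paths` is a hypothetical helper name.

```python
import os

REQUIRED_VARS = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")


def missing_nnunet_paths(environ=os.environ):
    """Return the required variables that are unset or not existing directories."""
    missing = []
    for name in REQUIRED_VARS:
        value = environ.get(name)
        if not value or not os.path.isdir(value):
            missing.append(name)
    return missing


problems = missing_nnunet_paths()
if problems:
    print("Unset or invalid:", ", ".join(problems))
else:
    print("nnU-Net paths look fine")
```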
Medical Segmentation Decathlon (MSD) Datasets
This app works with any nnU-Net-preprocessed dataset. The Medical Segmentation Decathlon provides 10 benchmark datasets ideal for testing federated scenarios:
| Task | Dataset | Modality | Targets | Train/Test | Download |
|---|---|---|---|---|---|
| Task01 | BrainTumour | MR (FLAIR, T1w, T1gd, T2w) | Glioma subregions | 484 / 266 | Download |
| Task02 | Heart | MR (Mono) | Left atrium | 20 / 10 | Download |
| Task03 | Liver | CT (Portal venous) | Liver + tumors | 131 / 70 | Download |
| Task04 | Hippocampus | MR (Mono) | Hippocampus head/body | 260 / 130 | Download |
| Task05 | Prostate | MR (T2, ADC) | Central gland + peripheral zone | 32 / 16 | Download |
| Task06 | Lung | CT | Lung tumors | 63 / 32 | Download |
| Task07 | Pancreas | CT (Portal venous) | Pancreas + tumors | 281 / 139 | Download |
| Task08 | HepaticVessel | CT | Hepatic vessels + tumors | 303 / 140 | Download |
| Task09 | Spleen | CT (Portal venous) | Spleen | 41 / 20 | Download |
| Task10 | Colon | CT (Portal venous) | Colon cancer | 126 / 64 | Download |
Quick Dataset Setup
Download and preprocess a dataset (e.g., Spleen):
# Download
wget https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar
tar -xf Task09_Spleen.tar

# Move to nnUNet raw directory and preprocess
# (copy Task09_Spleen into $nnUNet_raw/Dataset009_Spleen following nnU-Net conventions)
nnUNetv2_plan_and_preprocess -d 9 --verify_dataset_integrity
Or use the automated setup script:
bash setup_flwr_nnunet.sh --dataset Task09_Spleen --name flwr-nnunet-demo
Multi-Dataset Federated Scenarios
For testing cross-modality federation, try combining:
- CT datasets: Spleen (Task09) + Liver (Task03), same modality with different anatomy
- Cross-modality: Prostate MR (Task05) + Spleen CT (Task09), different modalities combined through modality-aware aggregation
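To see how such scenarios split into modality groups, the sketch below infers a modality from the `channel_names` entry that nnU-Net v2 stores in dataset.json. The matching rules (`MR_PREFIXES`, the exact-match checks) and both helper names are assumptions made for illustration; the app's own detection logic may differ.

```python
import json

# Hypothetical prefixes treated as MR sequences; extend as needed.
MR_PREFIXES = ("mr", "t1", "t2", "flair", "adc", "dwi")


def infer_modality(dataset_json):
    """Guess CT / MR / PET / US from the channel names of a dataset.json dict."""
    names = [
        str(value).strip().lower()
        for value in dataset_json.get("channel_names", {}).values()
    ]
    if any(n == "ct" for n in names):
        return "CT"
    if any(n == "pet" for n in names):
        return "PET"
    if any(n in ("us", "ultrasound") for n in names):
        return "US"
    if any(n.startswith(MR_PREFIXES) for n in names):
        return "MR"
    return "UNKNOWN"


def group_by_modality(datasets):
    """datasets: mapping of dataset name to its parsed dataset.json dict."""
    groups = {}
    for name, meta in datasets.items():
        groups.setdefault(infer_modality(meta), []).append(name)
    return groups


datasets = {
    "Dataset005_Prostate": json.loads('{"channel_names": {"0": "T2", "1": "ADC"}}'),
    "Dataset009_Spleen": json.loads('{"channel_names": {"0": "CT"}}'),
    "Dataset003_Liver": json.loads('{"channel_names": {"0": "CT"}}'),
}
groups = group_by_modality(datasets)
print(groups)
```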
Run the App
Run with the Simulation Engine
Install the dependencies defined in pyproject.toml as well as the flowernnunet package:
cd flowernnunet && pip install -e .
Set environment variables for nnU-Net:
export nnUNet_preprocessed="/path/to/nnUNet_preprocessed"
export TASK_NAME="Dataset009_Spleen"
Run with default settings:
flwr run .
Override settings:
flwr run . --run-config "num-server-rounds=5"
Note: Simulation runs all clients on the same machine. For large 3D medical datasets, ensure sufficient GPU memory or reduce options.num-supernodes in pyproject.toml.
Run with the Deployment Engine
For detailed deployment instructions, see DEPLOYMENT_GUIDE.md.
1. Start the SuperLink (Terminal 1):
flower-superlink --insecure
2. Start SuperNodes (separate terminals):
# Terminal 2: First SuperNode
flower-supernode --insecure --superlink 127.0.0.1:9092 \
  --clientappio-api-address 127.0.0.1:9094 \
  --node-config "partition-id=0"

# Terminal 3: Second SuperNode
flower-supernode --insecure --superlink 127.0.0.1:9092 \
  --clientappio-api-address 127.0.0.1:9095 \
  --node-config "partition-id=1"
You can also specify datasets and folds per SuperNode:
flower-supernode --insecure --superlink 127.0.0.1:9092 \
  --clientappio-api-address 127.0.0.1:9094 \
  --node-config 'partition-id=0 dataset-name="Dataset005_Prostate" fold=0'
3. Run the federation (Terminal 4):
flwr run . supernode-deployment
Automated Deployment Script
For convenience, use the all-in-one deployment script:
# Single dataset
bash run_federated_deployment.sh \
  --dataset Dataset009_Spleen --clients 2 --rounds 3 --local-epochs 2 --validate

# Multi-dataset with modality-aware aggregation
bash run_federated_deployment.sh \
  --client-datasets '{"0": "Dataset005_Prostate", "1": "Dataset009_Spleen"}' \
  --clients 2 --rounds 5 --enable-modality-aggregation

# Custom modality weights
bash run_federated_deployment.sh \
  --client-datasets '{"0": "Dataset005_Prostate", "1": "Dataset009_Spleen"}' \
  --clients 2 --rounds 5 --enable-modality-aggregation \
  --modality-weights '{"CT": 0.6, "MR": 0.4}'
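The JSON values passed to --client-datasets and --modality-weights are easy to get wrong on the command line. The sketch below shows one way to sanity-check them before launching; whether the script itself validates or normalizes its inputs like this is an assumption, and both function names are hypothetical.

```python
import json


def parse_client_datasets(raw, num_clients):
    """Parse the client-to-dataset JSON and require one entry per client ID."""
    mapping = json.loads(raw)
    expected = {str(i) for i in range(num_clients)}
    if set(mapping) != expected:
        raise ValueError(
            f"expected client IDs {sorted(expected)}, got {sorted(mapping)}"
        )
    return {int(client_id): dataset for client_id, dataset in mapping.items()}


def parse_modality_weights(raw):
    """Parse modality weights and rescale them so that they sum to 1."""
    weights = {name: float(value) for name, value in json.loads(raw).items()}
    total = sum(weights.values())
    if any(w < 0 for w in weights.values()) or total <= 0:
        raise ValueError("modality weights must be non-negative with a positive sum")
    return {name: w / total for name, w in weights.items()}


client_datasets = parse_client_datasets(
    '{"0": "Dataset005_Prostate", "1": "Dataset009_Spleen"}', num_clients=2
)
modality_weights = parse_modality_weights('{"CT": 3, "MR": 1}')
print(client_datasets, modality_weights)
```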
All deployment script arguments
| Category | Argument | Default | Description |
|---|---|---|---|
| Dataset | --dataset | (none) | Single dataset for all clients |
| Dataset | --client-datasets | (none) | JSON mapping of client IDs to datasets |
| Dataset | --list-datasets | (none) | List available preprocessed datasets |
| Dataset | --validate-datasets | (none) | Validate multi-dataset compatibility |
| Training | --clients | 2 | Number of federated clients |
| Training | --rounds | 3 | Number of federated rounds |
| Training | --local-epochs | 2 | Local epochs per client per round |
| Modality | --enable-modality-aggregation | false | Enable modality-aware aggregation |
| Modality | --modality-weights | (none) | JSON modality weight overrides |
| Deployment | --mode | run | superlink, supernode, or run |
| Deployment | --superlink-host | 127.0.0.1 | SuperLink host address |
| Validation | --validate | true | Enable validation during training |
| Validation | --no-validate | (none) | Skip validation |
| Output | --output-dir | federated_models | Model output directory |
| Output | --save-frequency | 1 | Save models every N rounds |
| System | --gpu | 0 | GPU device ID |
Configuration
Federation Settings (pyproject.toml)
[tool.flwr.app.config]
num-server-rounds = 100     # Number of training rounds
fraction-fit = 1.0          # Fraction of clients for training
fraction-evaluate = 0.0     # Fraction of clients for evaluation

[tool.flwr.federations.local-simulation]
options.num-supernodes = 1  # Number of simulated clients

[tool.flwr.federations.supernode-deployment]
address = "127.0.0.1:9093"
insecure = true
options.num-supernodes = 2
options.enable-modality-aggregation = true
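As a rough guide to what fraction-fit and fraction-evaluate mean in practice, the sketch below computes how many clients would be sampled per round. It follows the convention of Flower's built-in FedAvg strategy as far as I understand it (a fraction of the connected clients, bounded below by a minimum, with a fraction of 0.0 switching evaluation off); that this app's custom strategies behave identically, and the minimum of 2, are assumptions.

```python
def num_fit_clients(num_available, fraction_fit, min_fit_clients=2):
    """Number of clients sampled for training in one round."""
    return max(int(num_available * fraction_fit), min_fit_clients)


def num_evaluate_clients(num_available, fraction_evaluate, min_evaluate_clients=2):
    """Number of clients sampled for evaluation; 0.0 disables evaluation."""
    if fraction_evaluate == 0.0:
        return 0
    return max(int(num_available * fraction_evaluate), min_evaluate_clients)


# Values from the configuration above: two SuperNodes in the deployment federation.
fit_count = num_fit_clients(num_available=2, fraction_fit=1.0)
eval_count = num_evaluate_clients(num_available=2, fraction_evaluate=0.0)
print(fit_count, eval_count)
```

With the settings shown, every connected client trains in each round and no federated evaluation round is scheduled.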
GPU Configuration
# Enable specific GPU
export CUDA_VISIBLE_DEVICES=0

# Set model saving directory
export OUTPUT_ROOT="./federated_models"
In task.py, the default device is CPU for Ray compatibility. For GPU training, modify:
device = torch.device("cuda") # Instead of "cpu"
Data Format Support
| Format | Extension | Description |
|---|---|---|
| B2ND | .b2nd | Compressed Blosc2 format (preferred, requires blosc2) |
| NPZ | .npz | Standard NumPy compressed format (legacy) |
| Properties | .pkl | Medical imaging metadata per case |
Expected preprocessed data structure:
nnUNet_preprocessed/DatasetXXX_Name/
├── dataset.json
├── dataset_fingerprint.json
├── nnUNetPlans.json
├── splits_final.json
└── nnUNetPlans_3d_fullres/
    ├── case_001.b2nd (or .npz)
    ├── case_001_seg.b2nd
    ├── case_001.pkl
    └── ...
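A quick way to confirm that a preprocessed dataset matches this layout is to walk the directory and report what is missing. The checker below is a minimal sketch using only the standard library; `check_preprocessed_dataset` is a hypothetical helper, and it assumes every case is identified by its .pkl file.

```python
import tempfile
from pathlib import Path

REQUIRED_FILES = (
    "dataset.json",
    "dataset_fingerprint.json",
    "nnUNetPlans.json",
    "splits_final.json",
)


def check_preprocessed_dataset(dataset_dir, config_dir="nnUNetPlans_3d_fullres"):
    """Return a list of problems; an empty list means the layout looks complete."""
    root = Path(dataset_dir)
    problems = [
        f"missing {name}" for name in REQUIRED_FILES if not (root / name).is_file()
    ]
    data_dir = root / config_dir
    if not data_dir.is_dir():
        return problems + [f"missing {config_dir}/"]
    # Each .pkl needs either a .b2nd image/segmentation pair or an .npz file.
    for pkl in sorted(data_dir.glob("*.pkl")):
        case = pkl.stem
        has_b2nd = all(
            (data_dir / f"{case}{suffix}").is_file()
            for suffix in (".b2nd", "_seg.b2nd")
        )
        has_npz = (data_dir / f"{case}.npz").is_file()
        if not (has_b2nd or has_npz):
            problems.append(f"{case}: no image data next to {pkl.name}")
    return problems


# Demo on a throwaway directory that mimics the layout above.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "Dataset009_Spleen"
    data_dir = root / "nnUNetPlans_3d_fullres"
    data_dir.mkdir(parents=True)
    for name in REQUIRED_FILES:
        (root / name).write_text("{}")
    for name in ("case_001.b2nd", "case_001_seg.b2nd", "case_001.pkl", "case_002.pkl"):
        (data_dir / name).touch()
    demo_problems = check_preprocessed_dataset(root)

print(demo_problems)
```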
More Information
- DEPLOYMENT_GUIDE.md: Step-by-step SuperLink/SuperNode deployment
- MULTI_DATASET_GUIDE.md: Multi-dataset federation with heterogeneous data
- nnU-Net v2: Medical image segmentation framework
- Flower Framework: Federated learning infrastructure
- Medical Segmentation Decathlon: Benchmark datasets
Acknowledgments
- nnU-Net v2 for medical image segmentation
- Flower Framework for federated learning infrastructure
- Kaapana: federated learning concepts used in federating nnU-Net
License
Apache License 2.0