Deep learning pipeline for brain MRI analysis.
Train and evaluate deep learning models on 3D brain MRI data with support for self-supervised learning (SSL) pretraining, LoRA fine-tuning, and explainability visualizations using GradCAM and saliency maps.
- Multiple Training Modes - SFCN, linear probing, SSL fine-tuning, and LoRA
- Self-Supervised Learning - Pretrain on unlabeled brain MRI data
- Evaluation - ROC/PRC curves, confusion matrices, bootstrap confidence intervals
- Explainability - GradCAM and saliency maps with AAL atlas region analysis
- Preprocessing - DICOM to NIfTI conversion, bias correction, registration, skull stripping, resampling, and conversion to .npy tensors
- Python 3.8+
- CUDA-compatible GPU
# Clone the repository
git clone <repo-url>
cd BrainTrain
# Create virtual environment
python -m venv .venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
# Install dependencies
pip install -r requirements.txt
Your CSV files should contain:
- eid - Subject identifier (first column)
- label - Target variable (integer for classification, float for regression)
Example:
eid,label
sub001,1
sub002,0
If you have DICOM files, run preprocess.py first to convert them to .npy tensors. After preprocessing, each image must follow this format:
- Format: NumPy arrays saved as .npy files
- Shape: (96, 96, 96) for 3D MRI volumes
- Naming: the filename must match the eid in the CSV (e.g., sub001.npy)
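Before training, it can help to confirm that the CSV and the tensors agree with the requirements above. The sketch below is not part of BrainTrain: check_dataset is a hypothetical helper, and it assumes the CSV has a header row (eid,label as in the example) and that all tensors sit in one directory, like TENSOR_DIR in config.py.

```python
import csv
from pathlib import Path

import numpy as np

EXPECTED_SHAPE = (96, 96, 96)


def check_dataset(csv_path, tensor_dir, expected_shape=EXPECTED_SHAPE):
    """Hypothetical helper (not part of BrainTrain): list data problems.

    Checks the eid/label columns, numeric labels, and that every eid has a
    matching <eid>.npy of the expected shape. An empty list means no problems.
    """
    problems = []
    tensor_dir = Path(tensor_dir)
    with open(csv_path, newline="") as f:
        reader = csv.DictReader(f)
        fields = reader.fieldnames or []
        if fields[:1] != ["eid"]:
            problems.append("first column must be 'eid'")
        if "eid" not in fields or "label" not in fields:
            problems.append("CSV needs both 'eid' and 'label' columns")
            return problems
        for row in reader:
            eid = row["eid"]
            try:
                float(row["label"])  # int (classification) or float (regression)
            except (TypeError, ValueError):
                problems.append(f"{eid}: non-numeric label {row['label']!r}")
            npy_path = tensor_dir / f"{eid}.npy"
            if not npy_path.exists():
                problems.append(f"{eid}: missing {npy_path.name}")
                continue
            # mmap_mode="r" maps the file instead of reading the whole volume
            shape = np.load(npy_path, mmap_mode="r").shape
            if shape != tuple(expected_shape):
                problems.append(f"{eid}: shape {shape}, expected {tuple(expected_shape)}")
    return problems
```

For example, calling check_dataset(CSV_TRAIN, TENSOR_DIR) with the values from config.py should return an empty list when everything is in place.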
Edit config.py to set all your desired input and output paths:
TRAIN_COHORT = 'your-cohort-name'
TENSOR_DIR = f'../images/{TRAIN_COHORT}/npy96'
CSV_TRAIN = f'../data/{TRAIN_COHORT}/train/data.csv'
Make sure to create separate data and image folders for the train and test cohorts.
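One way to set up those folders is to mirror the path pattern shown in config.py. In this sketch, make_cohort_dirs is a hypothetical helper, and the "test" split name and the test cohort name are assumptions; only the train paths appear in the config excerpt above.

```python
from pathlib import Path


def make_cohort_dirs(root, cohort, split):
    """Hypothetical helper: create image and data folders for one cohort/split.

    Follows the config.py pattern:
      <root>/images/<cohort>/npy96
      <root>/data/<cohort>/<split>
    """
    root = Path(root)
    image_dir = root / "images" / cohort / "npy96"
    data_dir = root / "data" / cohort / split
    for d in (image_dir, data_dir):
        d.mkdir(parents=True, exist_ok=True)  # safe to call repeatedly
    return image_dir, data_dir
```

For example, make_cohort_dirs("..", "your-cohort-name", "train") and make_cohort_dirs("..", "your-test-cohort", "test") keep the two cohorts in separate trees, so test tensors never land in the training folder.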
python preprocess.py
Preprocessing pipeline:
- DICOM to NIfTI conversion
- N4 bias field correction
- Registration to MNI template
- Skull stripping
- Resampling to 96×96×96
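The final resampling and saving step can be illustrated with NumPy alone. This is a sketch under assumptions: it uses nearest-neighbour interpolation and z-score normalization over nonzero (brain) voxels, whereas preprocess.py may use a different interpolation (e.g., trilinear) or normalization. The function names are hypothetical, and the earlier steps (N4, registration, skull stripping) need dedicated imaging tools and are not shown.

```python
import numpy as np


def resample_nearest(volume, target_shape=(96, 96, 96)):
    """Hypothetical helper: nearest-neighbour resampling of a 3D volume."""
    # For each axis, map every output voxel centre to the closest input index.
    idx = [
        np.minimum(((np.arange(t) + 0.5) * s / t).astype(int), s - 1)
        for s, t in zip(volume.shape, target_shape)
    ]
    return volume[np.ix_(*idx)]


def to_tensor(volume, out_path, target_shape=(96, 96, 96)):
    """Hypothetical helper: resample, z-score nonzero voxels, save as .npy."""
    vol = resample_nearest(np.asarray(volume, dtype=np.float32), target_shape)
    brain = vol[vol != 0]  # assumes skull-stripped background is exactly 0
    if brain.size:
        z = (vol - brain.mean()) / (brain.std() + 1e-8)
        vol = np.where(vol != 0, z, 0.0).astype(np.float32)
    np.save(out_path, vol)
    return vol
```

Nearest-neighbour keeps the sketch dependency-free; for intensity images a smoother interpolation is usually preferred.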
# Train with default settings (LoRA fine-tuning)
python train.py
# Train with specific options
python train.py --mode lora --column label --gpu cuda:0
Training modes:
- sfcn - Train SFCN from scratch
- linear - Linear probing (frozen SSL backbone)
- ssl-finetuned - Fine-tune the SSL-pretrained model
- lora - LoRA adaptation (parameter-efficient)
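LoRA keeps a pretrained weight matrix W frozen and learns a low-rank update, so a layer computes y = W x + (alpha / r) B A x, with A of shape (r, d_in) and B of shape (d_out, r). The NumPy sketch below shows the forward pass and the parameter savings for a single linear layer. It illustrates the technique only; it is not BrainTrain's implementation, and the class name, rank, and alpha values are assumptions.

```python
import numpy as np


class LoRALinear:
    """Illustrative sketch: frozen weight W plus trainable low-rank update B @ A."""

    def __init__(self, weight, rank=4, alpha=8.0, seed=0):
        rng = np.random.default_rng(seed)
        self.weight = weight  # frozen, shape (d_out, d_in)
        d_out, d_in = weight.shape
        # A starts small and random, B starts at zero: the initial update is zero,
        # so the adapted layer initially reproduces the pretrained one.
        self.A = rng.normal(0.0, 0.01, size=(rank, d_in))
        self.B = np.zeros((d_out, rank))
        self.scale = alpha / rank

    def __call__(self, x):
        # x: (batch, d_in) -> (batch, d_out)
        return x @ self.weight.T + self.scale * (x @ self.A.T) @ self.B.T

    def trainable_params(self):
        return self.A.size + self.B.size


W = np.random.default_rng(1).normal(size=(256, 512))
layer = LoRALinear(W, rank=4)
print(layer.trainable_params(), "trainable vs", W.size, "frozen")
# prints: 3072 trainable vs 131072 frozen
```

Only A and B would receive gradient updates, which is why the mode is described as parameter-efficient.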
python test.py
Generates:
- ROC and Precision-Recall curves with bootstrap CI
- Confusion matrix
- Classification metrics (accuracy, precision, recall, F1, AUC)
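A bootstrap confidence interval resamples subjects with replacement, recomputes the metric on each resample, and reads the interval off the percentiles. The sketch below does this for AUC with NumPy only. The exact procedure in test.py is not documented here; this percentile bootstrap is one common choice, and the function names are hypothetical.

```python
import numpy as np


def auc(y_true, y_score):
    """ROC AUC as P(random positive scores above random negative), ties count 0.5."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    pos, neg = y_score[y_true == 1], y_score[y_true == 0]
    greater = (pos[:, None] > neg[None, :]).mean()
    ties = (pos[:, None] == neg[None, :]).mean()
    return greater + 0.5 * ties


def bootstrap_auc_ci(y_true, y_score, n_boot=1000, alpha=0.05, seed=0):
    """Hypothetical helper: point AUC plus a percentile bootstrap CI."""
    rng = np.random.default_rng(seed)
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    n = len(y_true)
    stats = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)  # resample subjects with replacement
        if len(np.unique(y_true[idx])) < 2:
            continue  # AUC is undefined unless both classes are present
        stats.append(auc(y_true[idx], y_score[idx]))
    lower, upper = np.percentile(stats, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return auc(y_true, y_score), lower, upper
```

With alpha=0.05 the returned bounds are the 2.5th and 97.5th percentiles, i.e., a 95% interval.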
python heatmap.py
Creates:
- GradCAM or saliency maps overlaid on MRI scans
- Regional attribution analysis using AAL atlas
- Top-N most important brain regions per prediction
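Regional attribution boils down to pooling a heatmap within each atlas label and ranking the regions. The sketch below assumes the AAL atlas has already been resampled to the same 96×96×96 grid as the heatmap, treats label 0 as background, and uses mean absolute attribution per region; heatmap.py may aggregate differently. top_regions is a hypothetical helper, and the label-to-name mapping would come from the atlas's own lookup table.

```python
import numpy as np


def top_regions(attribution, atlas, region_names, top_n=5):
    """Hypothetical helper: rank atlas regions by mean |attribution|.

    attribution: 3D heatmap (e.g., GradCAM or saliency) on the atlas grid.
    atlas: 3D integer label volume; label 0 is treated as background.
    region_names: dict mapping label value -> region name.
    """
    if attribution.shape != atlas.shape:
        raise ValueError("heatmap and atlas must be on the same grid")
    scores = []
    for label in np.unique(atlas):
        if label == 0:
            continue  # skip background
        mask = atlas == label
        name = region_names.get(int(label), f"region_{int(label)}")
        scores.append((name, float(np.abs(attribution[mask]).mean())))
    scores.sort(key=lambda item: item[1], reverse=True)
    return scores[:top_n]
```

Averaging rather than summing keeps large regions from dominating the ranking purely because of their size.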
python ssl_train.py
Train a self-supervised backbone on unlabeled brain MRI data. The pretrained backbone can then be reused for transfer learning, for example by the linear and ssl-finetuned training modes.
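This section does not say which self-supervised objective ssl_train.py optimizes. One widely used choice is the SimCLR-style NT-Xent contrastive loss: two augmented views of the same scan should embed close together and away from the other scans in the batch. The NumPy sketch below computes that loss for a batch of embeddings; treat it as an illustration of the idea under that assumption, not as BrainTrain's loss, and nt_xent as a hypothetical name.

```python
import numpy as np


def nt_xent(z1, z2, temperature=0.5):
    """Illustrative NT-Xent loss for two views z1, z2 of shape (batch, dim).

    Row i of z1 and row i of z2 are a positive pair; every other sample
    in the combined batch of 2 * batch embeddings acts as a negative.
    """
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # cosine similarity
    sim = z @ z.T / temperature
    n = z1.shape[0]
    np.fill_diagonal(sim, -np.inf)  # a sample is never its own candidate
    positives = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])
    # Numerically stable log-sum-exp over each row's candidates.
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    log_prob_pos = sim[np.arange(2 * n), positives] - log_denom
    return float(-log_prob_pos.mean())
```

In practice the two views come from random augmentations (utils/augmentations/ holds the SSL augmentations), and the loss is minimized with respect to the encoder that produces z1 and z2.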
python features.py
Extract features from the self-supervised pretrained model and visualize them with UMAP or t-SNE.
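UMAP and t-SNE need extra packages (umap-learn, scikit-learn). For a quick, dependency-free first look at extracted features, a linear PCA projection can be computed with NumPy alone. This is a different, linear technique from the nonlinear embeddings that features.py produces, and pca_2d is a hypothetical helper.

```python
import numpy as np


def pca_2d(features):
    """Hypothetical helper: project (n_subjects, n_features) onto the top 2 PCs."""
    X = np.asarray(features, dtype=float)
    X = X - X.mean(axis=0)  # centre each feature
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(X, full_matrices=False)
    return X @ vt[:2].T
```

The two output columns can be scatter-plotted and coloured by label; if classes already separate linearly here, UMAP or t-SNE will usually show at least as much structure.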
BrainTrain/
├── utils/
│ ├── architectures/ # Neural network models
│ ├── dataloaders/ # Dataset loaders
│ ├── augmentations/ # SSL augmentations
│ └── ...
├── preprocess.py # MRI preprocessing
├── ssl_train.py # Self-supervised pretraining
├── train.py # Model training
├── test.py # Model evaluation
├── heatmap.py # Explainability visualization
└── config.py # Configuration
Results can be saved to the parent directory (or any other suitable location):
- ../models/ - Model checkpoints
- ../scores/ - Prediction scores
- ../logs/ - Training logs
- ../evaluations/ - Evaluation plots
- ../explainability/ - Heatmaps