A PyTorch-based GAN framework for bidirectional generation between neuroimaging data (fMRI, DTI) and behavioral measures, with a focus on robust training and reproducibility.
- Image → Behavior: Generate behavioral predictions from neuroimaging data (single or multiple modalities, voxelwise or ROI-wise)
- Behavior → Image: Generate synthetic neuroimaging data from behavioral scores or batteries of scores
- Optional self-attention mechanisms for capturing long-range dependencies
- GAN, lsGAN, and/or WGAN loss functions
- Flexible network architecture with optimizable hidden layers and other components
- Some support for conditional generation (cGAN is a work in progress)
- Nested cross-validation with Optuna (Bayesian) hyperparameter optimization (sketched after this list)
- Smart learning rate scheduling with warmup periods and plateau reduction
- Early stopping with customizable patience independently for generator and discriminator
- Gradient penalty and/or clipping for stability
- Multi-GPU support with parallel fold processing (but can default to CPU if no GPU is available)
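For readers unfamiliar with the pattern, the sketch below shows how nested cross-validation and Optuna's Bayesian (TPE) sampler typically fit together. The dataset, parameter names, and placeholder objective are illustrative assumptions, not this framework's actual API; see `hyperparameter_optimization.py` for the real implementation.

```python
# Minimal sketch of nested CV + Optuna (TPE) hyperparameter search.
# Everything inside objective() is a placeholder, not the framework's API.
import numpy as np
import optuna
from sklearn.model_selection import KFold

X = np.random.rand(100, 50)                          # placeholder dataset

outer_cv = KFold(n_splits=5, shuffle=True, random_state=42)
for outer_train_idx, outer_test_idx in outer_cv.split(X):

    def objective(trial):
        lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
        hidden_dim = trial.suggest_int("hidden_dim", 64, 512)
        inner_cv = KFold(n_splits=3, shuffle=True, random_state=42)
        val_losses = []
        for inner_train_idx, inner_val_idx in inner_cv.split(X[outer_train_idx]):
            # Placeholder: train the GAN with (lr, hidden_dim) and record validation loss.
            val_losses.append(0.0)
        return float(np.mean(val_losses))

    study = optuna.create_study(direction="minimize",
                                sampler=optuna.samplers.TPESampler(seed=42))
    study.optimize(objective, n_trials=20)
    # Retrain on the full outer-training split with study.best_params,
    # then evaluate once on the held-out outer test split.
```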
pip install -r requirements.txt
Edit gan_settings.py before running!
python gan_train.py
├── 🎯 Core Files
│ ├── gan_settings.py # Global settings and hyperparameters (start here!)
│ ├── gan_train.py # Main training script with nested CV
│ ├── gan_arch.py # GAN architecture definitions used by gan_train.py
│ ├── gan_eval.py # Standalone evaluation script (not used during training)
│ └── gan_trainer.py # Training orchestration and management
│
├── 🧮 Data Processing
│ ├── data_utils.py # Data loading and preprocessing
│ ├── process_utils.py # GPU/device management and process orchestration
│ └── concat_mat_beh.py # MAT and behavioral data handling
│
├── 🔧 Training Support
│ ├── training_utils.py # Early stopping, LR scheduling
│ ├── model_utils.py # Model initialization, checkpointing, and validation
│ ├── metrics.py # Loss functions used by gan_trainer.py
│ ├── visualization.py # Plotting and visualization tools
│ └── seed_manager.py # Random seed management
│
├── 🎛️ Hyperparameter Optimization
│ ├── hyperparameter_optimization.py # Optuna optimization used by gan_train.py
│ └── test_hyperparameter_optimization.py # Optimization tests
│
├── 🧪 Testing
│ ├── test_gan.py # Core GAN functionality tests
│ ├── test_data_processing.py # Data pipeline tests
│ └── test_gan_dimensions.py # Input/output dimension tests
│
├── 🧹 Utility
│ └── requirements.txt # Package dependencies
│
└── 📂 outputs/ # Generated during training
├── cache/ # Intermediate results
│ ├── data_cache.pkl # Preprocessed data
│ └── transform_info.pkl # Transform metadata
│
├── checkpoints/ # Model states
│ ├── outer_fold_{n}/
│ │ ├── inner_fold_{m}/
│ │ │ ├── generator_checkpoint.pt
│ │ │ └── discriminator_checkpoint.pt
│ │ └── best_model.pt
│ └── best_model.pt
│
├── logs/ # Training records
│ ├── random_seed.json
│ ├── settings.json
│ └── training_log.txt
│
├── plots/ # Visualizations
│ ├── outer_fold_{n}/
│ │ ├── loss_curves/
│ │ └── predictions/
│ └── final_results/
│
└── results/ # Evaluation outputs
├── outer_fold_{n}/
│ ├── predictions/
│ │ ├── behavior_to_image/ # When MODE="behavior_to_image"
│ │ │ ├── batch_{k}/
│ │ │ │ ├── comparison_subject{i}.png
│ │ │ │ └── generated_subject{i}.npy
│ │ │ └── average_images/
│ │ │ ├── average_comparison_sagittal.png
│ │ │ ├── average_comparison_coronal.png
│ │ │ └── average_comparison_axial.png
│ │ └── image_to_behavior/ # When MODE="image_to_behavior"
│ │ ├── predicted_scores.csv
│ │ └── score_comparisons/
│ │ ├── scatter_plots.png
│ │ └── error_dist.png
│ ├── true_data_fold{n}.npy
│ └── generated_data_fold{n}.npy
└── aggregate_results/
├── combined_true_data.npy
└── combined_generated_data.npy
- 🏗️ Generator & Discriminator Foundation
- Shared `BaseGANModule` with configurable dimensions (see the sketch below)
- Smart normalization selection (batch/layer) and dynamic weight initialization strategies depending on loss function
- Dropout regularization (configurable rate)
- Other configurable components (e.g., depth, activation functions, GAN type)
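The real definitions live in `gan_arch.py`; the block below is only a rough illustration of what a shared, configurable base module might look like. The class name, argument names, and defaults are assumptions made for the sketch.

```python
# Hypothetical sketch of a shared, configurable GAN building block.
# Names and arguments are illustrative; see gan_arch.py for the real interface.
import torch.nn as nn

class BaseGANModuleSketch(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dims=(256, 128),
                 norm="batch", dropout=0.2, activation=nn.LeakyReLU(0.2)):
        super().__init__()
        layers, prev = [], in_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            # Normalization choice can depend on the GAN loss (e.g. layer norm for WGAN-GP).
            if norm == "batch":
                layers.append(nn.BatchNorm1d(h))
            elif norm == "layer":
                layers.append(nn.LayerNorm(h))
            layers.append(activation)
            layers.append(nn.Dropout(dropout))
            prev = h
        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
```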
- 🔍 Multi-Head Self-Attention
- 4-head scaled dot-product attention
- Learnable attention strength (γ parameter)
- Strategic placement at network intervals
- Dimension-scaled transformations (d⁻⁰·⁵)
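As a hedged sketch of the idea (the repository's own attention block in `gan_arch.py` may be implemented differently), a 4-head scaled dot-product attention layer with a learnable strength parameter can be written as:

```python
# Illustrative self-attention block with a learnable residual strength (gamma).
import torch
import torch.nn as nn

class SelfAttentionSketch(nn.Module):
    def __init__(self, dim, heads=4):
        super().__init__()
        # Scaled dot-product attention (queries/keys scaled by d^-0.5 internally).
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)
        # gamma starts at 0, so the block begins as an identity mapping and
        # the network learns how much attention to mix in.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):            # x: (batch, seq_len, dim)
        out, _ = self.attn(x, x, x)
        return x + self.gamma * out  # learnable attention strength
```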
- 🛑 Early Stopping
- Generator: discriminator validity ↑ (behavior_to_image) or MSE ↓ (image_to_behavior)
- Discriminator: W-distance/accuracy ↑
- Auto-checkpoint at peak performance
- Configurable minimum epochs and change thresholds
- Different validation metrics used based on mode:
- behavior_to_image: Uses discriminator's real/fake validity scores during training, MS-SSIM for final evaluation
- image_to_behavior: Uses MSE between predicted and actual behavioral scores
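The early-stopping logic lives in `training_utils.py`; the snippet below is only a hedged sketch of independent patience tracking for the generator and discriminator, with invented class and argument names.

```python
# Hypothetical sketch: independent early stopping for generator and discriminator.
class EarlyStopper:
    def __init__(self, patience=20, min_epochs=50, min_delta=1e-4, mode="min"):
        self.patience, self.min_epochs = patience, min_epochs
        self.min_delta, self.mode = min_delta, mode
        self.best, self.bad_epochs = None, 0

    def step(self, value, epoch):
        improved = (self.best is None or
                    (value < self.best - self.min_delta if self.mode == "min"
                     else value > self.best + self.min_delta))
        if improved:
            self.best, self.bad_epochs = value, 0   # checkpoint here at peak performance
        else:
            self.bad_epochs += 1
        return epoch >= self.min_epochs and self.bad_epochs >= self.patience

# Generator tracks MSE (image_to_behavior) or validity (behavior_to_image);
# discriminator tracks W-distance/accuracy. Each gets its own patience.
gen_stop = EarlyStopper(patience=30, mode="min")
disc_stop = EarlyStopper(patience=15, mode="max")
```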
- 📈 Advanced Scheduling
- Warmup → Plateau reduction
- Independent attention-specific optimization
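One plausible way to wire "warmup → plateau reduction" with standard PyTorch schedulers is sketched below; the epoch counts, factors, and placeholder loss are assumptions, and the actual schedule settings live in `gan_settings.py` / `training_utils.py`.

```python
# Sketch: linear warmup followed by plateau-based LR reduction.
import torch
from torch.optim.lr_scheduler import LinearLR, ReduceLROnPlateau

model = torch.nn.Linear(10, 1)                       # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

warmup_epochs = 10
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs)
plateau = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=10)

for epoch in range(200):
    val_loss = 1.0 / (epoch + 1)                     # placeholder for the real validation loss
    if epoch < warmup_epochs:
        warmup.step()                                # ramp the LR up during warmup
    else:
        plateau.step(val_loss)                       # then cut the LR when validation plateaus
```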
- 🔒 Training Safeguards
- WGAN gradient penalty
- Norm-based gradient clipping
- Hybrid metric scoring (MSE + MS-SSIM)
- Multiple generations per input for point estimates of performance
- Nested cross-validation for hyperparameter optimization
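For reference, a standard WGAN-GP gradient penalty looks like the sketch below; `metrics.py` / `gan_trainer.py` hold the framework's own versions, so treat this as an illustration rather than the exact implementation.

```python
# Sketch of a WGAN-GP gradient penalty plus norm-based clipping.
import torch

def gradient_penalty(discriminator, real, fake):
    # Random interpolation coefficients, broadcast over all non-batch dimensions.
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = discriminator(interp)
    grads = torch.autograd.grad(outputs=scores, inputs=interp,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True)[0]
    grad_norms = grads.view(grads.size(0), -1).norm(2, dim=1)
    return ((grad_norms - 1) ** 2).mean()            # penalize deviation from unit norm

# Norm-based gradient clipping is applied separately after backward(), e.g.:
# torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
```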
- 📊 Performance Tracking
- MSE, RMSE, MAE → lower is better
- MS-SSIM → higher is better
- Multi-prediction correlation handling
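The tracked regression metrics are straightforward to compute; the sketch below shows one way, with the MS-SSIM line left as a comment because it assumes the `pytorch_msssim` package, which may not be what `metrics.py` actually uses.

```python
# Sketch of the tracked regression metrics.
import torch

def regression_metrics(pred, target):
    err = pred - target
    mse = err.pow(2).mean()
    return {"MSE": mse.item(),                 # lower is better
            "RMSE": mse.sqrt().item(),         # lower is better
            "MAE": err.abs().mean().item()}    # lower is better

# from pytorch_msssim import ms_ssim          # higher is better, images scaled to [0, 1]
# quality = ms_ssim(generated, real, data_range=1.0)
```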
- 📈 Analysis Pipeline
- Cross-validation aggregation
- Statistical analysis (correlation, MSE, MAE)
- Modality-specific evaluations
- Automatic downsampling of brain images
- Masking out empty voxels across the group of subjects analyzed
- Reconstruction of outputs in original space
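A rough sketch of the group-masking and reconstruction step is shown below; array shapes, the zero-valued definition of "empty", and the all-subjects criterion are assumptions, not the framework's exact pipeline.

```python
# Sketch: group-level mask of non-empty voxels and reconstruction to original space.
import numpy as np

images = np.random.rand(20, 64, 64, 64)              # (subjects, x, y, z), placeholder data
group_mask = (images != 0).all(axis=0)                # voxels non-empty in every subject

features = images[:, group_mask]                      # (subjects, n_voxels_in_mask)

# ... train / generate in the masked feature space ...

reconstructed = np.zeros_like(images[0])
reconstructed[group_mask] = features[0]               # place one subject back in original space
```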
- 📥 Input Processing
- Structured .mat parsing
- Automated normalization with robust scaling for behavioral scores
- Modality-based feature extraction
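As a non-authoritative sketch of the input side, .mat loading and robust scaling of behavioral scores could look like the following; the file name and variable key are hypothetical, and the real handling lives in `concat_mat_beh.py` / `data_utils.py`.

```python
# Sketch of .mat loading and robust scaling for behavioral scores.
import scipy.io as sio
from sklearn.preprocessing import RobustScaler

mat = sio.loadmat("subject_data.mat")                 # hypothetical file name
behavior = mat["behavior"]                            # (subjects, n_scores), hypothetical key

scaler = RobustScaler()                               # median/IQR scaling, robust to outliers
behavior_scaled = scaler.fit_transform(behavior)
# The fitted scaler can be stored (cf. transform_info.pkl) to invert the transform later.
```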
- 🎲 Randomization Control
- Global seed management
- Framework-wide determinism
- Consistent cross-validation
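Global seeding for framework-wide determinism typically looks like the sketch below; `seed_manager.py` is the authoritative implementation, and the function name here is illustrative.

```python
# Sketch of global seed management for reproducible runs.
import os
import random
import numpy as np
import torch

def set_global_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Trade speed for reproducibility in cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```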