
SynthRAD2023 CT-to-MRI Translation Setup Guide

Complete guide for training Brownian Bridge Diffusion Model on SynthRAD2023 dataset.

📋 Prerequisites

Environment Setup

# Activate conda environment
conda activate ct2mri

# Verify installation
python -c "import torch; print(f'PyTorch: {torch.__version__}')"
python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}')"
python -c "import neptune; print('Neptune OK')"

Required Packages

Ensure these are installed in your ct2mri conda environment:

  • PyTorch >= 2.0 (with CUDA support)
  • neptune >= 1.0
  • albumentations
  • h5py
  • nibabel
  • tqdm
  • pyyaml
  • wandb (optional, legacy)

πŸ“ Dataset Structure

Your SynthRAD2023 dataset should be organized as follows:

/pscratch/sd/s/seojw/CT_to_MRI/Task1-2/brain/
├── 1BA001/
│   ├── ct_brain_crop_pad_(0,1).pt      # CT volume [D, H, W] in [0, 1]
│   └── mri_brain_crop_pad_(0,1).pt     # MRI volume [D, H, W] in [0, 1]
├── 1BA002/
│   ├── ct_brain_crop_pad_(0,1).pt
│   └── mri_brain_crop_pad_(0,1).pt
└── ...

Important Notes:

  • Files must be PyTorch tensors saved with torch.save()
  • Tensors shape: [Depth, Height, Width]
  • Values normalized to [0, 1] range
  • Subject IDs must contain "1B" for auto-detection
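The notes above can be checked programmatically before submitting a job. A minimal validation sketch (the helper names are ours, not part of the repo); `check_volume` is duck-typed so it works on anything exposing `ndim`, `shape`, `min`, and `max`:

```python
import sys
from pathlib import Path


def check_volume(vol, name: str) -> bool:
    """Validate one volume against the notes above: 3-D shape, values in [0, 1]."""
    assert vol.ndim == 3, f"{name}: expected [D, H, W], got {tuple(vol.shape)}"
    lo, hi = float(vol.min()), float(vol.max())
    assert 0.0 <= lo and hi <= 1.0, f"{name}: values outside [0, 1] ({lo}..{hi})"
    return True


def check_subject(subject_dir: Path) -> None:
    """Load both .pt files of one subject and validate them."""
    import torch  # deferred import; assumes the ct2mri environment

    for fname in ("ct_brain_crop_pad_(0,1).pt", "mri_brain_crop_pad_(0,1).pt"):
        vol = torch.load(subject_dir / fname, map_location="cpu")
        check_volume(vol, f"{subject_dir.name}/{fname}")


if __name__ == "__main__" and len(sys.argv) > 1:
    base = Path(sys.argv[1])  # e.g. /pscratch/sd/s/seojw/CT_to_MRI/Task1-2/brain
    for sub in sorted(p for p in base.iterdir() if p.is_dir() and "1B" in p.name):
        check_subject(sub)
        print(f"{sub.name}: OK")
```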

🚀 Quick Start

1. Training (SLURM)

# Submit training job
cd /pscratch/sd/s/seojw/CT_to_MRI/CT2MRI
sbatch scripts/train_synthrad.sh

2. Training (Direct)

python train_synthrad.py \
    --train \
    --config configs/BBDM_synthrad.yaml \
    --base_dir /pscratch/sd/s/seojw/CT_to_MRI/Task1-2/brain \
    --exp_name synthrad_bbdm_baseline \
    --batch 8 \
    --max_epoch 5000 \
    --gpu_ids 0

3. Resume Training

Training automatically resumes from last checkpoint if available:

# Manual resume
python train_synthrad.py \
    --train \
    --config configs/BBDM_synthrad.yaml \
    --resume_model results/SynthRAD2023_176/synthrad_bbdm_baseline/checkpoint/last_model.pth \
    --resume_optim results/SynthRAD2023_176/synthrad_bbdm_baseline/checkpoint/last_optim_sche.pth \
    --exp_name synthrad_bbdm_baseline \
    --gpu_ids 0
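The auto-resume behaviour amounts to looking for the two `last_*` files in the checkpoint directory. A sketch of that lookup (the helper name is ours; the script's actual logic may differ):

```python
from pathlib import Path


def find_resume_paths(exp_dir):
    """Return (model_ckpt, optim_ckpt) if both last_* files exist, else None."""
    ckpt_dir = Path(exp_dir) / "checkpoint"
    model = ckpt_dir / "last_model.pth"
    optim = ckpt_dir / "last_optim_sche.pth"
    if model.is_file() and optim.is_file():
        return model, optim
    return None  # fresh run: nothing to resume from
```

For example, `find_resume_paths('results/SynthRAD2023_176/synthrad_bbdm_baseline')` would return the two paths passed to `--resume_model` and `--resume_optim` above.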

4. Testing/Inference

# Submit test job
sbatch scripts/test_synthrad.sh

# Or directly
python train_synthrad.py \
    --sample_to_eval \
    --config configs/BBDM_synthrad.yaml \
    --resume_model results/SynthRAD2023_176/synthrad_bbdm_baseline/checkpoint/last_model.pth \
    --base_dir /pscratch/sd/s/seojw/CT_to_MRI/Task1-2/brain \
    --gpu_ids 0

🔧 Configuration

Key Parameters (configs/BBDM_synthrad.yaml)

Model Architecture:

  • num_timesteps: 1000 - Diffusion timesteps
  • sample_step: 50 - Sampling steps (inference)
  • mt_type: 'linear' - Marginal schedule m_t (linear/sin)
  • objective: 'grad' - Training objective (grad/noise/ysubx)
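In a Brownian bridge, `mt_type` sets the schedule m_t that blends the source (CT) and target (MRI) endpoints, with m_0 = 0 and m_T = 1. A toy sketch, under the assumption that 'linear' means m_t = t/T and 'sin' a sine-eased ramp (the repo's exact parameterization may differ):

```python
import math


def m_t_schedule(T: int, mt_type: str = "linear") -> list:
    """Illustrative marginal schedule m_t for t = 0..T (m_0 = 0, m_T = 1)."""
    if mt_type == "linear":
        # straight-line ramp from 0 to 1
        return [t / T for t in range(T + 1)]
    if mt_type == "sin":
        # sine-eased ramp: slower near t = 0, still reaching 1 at t = T
        return [math.sin(0.5 * math.pi * t / T) for t in range(T + 1)]
    raise ValueError(f"unknown mt_type: {mt_type!r}")
```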

Training:

  • n_epochs: 5000 - Maximum epochs
  • batch_size: 8 - Batch size per GPU
  • lr: 1e-4 - Learning rate
  • use_amp: True - Enable bfloat16 AMP
  • sample_interval: 200 - Image logging frequency
  • checkpoint_interval: 5 - Save checkpoint every N epochs

Data:

  • image_size: 176 - Input image resolution
  • channels: 3 - Multi-channel slices (center ± 1)
  • augment: True - Random flip augmentation
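`channels: 3` means each training sample stacks a slice with its two neighbours along the depth axis. A sketch of the index arithmetic with edge clamping (the dataset code may handle borders differently, e.g. by reflection):

```python
def neighbor_slices(center: int, depth: int, radius: int = 1) -> list:
    """Indices for a (2*radius + 1)-channel stack around `center`,
    clamped to [0, depth - 1] so edge slices repeat their boundary neighbour."""
    return [min(max(center + off, 0), depth - 1) for off in range(-radius, radius + 1)]
```

For example, `neighbor_slices(0, 176)` returns `[0, 0, 1]`: the first slice is duplicated where no lower neighbour exists.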

Command Line Overrides

# Flags override the YAML config: --HW (image size), --batch (batch size),
# --lr (learning rate), --max_epoch (max epochs), --sample_step (fewer
# steps = faster sampling), --exp_name (experiment name).
python train_synthrad.py \
    --train \
    --config configs/BBDM_synthrad.yaml \
    --HW 256 \
    --batch 4 \
    --lr 5e-5 \
    --max_epoch 10000 \
    --sample_step 25 \
    --exp_name custom_experiment

📊 Neptune Logging

Tracked Metrics

Training:

  • train/loss - Training loss (latent L1)
  • train/lr - Current learning rate
  • train/images/source_ct - Input CT slices
  • train/images/target_mri - Ground truth MRI
  • train/images/generated_mri - Generated MRI

Validation:

  • val/loss - Validation loss

Model:

  • model/total_params - Total parameters
  • model/trainable_params - Trainable parameters

Checkpoints:

  • checkpoints/epoch - Epoch checkpoints
  • checkpoints/last - Latest checkpoint

Accessing Results

  1. Go to https://app.neptune.ai
  2. Navigate to project: ejswjawnj/CT-to-MRI
  3. Find your experiment by name
  4. View metrics, images, and download checkpoints

📂 Output Structure

results/SynthRAD2023_176/
└── synthrad_bbdm_baseline/
    ├── checkpoint/
    │   ├── epoch_0005.pth
    │   ├── epoch_0010.pth
    │   ├── last_model.pth          # Latest model
    │   └── config.yaml             # Saved config
    ├── log/                         # TensorBoard logs
    ├── image/                       # Training visualizations
    ├── sample/                      # Generated samples (training)
    └── sample_to_eval/              # Test-time generation
        ├── 1BA001_generated_mri.npy
        ├── 1BA002_generated_mri.npy
        └── ...

🎯 Usage Scenarios

Scenario 1: Initial Training

# 1. Check dataset
ls /pscratch/sd/s/seojw/CT_to_MRI/Task1-2/brain/

# 2. Submit job
sbatch scripts/train_synthrad.sh

# 3. Monitor
squeue -u $USER
tail -f logs/train_*.out

# 4. Check Neptune dashboard

Scenario 2: Hyperparameter Tuning

# Test different learning rates
for lr in 5e-5 1e-4 2e-4; do
    python train_synthrad.py \
        --train \
        --config configs/BBDM_synthrad.yaml \
        --lr $lr \
        --exp_name "synthrad_lr${lr}" \
        --max_epoch 100 \
        --gpu_ids 0
done

Scenario 3: Resume After Interruption

# Training automatically resumes from last checkpoint
sbatch scripts/train_synthrad.sh

Scenario 4: Generate Test Results

# 1. Ensure model checkpoint exists
ls results/SynthRAD2023_176/synthrad_bbdm_baseline/checkpoint/last_model.pth

# 2. Run inference
sbatch scripts/test_synthrad.sh

# 3. Check results
ls results/SynthRAD2023_176/synthrad_bbdm_baseline/sample_to_eval/

πŸ› Troubleshooting

Issue: Dataset not found

# Verify base_dir in config
grep base_dir configs/BBDM_synthrad.yaml

# Check subjects
ls /pscratch/sd/s/seojw/CT_to_MRI/Task1-2/brain/ | grep 1B | head -5

Issue: Neptune connection failed

# Test Neptune connection
python -c "import neptune; run = neptune.init_run(project='ejswjawnj/CT-to-MRI', api_token='YOUR_TOKEN'); run.stop()"

Issue: CUDA out of memory

# Reduce batch size
python train_synthrad.py --train --batch 4 ...

# Or reduce image size
python train_synthrad.py --train --HW 128 ...

Issue: Checkpoint not loading

# Check checkpoint content (load to CPU so no GPU is needed)
python -c "import torch; ckpt = torch.load('path/to/checkpoint.pth', map_location='cpu'); print(ckpt.keys())"

🔄 Train/Validation Split

Default split: 90% train, 10% validation

  • Controlled in datasets/synthrad_dataset.py
  • Subjects sorted alphabetically, then split
  • Consistent across runs (deterministic)

To customize:

# In synthrad_dataset.py, line ~40
n_train = int(n_total * 0.9)  # Change 0.9 to desired ratio
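Put together, the whole split can be sketched as follows (function name is ours; see `datasets/synthrad_dataset.py` for the authoritative version):

```python
def split_subjects(subject_ids, train_frac: float = 0.9):
    """Deterministic train/val split: sort subject IDs alphabetically, then slice."""
    ids = sorted(subject_ids)          # alphabetical order makes the split reproducible
    n_train = int(len(ids) * train_frac)
    return ids[:n_train], ids[n_train:]
```

Because the IDs are sorted rather than shuffled, the same subjects land in validation on every run.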

⚡ Performance Tips

  1. Use bfloat16 AMP: Already enabled by default
  2. Optimize batch size: Balance between GPU memory and speed
  3. Reduce sample_interval: Less frequent image logging speeds up training
  4. Use pin_memory: Already enabled in DataLoader
  5. Multi-GPU: Set --gpu_ids 0,1,2,3 for DDP training

πŸ“ Checkpoint Information

Each checkpoint contains:

  • epoch: Current epoch number
  • model_state_dict: Model weights
  • optimizer_state_dict: Optimizer state
  • scheduler_state_dict: LR scheduler state
  • ema_state_dict: EMA weights (if enabled)
  • val_loss: Validation loss at save time
  • config: Full configuration

🎓 Citation

If you use this code, please cite:

@inproceedings{choo2024ct2mri,
  title={Slice-Consistent 3D Volumetric Brain CT-to-MRI Translation with 2D Brownian Bridge Diffusion Model},
  author={Choo and others},
  booktitle={MICCAI},
  year={2024}
}

📧 Support


Last Updated: 2025-10-24 · Version: 1.0