Deep learning models for vessel segmentation in medical images. Implements U-Net and variants with a modular, configurable training pipeline.
- Multiple architectures: U-Net, Attention U-Net, ResNet U-Net
- Configurable: JSON config files for reproducible experiments
- Checkpointing: Save/resume training with full state preservation
- Validation: Automatic train/val split with early stopping on validation loss
- TensorBoard: Real-time training monitoring with loss curves, metrics, and sample predictions
- Evaluation: Dice, IoU, precision, recall, specificity metrics
- Visualization: Overlay predictions on images with TP/FP/FN highlighting
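Early stopping, as listed above, halts training once validation loss stops improving for `patience` consecutive epochs. A minimal sketch of that logic in plain Python (class and method names here are illustrative, not the pipeline's actual identifiers):

```python
class EarlyStopping:
    """Stop training when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best_loss = float("inf")
        self.counter = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0  # reset on improvement
        else:
            self.counter += 1
        return self.counter >= self.patience


stopper = EarlyStopping(patience=2)
losses = [0.9, 0.8, 0.85, 0.83]  # improves twice, then stalls for two epochs
stops = [stopper.step(loss) for loss in losses]
print(stops)  # [False, False, False, True]
```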
```bash
git clone <repository-url>
cd VesselSegmentation
pip install -r requirements.txt
```

```
VesselSegmentation/
├── models/                # Model architectures
│   ├── blocks.py          # Shared building blocks
│   ├── unet.py            # Standard U-Net
│   ├── attention_unet.py  # U-Net with attention gates
│   └── resnet_unet.py     # U-Net with residual blocks
├── config/                # Configuration files
│   ├── config.py          # Config dataclasses
│   ├── unet.json
│   ├── attention.json
│   └── resnet.json
├── scripts/               # Training and evaluation
│   ├── train.py
│   ├── evaluate.py
│   ├── visualize.py
│   ├── metrics.py
│   └── datasetloader.py
├── training_data/         # Dataset (not included)
│   ├── images/
│   └── 1st_manual/
├── results/               # Output weights and metrics
├── runs/                  # TensorBoard logs
├── Makefile
└── requirements.txt
```
```bash
# Train U-Net (default)
make train-unet

# Train other architectures
make train-attention
make train-resnet

# Train all models
make train-all

# Custom training with overrides
python scripts/train.py --config config/unet.json --epochs 50 --lr 0.0001

# Adjust validation split (default: 20%)
python scripts/train.py --config config/unet.json --val-split 0.15

# Name your experiment for TensorBoard
python scripts/train.py --config config/unet.json --experiment my_experiment_v1
```

Training produces:

- `results/<model>_best.pth` - Best model weights (lowest validation loss)
- `results/<model>_checkpoint.pth` - Latest checkpoint for resuming
- `runs/<experiment>/` - TensorBoard logs
Monitor training in real-time with TensorBoard:
```bash
# Start TensorBoard server
make tensorboard

# Or with custom port
make tensorboard PORT=8080
```

Then open http://localhost:6006 in your browser.
What's logged:
- Training and validation loss curves
- Validation metrics (Dice, IoU, accuracy, precision, recall, specificity)
- Sample predictions (input, ground truth, prediction) every 5 epochs
- Hyperparameters and final metrics
If training is interrupted, resume from the last checkpoint:
```bash
# Resume U-Net training
make resume-unet

# Resume other models
make resume-attention
make resume-resnet

# Resume with custom checkpoint
python scripts/train.py --config config/unet.json --resume results/unet_checkpoint.pth

# Resume and extend training to more epochs
python scripts/train.py --config config/unet.json --resume results/unet_checkpoint.pth --epochs 100
```

Checkpoints preserve:
- Model weights
- Optimizer state (momentum, adaptive learning rates)
- Current epoch
- Best validation loss achieved
- Early stopping patience counter
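A checkpoint is essentially one dictionary bundling all of the state listed above. A rough pure-Python sketch of the save/resume round trip (in the real pipeline, `model_state` and `optimizer_state` would come from `state_dict()` calls and be saved with `torch.save`; the field names here are illustrative):

```python
import os
import pickle
import tempfile

# State a checkpoint preserves (stand-ins for real tensors/state dicts).
checkpoint = {
    "model_state": {"encoder.weight": [0.1, 0.2]},
    "optimizer_state": {"momentum_buffers": [0.0, 0.0]},
    "epoch": 42,
    "best_val_loss": 0.137,
    "patience_counter": 1,
}

path = os.path.join(tempfile.mkdtemp(), "unet_checkpoint.pth")
with open(path, "wb") as f:   # torch.save is pickle-based under the hood
    pickle.dump(checkpoint, f)

with open(path, "rb") as f:   # resume: reload and continue from the next epoch
    restored = pickle.load(f)

start_epoch = restored["epoch"] + 1
print(start_epoch, restored["best_val_loss"])  # 43 0.137
```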
```bash
# Evaluate on test set
make eval-unet

# Custom evaluation
python scripts/evaluate.py \
    --weights results/unet_best.pth \
    --model unet \
    --images test_data/images \
    --labels test_data/masks \
    --output results/metrics.json
```

```bash
python scripts/evaluate.py \
    --weights results/unet_best.pth \
    --model unet \
    --image path/to/image.png \
    --output prediction.png
```

```bash
python scripts/visualize.py \
    --weights results/unet_best.pth \
    --model unet \
    --image path/to/image.png \
    --label path/to/ground_truth.png \
    --output visualization.png
```

Visualization shows: original | prediction overlay | ground truth overlay | TP (yellow) / FP (green) / FN (red)
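The TP/FP/FN highlighting boils down to a per-pixel comparison of the binarized prediction against the ground truth. A small sketch of that classification on flat 0/1 lists (the real script works on image arrays and maps each class to its overlay color):

```python
def classify_pixels(pred, label):
    """Label each pixel TP, FP, FN, or TN from binary prediction/ground-truth pairs."""
    classes = []
    for p, y in zip(pred, label):
        if p == 1 and y == 1:
            classes.append("TP")  # drawn yellow in the overlay
        elif p == 1 and y == 0:
            classes.append("FP")  # green
        elif p == 0 and y == 1:
            classes.append("FN")  # red
        else:
            classes.append("TN")  # left unhighlighted
    return classes


print(classify_pixels([1, 1, 0, 0], [1, 0, 1, 0]))  # ['TP', 'FP', 'FN', 'TN']
```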
| Model | Description |
|---|---|
| U-Net | Standard encoder-decoder with skip connections |
| Attention U-Net | Adds attention gates to focus on relevant features |
| ResNet U-Net | Uses residual blocks for better gradient flow |
All models use the same interface:
```python
from models import UNet, AttentionUNet, ResNetUNet

model = UNet(input_channels=3, num_classes=1, feature_dims=[64, 128, 256, 512, 1024])
```

Configs are JSON files with three sections:
```json
{
  "model": {
    "name": "unet",
    "input_channels": 3,
    "num_classes": 1,
    "feature_dims": [64, 128, 256, 512, 1024]
  },
  "training": {
    "epochs": 10,
    "batch_size": 8,
    "learning_rate": 0.001,
    "patience": 5
  },
  "data": {
    "images_path": "training_data/images",
    "labels_path": "training_data/1st_manual",
    "output_dir": "results"
  }
}
```

Predefined configs: `unet`, `attention`, `resnet`, `lightweight`
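Loading such a config amounts to reading the JSON and unpacking each section into a dataclass. A hedged sketch of that pattern (the field names mirror the JSON above; the actual classes in `config/config.py` may differ):

```python
import json
from dataclasses import dataclass


@dataclass
class ModelConfig:
    name: str
    input_channels: int
    num_classes: int
    feature_dims: list


@dataclass
class TrainingConfig:
    epochs: int
    batch_size: int
    learning_rate: float
    patience: int


raw = json.loads("""
{
  "model": {"name": "unet", "input_channels": 3, "num_classes": 1,
            "feature_dims": [64, 128, 256, 512, 1024]},
  "training": {"epochs": 10, "batch_size": 8, "learning_rate": 0.001, "patience": 5}
}
""")

# ** unpacks each JSON section directly into the matching dataclass fields.
model_cfg = ModelConfig(**raw["model"])
train_cfg = TrainingConfig(**raw["training"])
print(model_cfg.name, train_cfg.learning_rate)  # unet 0.001
```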
| Metric | Description |
|---|---|
| Dice | F1 score for segmentation (2 × |A ∩ B| / (|A| + |B|)) |
| IoU | Intersection over Union (Jaccard index) |
| Accuracy | Pixel-wise accuracy |
| Precision | True positives / predicted positives |
| Recall | True positives / actual positives (sensitivity) |
| Specificity | True negatives / actual negatives |
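The overlap metrics in the table reduce to counts over the binarized masks. A plain-Python sketch of Dice and IoU on binary lists (the pipeline's `metrics.py` presumably computes these on tensors):

```python
def dice_and_iou(pred, label):
    """Dice = 2|A∩B| / (|A|+|B|) and IoU = |A∩B| / |A∪B| for binary 0/1 lists."""
    intersection = sum(p * y for p, y in zip(pred, label))
    pred_sum, label_sum = sum(pred), sum(label)
    union = pred_sum + label_sum - intersection
    # Empty masks on both sides count as a perfect match.
    dice = 2 * intersection / (pred_sum + label_sum) if pred_sum + label_sum else 1.0
    iou = intersection / union if union else 1.0
    return dice, iou


dice, iou = dice_and_iou([1, 1, 1, 0], [1, 1, 0, 0])
print(dice, iou)  # 0.8 0.666...
```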
```bash
# Training
make train-unet          # Train U-Net
make train-attention     # Train Attention U-Net
make train-resnet        # Train ResNet U-Net
make train-all           # Train all models

# Resume training
make resume-unet         # Resume U-Net from checkpoint
make resume-attention    # Resume Attention U-Net
make resume-resnet       # Resume ResNet U-Net

# Monitoring
make tensorboard         # Launch TensorBoard (http://localhost:6006)

# Evaluation
make eval-unet           # Evaluate U-Net
make eval-all            # Evaluate all models

# Testing
make test                # Run all tests

# Cleanup
make clean               # Remove all generated files
make clean-checkpoints   # Remove only checkpoint files
make clean-logs          # Remove TensorBoard logs
```

- Images: RGB images in PNG/JPG format
- Labels: Binary masks (0 = background, 255 = vessel)
- Images and labels are paired by alphabetically sorted filename order, so both directories must contain the same number of files
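The sorted-filename pairing can be sketched as follows; the dataset loader presumably does something similar (directory names come from the layout above, the helper name is illustrative):

```python
import os
import tempfile


def paired_files(images_dir, labels_dir):
    """Pair images with labels by sorted filename order; counts must match."""
    images = sorted(os.listdir(images_dir))
    labels = sorted(os.listdir(labels_dir))
    assert len(images) == len(labels), "every image needs a matching label"
    return list(zip(images, labels))


# Tiny demo with empty placeholder files.
root = tempfile.mkdtemp()
img_dir = os.path.join(root, "images")
lbl_dir = os.path.join(root, "1st_manual")
os.makedirs(img_dir)
os.makedirs(lbl_dir)
for name in ["02.png", "01.png"]:
    open(os.path.join(img_dir, name), "w").close()
for name in ["02_mask.png", "01_mask.png"]:
    open(os.path.join(lbl_dir, name), "w").close()

pairs = paired_files(img_dir, lbl_dir)
print(pairs)  # [('01.png', '01_mask.png'), ('02.png', '02_mask.png')]
```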
MIT