Controllable latent diffusion model to evaluate the performance of cardiac segmentation methods

[Figure: pipeline overview]

Prerequisites

  • Python 3.10+

Installation

Create the Conda environment

  • Create a conda environment from the provided file:

    conda env create -f environment.yaml

  • Activate it:

    conda activate cLDM_env

Data preparation

  • The data come from the MYOSAIQ challenge; they were gathered in a single folder and organized by patient.
  • The D8 subset was not used during training.
  • In the code, D_metrics is a dictionary containing information (such as the metrics) obtained from the MYOSAIQ database.
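
    For illustration, D_metrics might be organized per patient and per slice, along these lines; the structure and values below are assumptions (only the metric names appear in the configuration files):

    # Hypothetical sketch of D_metrics; structure and values assumed for illustration
    D_metrics = {
        "patient_001": {
            "slice_00": {
                "z_vals": 0.15,              # e.g. normalized slice position
                "transmurality": 0.42,
                "endo_surface_length": 18.7,
                "infarct_size_2D": 112.0,
            },
            # ... one entry per slice
        },
        # ... one entry per patient
    }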

Configuration files

  • The configuration files used to run the experiments can be found in the folder nn_lib/config/, organized as follows:
    • Config_VanillaVAE.yaml is the main config file that will use other configuration files to run the VAE model.
    • The folder config/model contains the configuration files with the parameters for each model.
    • The folder config/architecture contains the configuration files to select a specific architecture depending on the model you want to run.
    • The folder config/processing contains the path to a dictionary holding the metrics of all patients for all slices.
    • The folders config/dataset and config/datamodule contain the configuration files to create the dataloaders and use them to train the model.
  • When running a main configuration, the sub-configuration files it uses should have the same name.
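
    For orientation, the layout described above might look as follows (entries not named above are illustrative assumptions):

    nn_lib/config/
    ├── Config_VanillaVAE.yaml    # main config composing the sub-configs below
    ├── model/                    # training parameters for each model
    ├── architecture/
    │   └── unets/                # e.g. unet_cLDM_light, unet_LDM_light
    ├── processing/               # path to the per-patient metrics dictionary
    ├── dataset/                  # dataset definitions
    └── datamodule/               # dataloader creation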

Models used for the experiments

  • For strategy 1, the cLDM model is used. The conditioning (using cross-attention; see the sketch after this list) was done with:
    • A vector of scalars (clinical attributes) derived from the segmentations (strategy 1.1).
    • The latent representation from the VanillaVAE model trained on images (strategy 1.2).
    • The latent representation from the ARVAE model trained on images and regularized with clinical attributes (strategy 1.3).
  • For strategy 2, the ControlNet architecture is employed with an LDM backbone.
  • For strategy 3, the cLDM_concat architecture is conditioned on a 2D representation of the segmentation masks obtained with the AE_KL model (a minimal sketch follows the cLDM_concat command below).
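
    As a rough illustration of the cross-attention conditioning in strategy 1, here is a minimal PyTorch sketch; the module name, shapes, and the projection of the condition into tokens are assumptions, not the repository's implementation:

    import torch
    import torch.nn as nn

    class CrossAttentionCond(nn.Module):
        """U-Net feature maps attend to a conditioning sequence (scalars or latents)."""
        def __init__(self, feat_dim: int, cond_dim: int, n_heads: int = 4):
            super().__init__()
            self.attn = nn.MultiheadAttention(feat_dim, n_heads, kdim=cond_dim,
                                              vdim=cond_dim, batch_first=True)
            self.norm = nn.LayerNorm(feat_dim)

        def forward(self, feats, cond):
            # feats: (B, C, H, W) U-Net features; cond: (B, L, cond_dim) context,
            # e.g. embedded clinical scalars (1.1) or a VAE/ARVAE latent (1.2/1.3).
            b, c, h, w = feats.shape
            q = feats.flatten(2).transpose(1, 2)          # (B, H*W, C) queries
            out, _ = self.attn(self.norm(q), cond, cond)  # keys/values from the condition
            return (q + out).transpose(1, 2).reshape(b, c, h, w)

    # Example: a condition projected to one context token of dimension 32
    block = CrossAttentionCond(feat_dim=64, cond_dim=32)
    y = block(torch.randn(2, 64, 16, 16), torch.randn(2, 1, 32))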

Figures from the paper

  • Figures 2 and 3 were obtained using the file fig_originalSeg_vs_generatedSeg.py. It needs the original segmentations as well as the segmentations produced by the nnU-Net model when synthetic images are used as inputs.
  • For Figure 2, we chose an arbitrary mask to illustrate our pipeline.

[Figure 2]

  • For Figure 3, we selected specific masks with relevant characteristics; synthetic images were then generated conditioned on those masks, as illustrated in the figure. For the final row, manual rotations of 90°, 180°, and 270° were applied to the mask.

[Figure 3]

How to run

Below are the commands to run the models:

  • VanillaVAE

    python train_VanillaVAE.py \
        +config_name=Config_VanillaVAE.yaml \
        model.train_params.num_workers=24 \
        model.train_params.batch_size=32 \
        model.train_params.max_epoch=500 \
        model.net.shape_data=[1,128,128] \
        model.net.lat_dims=8 \
        model.net.alpha=5 \
        model.net.beta=8e-3
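
    The parameter names suggest a beta-VAE-style objective of the form L = alpha * L_recon + beta * D_KL(q(z|x) || p(z)), with alpha weighting the reconstruction term and beta the KL term; this reading is inferred from the option names, not from the repository's loss code.
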
  • ARVAE

    python train_ARVAE.py \
        +config_name=Config_ARVAE.yaml \
        \
        model.train_params.num_workers=24 \
        model.train_params.batch_size=32 \
        model.train_params.max_epoch=500 \
        model.net.shape_data=[1,128,128] \
        model.net.lat_dims=8 \
        model.net.alpha=5 \
        model.net.beta=8e-3 \
        model.net.gamma=3 \
        \
        +model.net.keys_cond_data=["z_vals","transmurality","endo_surface_length","infarct_size_2D"]
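
    The gamma weight together with keys_cond_data suggests an attribute-regularization term tying each listed clinical attribute to one latent dimension. A minimal sketch of one common formulation (the AR-VAE regularizer of Pati and Lerch), given as an assumed illustration rather than the repository's code:

    import torch
    import torch.nn.functional as F

    def attribute_reg_loss(z_dim, attr, delta=1.0):
        # z_dim, attr: tensors of shape (B,) -- one latent dimension and one
        # attribute (e.g. transmurality) over a batch. The loss encourages the
        # latent dimension to increase monotonically with the attribute.
        dz = z_dim.unsqueeze(0) - z_dim.unsqueeze(1)  # pairwise latent differences
        da = attr.unsqueeze(0) - attr.unsqueeze(1)    # pairwise attribute differences
        return F.l1_loss(torch.tanh(delta * dz), torch.sign(da))
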
  • AE_KL

    python train_AE_KL.py \
        +config_name=ConfigAE_KL.yaml \
        model.train_params.num_workers=24 \
        model.train_params.batch_size=32 \
        model.train_params.max_epoch=1000 \
        model.net.lat_dims=1
  • cLDM

    # Conditioning with Scalars
    python train_cLDM.py \
        +config_name=Config_cLDM.yaml \
        \
        model.train_params.num_workers=24 \
        model.train_params.batch_size=32 \
        model.train_params.max_epoch=5100 \
        model.path_model_cond=null \
        \
        processing=processing_CompressLgeSegCond_Scalars \
        dataset=CompressLgeSegCond_Scalars_Dataset \
        datamodule=CompressLgeSegCond_Scalars_Datamodule \
        datamodule.keys_cond_data=["z_vals","transmurality","endo_surface_length","infarct_size_2D"] \
        \
        architecture/unets=unet_cLDM_light
    
    # Conditioning with latent representation from VAE 
    python train_cLDM.py \
        +config_name=Config_cLDM.yaml \
        \
        model.train_params.num_workers=24 \
        model.train_params.batch_size=32 \
        model.train_params.max_epoch=5100 \
        model.path_model_cond="/home/deleat/Documents/RomainD/Working_space/NN_models/training_Pytorch/training_VAE/training_LgeMyosaiq_v2/2025-01-06 10:13:45_106e_img_base" \
        \
        architecture/unets=unet_cLDM_light
    
    # Conditioning with latent representation from ARVAE
    python train_cLDM.py \
        +config_name=Config_cLDM.yaml \
        \
        model.train_params.num_workers=24 \
        model.train_params.batch_size=32 \
        model.train_params.max_epoch=5100 \
        model.path_model_cond="/home/deleat/Documents/RomainD/Working_space/NN_models/training_Pytorch/training_ARVAE/training_LgeMyosaiq_v2/2025-01-06 14:23:29_72e_img_base" \
        \
        architecture/unets=unet_cLDM_light
  • LDM

    python train_LDM.py \
        +config_name=Config_LDM.yaml \
        model.train_params.num_workers=24 \
        model.train_params.batch_size=32 \
        model.train_params.max_epoch=5100 \
        architecture/unets=unet_LDM_light
  • ControlNet

    python train_ControlNet.py \
        +config_name=Config_ControlNet.yaml \
        model.train_params.num_workers=24 \
        model.train_params.batch_size=32 \
        model.train_params.max_epoch=5100 \
  • cLDM_concat

    python train_cLDM_concat.py \
        +config_name=Config_cLDM_concat.yaml \
        model.train_params.num_workers=24 \
        model.train_params.batch_size=32 \
        model.train_params.max_epoch=5100 \
        \
        architecture/unets=unet_cLDM_concat_light
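
    As a rough illustration of the concatenation conditioning in strategy 3 (shapes are assumptions): the segmentation mask is first compressed to a 2D latent by the AE_KL encoder, then concatenated channel-wise with the noisy image latent before entering the denoising U-Net.

    import torch

    # Assumed shapes for illustration only
    noisy_latent = torch.randn(2, 3, 32, 32)  # x_t from the diffusion forward process
    mask_latent = torch.randn(2, 1, 32, 32)   # AE_KL encoding of the mask (lat_dims=1)
    unet_input = torch.cat([noisy_latent, mask_latent], dim=1)  # channel-wise concat
    # unet_input is then denoised by the cLDM_concat U-Net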
