Skip to content

Commit 2d41655

Browse files
virginiafdez, Virginia Fernandez, ericspod, KumoLiu
authored
Classifier free guidance (#8460)
Fixes #8448

### Description

The classifier-free guidance (CFG) scale can be used in the sampling methods of diffusion models to strengthen the conditioning. It has been used in one of the generative MONAI tutorials, but the actual inferers (`DiffusionInferer` and `ControlNetDiffusionInferer`) do not support this feature. This means that, whenever users want to use CFG, they have to either copy the inferer object or write their own sampling method.

This PR incorporates classifier-free guidance into the inferer objects by modifying their sampling methods and adding a `cfg` argument to control it. This should not change the default behaviour (`cfg=None`), although some rewriting has been necessary.

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes (to test_controlnet_inferers.py, test_diffusion_inferer.py and test_latent_diffusion_inferer.py).
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Virginia Fernandez <[email protected]>
Co-authored-by: Virginia Fernandez <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Virginia Fernandez <virginia.fernandez.kcl.ac.uk>
1 parent f85135b commit 2d41655

File tree

4 files changed

+176
-19
lines changed

4 files changed

+176
-19
lines changed

monai/inferers/inferer.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,7 @@ def sample(
839839
mode: str = "crossattn",
840840
verbose: bool = True,
841841
seg: torch.Tensor | None = None,
842+
cfg: float | None = None,
842843
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
843844
"""
844845
Args:
@@ -851,6 +852,7 @@ def sample(
851852
mode: Conditioning mode for the network.
852853
verbose: if true, prints the progression bar of the sampling process.
853854
seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
855+
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
854856
"""
855857
if mode not in ["crossattn", "concat"]:
856858
raise NotImplementedError(f"{mode} condition is not supported")
@@ -877,15 +879,31 @@ def sample(
877879
if isinstance(diffusion_model, SPADEDiffusionModelUNet)
878880
else diffusion_model
879881
)
880-
if mode == "concat" and conditioning is not None:
881-
model_input = torch.cat([image, conditioning], dim=1)
882+
if (
883+
cfg is not None
884+
): # if classifier-free guidance is used, a conditioned and unconditioned bit is generated.
885+
model_input = torch.cat([image] * 2, dim=0)
886+
if conditioning is not None:
887+
uncondition = torch.ones_like(conditioning)
888+
uncondition.fill_(-1)
889+
conditioning_input = torch.cat([uncondition, conditioning], dim=0)
890+
else:
891+
conditioning_input = None
892+
else:
893+
model_input = image
894+
conditioning_input = conditioning
895+
if mode == "concat" and conditioning_input is not None:
896+
model_input = torch.cat([model_input, conditioning_input], dim=1)
882897
model_output = diffusion_model(
883898
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
884899
)
885900
else:
886901
model_output = diffusion_model(
887-
image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
902+
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning_input
888903
)
904+
if cfg is not None:
905+
model_output_uncond, model_output_cond = model_output.chunk(2)
906+
model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)
889907

890908
# 2. compute previous image: x_t -> x_t-1
891909
if not isinstance(scheduler, RFlowScheduler):
@@ -1166,6 +1184,7 @@ def sample( # type: ignore[override]
11661184
mode: str = "crossattn",
11671185
verbose: bool = True,
11681186
seg: torch.Tensor | None = None,
1187+
cfg: float | None = None,
11691188
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
11701189
"""
11711190
Args:
@@ -1180,6 +1199,7 @@ def sample( # type: ignore[override]
11801199
verbose: if true, prints the progression bar of the sampling process.
11811200
seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
11821201
is instance of SPADEAutoencoderKL, segmentation must be provided.
1202+
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
11831203
"""
11841204

11851205
if (
@@ -1203,6 +1223,7 @@ def sample( # type: ignore[override]
12031223
mode=mode,
12041224
verbose=verbose,
12051225
seg=seg,
1226+
cfg=cfg,
12061227
)
12071228

12081229
if save_intermediates:
@@ -1381,6 +1402,7 @@ def sample( # type: ignore[override]
13811402
mode: str = "crossattn",
13821403
verbose: bool = True,
13831404
seg: torch.Tensor | None = None,
1405+
cfg: float | None = None,
13841406
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
13851407
"""
13861408
Args:
@@ -1395,6 +1417,7 @@ def sample( # type: ignore[override]
13951417
mode: Conditioning mode for the network.
13961418
verbose: if true, prints the progression bar of the sampling process.
13971419
seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
1420+
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
13981421
"""
13991422
if mode not in ["crossattn", "concat"]:
14001423
raise NotImplementedError(f"{mode} condition is not supported")
@@ -1413,14 +1436,31 @@ def sample( # type: ignore[override]
14131436
progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
14141437
intermediates = []
14151438

1439+
if cfg is not None:
1440+
cn_cond = torch.cat([cn_cond] * 2, dim=0)
1441+
14161442
for t, next_t in progress_bar:
1443+
# Controlnet prediction
1444+
if cfg is not None:
1445+
model_input = torch.cat([image] * 2, dim=0)
1446+
if conditioning is not None:
1447+
uncondition = torch.ones_like(conditioning)
1448+
uncondition.fill_(-1)
1449+
conditioning_input = torch.cat([uncondition, conditioning], dim=0)
1450+
else:
1451+
conditioning_input = None
1452+
else:
1453+
model_input = image
1454+
conditioning_input = conditioning
1455+
1456+
# Diffusion model prediction
14171457
diffuse = diffusion_model
14181458
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
14191459
diffuse = partial(diffusion_model, seg=seg)
14201460

1421-
if mode == "concat" and conditioning is not None:
1461+
if mode == "concat" and conditioning_input is not None:
14221462
# 1. Conditioning
1423-
model_input = torch.cat([image, conditioning], dim=1)
1463+
model_input = torch.cat([model_input, conditioning_input], dim=1)
14241464
# 2. ControlNet forward
14251465
down_block_res_samples, mid_block_res_sample = controlnet(
14261466
x=model_input,
@@ -1437,20 +1477,28 @@ def sample( # type: ignore[override]
14371477
mid_block_additional_residual=mid_block_res_sample,
14381478
)
14391479
else:
1480+
# 1. Controlnet forward
14401481
down_block_res_samples, mid_block_res_sample = controlnet(
1441-
x=image,
1482+
x=model_input,
14421483
timesteps=torch.Tensor((t,)).to(input_noise.device),
14431484
controlnet_cond=cn_cond,
1444-
context=conditioning,
1485+
context=conditioning_input,
14451486
)
1487+
# 2. predict noise model_output
14461488
model_output = diffuse(
1447-
image,
1489+
model_input,
14481490
timesteps=torch.Tensor((t,)).to(input_noise.device),
1449-
context=conditioning,
1491+
context=conditioning_input,
14501492
down_block_additional_residuals=down_block_res_samples,
14511493
mid_block_additional_residual=mid_block_res_sample,
14521494
)
14531495

1496+
# If classifier-free guidance isn't None, we split and compute the weighting between
1497+
# conditioned and unconditioned output.
1498+
if cfg is not None:
1499+
model_output_uncond, model_output_cond = model_output.chunk(2)
1500+
model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)
1501+
14541502
# 3. compute previous image: x_t -> x_t-1
14551503
if not isinstance(scheduler, RFlowScheduler):
14561504
image, _ = scheduler.step(model_output, t, image) # type: ignore
@@ -1714,6 +1762,7 @@ def sample( # type: ignore[override]
17141762
mode: str = "crossattn",
17151763
verbose: bool = True,
17161764
seg: torch.Tensor | None = None,
1765+
cfg: float | None = None,
17171766
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
17181767
"""
17191768
Args:
@@ -1730,6 +1779,7 @@ def sample( # type: ignore[override]
17301779
verbose: if true, prints the progression bar of the sampling process.
17311780
seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
17321781
is instance of SPADEAutoencoderKL, segmentation must be provided.
1782+
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
17331783
"""
17341784

17351785
if (
@@ -1757,6 +1807,7 @@ def sample( # type: ignore[override]
17571807
mode=mode,
17581808
verbose=verbose,
17591809
seg=seg,
1810+
cfg=cfg,
17601811
)
17611812

17621813
if save_intermediates:

tests/inferers/test_controlnet_inferers.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -482,16 +482,20 @@ def test_sample_intermediates(self, model_params, controlnet_params, input_shape
482482
scheduler = DDPMScheduler(num_train_timesteps=10)
483483
inferer = ControlNetDiffusionInferer(scheduler=scheduler)
484484
scheduler.set_timesteps(num_inference_steps=10)
485-
sample, intermediates = inferer.sample(
486-
input_noise=noise,
487-
diffusion_model=model,
488-
scheduler=scheduler,
489-
controlnet=controlnet,
490-
cn_cond=mask,
491-
save_intermediates=True,
492-
intermediate_steps=1,
493-
)
494-
self.assertEqual(len(intermediates), 10)
485+
486+
for cfg in [5, None]:
487+
sample, intermediates = inferer.sample(
488+
input_noise=noise,
489+
diffusion_model=model,
490+
scheduler=scheduler,
491+
controlnet=controlnet,
492+
cn_cond=mask,
493+
save_intermediates=True,
494+
intermediate_steps=1,
495+
cfg=cfg,
496+
)
497+
498+
self.assertEqual(len(intermediates), 10)
495499

496500
@parameterized.expand(CNDM_TEST_CASES)
497501
@skipUnless(has_einops, "Requires einops")

tests/inferers/test_diffusion_inferer.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,27 @@ def test_sample_intermediates(self, model_params, input_shape):
8888
)
8989
self.assertEqual(len(intermediates), 10)
9090

91+
@parameterized.expand(TEST_CASES)
92+
@skipUnless(has_einops, "Requires einops")
93+
def test_sample_cfg(self, model_params, input_shape):
94+
model = DiffusionModelUNet(**model_params)
95+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
96+
model.to(device)
97+
model.eval()
98+
noise = torch.randn(input_shape).to(device)
99+
scheduler = DDPMScheduler(num_train_timesteps=10)
100+
inferer = DiffusionInferer(scheduler=scheduler)
101+
scheduler.set_timesteps(num_inference_steps=10)
102+
sample, intermediates = inferer.sample(
103+
input_noise=noise,
104+
diffusion_model=model,
105+
scheduler=scheduler,
106+
save_intermediates=True,
107+
intermediate_steps=1,
108+
cfg=5,
109+
)
110+
self.assertEqual(sample.shape, noise.shape)
111+
91112
@parameterized.expand(TEST_CASES)
92113
@skipUnless(has_einops, "Requires einops")
93114
def test_ddpm_sampler(self, model_params, input_shape):
@@ -244,6 +265,38 @@ def test_sampler_conditioned_concat(self, model_params, input_shape):
244265
)
245266
self.assertEqual(len(intermediates), 10)
246267

268+
@parameterized.expand(TEST_CASES)
269+
@skipUnless(has_einops, "Requires einops")
270+
def test_sampler_conditioned_concat_cfg(self, model_params, input_shape):
271+
# copy the model_params dict to prevent from modifying test cases
272+
model_params = model_params.copy()
273+
n_concat_channel = 2
274+
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
275+
model_params["cross_attention_dim"] = None
276+
model_params["with_conditioning"] = False
277+
model = DiffusionModelUNet(**model_params)
278+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
279+
model.to(device)
280+
model.eval()
281+
noise = torch.randn(input_shape).to(device)
282+
conditioning_shape = list(input_shape)
283+
conditioning_shape[1] = n_concat_channel
284+
conditioning = torch.randn(conditioning_shape).to(device)
285+
scheduler = DDIMScheduler(num_train_timesteps=1000)
286+
inferer = DiffusionInferer(scheduler=scheduler)
287+
scheduler.set_timesteps(num_inference_steps=10)
288+
sample, intermediates = inferer.sample(
289+
input_noise=noise,
290+
diffusion_model=model,
291+
scheduler=scheduler,
292+
save_intermediates=True,
293+
intermediate_steps=1,
294+
conditioning=conditioning,
295+
mode="concat",
296+
cfg=5,
297+
)
298+
self.assertEqual(len(intermediates), 10)
299+
247300
@parameterized.expand(TEST_CASES)
248301
@skipUnless(has_einops, "Requires einops")
249302
def test_sampler_conditioned_concat_rflow(self, model_params, input_shape):

tests/inferers/test_latent_diffusion_inferer.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,55 @@ def test_sample_shape(
414414
)
415415
self.assertEqual(sample.shape, input_shape)
416416

417+
@parameterized.expand(TEST_CASES)
418+
@skipUnless(has_einops, "Requires einops")
419+
def test_sample_shape_with_cfg(
420+
self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
421+
):
422+
stage_1 = None
423+
424+
if ae_model_type == "AutoencoderKL":
425+
stage_1 = AutoencoderKL(**autoencoder_params)
426+
if ae_model_type == "VQVAE":
427+
stage_1 = VQVAE(**autoencoder_params)
428+
if dm_model_type == "SPADEDiffusionModelUNet":
429+
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
430+
else:
431+
stage_2 = DiffusionModelUNet(**stage_2_params)
432+
433+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
434+
stage_1.to(device)
435+
stage_2.to(device)
436+
stage_1.eval()
437+
stage_2.eval()
438+
439+
noise = torch.randn(latent_shape).to(device)
440+
441+
for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
442+
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
443+
scheduler.set_timesteps(num_inference_steps=10)
444+
445+
if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
446+
input_shape_seg = list(input_shape)
447+
if "label_nc" in stage_2_params.keys():
448+
input_shape_seg[1] = stage_2_params["label_nc"]
449+
else:
450+
input_shape_seg[1] = autoencoder_params["label_nc"]
451+
input_seg = torch.randn(input_shape_seg).to(device)
452+
sample = inferer.sample(
453+
input_noise=noise,
454+
autoencoder_model=stage_1,
455+
diffusion_model=stage_2,
456+
scheduler=scheduler,
457+
seg=input_seg,
458+
cfg=5,
459+
)
460+
else:
461+
sample = inferer.sample(
462+
input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler, cfg=5
463+
)
464+
self.assertEqual(sample.shape, input_shape)
465+
417466
@parameterized.expand(TEST_CASES)
418467
@skipUnless(has_einops, "Requires einops")
419468
def test_sample_intermediates(

0 commit comments

Comments
 (0)