Skip to content

Commit 42d87b8

Browse files
committed
offer convenient way to return list of pillow images for saving
1 parent c468b4f commit 42d87b8

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

imagen_pytorch/imagen_pytorch.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,8 @@ def sample(
11431143
batch_size = 1,
11441144
cond_scale = 1.,
11451145
lowres_sample_noise_level = None,
1146-
stop_at_unet_number = None
1146+
stop_at_unet_number = None,
1147+
return_pil_images = False
11471148
):
11481149
device = next(self.parameters()).device
11491150

@@ -1196,7 +1197,11 @@ def sample(
11961197
if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
11971198
break
11981199

1199-
return img
1200+
if not return_pil_images:
1201+
return img
1202+
1203+
pil_images = list(map(T.ToPILImage(), img.unbind(dim = 0)))
1204+
return pil_images # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)
12001205

12011206
def p_losses(self, unet, x_start, times, *, noise_scheduler, lowres_cond_img = None, lowres_aug_times = None, text_embeds = None, text_mask = None, noise = None, learned_variance = False, clip_denoised = False):
12021207
noise = default(noise, lambda: torch.randn_like(x_start))

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'imagen-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.24',
6+
version = '0.0.25',
77
license='MIT',
88
description = 'Imagen - unprecedented photorealism × deep level of language understanding',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)