
Commit b5deb14

remove random crop augmentation and kornia altogether, and also ready for separate noise schedules per unet in the cascade
1 parent: 032bbe5

2 files changed (+33, -42 lines)

Diff for: imagen_pytorch/imagen_pytorch.py (+32, -40)
```diff
@@ -17,9 +17,6 @@
 from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom
 
-from kornia.filters import gaussian_blur2d
-import kornia.augmentation as K
-
 from resize_right import resize
 
 from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
@@ -175,8 +172,8 @@ def linear_beta_schedule(timesteps):
     return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
 
 
-class BaseGaussianDiffusion(nn.Module):
-    def __init__(self, *, beta_schedule, timesteps, loss_type):
+class GaussianDiffusion(nn.Module):
+    def __init__(self, *, beta_schedule, timesteps):
        super().__init__()
 
        if beta_schedule == "cosine":
```
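`BaseGaussianDiffusion` is renamed to `GaussianDiffusion` and reduced to a standalone module that owns only the schedule math, which is what makes separate noise schedules per unet possible. A minimal sketch of the wiring this enables (the per-unet list below is an assumption about where the refactor is headed, not code from this commit):

```python
import torch.nn as nn

# hypothetical follow-up: one scheduler per unet in the cascade,
# e.g. cosine for the base unet and linear for the super-res unet
noise_schedulers = nn.ModuleList([
    GaussianDiffusion(beta_schedule = 'cosine', timesteps = 1000),
    GaussianDiffusion(beta_schedule = 'linear', timesteps = 1000),
])
```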
```diff
@@ -193,18 +190,6 @@ def __init__(self, *, beta_schedule, timesteps, loss_type):
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
 
-        if loss_type == 'l1':
-            loss_fn = F.l1_loss
-        elif loss_type == 'l2':
-            loss_fn = F.mse_loss
-        elif loss_type == 'huber':
-            loss_fn = F.smooth_l1_loss
-        else:
-            raise NotImplementedError()
-
-        self.loss_type = loss_type
-        self.loss_fn = loss_fn
-
        # register buffer helper function to cast double back to float
 
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32), persistent = False)
@@ -922,7 +907,7 @@ def forward(
 
        return self.final_conv(x)
 
-class Imagen(BaseGaussianDiffusion):
+class Imagen(nn.Module):
     def __init__(
        self,
        unets,
@@ -943,11 +928,26 @@ def __init__(
        auto_normalize_img = True,            # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
        dynamic_thresholding_percentile = 0.9 # unsure what this was based on perusal of paper
     ):
-        super().__init__(
-            beta_schedule = beta_schedule,
-            timesteps = timesteps,
-            loss_type = loss_type
-        )
+        super().__init__()
+
+        self.noise_scheduler = GaussianDiffusion(beta_schedule = beta_schedule, timesteps = timesteps)
+        self.num_timesteps = self.noise_scheduler.num_timesteps
+
+        # loss
+
+        if loss_type == 'l1':
+            loss_fn = F.l1_loss
+        elif loss_type == 'l2':
+            loss_fn = F.mse_loss
+        elif loss_type == 'huber':
+            loss_fn = F.smooth_l1_loss
+        else:
+            raise NotImplementedError()
+
+        self.loss_type = loss_type
+        self.loss_fn = loss_fn
+
+        # conditioning hparams
 
        self.condition_on_text = condition_on_text
        self.unconditional = not condition_on_text
```
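Loss selection moves out of the scheduler and into `Imagen` itself: the loss belongs to the model being trained, not to the noise schedule, so the scheduler can be shared or swapped freely. The if/elif chain simply maps the `loss_type` string to a torch functional; a dict lookup would be equivalent (a style sketch, not the commit's code):

```python
import torch
import torch.nn.functional as F

# same mapping as the if/elif chain above, written as a lookup table
LOSS_FNS = dict(l1 = F.l1_loss, l2 = F.mse_loss, huber = F.smooth_l1_loss)

pred, target = torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)
loss = LOSS_FNS['l2'](pred, target)  # mean squared error over all elements
```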
```diff
@@ -1058,7 +1058,7 @@ def p_mean_variance(self, unet, x, t, text_embeds = None, text_mask = None, lowr
        if learned_variance:
            pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
 
-        x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
+        x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
 
        if clip_denoised:
            # following pseudocode in appendix
```
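`predict_start_from_noise` is the standard DDPM inversion of the forward process: given the noisy image x_t and the predicted noise, recover an estimate of x_0. A sketch of the usual formula (the buffer names follow the common DDPM convention and are assumed here, not quoted from this file):

```python
# x_0 ≈ sqrt(1 / ᾱ_t) * x_t - sqrt(1 / ᾱ_t - 1) * predicted_noise
def predict_start_from_noise(self, x_t, t, noise):
    return (
        extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
        extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
    )
```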
```diff
@@ -1074,14 +1074,14 @@
            s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
            x_recon = x_recon.clamp(-s, s) / s
 
-        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
 
        if learned_variance:
            # if learned variance, posterior variance and posterior log variance are predicted by the network
            # by an interpolation of the max and min log beta values
            # eq 15 - https://arxiv.org/abs/2102.09672
-            min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
-            max_log = extract(torch.log(self.betas), t, x.shape)
+            min_log = extract(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape)
+            max_log = extract(torch.log(self.noise_scheduler.betas), t, x.shape)
            var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
 
            posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
@@ -1100,7 +1100,7 @@ def p_sample(self, unet, x, t, text_embeds = None, text_mask = None, cond_scale
 
    @torch.no_grad()
    def p_sample_loop(self, unet, shape, learned_variance = False, clip_denoised = True, lowres_cond_img = None, lowres_noise_times = None, text_embeds = None, text_mask = None, cond_scale = 1):
-        device = self.betas.device
+        device = self.noise_scheduler.betas.device
 
        b = shape[0]
        img = torch.randn(shape, device = device)
```
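`p_sample_loop` now reads its device from the scheduler's buffers rather than from `self`. The loop itself is ancestral sampling: start from pure Gaussian noise and denoise one timestep at a time. A bare-bones sketch of its shape (conditioning arguments omitted, names illustrative):

```python
# skeleton of the sampling loop - not the repo's exact code
img = torch.randn(shape, device = device)
for t in reversed(range(self.num_timesteps)):
    times = torch.full((b,), t, device = device, dtype = torch.long)
    img = self.p_sample(unet, img, times)
```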
```diff
@@ -1133,15 +1133,15 @@ def p_losses(self, unet, x_start, times, *, lowres_cond_img = None, lowres_aug_t
 
        # get x_t
 
-        x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
+        x_noisy = self.noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
 
        # also noise the lowres conditioning image
        # at sample time, they then fix the noise level of 0.1 - 0.3
 
        lowres_cond_img_noisy = None
        if exists(lowres_cond_img):
            lowres_aug_times = default(lowres_aug_times, times)
-            lowres_cond_img_noisy = self.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))
+            lowres_cond_img_noisy = self.noise_scheduler.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))
 
        # get prediction
 
```
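`q_sample` evaluates the closed-form forward process q(x_t | x_0) in one shot rather than stepping through the chain. A sketch of the operation (buffer names again follow the usual DDPM convention and are assumptions):

```python
# x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * noise
def q_sample(self, x_start, t, noise):
    return (
        extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
        extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
    )
```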
```diff
@@ -1173,7 +1173,7 @@ def p_losses(self, unet, x_start, times, *, lowres_cond_img = None, lowres_aug_t
 
        # if learning the variance, also include the extra weight kl loss
 
-        true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
+        true_mean, _, true_log_variance_clipped = self.noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
 
        # kl loss with detached model predicted mean, for stability reasons as in paper
@@ -1238,7 +1238,7 @@ def sample(
 
        if unet.lowres_cond:
            lowres_cond_img = resize_image_to(img, image_size)
-            lowres_cond_img = self.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))
+            lowres_cond_img = self.noise_scheduler.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))
 
        shape = (batch_size, self.channels, image_size, image_size)
 
```
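At sample time the low-res conditioning image gets a fixed, mild noising (the `p_losses` comment above pins it at 0.1 - 0.3 of the schedule). A hypothetical way `lowres_noise_times` could be built, shown only to make the mechanism concrete:

```python
import torch

# assumed helper: pick ~20% of the schedule as the fixed noise level
lowres_sample_noise_level = 0.2
t = int(lowres_sample_noise_level * self.num_timesteps)
lowres_noise_times = torch.full((batch_size,), t, device = device, dtype = torch.long)
```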
```diff
@@ -1301,12 +1301,4 @@ def forward(
 
        image = resize_image_to(image, target_image_size)
 
-        if exists(random_crop_size):
-            aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)
-
-            # make sure low res conditioner and image both get augmented the same way
-            # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
-            image = aug(image)
-            lowres_cond_img = aug(lowres_cond_img, params = aug._params)
-
        return self.p_losses(unet, image, times, text_embeds = text_embeds, text_mask = text_masks, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, learned_variance = learned_variance)
```
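With kornia dropped, `forward` no longer applies paired random crops. If you still want the augmentation without the dependency, a minimal kornia-free stand-in (hypothetical helper, not part of this commit) would sample one crop window and apply it to both tensors:

```python
import torch

def paired_random_crop(image, lowres_cond_img, crop_size):
    # sample a single window so the low-res conditioner stays
    # spatially aligned with the target image
    _, _, h, w = image.shape
    top  = torch.randint(0, h - crop_size + 1, (1,)).item()
    left = torch.randint(0, w - crop_size + 1, (1,)).item()
    image = image[..., top:top + crop_size, left:left + crop_size]
    lowres_cond_img = lowres_cond_img[..., top:top + crop_size, left:left + crop_size]
    return image, lowres_cond_img
```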

Diff for: setup.py (+1, -2)
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'imagen-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.22',
+  version = '0.0.23',
   license='MIT',
   description = 'Imagen - unprecedented photorealism × deep level of language understanding',
   author = 'Phil Wang',
@@ -20,7 +20,6 @@
   install_requires=[
     'einops>=0.4',
     'einops-exts',
-    'kornia',
     'numpy',
     'pydantic',
     'resize-right',
```
