 from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom

-from kornia.filters import gaussian_blur2d
-import kornia.augmentation as K
-
 from resize_right import resize

 from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
@@ -175,8 +172,8 @@ def linear_beta_schedule(timesteps):
     return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)


-class BaseGaussianDiffusion(nn.Module):
-    def __init__(self, *, beta_schedule, timesteps, loss_type):
+class GaussianDiffusion(nn.Module):
+    def __init__(self, *, beta_schedule, timesteps):
         super().__init__()

         if beta_schedule == "cosine":
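The base diffusion class is renamed to GaussianDiffusion and slimmed down to a pure noise scheduler. A minimal sketch of standalone use, assuming the class above is importable from this module and using only members referenced elsewhere in this diff (num_timesteps, q_sample):

import torch
from imagen_pytorch.imagen_pytorch import GaussianDiffusion   # module path assumed from the imports above

# build the scheduler on its own, outside of Imagen
scheduler = GaussianDiffusion(beta_schedule = 'cosine', timesteps = 1000)

images = torch.randn(4, 3, 64, 64)                               # batch of images already scaled to [-1, 1]
times = torch.randint(0, scheduler.num_timesteps, (4,)).long()   # one diffusion timestep per sample
noised = scheduler.q_sample(x_start = images, t = times, noise = torch.randn_like(images))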
@@ -193,18 +190,6 @@ def __init__(self, *, beta_schedule, timesteps, loss_type):
         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)

-        if loss_type == 'l1':
-            loss_fn = F.l1_loss
-        elif loss_type == 'l2':
-            loss_fn = F.mse_loss
-        elif loss_type == 'huber':
-            loss_fn = F.smooth_l1_loss
-        else:
-            raise NotImplementedError()
-
-        self.loss_type = loss_type
-        self.loss_fn = loss_fn
-
         # register buffer helper function to cast double back to float

         register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32), persistent = False)
@@ -922,7 +907,7 @@ def forward(

         return self.final_conv(x)

-class Imagen(BaseGaussianDiffusion):
+class Imagen(nn.Module):
     def __init__(
         self,
         unets,
@@ -943,11 +928,26 @@ def __init__(
943
928
auto_normalize_img = True , # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
944
929
dynamic_thresholding_percentile = 0.9 # unsure what this was based on perusal of paper
945
930
):
946
- super ().__init__ (
947
- beta_schedule = beta_schedule ,
948
- timesteps = timesteps ,
949
- loss_type = loss_type
950
- )
931
+ super ().__init__ ()
932
+
933
+ self .noise_scheduler = GaussianDiffusion (beta_schedule = beta_schedule , timesteps = timesteps )
934
+ self .num_timesteps = self .noise_scheduler .num_timesteps
935
+
936
+ # loss
937
+
938
+ if loss_type == 'l1' :
939
+ loss_fn = F .l1_loss
940
+ elif loss_type == 'l2' :
941
+ loss_fn = F .mse_loss
942
+ elif loss_type == 'huber' :
943
+ loss_fn = F .smooth_l1_loss
944
+ else :
945
+ raise NotImplementedError ()
946
+
947
+ self .loss_type = loss_type
948
+ self .loss_fn = loss_fn
949
+
950
+ # conditioning hparams
951
951
952
952
self .condition_on_text = condition_on_text
953
953
self .unconditional = not condition_on_text
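Imagen no longer inherits the diffusion machinery; it composes a GaussianDiffusion noise scheduler and keeps the loss selection locally. A self-contained sketch of what self.loss_fn resolves to for each loss_type, using the same mapping as the __init__ above pulled out into an illustrative helper (the helper name is mine):

import torch
import torch.nn.functional as F

# illustrative helper mirroring the loss_type branch in __init__ above
def resolve_loss_fn(loss_type):
    if loss_type == 'l1':
        return F.l1_loss
    if loss_type == 'l2':
        return F.mse_loss
    if loss_type == 'huber':
        return F.smooth_l1_loss
    raise NotImplementedError()

pred   = torch.randn(2, 3, 64, 64)
target = torch.randn(2, 3, 64, 64)
loss   = resolve_loss_fn('l2')(pred, target)   # mean-reduced scalar; F.mse_loss defaults to reduction = 'mean'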
@@ -1058,7 +1058,7 @@ def p_mean_variance(self, unet, x, t, text_embeds = None, text_mask = None, lowr
         if learned_variance:
             pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)

-        x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
+        x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

         if clip_denoised:
             # following pseudocode in appendix
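For reference, predict_start_from_noise inverts the standard DDPM forward process to recover the clean-image estimate that then gets (optionally) dynamically thresholded:

\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}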
@@ -1074,14 +1074,14 @@ def p_mean_variance(self, unet, x, t, text_embeds = None, text_mask = None, lowr
             s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
             x_recon = x_recon.clamp(-s, s) / s

-        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_recon, x_t = x, t = t)
+        model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start = x_recon, x_t = x, t = t)

         if learned_variance:
             # if learned variance, posterior variance and posterior log variance are predicted by the network
             # by an interpolation of the max and min log beta values
             # eq 15 - https://arxiv.org/abs/2102.09672
-            min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
-            max_log = extract(torch.log(self.betas), t, x.shape)
+            min_log = extract(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape)
+            max_log = extract(torch.log(self.noise_scheduler.betas), t, x.shape)
             var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)

             posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
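The interpolation above is eq. 15 of Nichol & Dhariwal (https://arxiv.org/abs/2102.09672): the network outputs a per-pixel fraction v (var_interp_frac) and the variance is formed in log space between beta_t (max_log) and the clipped posterior variance (min_log):

\Sigma_\theta(x_t, t) = \exp\big(v \log \beta_t + (1 - v) \log \tilde{\beta}_t\big)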
@@ -1100,7 +1100,7 @@ def p_sample(self, unet, x, t, text_embeds = None, text_mask = None, cond_scale

     @torch.no_grad()
     def p_sample_loop(self, unet, shape, learned_variance = False, clip_denoised = True, lowres_cond_img = None, lowres_noise_times = None, text_embeds = None, text_mask = None, cond_scale = 1):
-        device = self.betas.device
+        device = self.noise_scheduler.betas.device

         b = shape[0]
         img = torch.randn(shape, device = device)
@@ -1133,15 +1133,15 @@ def p_losses(self, unet, x_start, times, *, lowres_cond_img = None, lowres_aug_t

         # get x_t

-        x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
+        x_noisy = self.noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)

         # also noise the lowres conditioning image
         # at sample time, they then fix the noise level of 0.1 - 0.3

         lowres_cond_img_noisy = None
         if exists(lowres_cond_img):
             lowres_aug_times = default(lowres_aug_times, times)
-            lowres_cond_img_noisy = self.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))
+            lowres_cond_img_noisy = self.noise_scheduler.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))

         # get prediction

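Both q_sample calls draw x_t from the closed-form DDPM forward process, which is why a single call with per-sample timesteps suffices:

x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)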
@@ -1173,7 +1173,7 @@ def p_losses(self, unet, x_start, times, *, lowres_cond_img = None, lowres_aug_t

         # if learning the variance, also include the extra weight kl loss

-        true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
+        true_mean, _, true_log_variance_clipped = self.noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
         model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)

         # kl loss with detached model predicted mean, for stability reasons as in paper
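The kl term compares two diagonal Gaussians (true posterior vs. model), for which the KL divergence has a closed form; with the model mean detached, only the predicted log variance receives gradient from this term:

D_{\mathrm{KL}}\big(\mathcal{N}(\mu_1, \sigma_1^2)\,\|\,\mathcal{N}(\mu_2, \sigma_2^2)\big) = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}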
@@ -1238,7 +1238,7 @@ def sample(

             if unet.lowres_cond:
                 lowres_cond_img = resize_image_to(img, image_size)
-                lowres_cond_img = self.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))
+                lowres_cond_img = self.noise_scheduler.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))

             shape = (batch_size, self.channels, image_size, image_size)
@@ -1301,12 +1301,4 @@ def forward(

         image = resize_image_to(image, target_image_size)

-        if exists(random_crop_size):
-            aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)
-
-            # make sure low res conditioner and image both get augmented the same way
-            # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
-            image = aug(image)
-            lowres_cond_img = aug(lowres_cond_img, params = aug._params)
-
         return self.p_losses(unet, image, times, text_embeds = text_embeds, text_mask = text_masks, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, learned_variance = learned_variance)
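With kornia dropped, the synchronized random crop of the image and its low-resolution conditioner is removed from forward. If that behaviour is wanted back without the dependency, one sketch is to sample a single crop box and slice both tensors with it; the helper name and the assumption that both tensors share the same spatial size are mine, not part of the diff:

import torch

# illustrative only: apply one random crop identically to both tensors
def synced_random_crop(image, lowres_cond_img, crop_size):
    _, _, h, w = image.shape
    top  = torch.randint(0, h - crop_size + 1, (1,)).item()
    left = torch.randint(0, w - crop_size + 1, (1,)).item()
    crop = (..., slice(top, top + crop_size), slice(left, left + crop_size))
    return image[crop], lowres_cond_img[crop]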