diff --git a/guided_diffusion/gaussian_diffusion.py b/guided_diffusion/gaussian_diffusion.py
index b1e37fe4e..1870bf497 100644
--- a/guided_diffusion/gaussian_diffusion.py
+++ b/guided_diffusion/gaussian_diffusion.py
@@ -169,6 +169,17 @@ def __init__(
             / (1.0 - self.alphas_cumprod)
         )
 
+        # Noise-to-signal ratio g_t = sqrt(1 - alpha_bar_t) / sqrt(alpha_bar_t)
+        # over the full 1000-step linear schedule, used by the splitting samplers.
+        betas = get_named_beta_schedule('linear', 1000)
+        alphas = 1.0 - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        g_full = np.sqrt(1 - alphas_cumprod) / np.sqrt(alphas_cumprod)
+        self.g_full = th.tensor(g_full).cuda()
+        # The same ratio on the schedule in use, at steps t and t-1.
+        self.g = np.sqrt(1 - self.alphas_cumprod) / np.sqrt(self.alphas_cumprod)
+        self.g_prev = np.sqrt(1 - self.alphas_cumprod_prev) / np.sqrt(self.alphas_cumprod_prev)
+
     def q_mean_variance(self, x_start, t):
         """
         Get the distribution q(x_t | x_0).
@@ -326,6 +335,38 @@ def process_xstart(x):
             "pred_xstart": pred_xstart,
         }
 
+    def p_mean_variance_non_warp(
+        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
+    ):
+        """
+        Apply the model to x at timestep t and return the raw epsilon
+        prediction, without converting it into p(x_{t-1} | x_t).
+
+        :param model: the model, which takes a signal and a batch of timesteps
+                      as input.
+        :param x: the [N x C x ...] tensor at time t.
+        :param t: a 1-D Tensor of timesteps.
+        :param clip_denoised: unused here; kept for signature compatibility
+                              with p_mean_variance().
+        :param denoised_fn: unused here; kept for signature compatibility.
+        :param model_kwargs: if not None, a dict of extra keyword arguments to
+            pass to the model. This can be used for conditioning.
+        :return: a dict with a single key:
+                 - 'eps': the model's raw noise prediction.
+        """
+        if model_kwargs is None:
+            model_kwargs = {}
+
+        B, C = x.shape[:2]
+        assert t.shape == (B,)
+        model_output = model(x, self._scale_timesteps(t), **model_kwargs)
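+        # For reference, the raw eps relates to x_0 through the standard DDPM
+        # identity x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t),
+        # which is exactly what _predict_xstart_from_eps() below computes.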
+        # The mean/variance/pred_xstart computation of p_mean_variance() is
+        # intentionally skipped here: callers of this helper consume eps directly.
+        return {"eps": model_output}
+
     def _predict_xstart_from_eps(self, x_t, t, eps):
         assert x_t.shape == eps.shape
         return (
@@ -693,6 +800,7 @@ def ddim_sample(
         # in case we used x_start or x_prev prediction.
         eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
 
+        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
         alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
         sigma = (
             eta
@@ -814,6 +922,7 @@ def ddim_sample_loop(
         clip_denoised=True,
         denoised_fn=None,
         cond_fn=None,
+        impu_fn=None,
         model_kwargs=None,
         device=None,
         progress=False,
@@ -856,6 +965,7 @@ def ddim_sample_loop_progressive(
         clip_denoised=True,
         denoised_fn=None,
         cond_fn=None,
+        impu_fn=None,
         model_kwargs=None,
         device=None,
         progress=False,
@@ -950,6 +1060,7 @@ def plms_sample(
         clip_denoised=True,
         denoised_fn=None,
         cond_fn=None,
+        impu_fn=None,
         model_kwargs=None,
         cond_fn_with_grad=False,
         order=2,
@@ -992,7 +1103,7 @@ def get_model_output(x, t):
         alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
         eps, out, out_orig = get_model_output(x, t)
 
-        if order > 1 and old_out is None:
+        if order >= 1 and old_out is None:
             # Pseudo Improved Euler
             old_eps = [eps]
             mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps
@@ -1024,6 +1135,9 @@ def get_model_output(x, t):
         nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
         sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask)
 
+        if impu_fn is not None:
+            sample = self.condition_score3(impu_fn, None, sample, t, model_kwargs=model_kwargs)
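+            # (impu_fn is presumably an imputation/projection hook, e.g. one
+            # that re-imposes known pixel values; condition_score3() simply
+            # applies it to the current sample.)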
+
         return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps}
 
     def plms_sample_loop(
@@ -1034,6 +1148,7 @@ def plms_sample_loop(
         clip_denoised=True,
         denoised_fn=None,
         cond_fn=None,
+        impu_fn=None,
         model_kwargs=None,
         device=None,
         progress=False,
@@ -1056,6 +1171,7 @@ def plms_sample_loop(
             clip_denoised=clip_denoised,
             denoised_fn=denoised_fn,
             cond_fn=cond_fn,
+            impu_fn=impu_fn,
             model_kwargs=model_kwargs,
             device=device,
             progress=progress,
@@ -1076,6 +1192,7 @@ def plms_sample_loop_progressive(
         clip_denoised=True,
         denoised_fn=None,
         cond_fn=None,
+        impu_fn=None,
         model_kwargs=None,
         device=None,
         progress=False,
@@ -1130,6 +1247,7 @@ def plms_sample_loop_progressive(
                     clip_denoised=clip_denoised,
                     denoised_fn=denoised_fn,
                     cond_fn=cond_fn,
+                    impu_fn=impu_fn,
                     model_kwargs=model_kwargs,
                     cond_fn_with_grad=cond_fn_with_grad,
                     order=order,
@@ -1138,6 +1256,362 @@ def plms_sample_loop_progressive(
             yield out
             old_out = out
             img = out["sample"]
+
+    def condition_score2(self, cond_fn, p_mean_var, x, t, del_g, s0=0, s_1=0, model_kwargs=None):
+        """
+        Unlike condition_score(), this function returns only the gradient term
+        of the conditioning update. Note that p_mean_var and s_1 are never
+        used in this version.
+        """
+        grad = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
+        return -grad * del_g * (1 - s0**2).sqrt()
+
+    def condition_score3(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+        x = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
+        return x
+
+    def stsp_sample(
+        self,
+        model,
+        x,
+        t,
+        clip_denoised=True,
+        denoised_fn=None,
+        cond_fn=None,
+        impu_fn=None,
+        model_kwargs=None,
+        cond_fn_with_grad=False,
+        order=2,
+        old_out=None,
+    ):
+        """
+        Sample x_{t-1} from the model using Pseudo Linear Multistep
+        and Strang splitting for conditioning.
+        Same usage as p_sample().
+        """
+        g0 = _extract_into_tensor(self.g, t[0], (1,))
+        g_1 = _extract_into_tensor(self.g_prev, t[0], (1,))
+        s0 = 1 / th.sqrt(g0**2 + 1)
+        s_1 = 1 / th.sqrt(g_1**2 + 1)
+        del_g = g_1 - g0
+
+        if cond_fn is not None:
+            # First half-step of the conditioning flow (Strang splitting).
+            alpha_half = 1 / ((g0 + g_1)**2 / 4 + 1)
+            s_half = th.sqrt(alpha_half)
+            grad = self.condition_score2(cond_fn, None, x, t, del_g / 2, s0, s_half, model_kwargs=model_kwargs)
+            x = x + grad * s0
+
+        out_orig = self.p_mean_variance(
+            model, x, t,
+            clip_denoised=clip_denoised,
+            denoised_fn=denoised_fn,
+            model_kwargs=model_kwargs,
+        )
+        eps = self._predict_eps_from_xstart(x, t, out_orig["pred_xstart"])
+
+        if old_out is None:
+            old_eps = [eps]
+        else:
+            old_eps = old_out['old_eps']
+            old_eps.append(eps)
+
+        eps_prime = plms_mixer(old_eps, order)
+        sample = (x / s0 + del_g * eps_prime) * s_1
+
+        if cond_fn is not None:
+            # Second half-step of the conditioning flow.
+            if t[0].long() < 1:
+                # Clamp so the half-step below is evaluated at a valid timestep.
+                t = t + 1
+            grad = self.condition_score2(cond_fn, out_orig, sample, t - 1, del_g / 2, s_half, s_1, model_kwargs=model_kwargs)
+            sample = sample + grad * s_1
+
+        if impu_fn is not None:
+            sample = self.condition_score3(impu_fn, None, sample, t, model_kwargs=model_kwargs)
+
+        return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps}
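+
+    # Note on stsp_sample: Strang splitting composes a half-step of the
+    # conditioning flow, one full PLMS denoising step in the rescaled variable
+    # x / s_t (where s_t = 1 / sqrt(g_t**2 + 1) = sqrt(alpha_bar_t)), and a
+    # second conditioning half-step; this symmetric arrangement is what makes
+    # the splitting second-order accurate in the step size.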
+ """ + final = None + for sample in self.stsp_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + impu_fn=impu_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + skip_timesteps=skip_timesteps, + init_image=init_image, + randomize_class=randomize_class, + cond_fn_with_grad=cond_fn_with_grad, + order=order, + ): + final = sample + return final["sample"] + + def stsp_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + impu_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + order=2, + ): + """ + Use STSP to sample from the model and yield intermediate samples from each + timestep of STSP. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + + if skip_timesteps and init_image is None: + init_image = th.zeros_like(img) + + indices = list(range(self.num_timesteps - skip_timesteps))[::-1] + + if init_image is not None: + my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] + img = self.q_sample(init_image, my_t, img) + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + indices = tqdm(indices, desc="Steps") + + old_out = None + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + if randomize_class and 'y' in model_kwargs: + model_kwargs['y'] = th.randint(low=0, high=model.num_classes, + size=model_kwargs['y'].shape, + device=model_kwargs['y'].device) + with th.no_grad(): + out = self.stsp_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + impu_fn=impu_fn, + model_kwargs=model_kwargs, + order=order, + old_out=old_out, + ) + yield out + old_out = out + img = out["sample"] + + def ltsp_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + impu_fn=None, + model_kwargs=None, + cond_fn_with_grad=False, + order=2, + old_out=None, + ): + """ + Sample x_{t-1} from the model using Lie-Trotter Splitting + and Pseudo Linear Multistep. + Same usage as p_sample(). 
+ """ + + g0 = _extract_into_tensor(self.g, t[0], (1,)) + g_1 = _extract_into_tensor(self.g_prev, t[0], (1,)) + s0 = 1/th.sqrt(g0**2 + 1) + s_1 = 1/th.sqrt(g_1**2 + 1) + del_g = g_1 - g0 + + if cond_fn is not None and True: # For testing only + grad = self.condition_score2(cond_fn, None, x, t, del_g, s0, s_1, model_kwargs=model_kwargs) + x = x + grad * s0 + + out_orig = self.p_mean_variance( + model, x, t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + eps = self._predict_eps_from_xstart(x, t, out_orig["pred_xstart"]) + + if old_out is None: + old_out = [] + old_eps = [eps] + else: + old_eps = old_out['old_eps'] + old_eps.append(eps) + + + eps_prime = plms_mixer(old_eps, order) + sample = (x/s0 + del_g * eps_prime) * s_1 + + if cond_fn is not None and False: + grad = self.condition_score2(cond_fn, None, sample, t, del_g, s0, s_1, model_kwargs=model_kwargs) + sample = sample + grad * s_1 + + if impu_fn is not None: + sample = self.condition_score3(impu_fn, None, sample, t, model_kwargs=model_kwargs) + + return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps} + + def ltsp_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + impu_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + order=2, + ): + """ + Generate samples from the model using Lie-Trotter Splitting + and Pseudo Linear Multistep. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ltsp_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + impu_fn=impu_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + skip_timesteps=skip_timesteps, + init_image=init_image, + randomize_class=randomize_class, + cond_fn_with_grad=cond_fn_with_grad, + order=order, + ): + final = sample + return final["sample"] + + def ltsp_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + impu_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + order=2, + ): + """ + Use LTPS to sample from the model and yield intermediate samples from each + timestep of LTSP. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + + if skip_timesteps and init_image is None: + init_image = th.zeros_like(img) + + indices = list(range(self.num_timesteps - skip_timesteps))[::-1] + + if init_image is not None: + my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] + img = self.q_sample(init_image, my_t, img) + + if progress: + # Lazy import so that we don't depend on tqdm. 
+
+    def ltsp_sample_loop(
+        self,
+        model,
+        shape,
+        noise=None,
+        clip_denoised=True,
+        denoised_fn=None,
+        cond_fn=None,
+        impu_fn=None,
+        model_kwargs=None,
+        device=None,
+        progress=False,
+        skip_timesteps=0,
+        init_image=None,
+        randomize_class=False,
+        cond_fn_with_grad=False,
+        order=2,
+    ):
+        """
+        Generate samples from the model using Lie-Trotter splitting
+        and Pseudo Linear Multistep.
+        Same usage as p_sample_loop().
+        """
+        final = None
+        for sample in self.ltsp_sample_loop_progressive(
+            model,
+            shape,
+            noise=noise,
+            clip_denoised=clip_denoised,
+            denoised_fn=denoised_fn,
+            cond_fn=cond_fn,
+            impu_fn=impu_fn,
+            model_kwargs=model_kwargs,
+            device=device,
+            progress=progress,
+            skip_timesteps=skip_timesteps,
+            init_image=init_image,
+            randomize_class=randomize_class,
+            cond_fn_with_grad=cond_fn_with_grad,
+            order=order,
+        ):
+            final = sample
+        return final["sample"]
+
+    def ltsp_sample_loop_progressive(
+        self,
+        model,
+        shape,
+        noise=None,
+        clip_denoised=True,
+        denoised_fn=None,
+        cond_fn=None,
+        impu_fn=None,
+        model_kwargs=None,
+        device=None,
+        progress=False,
+        skip_timesteps=0,
+        init_image=None,
+        randomize_class=False,
+        cond_fn_with_grad=False,
+        order=2,
+    ):
+        """
+        Use LTSP to sample from the model and yield intermediate samples from
+        each timestep of LTSP.
+        Same usage as p_sample_loop_progressive().
+        """
+        if device is None:
+            device = next(model.parameters()).device
+        assert isinstance(shape, (tuple, list))
+        if noise is not None:
+            img = noise
+        else:
+            img = th.randn(*shape, device=device)
+
+        if skip_timesteps and init_image is None:
+            init_image = th.zeros_like(img)
+
+        indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
+
+        if init_image is not None:
+            my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
+            img = self.q_sample(init_image, my_t, img)
+
+        if progress:
+            # Lazy import so that we don't depend on tqdm.
+            from tqdm.auto import tqdm
+            indices = tqdm(indices, desc="Steps")
+
+        old_out = None
+
+        for i in indices:
+            t = th.tensor([i] * shape[0], device=device)
+            if randomize_class and 'y' in model_kwargs:
+                model_kwargs['y'] = th.randint(low=0, high=model.num_classes,
+                                               size=model_kwargs['y'].shape,
+                                               device=model_kwargs['y'].device)
+            with th.no_grad():
+                out = self.ltsp_sample(
+                    model,
+                    img,
+                    t,
+                    clip_denoised=clip_denoised,
+                    denoised_fn=denoised_fn,
+                    cond_fn=cond_fn,
+                    impu_fn=impu_fn,
+                    model_kwargs=model_kwargs,
+                    order=order,
+                    old_out=old_out,
+                )
+                yield out
+                old_out = out
+                img = out["sample"]
 
     def _vb_terms_bpd(
         self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
@@ -1339,3 +1813,26 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res.expand(broadcast_shape)
+
+
+def plms_mixer(old_eps, order=1):
+    # Mix the stored eps history using the classic Adams-Bashforth
+    # coefficients, as in pseudo linear multistep (PLMS) methods.
+    cur_order = min(order, len(old_eps), 4)  # coefficients implemented up to order 4
+    if cur_order == 1:
+        eps_prime = old_eps[-1]
+    elif cur_order == 2:
+        eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2
+    elif cur_order == 3:
+        eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12
+    else:
+        eps_prime = (55 * old_eps[-1] - 59 * old_eps[-2] + 37 * old_eps[-3] - 9 * old_eps[-4]) / 24
+
+    # Keep the history window at most `order` entries long.
+    if len(old_eps) >= order:
+        old_eps.pop(0)
+    return eps_prime
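+
+# Quick sanity check for the order-2 mixer (hypothetical, not part of the
+# library): with an eps history of [1.0, 3.0],
+#     plms_mixer([1.0, 3.0], order=2) == (3 * 3.0 - 1.0) / 2 == 4.0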