From ae09be912b7febed58790d51ba3dccbe4ae12a29 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 8 Oct 2025 17:53:44 -0400 Subject: [PATCH 1/5] Fixing the gaussian multinomial diffusion module --- .../gaussian_multinomial_diffusion.py | 1238 +++++++++++------ .../models/clavaddpm/test_model.py | 6 +- 2 files changed, 800 insertions(+), 444 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py b/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py index 29cd4b67..a029c8a1 100644 --- a/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py +++ b/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py @@ -8,15 +8,15 @@ import math from collections.abc import Callable from enum import Enum -from typing import Any, cast +from logging import DEBUG, INFO +from typing import Any, Literal, Protocol, cast import numpy as np import torch -import torch.nn.functional as F - -# ruff: noqa: N812 from torch import Tensor +from torch.nn import functional +from midst_toolkit.common.logger import log from midst_toolkit.models.clavaddpm.diffusion_utils import ( FoundNaNsError, discretized_gaussian_log_likelihood, @@ -66,6 +66,24 @@ class Parametrization(Enum): DIRECT = "direct" +class ConditioningFunction(Protocol): + """The definition of a function used to condition the model output.""" + + def __call__(self, features: Tensor, timestep: Tensor, **kwargs: Any) -> Tensor: + """ + The function call definition. + + Args: + features: The input features. + timestep: The timestep. + **kwargs: Extra keyword arguments passed to the model. + + Returns: + The model output. + """ + ... + + def get_named_beta_schedule(scheduler_type: SchedulerType, num_diffusion_timesteps: int) -> np.ndarray: """ Get a pre-defined beta schedule for the given name. @@ -88,11 +106,13 @@ def get_named_beta_schedule(scheduler_type: SchedulerType, num_diffusion_timeste beta_start = scale * 0.0001 beta_end = scale * 0.02 return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + if scheduler_type == SchedulerType.COSINE: return betas_for_alpha_bar( num_diffusion_timesteps, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, ) + raise ValueError(f"Unsupported scheduler: {scheduler_type.value}") @@ -100,24 +120,29 @@ def betas_for_alpha_bar(num_diffusion_timesteps: int, alpha_bar: Callable, max_b """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. + + Args: + num_diffusion_timesteps: The number of betas to produce. + alpha_bar: A lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + max_beta: The maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + The beta schedule. 
""" betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) class GaussianMultinomialDiffusion(torch.nn.Module): def __init__( - # ruff: noqa: PLR0915 self, num_classes: np.ndarray, num_numerical_features: int, @@ -129,7 +154,20 @@ def __init__( scheduler_type: SchedulerType = SchedulerType.COSINE, device: torch.device | None = None, ): - # ruff: noqa: D107 + """ + Initialize a GaussianMultinomialDiffusion instance. + + Args: + num_classes: The number of classes. + num_numerical_features: The number of numerical features. + denoise_fn: The denoising function. + num_timesteps: The number of timesteps. Default is 1000. + gaussian_loss_type: The type of Gaussian loss. Default is GaussianLossType.MSE. + gaussian_parametrization: The type of Gaussian parametrization. Default is GaussianParametrization.EPS. + parametrization: The type of parametrization. Default is Parametrization.X0. + scheduler_type: The type of scheduler. Default is SchedulerType.COSINE. + device: The device to use. Default is None, which means the device is the CPU. + """ if device is None: device = torch.device("cpu") @@ -166,97 +204,138 @@ def __init__( self.log_1_min_cumprod_alpha: Tensor self.sqrt_recipm1_alphas_cumprod: Tensor self.sqrt_recip_alphas_cumprod: Tensor - self.Lt_history: Tensor - self.Lt_count: Tensor + self.lt_history: Tensor + self.lt_count: Tensor - a = 1.0 - get_named_beta_schedule(scheduler_type, num_timesteps) + buffers = self._calculate_buffer_values() + + # Gaussian diffusion + betas = 1.0 - buffers["alphas"] + self.posterior_variance = betas * (1.0 - buffers["alphas_cumprod_prev"]) / (1.0 - buffers["alphas_cumprod"]) + posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + self.posterior_log_variance_clipped = torch.from_numpy(posterior_log_variance_clipped).float().to(self.device) + posterior_mean_coef1 = betas * np.sqrt(buffers["alphas_cumprod_prev"]) / (1.0 - buffers["alphas_cumprod"]) + self.posterior_mean_coef1 = posterior_mean_coef1.float().to(self.device) + coef2_denominator = (1.0 - buffers["alphas_cumprod_prev"]) * np.sqrt(buffers["alphas"].numpy()) + coef2_numerator = 1.0 - buffers["alphas_cumprod"] + self.posterior_mean_coef2 = (coef2_denominator / coef2_numerator).float().to(self.device) + + assert log_add_exp(buffers["log_alpha"], buffers["log_1_min_alpha"]).abs().sum().item() < 1.0e-5 + assert log_add_exp(buffers["log_cumprod_alpha"], buffers["log_1_min_cumprod_alpha"]).abs().sum().item() < 1e-5 + diff: Tensor = cast(Tensor, np.cumsum(buffers["log_alpha"]) - buffers["log_cumprod_alpha"]) + assert diff.abs().sum().item() < 1.0e-5 + + # Convert to float32 and register buffers. + for key, value in buffers.items(): + self.register_buffer(key, value.float().to(self.device)) + + def _calculate_buffer_values(self) -> dict[str, Tensor]: + """ + Calculate the values to register in this module's buffer. + + Returns: + A dictionary of tensors with the values to register in the buffer. 
Will contain the keys: + log_alpha, log_cumprod_alpha, log_1_min_alpha, log_1_min_cumprod_alpha, alphas_cumprod, + alphas_cumprod_prev, alphas_cumprod_next, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, + sqrt_recip_alphas_cumprod, sqrt_recipm1_alphas_cumprod, lt_history, lt_count + """ + a = 1.0 - get_named_beta_schedule(self.scheduler_type, self.num_timesteps) alphas = torch.tensor(a.astype("float64")) - betas = 1.0 - alphas - log_alpha: Tensor = np.log(alphas) # type: ignore[assignment] - log_cumprod_alpha: Tensor = np.cumsum(log_alpha) # type: ignore[assignment] + log_alpha = torch.tensor(np.log(alphas)) + log_cumprod_alpha = torch.tensor(np.cumsum(log_alpha)) log_1_min_alpha: Tensor = log_1_min_a(log_alpha) log_1_min_cumprod_alpha: Tensor = log_1_min_a(log_cumprod_alpha) - alphas_cumprod: Tensor = np.cumprod(alphas, axis=0) # type: ignore[assignment] + alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0)) alphas_cumprod_prev = torch.tensor(np.append(1.0, alphas_cumprod[:-1])) alphas_cumprod_next = torch.tensor(np.append(alphas_cumprod[1:], 0.0)) - sqrt_alphas_cumprod: Tensor = np.sqrt(alphas_cumprod) # type: ignore[assignment] - sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod) - sqrt_recip_alphas_cumprod: Tensor = np.sqrt(1.0 / alphas_cumprod) - sqrt_recipm1_alphas_cumprod: Tensor = np.sqrt(1.0 / alphas_cumprod - 1) + sqrt_alphas_cumprod = torch.tensor(np.sqrt(alphas_cumprod)) + sqrt_one_minus_alphas_cumprod = torch.tensor(np.sqrt(1.0 - alphas_cumprod)) + sqrt_recip_alphas_cumprod = torch.tensor(np.sqrt(1.0 / alphas_cumprod)) + sqrt_recipm1_alphas_cumprod = torch.tensor(np.sqrt(1.0 / alphas_cumprod - 1)) - # Gaussian diffusion + return { + "alphas": alphas, + "log_alpha": log_alpha, + "log_1_min_alpha": log_1_min_alpha, + "log_1_min_cumprod_alpha": log_1_min_cumprod_alpha, + "log_cumprod_alpha": log_cumprod_alpha, + "alphas_cumprod": alphas_cumprod, + "alphas_cumprod_prev": alphas_cumprod_prev, + "alphas_cumprod_next": alphas_cumprod_next, + "sqrt_alphas_cumprod": sqrt_alphas_cumprod, + "sqrt_one_minus_alphas_cumprod": sqrt_one_minus_alphas_cumprod, + "sqrt_recip_alphas_cumprod": sqrt_recip_alphas_cumprod, + "sqrt_recipm1_alphas_cumprod": sqrt_recipm1_alphas_cumprod, + "lt_history": torch.zeros(self.num_timesteps), + "lt_count": torch.zeros(self.num_timesteps), + } - self.posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) - self.posterior_log_variance_clipped = ( - torch.from_numpy(np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))) - .float() - .to(device) - ) - self.posterior_mean_coef1 = (betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)).float().to(device) - self.posterior_mean_coef2 = ( - ((1.0 - alphas_cumprod_prev) * np.sqrt(alphas.numpy()) / (1.0 - alphas_cumprod)).float().to(device) - ) + # Gaussian part + def gaussian_q_mean_variance(self, x_start: Tensor, timestep: Tensor) -> tuple[Tensor, Tensor, Tensor]: + """ + Calculate the mean and variance of the Gaussian posterior distribution. - assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.0e-5 - assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5 - diff: Tensor = cast(Tensor, np.cumsum(log_alpha) - log_cumprod_alpha) - assert diff.abs().sum().item() < 1.0e-5 + Args: + x_start: The initial, noiseless input. + timestep: The timestep. - # Convert to float32 and register buffers. 
- self.register_buffer("alphas", alphas.float().to(device)) - self.register_buffer("log_alpha", log_alpha.float().to(device)) - self.register_buffer("log_1_min_alpha", log_1_min_alpha.float().to(device)) - self.register_buffer("log_1_min_cumprod_alpha", log_1_min_cumprod_alpha.float().to(device)) - self.register_buffer("log_cumprod_alpha", log_cumprod_alpha.float().to(device)) - self.register_buffer("alphas_cumprod", alphas_cumprod.float().to(device)) - self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev.float().to(device)) - self.register_buffer("alphas_cumprod_next", alphas_cumprod_next.float().to(device)) - self.register_buffer("sqrt_alphas_cumprod", sqrt_alphas_cumprod.float().to(device)) - self.register_buffer( - "sqrt_one_minus_alphas_cumprod", - sqrt_one_minus_alphas_cumprod.float().to(device), - ) - self.register_buffer("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod.float().to(device)) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", - sqrt_recipm1_alphas_cumprod.float().to(device), - ) + Returns: + The mean and variance of the Gaussian posterior distribution. + """ + mean = extract(self.sqrt_alphas_cumprod, timestep, x_start.shape) * x_start + variance = extract(1.0 - self.alphas_cumprod, timestep, x_start.shape) + log_variance = extract(self.log_1_min_cumprod_alpha, timestep, x_start.shape) + return mean, variance, log_variance - self.register_buffer("Lt_history", torch.zeros(num_timesteps)) - self.register_buffer("Lt_count", torch.zeros(num_timesteps)) + def gaussian_q_sample(self, x_start: Tensor, timestep: Tensor, noise: Tensor | None = None) -> Tensor: + """ + Sample from the Gaussian posterior distribution. - # Gaussian part - def gaussian_q_mean_variance(self, x_start: Tensor, t: Tensor) -> tuple[Tensor, Tensor, Tensor]: - mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = extract(self.log_1_min_cumprod_alpha, t, x_start.shape) - return mean, variance, log_variance + Args: + x_start: The initial, noiseless input. + timestep: The timestep. + noise: The noise. Optional, default is None. - def gaussian_q_sample(self, x_start: Tensor, t: Tensor, noise: Tensor | None = None) -> Tensor: + Returns: + The sample from the Gaussian posterior distribution. + """ if noise is None: noise = torch.randn_like(x_start) + assert noise.shape == x_start.shape return ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + extract(self.sqrt_alphas_cumprod, timestep, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, timestep, x_start.shape) * noise ) def gaussian_q_posterior_mean_variance( self, x_start: Tensor, - x_t: Tensor, - t: Tensor, + features: Tensor, + timestep: Tensor, ) -> tuple[Tensor, Tensor, Tensor]: - assert x_start.shape == x_t.shape + """ + Calculate the mean and variance of the Gaussian posterior distribution. + + Args: + x_start: The initial, noiseless input. + features: The features used to compute the Gaussian parameters. + timestep: The timestep. + + Returns: + A tuple with 3 tensors: the mean, the variance, and the log variance of + the Gaussian posterior distribution. 
+        """
+        assert x_start.shape == features.shape
         posterior_mean = (
-            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
-            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+            extract(self.posterior_mean_coef1, timestep, features.shape) * x_start
+            + extract(self.posterior_mean_coef2, timestep, features.shape) * features
         )
-        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+        posterior_variance = extract(self.posterior_variance, timestep, features.shape)
+        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, timestep, features.shape)
         assert (
             posterior_mean.shape[0]
             == posterior_variance.shape[0]
@@ -268,41 +347,65 @@ def gaussian_q_posterior_mean_variance(
     def gaussian_p_mean_variance(
         self,
         model_output: Tensor,
-        x: Tensor,
-        t: Tensor,
-        clip_denoised: bool = False,
-        denoised_fn: Callable | None = None,
+        features: Tensor,
+        timestep: Tensor,
         model_kwargs: dict[str, Any] | None = None,
     ) -> dict[str, Tensor]:
+        """
+        Calculate the mean and variance of the Gaussian prior distribution.
+
+        Args:
+            model_output: The model output.
+            features: The features of the Gaussian distribution.
+            timestep: The timestep.
+            model_kwargs: The model kwargs. Optional, default is None.
+
+        Returns:
+            A dictionary with the following keys:
+            - "mean": the mean of the Gaussian prior distribution.
+            - "variance": the variance of the Gaussian prior distribution.
+            - "log_variance": the log variance of the Gaussian prior distribution.
+            - "pred_xstart": the predicted xstart of the Gaussian prior distribution.
+ """ if model_kwargs is None: model_kwargs = {} - B, C = x.shape[:2] - assert t.shape == (B,) + batch_size, _ = features.shape[:2] + assert timestep.shape == (batch_size,) model_variance = torch.cat( [ - self.posterior_variance[1].unsqueeze(0).to(x.device), + self.posterior_variance[1].unsqueeze(0).to(features.device), (1.0 - self.alphas)[1:], ], dim=0, ) model_log_variance = torch.log(model_variance) - model_variance = extract(model_variance, t, x.shape) - model_log_variance = extract(model_log_variance, t, x.shape) + model_variance = extract(model_variance, timestep, features.shape) + model_log_variance = extract(model_log_variance, timestep, features.shape) if self.gaussian_parametrization == GaussianParametrization.EPS: - pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + pred_xstart = self._predict_xstart_from_eps(features=features, timestep=timestep, eps=model_output) + elif self.gaussian_parametrization == GaussianParametrization.X0: pred_xstart = model_output + else: raise ValueError(f"Unsupported Gaussian parametrization: {self.gaussian_parametrization}") - model_mean, _, _ = self.gaussian_q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + model_mean, _, _ = self.gaussian_q_posterior_mean_variance( + x_start=pred_xstart, features=features, timestep=timestep + ) - assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape, ( - f"{model_mean.shape}, {model_log_variance.shape}, {pred_xstart.shape}, {x.shape}" + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == features.shape, ( + "Expected shapes to be equal, but got: ", + f"model_mean.shape: {model_mean.shape}, ", + f"model_log_variance.shape: {model_log_variance.shape}, ", + f"pred_xstart.shape: {pred_xstart.shape}, ", + f"features.shape: {features.shape}", ) return { @@ -316,38 +419,52 @@ def _vb_terms_bpd( self, model_output: Tensor, x_start: Tensor, - x_t: Tensor, - t: Tensor, - clip_denoised: bool = False, + features: Tensor, + timestep: Tensor, model_kwargs: dict[str, Any] | None = None, ) -> dict[str, Tensor]: + """ + Calculate the VB terms for the Gaussian part. + + Args: + model_output: The model output. + x_start: The initial, noiseless input. + features: The features used to compute the Gaussian parameters. + timestep: The timestep. + model_kwargs: The model kwargs. Optional, default is None. + + Returns: + A dictionary with the following keys: + - "output": The output of the VB terms. + - "pred_xstart": The predicted xstart of the Gaussian prior distribution. + - "out_mean": The mean of the Gaussian prior distribution. + - "true_mean": The true mean of the Gaussian prior distribution. 
+ """ if model_kwargs is None: model_kwargs = {} - ( - true_mean, - _, - true_log_variance_clipped, - ) = self.gaussian_q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) - out = self.gaussian_p_mean_variance( - model_output, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + true_mean, _, true_log_variance_clipped = self.gaussian_q_posterior_mean_variance( + x_start=x_start, + features=features, + timestep=timestep, ) - kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) + p_mean_variance = self.gaussian_p_mean_variance(model_output, features, timestep, model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, p_mean_variance["mean"], p_mean_variance["log_variance"]) kl = mean_flat(kl) / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( - x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + x_start, means=p_mean_variance["mean"], log_scales=0.5 * p_mean_variance["log_variance"] ) assert decoder_nll.shape == x_start.shape decoder_nll = mean_flat(decoder_nll) / np.log(2.0) # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) - output = torch.where((t == 0), decoder_nll, kl) + output = torch.where((timestep == 0), decoder_nll, kl) return { "output": output, - "pred_xstart": out["pred_xstart"], - "out_mean": out["mean"], + "pred_xstart": p_mean_variance["pred_xstart"], + "out_mean": p_mean_variance["mean"], "true_mean": true_mean, } @@ -358,8 +475,11 @@ def _prior_gaussian(self, x_start: Tensor) -> Tensor: This term can't be optimized, as it only depends on the encoder. - :param x_start: the [N x C x ...] tensor of inputs. - :return: a batch of [N] KL values (in bits), one per batch element. + Args: + x_start: the [N x C x ...] tensor of inputs. + + Returns: + A batch of [N] KL values (in bits), one per batch element. """ batch_size = x_start.shape[0] t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) @@ -371,47 +491,82 @@ def _gaussian_loss( self, model_out: Tensor, x_start: Tensor, - x_t: Tensor, - t: Tensor, + features: Tensor, + timestep: Tensor, noise: Tensor, model_kwargs: dict[str, Any] | None = None, ) -> Tensor: + """ + Calculate the Gaussian loss. + + Args: + model_out: The model output. + x_start: The initial, noiseless input. + features: The features used to compute the Gaussian parameters. + timestep: The timestep. + noise: The noise. + model_kwargs: The model kwargs. Optional, default is None. + + Returns: + The Gaussian loss. + """ if model_kwargs is None: model_kwargs = {} - terms = {} if self.gaussian_loss_type == GaussianLossType.MSE: - terms["loss"] = mean_flat((noise - model_out) ** 2) - elif self.gaussian_loss_type == GaussianLossType.KL: - terms["loss"] = self._vb_terms_bpd( + return mean_flat((noise - model_out) ** 2) + + if self.gaussian_loss_type == GaussianLossType.KL: + return self._vb_terms_bpd( model_output=model_out, x_start=x_start, - x_t=x_t, - t=t, - clip_denoised=False, + features=features, + timestep=timestep, model_kwargs=model_kwargs, )["output"] - return terms["loss"] + raise ValueError(f"Unsupported Gaussian loss type: {self.gaussian_loss_type}") - def _predict_xstart_from_eps(self, x_t: Tensor, t: Tensor, eps: Tensor) -> Tensor: - assert x_t.shape == eps.shape + def _predict_xstart_from_eps(self, features: Tensor, timestep: Tensor, eps: Tensor) -> Tensor: + """ + Predict the xstart from the eps. 
+ + Args: + features: The features used to compute the Gaussian parameters. + timestep: The timestep. + eps: The eps. + + Returns: + The predicted xstart. + """ + assert features.shape == eps.shape return ( - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + extract(self.sqrt_recip_alphas_cumprod, timestep, features.shape) * features + - extract(self.sqrt_recipm1_alphas_cumprod, timestep, features.shape) * eps ) - def _predict_eps_from_xstart(self, x_t: Tensor, t: Tensor, pred_xstart: Tensor) -> Tensor: - return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / extract( - self.sqrt_recipm1_alphas_cumprod, t, x_t.shape + def _predict_eps_from_xstart(self, features: Tensor, timestep: Tensor, pred_xstart: Tensor) -> Tensor: + """ + Predict the eps from the xstart. + + Args: + features: The features used to compute the Gaussian parameters. + timestep: The timestep. + pred_xstart: The predicted xstart. + + Returns: + The predicted eps. + """ + return (extract(self.sqrt_recip_alphas_cumprod, timestep, features.shape) * features - pred_xstart) / extract( + self.sqrt_recipm1_alphas_cumprod, timestep, features.shape ) def condition_mean( self, - cond_fn: Callable[[Tensor, Tensor, Any], Tensor], + cond_fn: ConditioningFunction, p_mean_var: dict[str, Tensor], - x: Tensor, - t: Tensor, + features: Tensor, + timestep: Tensor, model_kwargs: dict[str, Any] | None = None, ) -> Tensor: """ @@ -421,19 +576,29 @@ def condition_mean( condition on y. This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + + Args: + cond_fn: The conditioning function. + p_mean_var: The mean and variance of the Gaussian prior distribution. + features: The features used to compute the Gaussian parameters. + timestep: The timestep. + model_kwargs: The model kwargs. Optional, default is None. + + Returns: + The mean for the previous step. """ if model_kwargs is None: model_kwargs = {} - gradient = cond_fn(x, t, **model_kwargs) # type: ignore[call-arg] + gradient = cond_fn(features, timestep, **model_kwargs) return p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() def condition_score( self, - cond_fn: Callable[[Tensor, Tensor, Any], Tensor], + cond_fn: ConditioningFunction, p_mean_var: dict[str, Tensor], - x: Tensor, - t: Tensor, + features: Tensor, + timestep: Tensor, model_kwargs: dict[str, Any] | None = None, ) -> dict[str, Tensor]: """ @@ -444,46 +609,72 @@ def condition_score( Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). + + Args: + cond_fn: The conditioning function. + p_mean_var: The mean and variance of the Gaussian prior distribution. + features: The features used to compute the Gaussian parameters. + timestep: The timestep. + model_kwargs: The model kwargs. Optional, default is None. + + Returns: + The mean and variance of the Gaussian prior distribution. 
""" if model_kwargs is None: model_kwargs = {} - alpha_bar = extract(self.alphas_cumprod, t, x.shape) + alpha_bar = extract(self.alphas_cumprod, timestep, features.shape) - eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) - eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) # type: ignore[call-arg] + eps = self._predict_eps_from_xstart(features, timestep, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(features, timestep, **model_kwargs) out = p_mean_var.copy() - out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) - out["mean"], _, _ = self.gaussian_q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + out["pred_xstart"] = self._predict_xstart_from_eps(features, timestep, eps) + out["mean"], _, _ = self.gaussian_q_posterior_mean_variance( + x_start=out["pred_xstart"], + features=features, + timestep=timestep, + ) return out def gaussian_p_sample( self, model_out: Tensor, - x: Tensor, - t: Tensor, - clip_denoised: bool = False, - denoised_fn: Callable | None = None, + features: Tensor, + timestep: Tensor, model_kwargs: dict[str, Any] | None = None, - cond_fn: Callable | None = None, + cond_fn: ConditioningFunction | None = None, ) -> dict[str, Tensor]: + """ + Sample from the Gaussian posterior distribution. + + Args: + model_out: The model output. + features: The features used to compute the Gaussian parameters. + timestep: The timestep. + model_kwargs: The model kwargs. Optional, default is None. + cond_fn: The conditioning function. Optional, default is None. + + Returns: + A dictionary with teo tensors: + - "sample": the sample from the Gaussian posterior distribution. + - "pred_xstart": the predicted xstart. + """ if model_kwargs is None: model_kwargs = {} out = self.gaussian_p_mean_variance( model_out, - x, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, + features, + timestep, model_kwargs=model_kwargs, ) - noise = torch.randn_like(x) - nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + noise = torch.randn_like(features) + # no noise when t == 0 + nonzero_mask = (timestep != 0).float().view(-1, *([1] * (len(features.shape) - 1))) if cond_fn is not None: - out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + out["mean"] = self.condition_mean(cond_fn, out, features, timestep, model_kwargs=model_kwargs) sample = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise return {"sample": sample, "pred_xstart": out["pred_xstart"]} @@ -491,11 +682,31 @@ def gaussian_p_sample( # Multinomial part def multinomial_kl(self, log_prob1: Tensor, log_prob2: Tensor) -> Tensor: + """ + Calculate the KL divergence between two log probabilities. + + Args: + log_prob1: The first log probability. + log_prob2: The second log probability. + + Returns: + The KL divergence. + """ return (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1) - def q_pred_one_timestep(self, log_x_t: Tensor, t: Tensor) -> Tensor: - log_alpha_t = extract(self.log_alpha, t, log_x_t.shape) - log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape) + def q_pred_one_timestep(self, log_x_t: Tensor, timestep: Tensor) -> Tensor: + """ + Calculate the predicted log probability for one timestep. + + Args: + log_x_t: The log probability of the features. + timestep: The timestep. + + Returns: + The predicted log probability. 
+ """ + log_alpha_t = extract(self.log_alpha, timestep, log_x_t.shape) + log_1_min_alpha_t = extract(self.log_1_min_alpha, timestep, log_x_t.shape) # alpha_t * E[xt] + (1 - alpha_t) 1 / K return log_add_exp( @@ -503,59 +714,124 @@ def q_pred_one_timestep(self, log_x_t: Tensor, t: Tensor) -> Tensor: log_1_min_alpha_t - torch.log(self.num_classes_expanded), ) - def q_pred(self, log_x_start: Tensor, t: Tensor) -> Tensor: - log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape) - log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape) + def q_pred(self, log_x_start: Tensor, timestep: Tensor) -> Tensor: + """ + Calculate the predicted log probability for one timestep. + + Args: + log_x_start: The log probability of the start. + timestep: The timestep. + + Returns: + The predicted log probability. + """ + log_cumprod_alpha_t = extract(self.log_cumprod_alpha, timestep, log_x_start.shape) + log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, timestep, log_x_start.shape) return log_add_exp( log_x_start + log_cumprod_alpha_t, log_1_min_cumprod_alpha - torch.log(self.num_classes_expanded), ) - def predict_start(self, model_out: Tensor, log_x_t: Tensor, t: Tensor, out_dict: dict[str, Tensor]) -> Tensor: + def predict_start(self, model_out: Tensor, log_x_t: Tensor) -> Tensor: + """ + Predict the start from the model output. + + Args: + model_out: The model output. + log_x_t: The log probability of the features. + + Returns: + The predicted start. + """ assert model_out.size(0) == log_x_t.size(0) assert self.num_classes is not None assert model_out.size(1) == self.num_classes.sum(), f"{model_out.size()}" log_pred = torch.empty_like(model_out) for ix in self.slices_for_classes: - log_pred[:, ix] = F.log_softmax(model_out[:, ix], dim=1) + log_pred[:, ix] = functional.log_softmax(model_out[:, ix], dim=1) return log_pred - def q_posterior(self, log_x_start: Tensor, log_x_t: Tensor, t: Tensor) -> Tensor: - t_minus_1 = t - 1 + def q_posterior(self, log_x_start: Tensor, log_x_t: Tensor, timestep: Tensor) -> Tensor: + """ + Calculate the posterior probability for one timestep. + + Args: + log_x_start: The log probability of the initial input. + log_x_t: The log probability of the features. + timestep: The timestep. + + Returns: + The posterior probability. + """ + t_minus_1 = timestep - 1 # Remove negative values, will not be used anyway for final decoder t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1) - log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1) + log_ev_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1) num_axes = (1,) * (len(log_x_start.size()) - 1) - t_broadcast = t.to(log_x_start.device).view(-1, *num_axes) * torch.ones_like(log_x_start) - log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0.to(torch.float32)) + t_broadcast = timestep.to(log_x_start.device).view(-1, *num_axes) * torch.ones_like(log_x_start) + log_ev_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_ev_qxtmin_x0.to(torch.float32)) # unnormed_logprobs = log_EV_qxtmin_x0 + # log q_pred_one_timestep(x_t, t) # Note: _NOT_ x_tmin1, which is how the formula is typically used!!! # Not very easy to see why this is true. 
But it is :) - unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t) + unnormed_logprobs = log_ev_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, timestep) return unnormed_logprobs - sliced_logsumexp(unnormed_logprobs, self.offsets) - def p_pred(self, model_out: Tensor, log_x: Tensor, t: Tensor, out_dict: dict[str, Tensor]) -> Tensor: + def p_pred(self, model_out: Tensor, log_x: Tensor, timestep: Tensor) -> Tensor: + """ + Predict the log probability of the model output. + + Args: + model_out: The model output. + log_x: The log probability of the features. + timestep: The timestep. + + Returns: + The log probability of the model output. + """ if self.parametrization == Parametrization.X0: - log_x_recon = self.predict_start(model_out, log_x, t=t, out_dict=out_dict) - log_model_pred = self.q_posterior(log_x_start=log_x_recon, log_x_t=log_x, t=t) + log_x_recon = self.predict_start(model_out, log_x) + log_model_pred = self.q_posterior(log_x_start=log_x_recon, log_x_t=log_x, timestep=timestep) + elif self.parametrization == Parametrization.DIRECT: - log_model_pred = self.predict_start(model_out, log_x, t=t, out_dict=out_dict) + log_model_pred = self.predict_start(model_out, log_x) + else: raise ValueError(f"Unsupported parametrization: {self.parametrization}") + return log_model_pred @torch.no_grad() - def p_sample(self, model_out: Tensor, log_x: Tensor, t: Tensor, out_dict: dict[str, Tensor]) -> Tensor: - model_log_prob = self.p_pred(model_out, log_x=log_x, t=t, out_dict=out_dict) + def p_sample(self, model_out: Tensor, log_x: Tensor, timestep: Tensor) -> Tensor: + """ + Sample from the model output. + + Args: + model_out: The model output. + log_x: The log probability of the features. + timestep: The timestep. + + Returns: + The sample from the model output. + """ + model_log_prob = self.p_pred(model_out, log_x=log_x, timestep=timestep) return self.log_sample_categorical(model_log_prob) def log_sample_categorical(self, logits: Tensor) -> Tensor: + """ + Sample from the categorical logits. + + Args: + logits: The logits. + + Returns: + The sample from the categorical logits. + """ full_sample = [] for i in range(len(self.num_classes)): one_class_logits = logits[:, self.slices_for_classes[i]] @@ -563,38 +839,67 @@ def log_sample_categorical(self, logits: Tensor) -> Tensor: gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30) sample = (gumbel_noise + one_class_logits).argmax(dim=1) full_sample.append(sample.unsqueeze(1)) + full_sample_tensor = torch.cat(full_sample, dim=1) return index_to_log_onehot(full_sample_tensor, torch.from_numpy(self.num_classes)) - def q_sample(self, log_x_start: Tensor, t: Tensor) -> Tensor: - log_EV_qxt_x0 = self.q_pred(log_x_start, t) - # ruff: noqa: N806 - return self.log_sample_categorical(log_EV_qxt_x0) + def q_sample(self, log_x_start: Tensor, timestep: Tensor) -> Tensor: + """ + Sample from the logits for one timestep. + + Args: + log_x_start: The log probability of the initial input. + timestep: The timestep. + + Returns: + The sample from the categorical logits. + """ + log_ev_qxt_x0 = self.q_pred(log_x_start, timestep) + return self.log_sample_categorical(log_ev_qxt_x0) def kl_prior(self, log_x_start: Tensor) -> Tensor: - b = log_x_start.size(0) + """ + Calculate the KL divergence between the prior and the posterior. + + Args: + log_x_start: The log probability of the initial input. + + Returns: + The KL divergence between the prior and the posterior. 
+ """ + batch_size = log_x_start.size(0) device = log_x_start.device - ones = torch.ones(b, device=device).long() + ones = torch.ones(batch_size, device=device).long() - log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones) - # ruff: noqa: N806 - log_half_prob = -torch.log(self.num_classes_expanded * torch.ones_like(log_qxT_prob)) + log_qxt_prob = self.q_pred(log_x_start, timestep=(self.num_timesteps - 1) * ones) + log_half_prob = -torch.log(self.num_classes_expanded * torch.ones_like(log_qxt_prob)) - kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob) + kl_prior = self.multinomial_kl(log_qxt_prob, log_half_prob) return sum_except_batch(kl_prior) - def compute_Lt( - # ruff: noqa: N802 + def compute_lt( self, model_out: Tensor, log_x_start: Tensor, log_x_t: Tensor, - t: Tensor, - out_dict: dict[str, Tensor], + timestep: Tensor, detach_mean: bool = False, ) -> Tensor: - log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t) - log_model_prob = self.p_pred(model_out, log_x=log_x_t, t=t, out_dict=out_dict) + """ + Calculate the KL divergence between the true and the model probability. + + Args: + model_out: The model output. + log_x_start: The log probability of the initial input. + log_x_t: The log probability of the features. + timestep: The timestep. + detach_mean: Whether to detach the mean. + + Returns: + The KL divergence between the true and the model probability. + """ + log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, timestep=timestep) + log_model_prob = self.p_pred(model_out, log_x=log_x_t, timestep=timestep) if detach_mean: log_model_prob = log_model_prob.detach() @@ -605,296 +910,308 @@ def compute_Lt( decoder_nll = -log_categorical(log_x_start, log_model_prob) decoder_nll = sum_except_batch(decoder_nll) - mask = (t == torch.zeros_like(t)).float() + mask = (timestep == torch.zeros_like(timestep)).float() return mask * decoder_nll + (1.0 - mask) * kl - def sample_time(self, b: int, device: torch.device, method: str = "uniform") -> tuple[Tensor, Tensor]: + def sample_time( + self, + batch_size: int, + device: torch.device, + method: Literal["uniform", "importance"] = "uniform", + ) -> tuple[Tensor, Tensor]: + """ + Sample the timestep. + + Args: + batch_size: The batch size. + device: The device. + method: The method to sample the timestep. + + Returns: + The timestep and the probability of the timestep. + """ if method == "importance": - if not (self.Lt_count > 10).all(): - return self.sample_time(b, device, method="uniform") + if not (self.lt_count > 10).all(): + return self.sample_time(batch_size, device, method="uniform") - Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001 - # ruff: noqa: N806 - Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1. - pt_all = (Lt_sqrt / Lt_sqrt.sum()).to(device) + lt_sqrt = torch.sqrt(self.lt_history + 1e-10) + 0.0001 + lt_sqrt[0] = lt_sqrt[1] # Overwrite decoder term with L1. 
+ pt_all = (lt_sqrt / lt_sqrt.sum()).to(device) - t = torch.multinomial(pt_all, num_samples=b, replacement=True).to(device) + timestep = torch.multinomial(pt_all, num_samples=batch_size, replacement=True).to(device) - pt = pt_all.gather(dim=0, index=t) + p_timestep = pt_all.gather(dim=0, index=timestep) - return t, pt + return timestep, p_timestep if method == "uniform": - t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + timestep = torch.randint(0, self.num_timesteps, (batch_size,), device=device).long() + + p_timestep = torch.ones_like(timestep).float() / self.num_timesteps + return timestep, p_timestep - pt = torch.ones_like(t).float() / self.num_timesteps - return t, pt - raise ValueError + raise ValueError(f"Unsupported method: {method}") def _multinomial_loss( self, model_out: Tensor, log_x_start: Tensor, log_x_t: Tensor, - t: Tensor, - pt: Tensor, - out_dict: dict[str, Tensor], + timestep: Tensor, + p_timestep: Tensor, ) -> Tensor: + """ + Calculate the multinomial loss. + + Args: + model_out: The model output. + log_x_start: The log probability of the initial input. + log_x_t: The log probability of the features. + timestep: The timestep. + p_timestep: The probability of the timestep. + + Returns: + The multinomial loss. + """ # Here we are calculating the VB_STOCHASTIC loss. In the original implementation, there # was a choice between VB_STOCHASTIC and VB_ALL. VB_ALL is deprecated for being too # expensive to calculate. - kl = self.compute_Lt(model_out, log_x_start, log_x_t, t, out_dict) + kl = self.compute_lt(model_out, log_x_start, log_x_t, timestep) kl_prior = self.kl_prior(log_x_start) # Upweigh loss term of the kl - return kl / pt + kl_prior - - def mixed_loss(self, x: Tensor, out_dict: dict[str, Tensor]) -> tuple[Tensor, Tensor]: - b = x.shape[0] - device = x.device - t, pt = self.sample_time(b, device, "uniform") - - x_num = x[:, : self.num_numerical_features] - x_cat = x[:, self.num_numerical_features :] - - x_num_t = x_num - log_x_cat_t = x_cat - if x_num.shape[1] > 0: - noise = torch.randn_like(x_num) - x_num_t = self.gaussian_q_sample(x_num, t, noise=noise) - if x_cat.shape[1] > 0: - log_x_cat = index_to_log_onehot(x_cat.long(), torch.from_numpy(self.num_classes)) - log_x_cat_t = self.q_sample(log_x_start=log_x_cat, t=t) + return (kl / p_timestep) + kl_prior - x_in = torch.cat([x_num_t, log_x_cat_t], dim=1) - - model_out = self._denoise_fn(x_in, t, **out_dict) + def mixed_loss(self, features: Tensor, out_dict: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + """ + Calculate the mixed loss. - model_out_num = model_out[:, : self.num_numerical_features] - model_out_cat = model_out[:, self.num_numerical_features :] + Args: + features: The input features. + out_dict: The output dictionary. - loss_multi = torch.zeros((1,)).float() - loss_gauss = torch.zeros((1,)).float() - if x_cat.shape[1] > 0: - loss_multi = self._multinomial_loss(model_out_cat, log_x_cat, log_x_cat_t, t, pt, out_dict) / len( - self.num_classes + Returns: + The multinomial loss and the Gaussian loss. 
+        """
+        batch_size = features.shape[0]
+        device = features.device
+        timestep, p_timestep = self.sample_time(batch_size, device, "uniform")
+
+        numerical_features = features[:, : self.num_numerical_features]
+        categorical_features = features[:, self.num_numerical_features :]
+
+        numerical_features_ts = numerical_features
+        log_categorical_features_ts = categorical_features
+        if numerical_features.shape[1] > 0:
+            noise = torch.randn_like(numerical_features)
+            numerical_features_ts = self.gaussian_q_sample(numerical_features, timestep, noise=noise)
+
+        if categorical_features.shape[1] > 0:
+            log_x_cat = index_to_log_onehot(categorical_features.long(), torch.from_numpy(self.num_classes))
+            log_categorical_features_ts = self.q_sample(log_x_start=log_x_cat, timestep=timestep)
+
+        input_features = torch.cat([numerical_features_ts, log_categorical_features_ts], dim=1)
+
+        model_output = self._denoise_fn(input_features, timestep, **out_dict)
+
+        model_numerical_output = model_output[:, : self.num_numerical_features]
+        model_categorical_output = model_output[:, self.num_numerical_features :]
+
+        multinomial_loss = torch.zeros((1,)).float()
+        gaussian_loss = torch.zeros((1,)).float()
+        if categorical_features.shape[1] > 0:
+            multinomial_loss = self._multinomial_loss(
+                model_categorical_output,
+                log_x_cat,
+                log_categorical_features_ts,
+                timestep,
+                p_timestep,
             )
-
-        if x_num.shape[1] > 0:
-            loss_gauss = self._gaussian_loss(model_out_num, x_num, x_num_t, t, noise)
-
-        # loss_multi = torch.where(out_dict['y'] == 1, loss_multi, 2 * loss_multi)
-        # loss_gauss = torch.where(out_dict['y'] == 1, loss_gauss, 2 * loss_gauss)
-
-        return loss_multi.mean(), loss_gauss.mean()
-
-    @torch.no_grad()
-    def mixed_elbo(self, x0: Tensor, out_dict: dict[str, Tensor]) -> dict[str, Tensor]:
-        b = x0.size(0)
-        device = x0.device
-
-        x_num = x0[:, : self.num_numerical_features]
-        x_cat = x0[:, self.num_numerical_features :]
-        has_cat = x_cat.shape[1] > 0
-        if has_cat:
-            log_x_cat = index_to_log_onehot(x_cat.long(), torch.from_numpy(self.num_classes)).to(device)
-
-        gaussian_loss = []
-        xstart_mse = []
-        mse = []
-        # mu_mse = []
-        out_mean = []
-        true_mean = []
-        multinomial_loss = []
-        for t in range(self.num_timesteps):
-            t_array = (torch.ones(b, device=device) * t).long()
-            noise = torch.randn_like(x_num)
-
-            x_num_t = self.gaussian_q_sample(x_start=x_num, t=t_array, noise=noise)
-            log_x_cat_t = self.q_sample(log_x_start=log_x_cat, t=t_array) if has_cat else x_cat
-
-            model_out = self._denoise_fn(torch.cat([x_num_t, log_x_cat_t], dim=1), t_array, **out_dict)
-
-            model_out_num = model_out[:, : self.num_numerical_features]
-            model_out_cat = model_out[:, self.num_numerical_features :]
-
-            kl = torch.tensor([0.0])
-            if has_cat:
-                kl = self.compute_Lt(
-                    model_out=model_out_cat,
-                    log_x_start=log_x_cat,
-                    log_x_t=log_x_cat_t,
-                    t=t_array,
-                    out_dict=out_dict,
-                )
-
-            out = self._vb_terms_bpd(
-                model_out_num,
-                x_start=x_num,
-                x_t=x_num_t,
-                t=t_array,
-                clip_denoised=False,
+            multinomial_loss = multinomial_loss / len(self.num_classes)
+
+        if numerical_features.shape[1] > 0:
+            gaussian_loss = self._gaussian_loss(
+                model_numerical_output,
+                numerical_features,
+                numerical_features_ts,
+                timestep,
+                noise,
             )
-            multinomial_loss.append(kl)
-            gaussian_loss.append(out["output"])
-            xstart_mse.append(mean_flat((out["pred_xstart"] - x_num) ** 2))
-            # mu_mse.append(mean_flat(out["mean_mse"]))
-            out_mean.append(mean_flat(out["out_mean"]))
-            true_mean.append(mean_flat(out["true_mean"]))
-
-            eps = self._predict_eps_from_xstart(x_num_t, t_array, out["pred_xstart"])
-            mse.append(mean_flat((eps - noise) ** 2))
-
-        gaussian_loss_tensor = torch.stack(gaussian_loss, dim=1)
-        multinomial_loss_tensor = torch.stack(multinomial_loss, dim=1)
-        xstart_mse_tensor = torch.stack(xstart_mse, dim=1)
-        mse_tensor = torch.stack(mse, dim=1)
-        # mu_mse = torch.stack(mu_mse, dim=1)
-        out_mean_tensor = torch.stack(out_mean, dim=1)
-        true_mean_tensor = torch.stack(true_mean, dim=1)
-
-        prior_gauss = self._prior_gaussian(x_num)
-
-        prior_multin = torch.tensor([0.0])
-        if has_cat:
-            prior_multin = self.kl_prior(log_x_cat)
-
-        total_gauss = gaussian_loss_tensor.sum(dim=1) + prior_gauss
-        total_multin = multinomial_loss_tensor.sum(dim=1) + prior_multin
-        return {
-            "total_gaussian": total_gauss,
-            "total_multinomial": total_multin,
-            "losses_gaussian": gaussian_loss_tensor,
-            "losses_multinimial": multinomial_loss_tensor,
-            "xstart_mse": xstart_mse_tensor,
-            "mse": mse_tensor,
-            # "mu_mse": mu_mse
-            "out_mean": out_mean_tensor,
-            "true_mean": true_mean_tensor,
-        }
+        return multinomial_loss.mean(), gaussian_loss.mean()

     @torch.no_grad()
     def gaussian_ddim_step(
         self,
-        model_out_num: Tensor,
-        x: Tensor,
-        t: Tensor,
-        clip_denoised: bool = False,
-        denoised_fn: Callable | None = None,
+        model_numerical_output: Tensor,
+        features: Tensor,
+        timestep: Tensor,
         eta: float = 0.0,
         model_kwargs: dict[str, Any] | None = None,
-        cond_fn: Callable | None = None,
+        cond_fn: ConditioningFunction | None = None,
     ) -> Tensor:
+        """
+        Calculate the Gaussian DDIM step.
+
+        Args:
+            model_numerical_output: The numerical features of the model output.
+            features: The input features.
+            timestep: The timestep.
+            eta: The DDIM stochasticity coefficient. Optional, default is 0.0.
+            model_kwargs: The model kwargs. Optional, default is None.
+            cond_fn: The conditioning function. Optional, default is None.
+
+        Returns:
+            The predicted features.
+        """
         if model_kwargs is None:
             model_kwargs = {}
         out = self.gaussian_p_mean_variance(
-            model_out_num,
-            x,
-            t,
-            clip_denoised=clip_denoised,
-            denoised_fn=denoised_fn,
+            model_numerical_output,
+            features,
+            timestep,
             model_kwargs=None,
         )
         if cond_fn is not None:
-            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+            out = self.condition_score(cond_fn, out, features, timestep, model_kwargs=model_kwargs)

-        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+        eps = self._predict_eps_from_xstart(features, timestep, out["pred_xstart"])

-        alpha_bar = extract(self.alphas_cumprod, t, x.shape)
-        alpha_bar_prev = extract(self.alphas_cumprod_prev, t, x.shape)
+        alpha_bar = extract(self.alphas_cumprod, timestep, features.shape)
+        alpha_bar_prev = extract(self.alphas_cumprod_prev, timestep, features.shape)
         sigma = eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * torch.sqrt(1 - alpha_bar / alpha_bar_prev)

-        noise = torch.randn_like(x)
+        noise = torch.randn_like(features)
         mean_pred = out["pred_xstart"] * torch.sqrt(alpha_bar_prev) + torch.sqrt(1 - alpha_bar_prev - sigma**2) * eps
-        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))  # no noise when t == 0
+        nonzero_mask = (timestep != 0).float().view(-1, *([1] * (len(features.shape) - 1)))  # no noise when t == 0

         return mean_pred + nonzero_mask * sigma * noise

     @torch.no_grad()
     def gaussian_ddim_sample(
         self,
         noise: Tensor,
-        T: int,
+        num_timesteps: int,
         out_dict: dict[str, Tensor],
         eta: float = 0.0,
         model_kwargs: Any | None = None,
-        cond_fn: Callable | None = None,
+        cond_fn: ConditioningFunction | None = None,
     ) -> Tensor:
-        # ruff: noqa: D102, N803
-        x = noise
-        b = x.shape[0]
-        device = x.device
-        for t in reversed(range(T)):
-            print(f"Sample timestep {t:4d}", end="\r")
-            t_array = (torch.ones(b, device=device) * t).long()
-            out_num = self._denoise_fn(x, t_array, **out_dict)
-            x = self.gaussian_ddim_step(out_num, x, t_array, model_kwargs=model_kwargs, cond_fn=cond_fn)
-        print()
-        return x
+        """
+        Produce the Gaussian DDIM sample.
+
+        Args:
+            noise: The noise.
+            num_timesteps: The number of timesteps.
+            out_dict: The output dictionary.
+            eta: The DDIM stochasticity coefficient. Optional, default is 0.0.
+            model_kwargs: The model kwargs. Optional, default is None.
+            cond_fn: The conditioning function. Optional, default is None.
+
+        Returns:
+            The predicted features.
+        """
+        features = noise
+        batch_size = features.shape[0]
+        device = features.device
+        for t in reversed(range(num_timesteps)):
+            log(DEBUG, f"Sample timestep {t:4d}")
+            t_array = (torch.ones(batch_size, device=device) * t).long()
+            out_num = self._denoise_fn(features, t_array, **out_dict)
+            features = self.gaussian_ddim_step(
+                out_num,
+                features,
+                t_array,
+                eta=eta,
+                model_kwargs=model_kwargs,
+                cond_fn=cond_fn,
+            )
+
+        return features

     @torch.no_grad()
     def gaussian_ddim_reverse_step(
         self,
         model_out_num: Tensor,
-        x: Tensor,
-        t: Tensor,
-        clip_denoised: bool = False,
-        eta: float = 0.0,
+        features: Tensor,
+        timestep: Tensor,
     ) -> Tensor:
-        # ruff: noqa: D102
-        assert eta == 0.0, "Eta must be zero."
-        out = self.gaussian_p_mean_variance(
-            model_out_num,
-            x,
-            t,
-            clip_denoised=clip_denoised,
-            denoised_fn=None,
-            model_kwargs=None,
-        )
+        """
+        Calculate the Gaussian DDIM reverse step.
- eps = (extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]) / extract( - self.sqrt_recipm1_alphas_cumprod, t, x.shape - ) - alpha_bar_next = extract(self.alphas_cumprod_next, t, x.shape) + Args: + model_out_num: The numerical features of the model output. + features: The input features. + timestep: The timestep. + + Returns: + The predicted features. + """ + out = self.gaussian_p_mean_variance(model_out_num, features, timestep) + + coefficient = extract(self.sqrt_recip_alphas_cumprod, timestep, features.shape) + denominator = extract(self.sqrt_recipm1_alphas_cumprod, timestep, features.shape) + numerator = coefficient * features - out["pred_xstart"] + eps = numerator / denominator + + alpha_bar_next = extract(self.alphas_cumprod_next, timestep, features.shape) return out["pred_xstart"] * torch.sqrt(alpha_bar_next) + torch.sqrt(1 - alpha_bar_next) * eps @torch.no_grad() def gaussian_ddim_reverse_sample( self, - x: Tensor, - T: int, - # ruff: noqa: N803 + features: Tensor, + num_timesteps: int, out_dict: dict[str, Tensor], ) -> Tensor: - # ruff: noqa: D102 - b = x.shape[0] - device = x.device - for t in range(T): - print(f"Reverse timestep {t:4d}", end="\r") - t_array = (torch.ones(b, device=device) * t).long() - out_num = self._denoise_fn(x, t_array, **out_dict) - x = self.gaussian_ddim_reverse_step(out_num, x, t_array, eta=0.0) - print() - - return x + """ + Produce the Gaussian DDIM reverse sample. + + Args: + features: The input features. + num_timesteps: The number of timesteps. + out_dict: The output dictionary. + + Returns: + The predicted features. + """ + batch_size = features.shape[0] + device = features.device + output_features = features.clone() + + for t in range(num_timesteps): + log(DEBUG, f"Reverse timestep {t:4d}") + t_array = (torch.ones(batch_size, device=device) * t).long() + out_num = self._denoise_fn(output_features, t_array, **out_dict) + output_features = self.gaussian_ddim_reverse_step(out_num, output_features, t_array) + + return output_features @torch.no_grad() def multinomial_ddim_step( self, model_out_cat: Tensor, log_x_t: Tensor, - t: Tensor, - out_dict: dict[str, Tensor], + timestep: Tensor, eta: float = 0.0, ) -> Tensor: - # ruff: noqa: D102 - # not ddim, essentially - log_x0 = self.predict_start(model_out_cat, log_x_t=log_x_t, t=t, out_dict=out_dict) + """ + Calculate the multinomial DDIM step. - alpha_bar = extract(self.alphas_cumprod, t, log_x_t.shape) - alpha_bar_prev = extract(self.alphas_cumprod_prev, t, log_x_t.shape) + Args: + model_out_cat: The categorical model output. + log_x_t: The log probability of the features. + timestep: The timestep. + eta: The DDIM stochasticity coefficient. Optional, default is 0.0. + + Returns: + The predicted log probability. + """ + log_x0 = self.predict_start(model_out_cat, log_x_t=log_x_t) + + alpha_bar = extract(self.alphas_cumprod, timestep, log_x_t.shape) + alpha_bar_prev = extract(self.alphas_cumprod_prev, timestep, log_x_t.shape) sigma = eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * torch.sqrt(1 - alpha_bar / alpha_bar_prev) coef1 = sigma @@ -918,29 +1235,40 @@ def multinomial_ddim_step( def sample_ddim( self, num_samples: int, - y_dist: Tensor, + target_dist: Tensor, model_kwargs: dict[str, Any] | None = None, - cond_fn: Callable | None = None, + cond_fn: ConditioningFunction | None = None, ) -> tuple[Tensor, dict[str, Tensor]]: - # ruff: noqa: D102 + """ + Sample using DDIM. + + Args: + num_samples: The number of samples. 
+ target_dist: Class distribution to sample labels from. + model_kwargs: The model kwargs. Optional, default is None. + cond_fn: The conditioning function. Optional, default is None. + + Returns: + The samples and the output dictionary. + """ if model_kwargs is None: model_kwargs = {} - b = num_samples - z_norm = torch.randn((b, self.num_numerical_features), device=self.device) + batch_size = num_samples + z_norm = torch.randn((batch_size, self.num_numerical_features), device=self.device) assert self.num_classes is not None has_cat = self.num_classes[0] != 0 - log_z = torch.zeros((b, 0), device=self.device).float() + log_z = torch.zeros((batch_size, 0), device=self.device).float() if has_cat: - uniform_logits = torch.zeros((b, len(self.num_classes_expanded)), device=self.device) + uniform_logits = torch.zeros((batch_size, len(self.num_classes_expanded)), device=self.device) log_z = self.log_sample_categorical(uniform_logits) - y = torch.multinomial(y_dist, num_samples=b, replacement=True) + y = torch.multinomial(target_dist, num_samples=batch_size, replacement=True) out_dict = {"y": y.long().to(self.device)} for i in reversed(range(0, self.num_timesteps)): - print(f"Sample timestep {i:4d}", end="\r") - t = torch.full((b,), i, device=self.device, dtype=torch.long) + log(DEBUG, f"Sample timestep {i:4d}") + t = torch.full((batch_size,), i, device=self.device, dtype=torch.long) model_out = self._denoise_fn(torch.cat([z_norm, log_z], dim=1).float(), t, **out_dict) model_out_num = model_out[:, : self.num_numerical_features] model_out_cat = model_out[:, self.num_numerical_features :] @@ -948,42 +1276,51 @@ def sample_ddim( model_out_num, z_norm, t, - clip_denoised=False, model_kwargs=model_kwargs, cond_fn=cond_fn, ) if has_cat: - log_z = self.multinomial_ddim_step(model_out_cat, log_z, t, out_dict) + log_z = self.multinomial_ddim_step(model_out_cat, log_z, t) - print() z_ohe = torch.exp(log_z).round() z_cat = log_z if has_cat: z_cat = ohe_to_categories(z_ohe, torch.from_numpy(self.num_classes)) sample = torch.cat([z_norm, z_cat], dim=1).cpu() + return sample, out_dict @torch.no_grad() def conditional_sample( self, - ys: Tensor, + targets: Tensor, model_kwargs: dict[str, Any] | None = None, - cond_fn: Callable | None = None, + cond_fn: ConditioningFunction | None = None, ) -> tuple[Tensor, dict[str, Tensor]]: - # ruff: noqa: D102 + """ + Sample using conditional DDIM. + + Args: + targets: The targets. + model_kwargs: The model kwargs. Optional, default is None. + cond_fn: The conditioning function. Optional, default is None. + + Returns: + The samples and the output dictionary. 
+ """ if model_kwargs is None: model_kwargs = {} - b = len(ys) - z_norm = torch.randn((b, self.num_numerical_features), device=self.device) + batch_size = len(targets) + z_norm = torch.randn((batch_size, self.num_numerical_features), device=self.device) assert self.num_classes is not None has_cat = self.num_classes[0] != 0 - log_z = torch.zeros((b, 0), device=self.device).float() + log_z = torch.zeros((batch_size, 0), device=self.device).float() - out_dict = {"y": ys.long().to(self.device)} + out_dict = {"y": targets.long().to(self.device)} for i in reversed(range(0, self.num_timesteps)): - print(f"Sample timestep {i:4d}", end="\r") - t = torch.full((b,), i, device=self.device, dtype=torch.long) + log(DEBUG, f"Sample timestep {i:4d}") + t = torch.full((batch_size,), i, device=self.device, dtype=torch.long) model_out = self._denoise_fn(torch.cat([z_norm, log_z], dim=1).float(), t, **out_dict) model_out_num = model_out[:, : self.num_numerical_features] model_out_cat = model_out[:, self.num_numerical_features :] @@ -991,14 +1328,12 @@ def conditional_sample( model_out_num, z_norm, t, - clip_denoised=False, model_kwargs=model_kwargs, cond_fn=cond_fn, )["sample"] if has_cat: - log_z = self.p_sample(model_out_cat, log_z, t, out_dict) + log_z = self.p_sample(model_out_cat, log_z, t) - print() z_ohe = torch.exp(log_z).round() z_cat = log_z if has_cat: @@ -1010,29 +1345,40 @@ def conditional_sample( def sample( self, num_samples: int, - y_dist: Tensor, + target_dist: Tensor, model_kwargs: dict[str, Any] | None = None, - cond_fn: Callable | None = None, + cond_fn: ConditioningFunction | None = None, ) -> tuple[Tensor, dict[str, Tensor]]: - # ruff: noqa: D102 + """ + Sample using ancestral (DDPM-style) sampling. + + Args: + num_samples: The number of samples. + target_dist: Class distribution to sample labels from. + model_kwargs: The model kwargs. Optional, default is None. + cond_fn: The conditioning function. Optional, default is None. + + Returns: + The samples and the output dictionary. 
+ """ if model_kwargs is None: model_kwargs = {} - b = num_samples - z_norm = torch.randn((b, self.num_numerical_features), device=self.device) + batch_size = num_samples + z_norm = torch.randn((batch_size, self.num_numerical_features), device=self.device) assert self.num_classes is not None has_cat = self.num_classes[0] != 0 - log_z = torch.zeros((b, 0), device=self.device).float() + log_z = torch.zeros((batch_size, 0), device=self.device).float() if has_cat: - uniform_logits = torch.zeros((b, len(self.num_classes_expanded)), device=self.device) + uniform_logits = torch.zeros((batch_size, len(self.num_classes_expanded)), device=self.device) log_z = self.log_sample_categorical(uniform_logits) - y = torch.multinomial(y_dist, num_samples=b, replacement=True) + y = torch.multinomial(target_dist, num_samples=batch_size, replacement=True) out_dict = {"y": y.long().to(self.device)} for i in reversed(range(0, self.num_timesteps)): - print(f"Sample timestep {i:4d}", end="\r") - t = torch.full((b,), i, device=self.device, dtype=torch.long) + log(DEBUG, f"Sample timestep {i:4d}") + t = torch.full((batch_size,), i, device=self.device, dtype=torch.long) model_out = self._denoise_fn(torch.cat([z_norm, log_z], dim=1).float(), t, **out_dict) model_out_num = model_out[:, : self.num_numerical_features] model_out_cat = model_out[:, self.num_numerical_features :] @@ -1040,14 +1386,12 @@ def sample( model_out_num, z_norm, t, - clip_denoised=False, model_kwargs=model_kwargs, cond_fn=cond_fn, )["sample"] if has_cat: - log_z = self.p_sample(model_out_cat, log_z, t, out_dict) + log_z = self.p_sample(model_out_cat, log_z, t) - print() z_ohe = torch.exp(log_z).round() z_cat = log_z if has_cat: @@ -1059,36 +1403,48 @@ def sample_all( self, num_samples: int, batch_size: int, - y_dist: Tensor, + target_dist: Tensor, ddim: bool = False, model_kwargs: dict[str, Any] | None = None, - cond_fn: Callable | None = None, + cond_fn: ConditioningFunction | None = None, ) -> tuple[Tensor, Tensor]: - # ruff: noqa: D102 + """ + Generate samples in batches of ``batch_size`` until ``num_samples`` are produced. + Uses DDIM if ``ddim`` is ``True``. + + Args: + num_samples: The number of samples. + batch_size: The batch size. + target_dist: Class distribution to sample labels from. + ddim: Whether to use DDIM. Optional, default is False. + model_kwargs: The model kwargs. Optional, default is None. + cond_fn: The conditioning function. Optional, default is None. + + Returns: + A tuple with the generated features and corresponding targets. 
+ """ if ddim: - print("Sample using DDIM.") + log(INFO, "Sample using DDIM.") sample_fn = self.sample_ddim else: sample_fn = self.sample - b = batch_size - - all_y = [] + all_targets = [] all_samples = [] num_generated = 0 while num_generated < num_samples: - sample, out_dict = sample_fn(b, y_dist, model_kwargs=model_kwargs, cond_fn=cond_fn) + sample, out_dict = sample_fn(batch_size, target_dist, model_kwargs=model_kwargs, cond_fn=cond_fn) mask_nan = torch.any(sample.isnan(), dim=1) sample = sample[~mask_nan] out_dict["y"] = out_dict["y"][~mask_nan] all_samples.append(sample) - all_y.append(out_dict["y"].cpu()) - if sample.shape[0] != b: + all_targets.append(out_dict["y"].cpu()) + if sample.shape[0] != batch_size: raise FoundNaNsError num_generated += sample.shape[0] - x_gen = torch.cat(all_samples, dim=0)[:num_samples] - y_gen = torch.cat(all_y, dim=0)[:num_samples] + generated_features = torch.cat(all_samples, dim=0)[:num_samples] + generated_targets = torch.cat(all_targets, dim=0)[:num_samples] - return x_gen, y_gen + return generated_features, generated_targets diff --git a/tests/integration/models/clavaddpm/test_model.py b/tests/integration/models/clavaddpm/test_model.py index 80cb68f5..847ab67d 100644 --- a/tests/integration/models/clavaddpm/test_model.py +++ b/tests/integration/models/clavaddpm/test_model.py @@ -372,9 +372,9 @@ def test_train_multi_table(tmp_path: Path): ys_tensor = torch.tensor(np.array(ys).reshape(-1, 1), requires_grad=False) conditional_sample, _ = models[1][key]["diffusion"].conditional_sample( - ys=ys_tensor, + targets=ys_tensor, model_kwargs={"y": ys_tensor}, - cond_fn=get_conditional_function_for_the_classifier(models[1][key]["classifier"], classifier_scale), + cond_fn=get_conditioning_function_for_diffusion(models[1][key]["classifier"], classifier_scale), ) expected_conditional_sample = torch.load( @@ -444,7 +444,7 @@ def test_clustering_reload(tmp_path: Path): unset_all_random_seeds() -def get_conditional_function_for_the_classifier(classifier: Classifier, classifier_scale: float) -> Callable: +def get_conditioning_function_for_diffusion(classifier: Classifier, classifier_scale: float) -> Callable: def cond_fn( x: torch.Tensor, t: torch.Tensor, From 1175bc613694a7e61659a75674a9223c564f2550 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 8 Oct 2025 17:56:18 -0400 Subject: [PATCH 2/5] Small docstring adjustment --- .../models/clavaddpm/gaussian_multinomial_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py b/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py index a029c8a1..779cf3aa 100644 --- a/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py +++ b/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py @@ -79,7 +79,7 @@ def __call__(self, features: Tensor, timestep: Tensor, **kwargs: Any) -> Tensor: **kwargs: Extra keyword arguments passed to the model. Returns: - The model output. + The tensor result of the conditioning function. """ ... 
From 99a6fbe0acf2df1c592725afc5c5fa0462d8ccf8 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Thu, 16 Oct 2025 15:02:57 -0400 Subject: [PATCH 3/5] Fixing refactoring issues --- .../models/clavaddpm/synthesizer.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/synthesizer.py b/src/midst_toolkit/models/clavaddpm/synthesizer.py index 9d90d6b6..4c65f4e4 100644 --- a/src/midst_toolkit/models/clavaddpm/synthesizer.py +++ b/src/midst_toolkit/models/clavaddpm/synthesizer.py @@ -181,18 +181,22 @@ def conditional_sampling_by_group_size( # noqa: PLR0915, PLR0912 """ def cond_fn( - x: torch.Tensor, - t: torch.Tensor, - y: torch.Tensor | None = None, - remove_first_col: bool = False, + features: torch.Tensor, + timestep: torch.Tensor, + **kwargs: Any, ) -> torch.Tensor: - assert y is not None + assert "y" in kwargs and kwargs["y"] is not None, "The kwargs parameter `y` must be provided." + assert isinstance(kwargs["y"], torch.Tensor), "The kwargs parameter `y` must be a Tensor." + + y = kwargs["y"] + remove_first_col = kwargs.get("remove_first_col", False) + with torch.enable_grad(): if remove_first_col: - x_in = x[:, 1:].detach().requires_grad_(True).float() + x_in = features[:, 1:].detach().requires_grad_(True).float() else: - x_in = x.detach().requires_grad_(True).float() - logits = classifier(x_in, t) + x_in = features.detach().requires_grad_(True).float() + logits = classifier(x_in, timestep) log_probs = F.log_softmax(logits, dim=-1) selected = log_probs[range(len(logits)), y.view(-1)] return torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale @@ -216,7 +220,7 @@ def cond_fn( curr_ys = torch.tensor(np.array(ys[curr_index:end_index]).reshape(-1, 1), requires_grad=False) curr_model_kwargs = {} curr_model_kwargs["y"] = curr_ys - curr_sample, _ = diffusion.conditional_sample(ys=curr_ys, model_kwargs=curr_model_kwargs, cond_fn=cond_fn) + curr_sample, _ = diffusion.conditional_sample(targets=curr_ys, model_kwargs=curr_model_kwargs, cond_fn=cond_fn) all_rows.extend([sample.cpu().numpy() for sample in [curr_sample]]) all_clusters.extend([curr_ys.cpu().numpy() for curr_ys in [curr_ys]]) curr_index += sample_batch_size From d648615ab74c9d7118b98d597f35cff501a5cbee Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 17 Oct 2025 12:17:43 -0400 Subject: [PATCH 4/5] David's CR --- .../gaussian_multinomial_diffusion.py | 239 +++++++++--------- 1 file changed, 125 insertions(+), 114 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py b/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py index 469ea33b..c859acea 100644 --- a/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py +++ b/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py @@ -8,8 +8,8 @@ import math from collections.abc import Callable from enum import Enum -from logging import DEBUG, INFO -from typing import Any, Literal, Protocol, cast +from logging import DEBUG, INFO, WARNING +from typing import Any, Protocol import numpy as np import torch @@ -17,6 +17,7 @@ from torch.nn import functional from midst_toolkit.common.logger import log +from midst_toolkit.common.variables import DEVICE from midst_toolkit.models.clavaddpm.diffusion_utils import ( FoundNaNsError, discretized_gaussian_log_likelihood, @@ -66,6 +67,13 @@ class Parametrization(Enum): DIRECT = "direct" +class SampleMethod(Enum): + """Possible types of sample method.""" + + UNIFORM = "uniform" + IMPORTANCE = 
"importance" + + class ConditioningFunction(Protocol): """The definition of a function used to condition the model output.""" @@ -122,7 +130,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps: int, alpha_bar: Callable, max_b which defines the cumulative product of (1-beta) over time from t = [0,1]. Args: - num_diffusion_timesteps: The number of betas to produce. + num_diffusion_timesteps: The number of timesteps to produce the betas. alpha_bar: A lambda that takes an argument t from 0 to 1 and produces the cumulative product of (1-beta) up to that part of the diffusion process. @@ -132,6 +140,9 @@ def betas_for_alpha_bar(num_diffusion_timesteps: int, alpha_bar: Callable, max_b Returns: The beta schedule. """ + if max_beta >= 1: + log(WARNING, f"max_beta is set to {max_beta}. Use values lower than 1 to prevent singularities.") + betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps @@ -153,7 +164,7 @@ def __init__( gaussian_parametrization: GaussianParametrization = GaussianParametrization.EPS, parametrization: Parametrization = Parametrization.X0, scheduler_type: SchedulerType = SchedulerType.COSINE, - device: torch.device | None = None, + device: torch.device = DEVICE, ): """ Initialize a GaussianMultinomialDiffusion instance. @@ -167,12 +178,9 @@ def __init__( gaussian_parametrization: The type of Gaussian parametrization. Default is GaussianParametrization.EPS. parametrization: The type of parametrization. Default is Parametrization.X0. scheduler_type: The type of scheduler. Default is SchedulerType.COSINE. - device: The device to use. Default is None, which means the device is the CPU. + device: The device to use. Default is midst_toolkit.common.variables.DEVICE. """ - if device is None: - device = torch.device("cpu") - - super(GaussianMultinomialDiffusion, self).__init__() + super().__init__() self.num_numerical_features = num_numerical_features self.num_classes = num_classes # it as a vector [K1, K2, ..., Km] @@ -215,15 +223,15 @@ def __init__( self.posterior_variance = betas * (1.0 - buffers["alphas_cumprod_prev"]) / (1.0 - buffers["alphas_cumprod"]) posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) self.posterior_log_variance_clipped = torch.from_numpy(posterior_log_variance_clipped).float().to(self.device) - posterior_mean_coef1 = betas * np.sqrt(buffers["alphas_cumprod_prev"]) / (1.0 - buffers["alphas_cumprod"]) + posterior_mean_coef1 = betas * torch.sqrt(buffers["alphas_cumprod_prev"]) / (1.0 - buffers["alphas_cumprod"]) self.posterior_mean_coef1 = posterior_mean_coef1.float().to(self.device) - coef2_denominator = (1.0 - buffers["alphas_cumprod_prev"]) * np.sqrt(buffers["alphas"].numpy()) + coef2_denominator = (1.0 - buffers["alphas_cumprod_prev"]) * torch.sqrt(buffers["alphas"]) coef2_numerator = 1.0 - buffers["alphas_cumprod"] self.posterior_mean_coef2 = (coef2_denominator / coef2_numerator).float().to(self.device) assert log_add_exp(buffers["log_alpha"], buffers["log_1_min_alpha"]).abs().sum().item() < 1.0e-5 assert log_add_exp(buffers["log_cumprod_alpha"], buffers["log_1_min_cumprod_alpha"]).abs().sum().item() < 1e-5 - diff: Tensor = cast(Tensor, np.cumsum(buffers["log_alpha"]) - buffers["log_cumprod_alpha"]) + diff = torch.cumsum(buffers["log_alpha"], dim=0) - buffers["log_cumprod_alpha"] assert diff.abs().sum().item() < 1.0e-5 # Convert to float32 and register buffers. 
@@ -243,19 +251,19 @@ def _calculate_buffer_values(self) -> dict[str, Tensor]: a = 1.0 - get_named_beta_schedule(self.scheduler_type, self.num_timesteps) alphas = torch.tensor(a.astype("float64")) - log_alpha = torch.tensor(np.log(alphas)) - log_cumprod_alpha = torch.tensor(np.cumsum(log_alpha)) + log_alpha = torch.log(alphas) + log_cumprod_alpha = torch.cumsum(log_alpha, dim=0) - log_1_min_alpha: Tensor = log_1_min_a(log_alpha) - log_1_min_cumprod_alpha: Tensor = log_1_min_a(log_cumprod_alpha) + log_1_min_alpha = log_1_min_a(log_alpha) + log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha) - alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0)) + alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod_prev = torch.tensor(np.append(1.0, alphas_cumprod[:-1])) alphas_cumprod_next = torch.tensor(np.append(alphas_cumprod[1:], 0.0)) - sqrt_alphas_cumprod = torch.tensor(np.sqrt(alphas_cumprod)) - sqrt_one_minus_alphas_cumprod = torch.tensor(np.sqrt(1.0 - alphas_cumprod)) - sqrt_recip_alphas_cumprod = torch.tensor(np.sqrt(1.0 / alphas_cumprod)) - sqrt_recipm1_alphas_cumprod = torch.tensor(np.sqrt(1.0 / alphas_cumprod - 1)) + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod) + sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1) return { "alphas": alphas, @@ -314,41 +322,43 @@ def gaussian_q_sample(self, x_start: Tensor, timestep: Tensor, noise: Tensor | N def gaussian_q_posterior_mean_variance( self, - x_start: Tensor, - features: Tensor, + features_start: Tensor, + features_timestep: Tensor, timestep: Tensor, ) -> tuple[Tensor, Tensor, Tensor]: """ Calculate the mean and variance of the Gaussian posterior distribution. Args: - x_start: The initial, noiseless input. - features: The features used to compute the Gaussian parameters. + features_start: The initial, noiseless input. + features_timestep: The features used to compute the Gaussian parameters at the given timestep. timestep: The timestep. Returns: A tuple with 3 tensors: the mean, the variance, and the log variance of the Gaussian posterior distribution. 
""" - assert x_start.shape == features.shape + assert features_start.shape == features_timestep.shape posterior_mean = ( - extract(self.posterior_mean_coef1, timestep, features.shape) * x_start - + extract(self.posterior_mean_coef2, timestep, features.shape) * features + extract(self.posterior_mean_coef1, timestep, features_timestep.shape) * features_start + + extract(self.posterior_mean_coef2, timestep, features_timestep.shape) * features_timestep + ) + posterior_variance = extract(self.posterior_variance, timestep, features_timestep.shape) + posterior_log_variance_clipped = extract( + self.posterior_log_variance_clipped, timestep, features_timestep.shape ) - posterior_variance = extract(self.posterior_variance, timestep, features.shape) - posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, timestep, features.shape) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] - == x_start.shape[0] + == features_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def gaussian_p_mean_variance( self, model_output: Tensor, - features: Tensor, + features_timestep: Tensor, timestep: Tensor, model_kwargs: dict[str, Any] | None = None, ) -> dict[str, Tensor]: @@ -357,10 +367,8 @@ def gaussian_p_mean_variance( Args: model_output: The model output. - features: The features of the Gaussian distribution. + features_timestep: The features of the Gaussian distribution at the given timestep. timestep: The timestep. - clip_denoised: Whether to clip the denoised output. Optional, default is False. - denoised_fn: The denoised function. Optional, default is None. model_kwargs: The model kwargs. Optional, default is None. Returns: @@ -373,23 +381,25 @@ def gaussian_p_mean_variance( if model_kwargs is None: model_kwargs = {} - batch_size, _ = features.shape[:2] + batch_size, _ = features_timestep.shape[:2] assert timestep.shape == (batch_size,) model_variance = torch.cat( [ - self.posterior_variance[1].unsqueeze(0).to(features.device), + self.posterior_variance[1].unsqueeze(0).to(self.device), (1.0 - self.alphas)[1:], ], dim=0, ) model_log_variance = torch.log(model_variance) - model_variance = extract(model_variance, timestep, features.shape) - model_log_variance = extract(model_log_variance, timestep, features.shape) + model_variance = extract(model_variance, timestep, features_timestep.shape) + model_log_variance = extract(model_log_variance, timestep, features_timestep.shape) if self.gaussian_parametrization == GaussianParametrization.EPS: - pred_xstart = self._predict_xstart_from_eps(features=features, timestep=timestep, eps=model_output) + pred_xstart = self._predict_xstart_from_eps( + features_timestep=features_timestep, timestep=timestep, eps=model_output + ) elif self.gaussian_parametrization == GaussianParametrization.X0: pred_xstart = model_output @@ -398,15 +408,17 @@ def gaussian_p_mean_variance( raise ValueError(f"Unsupported Gaussian parametrization: {self.gaussian_parametrization}") model_mean, _, _ = self.gaussian_q_posterior_mean_variance( - x_start=pred_xstart, features=features, timestep=timestep + features_start=pred_xstart, + features_timestep=features_timestep, + timestep=timestep, ) - assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == features.shape, ( + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == features_timestep.shape, ( "Expected shapes to be equal, but got: ", f"model_mean.shape: {model_mean.shape}, ", 
f"model_log_variance.shape: {model_log_variance.shape}, ", f"pred_xstart.shape: {pred_xstart.shape}, ", - f"features.shape: {features.shape}", + f"features.shape: {features_timestep.shape}", ) return { @@ -419,8 +431,8 @@ def gaussian_p_mean_variance( def _vb_terms_bpd( self, model_output: Tensor, - x_start: Tensor, - features: Tensor, + features_start: Tensor, + features_timestep: Tensor, timestep: Tensor, model_kwargs: dict[str, Any] | None = None, ) -> dict[str, Tensor]: @@ -429,8 +441,8 @@ def _vb_terms_bpd( Args: model_output: The model output. - x_start: The initial, noiseless input. - features: The features used to compute the Gaussian parameters. + features_start: The initial, noiseless input. + features_timestep: The features used to compute the Gaussian parameters at the given timestep. timestep: The timestep. model_kwargs: The model kwargs. Optional, default is None. @@ -445,18 +457,20 @@ def _vb_terms_bpd( model_kwargs = {} true_mean, _, true_log_variance_clipped = self.gaussian_q_posterior_mean_variance( - x_start=x_start, - features=features, + features_start=features_start, + features_timestep=features_timestep, timestep=timestep, ) - p_mean_variance = self.gaussian_p_mean_variance(model_output, features, timestep, model_kwargs=model_kwargs) + p_mean_variance = self.gaussian_p_mean_variance( + model_output, features_timestep, timestep, model_kwargs=model_kwargs + ) kl = normal_kl(true_mean, true_log_variance_clipped, p_mean_variance["mean"], p_mean_variance["log_variance"]) kl = mean_flat(kl) / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( - x_start, means=p_mean_variance["mean"], log_scales=0.5 * p_mean_variance["log_variance"] + features_start, means=p_mean_variance["mean"], log_scales=0.5 * p_mean_variance["log_variance"] ) - assert decoder_nll.shape == x_start.shape + assert decoder_nll.shape == features_start.shape decoder_nll = mean_flat(decoder_nll) / np.log(2.0) # At the first timestep return the decoder NLL, @@ -491,8 +505,8 @@ def _prior_gaussian(self, x_start: Tensor) -> Tensor: def _gaussian_loss( self, model_out: Tensor, - x_start: Tensor, - features: Tensor, + features_start: Tensor, + features_timestep: Tensor, timestep: Tensor, noise: Tensor, model_kwargs: dict[str, Any] | None = None, @@ -502,8 +516,8 @@ def _gaussian_loss( Args: model_out: The model output. - x_start: The initial, noiseless input. - features: The features used to compute the Gaussian parameters. + features_start: The initial, noiseless input. + features_timestep: The features used to compute the Gaussian parameters at the given timestep. timestep: The timestep. noise: The noise. model_kwargs: The model kwargs. Optional, default is None. @@ -520,30 +534,30 @@ def _gaussian_loss( if self.gaussian_loss_type == GaussianLossType.KL: return self._vb_terms_bpd( model_output=model_out, - x_start=x_start, - features=features, + features_start=features_start, + features_timestep=features_timestep, timestep=timestep, model_kwargs=model_kwargs, )["output"] raise ValueError(f"Unsupported Gaussian loss type: {self.gaussian_loss_type}") - def _predict_xstart_from_eps(self, features: Tensor, timestep: Tensor, eps: Tensor) -> Tensor: + def _predict_xstart_from_eps(self, features_timestep: Tensor, timestep: Tensor, eps: Tensor) -> Tensor: """ Predict the xstart from the eps. Args: - features: The features used to compute the Gaussian parameters. + features_timestep: The features at the given timestep. timestep: The timestep. eps: The eps. Returns: The predicted xstart. 
""" - assert features.shape == eps.shape + assert features_timestep.shape == eps.shape return ( - extract(self.sqrt_recip_alphas_cumprod, timestep, features.shape) * features - - extract(self.sqrt_recipm1_alphas_cumprod, timestep, features.shape) * eps + extract(self.sqrt_recip_alphas_cumprod, timestep, features_timestep.shape) * features_timestep + - extract(self.sqrt_recipm1_alphas_cumprod, timestep, features_timestep.shape) * eps ) def _predict_eps_from_xstart(self, features: Tensor, timestep: Tensor, pred_xstart: Tensor) -> Tensor: @@ -564,22 +578,22 @@ def _predict_eps_from_xstart(self, features: Tensor, timestep: Tensor, pred_xsta def condition_mean( self, - cond_fn: ConditioningFunction, + conditioning_function: ConditioningFunction, p_mean_var: dict[str, Tensor], features: Tensor, timestep: Tensor, model_kwargs: dict[str, Any] | None = None, ) -> Tensor: """ - Compute the mean for the previous step, given a function cond_fn that - computes the gradient of a conditional log probability with respect to - x. In particular, cond_fn computes grad(log(p(y|x))), and we want to - condition on y. + Compute the mean for the previous step, given a function ``conditioning_function`` + that computes the gradient of a conditional log probability with respect to + ``features``. In particular, ``conditioning_function`` computes grad(log(p(y|x))), + and we want to condition on y. This uses the conditioning strategy from Sohl-Dickstein et al. (2015). Args: - cond_fn: The conditioning function. + conditioning_function: The conditioning function. p_mean_var: The mean and variance of the Gaussian prior distribution. features: The features used to compute the Gaussian parameters. timestep: The timestep. @@ -591,12 +605,12 @@ def condition_mean( if model_kwargs is None: model_kwargs = {} - gradient = cond_fn(features, timestep, **model_kwargs) + gradient = conditioning_function(features, timestep, **model_kwargs) return p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() def condition_score( self, - cond_fn: ConditioningFunction, + conditioning_function: ConditioningFunction, p_mean_var: dict[str, Tensor], features: Tensor, timestep: Tensor, @@ -604,22 +618,24 @@ def condition_score( ) -> dict[str, Tensor]: """ Compute what the p_mean_variance output would have been, should the - model's score function be conditioned by cond_fn. + model's score function be conditioned by ``conditioning_function``. - See condition_mean() for details on cond_fn. + See condition_mean() for details on ``conditioning_function``. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). Args: - cond_fn: The conditioning function. + conditioning_function: The conditioning function. p_mean_var: The mean and variance of the Gaussian prior distribution. features: The features used to compute the Gaussian parameters. timestep: The timestep. model_kwargs: The model kwargs. Optional, default is None. Returns: - The mean and variance of the Gaussian prior distribution. + A copy of the ``p_mean_var`` dictionary with the following additional keys: + - "pred_xstart": the predicted xstart. + - "mean": the mean of the Gaussian prior distribution. 
""" if model_kwargs is None: model_kwargs = {} @@ -627,13 +643,13 @@ def condition_score( alpha_bar = extract(self.alphas_cumprod, timestep, features.shape) eps = self._predict_eps_from_xstart(features, timestep, p_mean_var["pred_xstart"]) - eps = eps - (1 - alpha_bar).sqrt() * cond_fn(features, timestep, **model_kwargs) + eps = eps - (1 - alpha_bar).sqrt() * conditioning_function(features, timestep, **model_kwargs) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(features, timestep, eps) out["mean"], _, _ = self.gaussian_q_posterior_mean_variance( - x_start=out["pred_xstart"], - features=features, + features_start=out["pred_xstart"], + features_timestep=features, timestep=timestep, ) return out @@ -700,7 +716,7 @@ def q_pred_one_timestep(self, log_x_t: Tensor, timestep: Tensor) -> Tensor: Calculate the predicted log probability for one timestep. Args: - log_x_t: The log probability of the features. + log_x_t: The log samples of the features at the given timestep. timestep: The timestep. Returns: @@ -717,10 +733,10 @@ def q_pred_one_timestep(self, log_x_t: Tensor, timestep: Tensor) -> Tensor: def q_pred(self, log_x_start: Tensor, timestep: Tensor) -> Tensor: """ - Calculate the predicted log probability for one timestep. + Compute the predicted log-probability at ``timestep`` given ``log_x_start``. Args: - log_x_start: The log probability of the start. + log_x_start: The log sample of the start. timestep: The timestep. Returns: @@ -740,7 +756,7 @@ def predict_start(self, model_out: Tensor, log_x_t: Tensor) -> Tensor: Args: model_out: The model output. - log_x_t: The log probability of the features. + log_x_t: The log sample of the features at the given timestep. Returns: The predicted start. @@ -759,8 +775,8 @@ def q_posterior(self, log_x_start: Tensor, log_x_t: Tensor, timestep: Tensor) -> Calculate the posterior probability for one timestep. Args: - log_x_start: The log probability of the initial input. - log_x_t: The log probability of the features. + log_x_start: The log sample of the initial input. + log_x_t: The log sample of the features at the given timestep. timestep: The timestep. Returns: @@ -785,15 +801,15 @@ def q_posterior(self, log_x_start: Tensor, log_x_t: Tensor, timestep: Tensor) -> def p_pred(self, model_out: Tensor, log_x: Tensor, timestep: Tensor) -> Tensor: """ - Predict the log probability of the model output. + Predict the start from the model output based on the parametrization set in ``self.parametrization``. Args: model_out: The model output. - log_x: The log probability of the features. + log_x: The log sample of the features. timestep: The timestep. Returns: - The log probability of the model output. + The predicted start from the model output. """ if self.parametrization == Parametrization.X0: log_x_recon = self.predict_start(model_out, log_x) @@ -814,7 +830,7 @@ def p_sample(self, model_out: Tensor, log_x: Tensor, timestep: Tensor) -> Tensor Args: model_out: The model output. - log_x: The log probability of the features. + log_x: The log sample of the features. timestep: The timestep. Returns: @@ -846,10 +862,10 @@ def log_sample_categorical(self, logits: Tensor) -> Tensor: def q_sample(self, log_x_start: Tensor, timestep: Tensor) -> Tensor: """ - Sample from the logits for one timestep. + Sample from the log of the initial input for one timestep. Args: - log_x_start: The log probability of the initial input. + log_x_start: The log of the initial input. timestep: The timestep. 
Returns: @@ -863,7 +879,7 @@ def kl_prior(self, log_x_start: Tensor) -> Tensor: Calculate the KL divergence between the prior and the posterior. Args: - log_x_start: The log probability of the initial input. + log_x_start: The log sample of the initial input. Returns: The KL divergence between the prior and the posterior. @@ -891,8 +907,8 @@ def compute_lt( Args: model_out: The model output. - log_x_start: The log probability of the initial input. - log_x_t: The log probability of the features. + log_x_start: The log sample of the initial input. + log_x_t: The log samples of the features at the given timestep. timestep: The timestep. detach_mean: Whether to detach the mean. @@ -918,7 +934,7 @@ def sample_time( self, batch_size: int, device: torch.device, - method: Literal["uniform", "importance"] = "uniform", + method: SampleMethod = SampleMethod.UNIFORM, ) -> tuple[Tensor, Tensor]: """ Sample the timestep. @@ -931,9 +947,9 @@ def sample_time( Returns: The timestep and the probability of the timestep. """ - if method == "importance": + if method == SampleMethod.IMPORTANCE: if not (self.lt_count > 10).all(): - return self.sample_time(batch_size, device, method="uniform") + return self.sample_time(batch_size, device, method=SampleMethod.UNIFORM) lt_sqrt = torch.sqrt(self.lt_history + 1e-10) + 0.0001 lt_sqrt[0] = lt_sqrt[1] # Overwrite decoder term with L1. @@ -945,7 +961,7 @@ def sample_time( return timestep, p_timestep - if method == "uniform": + if method == SampleMethod.UNIFORM: timestep = torch.randint(0, self.num_timesteps, (batch_size,), device=device).long() p_timestep = torch.ones_like(timestep).float() / self.num_timesteps @@ -966,8 +982,8 @@ def _multinomial_loss( Args: model_out: The model output. - log_x_start: The log probability of the initial input. - log_x_t: The log probability of the features. + log_x_start: The log samples of the initial input. + log_x_t: The log samples of the features at the given timestep. timestep: The timestep. p_timestep: The probability of the timestep. @@ -994,18 +1010,17 @@ def mixed_loss(self, features: Tensor, out_dict: dict[str, Tensor]) -> tuple[Ten The multinomial loss and the Gaussian loss. 
""" batch_size = features.shape[0] - device = features.device - timestep, p_timestep = self.sample_time(batch_size, device, "uniform") + timestep, p_timestep = self.sample_time(batch_size, self.device, method=SampleMethod.UNIFORM) numerical_features = features[:, : self.num_numerical_features] categorical_features = features[:, self.num_numerical_features :] numerical_features_ts = numerical_features - log_categrocial_features_ts = categorical_features if numerical_features.shape[1] > 0: noise = torch.randn_like(numerical_features) numerical_features_ts = self.gaussian_q_sample(numerical_features, timestep, noise=noise) + log_categrocial_features_ts = categorical_features if categorical_features.shape[1] > 0: log_x_cat = index_to_log_onehot(categorical_features.long(), torch.from_numpy(self.num_classes)) log_categrocial_features_ts = self.q_sample(log_x_start=log_x_cat, timestep=timestep) @@ -1114,10 +1129,9 @@ def gaussian_ddim_sample( """ features = noise batch_size = features.shape[0] - device = features.device for t in reversed(range(num_timesteps)): log(DEBUG, f"Sample timestep {t:4d}") - t_array = (torch.ones(batch_size, device=device) * t).long() + t_array = (torch.ones(batch_size, device=self.device) * t).long() out_num = self._denoise_fn(features, t_array, **out_dict) features = self.gaussian_ddim_step( out_num, @@ -1178,12 +1192,11 @@ def gaussian_ddim_reverse_sample( The predicted features. """ batch_size = features.shape[0] - device = features.device output_features = features.clone() for t in range(num_timesteps): log(DEBUG, f"Reverse timestep {t:4d}") - t_array = (torch.ones(batch_size, device=device) * t).long() + t_array = (torch.ones(batch_size, device=self.device) * t).long() out_num = self._denoise_fn(output_features, t_array, **out_dict) output_features = self.gaussian_ddim_reverse_step(out_num, output_features, t_array) @@ -1202,12 +1215,12 @@ def multinomial_ddim_step( Args: model_out_cat: The categorical model output. - log_x_t: The log probability of the features. + log_x_t: The log samples of the features at the given timestep. timestep: The timestep. eta: The DDIM stochasticity coefficient. Optional, default is 0.0. Returns: - The predicted log probability. + The multinomial DDIM step. """ log_x0 = self.predict_start(model_out_cat, log_x_t=log_x_t) @@ -1235,7 +1248,7 @@ def multinomial_ddim_step( @torch.no_grad() def sample_ddim( self, - num_samples: int, + batch_size: int, target_dist: Tensor, model_kwargs: dict[str, Any] | None = None, cond_fn: ConditioningFunction | None = None, @@ -1244,7 +1257,7 @@ def sample_ddim( Sample using DDIM. Args: - num_samples: The number of samples. + batch_size: The batch size. target_dist: Class distribution to sample labels from. model_kwargs: The model kwargs. Optional, default is None. cond_fn: The conditioning function. Optional, default is None. @@ -1255,7 +1268,6 @@ def sample_ddim( if model_kwargs is None: model_kwargs = {} - batch_size = num_samples z_norm = torch.randn((batch_size, self.num_numerical_features), device=self.device) assert self.num_classes is not None @@ -1345,7 +1357,7 @@ def conditional_sample( @torch.no_grad() def sample( self, - num_samples: int, + batch_size: int, target_dist: Tensor, model_kwargs: dict[str, Any] | None = None, cond_fn: ConditioningFunction | None = None, @@ -1354,7 +1366,7 @@ def sample( Sample using ancestral (DDPM-style) sampling. Args: - num_samples: The number of samples. + batch_size: The batch size. target_dist: Class distribution to sample labels from. 
model_kwargs: The model kwargs. Optional, default is None. cond_fn: The conditioning function. Optional, default is None. @@ -1365,7 +1377,6 @@ def sample( if model_kwargs is None: model_kwargs = {} - batch_size = num_samples z_norm = torch.randn((batch_size, self.num_numerical_features), device=self.device) assert self.num_classes is not None From 93ecd5aa8d74b8010337dc491131348032985ca0 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 17 Oct 2025 14:10:32 -0400 Subject: [PATCH 5/5] Change I forgot to submit --- .../models/clavaddpm/gaussian_multinomial_diffusion.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py b/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py index c859acea..fbbf097b 100644 --- a/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py +++ b/src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py @@ -600,7 +600,7 @@ def condition_mean( model_kwargs: The model kwargs. Optional, default is None. Returns: - The mean for the previous step. + The conditioned mean for the previous step. """ if model_kwargs is None: model_kwargs = {} @@ -1020,12 +1020,12 @@ def mixed_loss(self, features: Tensor, out_dict: dict[str, Tensor]) -> tuple[Ten noise = torch.randn_like(numerical_features) numerical_features_ts = self.gaussian_q_sample(numerical_features, timestep, noise=noise) - log_categrocial_features_ts = categorical_features + log_categorical_features_ts = categorical_features if categorical_features.shape[1] > 0: log_x_cat = index_to_log_onehot(categorical_features.long(), torch.from_numpy(self.num_classes)) - log_categrocial_features_ts = self.q_sample(log_x_start=log_x_cat, timestep=timestep) + log_categorical_features_ts = self.q_sample(log_x_start=log_x_cat, timestep=timestep) - input_features = torch.cat([numerical_features_ts, log_categrocial_features_ts], dim=1) + input_features = torch.cat([numerical_features_ts, log_categorical_features_ts], dim=1) model_output = self._denoise_fn(input_features, timestep, **out_dict) @@ -1038,7 +1038,7 @@ def mixed_loss(self, features: Tensor, out_dict: dict[str, Tensor]) -> tuple[Ten multinomial_loss = self._multinomial_loss( model_categorical_output, log_x_cat, - log_categrocial_features_ts, + log_categorical_features_ts, timestep, p_timestep, )
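
With the renames in this series applied, a call into the batched sampling API would look roughly like the sketch below. The `diffusion` instance and the two-class 50/50 label distribution are assumptions for illustration, not part of the patch; only the parameter names and return values come from `sample_all` as changed above.

import torch

# Hypothetical usage of the renamed API (illustrative values).
target_dist = torch.tensor([0.5, 0.5])  # probabilities over the class labels
generated_features, generated_targets = diffusion.sample_all(
    num_samples=256,
    batch_size=64,
    target_dist=target_dist,
    ddim=False,  # False -> ancestral sampling; True -> sample_ddim
)
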