Conversation

@lotif lotif commented Oct 8, 2025

PR Type

Fix

Short Description

Clickup Ticket(s): https://app.clickup.com/t/868fuke6e

General improvements in the midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py file:

  • Removing Ruff and mypy ignores
  • Refactoring the __init__ function
  • Adding docstrings
  • Fixing parameter and variable names
  • Removing unused parameters
  • Removing unused methods

Tests Added

Only minor adjustments; the functionality does not change.

@emersodb emersodb left a comment

I'm going to have to trust you a bit on the documentation fidelity in a lot of places, since I don't know the code nearly as well as you do 🙂. There were a few places where I wanted to clarify my understanding of the documentation. So definitely tell me if I'm off base anywhere.

Some other fairly minor comments throughout.

self,
batch_size: int,
device: torch.device,
method: Literal["uniform", "importance"] = "uniform",

Maybe we can just make this a local enum here?
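For example, something roughly like this (the enum name and member set are just a sketch, not necessarily what the file should use):

from enum import Enum


class TimestepSamplingMethod(str, Enum):
    # Hypothetical local enum standing in for Literal["uniform", "importance"].
    UNIFORM = "uniform"
    IMPORTANCE = "importance"

The parameter above would then become method: TimestepSamplingMethod = TimestepSamplingMethod.UNIFORM.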

DIRECT = "direct"


class ConditioningFunction(Protocol):

More for my own learning, but do you mind explaining why it might be advantageous for this to be a class inheriting from Protocol rather than a Callable? It's an easier typing annotation and also adds some documentation, which is nice. Perhaps that's enough justification, but I also thought perhaps there was something deeper?
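For reference, this is the tradeoff I have in mind (the argument names and types below are assumptions for illustration, not the real signature):

from typing import Callable, Protocol

import torch

# Option 1: a plain Callable alias. Compact, but the argument names are lost
# and there is nowhere to attach a docstring.
ConditioningCallable = Callable[[torch.Tensor, int], torch.Tensor]


# Option 2: a Protocol. The __call__ signature keeps argument names, allows
# keyword calls at the call site, and the class body can carry documentation.
class ConditioningFunction(Protocol):
    def __call__(self, features: torch.Tensor, timestep: int) -> torch.Tensor: ...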

prevent singularities.
Args:
num_diffusion_timesteps: The number of betas to produce.

This documentation is technically true, but also doesn't really capture what the variable actually represents? (I know you didn't write it 😂)
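For context, if I'm reading it right this follows the usual improved-DDPM betas_for_alpha_bar pattern, where num_diffusion_timesteps is really the number of discrete steps the [0, 1] diffusion interval is split into, roughly (a sketch, not necessarily the exact code in this file):

import numpy as np


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    # alpha_bar(t) gives the cumulative product of (1 - beta) at time t in [0, 1];
    # the interval is discretized into num_diffusion_timesteps steps and one beta
    # is recovered per step.
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)

So maybe the docstring could describe it as the number of steps the diffusion process is discretized into, rather than just "the number of betas to produce".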

produces the cumulative product of (1-beta) up to that
part of the diffusion process.
max_beta: The maximum beta to use; use values lower than 1 to
prevent singularities.

Given the documentation here, should we assert that max_beta is, in fact, lower than 1... to prevent singularities... 😂
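i.e., perhaps just a guard at the top of the function, something like:

assert max_beta < 1.0, "max_beta must be strictly less than 1 to prevent singularities"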

"""
if device is None:
device = torch.device("cpu")


The super call below is a bit old-school in style, isn't it?
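i.e., assuming it's the explicit super(ClassName, self) form, it could just be the zero-argument version (the class and base names below are my assumptions):

import torch


class GaussianMultinomialDiffusion(torch.nn.Module):
    def __init__(self) -> None:
        # Old, Python-2-era style:
        #     super(GaussianMultinomialDiffusion, self).__init__()
        # Modern equivalent:
        super().__init__()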

Args:
log_x_start: The log probability of the initial input.
log_x_t: The log probability of the features.

Similar comments about the log prob here.

Args:
model_out: The model output.
log_x: The log probability of the features.

Same comment here about probabilities. Also, are we sure we're predicting probabilities in the model output?

categorical_features = features[:, self.num_numerical_features :]

numerical_features_ts = numerical_features
log_categrocial_features_ts = categorical_features

This and the line above are a bit of weird indirection, tying values together. We keep numerical_features and categorical_features around only to check their shapes. Rather than creating new names here for the same tensors (no copies are made), why not just use the original tensors?
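Just to illustrate the "no copies are made" point, a tiny standalone sketch (made-up shapes):

import torch

features = torch.randn(4, 6)
num_numerical_features = 2

numerical_features = features[:, :num_numerical_features]
categorical_features = features[:, num_numerical_features:]

# Re-binding these to new names does not create new tensors; the *_ts aliases
# below are literally the same objects, so they only add indirection.
numerical_features_ts = numerical_features
log_categorical_features_ts = categorical_features
assert numerical_features_ts is numerical_features
assert log_categorical_features_ts is categorical_features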


b = num_samples
z_norm = torch.randn((b, self.num_numerical_features), device=self.device)
batch_size = num_samples

Why not just call the num_samples argument batch_size here? It doesn't look like num_samples is used anywhere else anyway.
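i.e., roughly this, with batch_size used directly instead of the num_samples -> b -> batch_size aliasing (a standalone sketch, not the real method signature):

import torch


def sample_numerical_noise(batch_size: int, num_numerical_features: int, device: torch.device) -> torch.Tensor:
    # Hypothetical standalone version of the snippet above, using batch_size directly.
    return torch.randn((batch_size, num_numerical_features), device=device)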

