Trainer: fixing the gaussian_multinomial_diffusion.py file #58
Conversation
I'm going to have to trust you a bit on the documentation fidelity in a lot of places, since I don't know the code nearly as well as you do 🙂. There were a few places where I wanted to clarify my understanding of the documentation. So definitely tell me if I'm off base anywhere.
Some other fairly minor comments throughout.
self,
batch_size: int,
device: torch.device,
method: Literal["uniform", "importance"] = "uniform",
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can just make this a local enum here?
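For concreteness, a rough sketch of what I had in mind (the enum and function names here are placeholders, and the body is dummy, not the real sampler):

from enum import Enum

import torch


class TimestepSamplingMethod(str, Enum):
    # Inheriting from str keeps comparisons against the raw strings working,
    # so existing call sites passing "uniform" wouldn't break immediately.
    UNIFORM = "uniform"
    IMPORTANCE = "importance"


def sample_timesteps(
    batch_size: int,
    device: torch.device,
    method: TimestepSamplingMethod = TimestepSamplingMethod.UNIFORM,
) -> torch.Tensor:
    # Dummy body for the sketch: uniform sampling over an arbitrary 1000 steps.
    if method is not TimestepSamplingMethod.UNIFORM:
        raise NotImplementedError(method)
    return torch.randint(0, 1000, (batch_size,), device=device)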
DIRECT = "direct" | ||
|
||
|
||
class ConditioningFunction(Protocol): |
More for my own learning, but do you mind explaining why it might be advantageous for this to be a class inheriting from Protocol rather than a Callable? It's an easier typing annotation and also adds some documentation, which is nice. Perhaps that's enough justification, but I also thought perhaps there was something deeper?
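From what I understand, beyond readability, a Protocol can express things a plain Callable can't: parameter names (so callers can use keywords), defaults, and overloads. A toy comparison, with a made-up signature since I don't know the real one:

from typing import Callable, Protocol

import torch


class ConditioningFunction(Protocol):
    # Named parameters with a docstring; keyword calls type-check.
    def __call__(self, features: torch.Tensor, timestep: int) -> torch.Tensor: ...


# The Callable spelling loses the parameter names and the documentation:
ConditioningCallable = Callable[[torch.Tensor, int], torch.Tensor]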
prevent singularities.
Args:
    num_diffusion_timesteps: The number of betas to produce.
This documentation is technically true, but also doesn't really capture what the variable actually represents? (I know you didn't write it 😂)
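Maybe something along these lines would be closer (my wording, so feel free to correct it):

num_diffusion_timesteps: The total number of steps T in the diffusion
    process; one beta is produced for each timestep.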
    produces the cumulative product of (1-beta) up to that
    part of the diffusion process.
max_beta: The maximum beta to use; use values lower than 1 to
    prevent singularities.
Given the documentation here, should we assert that max_beta is, in fact, lower than 1... to prevent singularities... 😂
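Something like this at the top of the function, assuming it's the usual betas_for_alpha_bar-style helper (name guessed from the docstring):

def betas_for_alpha_bar(num_diffusion_timesteps: int, alpha_bar, max_beta: float = 0.999):
    # Fail fast instead of silently producing a singular schedule.
    assert max_beta < 1.0, f"max_beta must be < 1 to prevent singularities, got {max_beta}"
    ...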
""" | ||
if device is None: | ||
device = torch.device("cpu") | ||
|
The super call below is a bit of an old-school style, isn't it?
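i.e. I'd expect the explicit two-argument form could just be the bare call (class name guessed from the file name):

# Python 2 style; still works in Python 3 but redundant:
super(GaussianMultinomialDiffusion, self).__init__()

# Modern equivalent:
super().__init__()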
Args:
    log_x_start: The log probability of the initial input.
    log_x_t: The log probability of the features.
Similar comments about the log prob here.
Args:
    model_out: The model output.
    log_x: The log probability of the features.
Same comment here about probabilities. Also, are we sure we're predicting probabilities of the model output?
categorical_features = features[:, self.num_numerical_features :]

numerical_features_ts = numerical_features
log_categrocial_features_ts = categorical_features
This and the line above are a bit of weird indirection, tying values together. We keep numerical_features and categorical_features around only to check their shapes. Rather than creating new tensors here (without copying), why not just use the original tensors?
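i.e. something like this sketch, keeping the existing names and dropping the aliases:

numerical_features = features[:, : self.num_numerical_features]
categorical_features = features[:, self.num_numerical_features :]
# ...shape checks on numerical_features / categorical_features...
# then use them directly instead of re-binding to the *_ts names.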
b = num_samples
z_norm = torch.randn((b, self.num_numerical_features), device=self.device)
batch_size = num_samples
Why not just call the num_samples argument batch_size here? It doesn't look like num_samples is used elsewhere anyway.
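i.e. a sketch of what I mean (method name and surrounding signature guessed):

def sample(self, batch_size: int) -> torch.Tensor:
    z_norm = torch.randn((batch_size, self.num_numerical_features), device=self.device)
    ...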
PR Type
Fix

Short Description
Clickup Ticket(s): https://app.clickup.com/t/868fuke6e

General improvements in the midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py file:
- __init__ function

Tests Added
Only minor adjustments, the functionality does not change.