-
Notifications
You must be signed in to change notification settings - Fork 32
Flow matching smc #236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Flow matching smc #236
Changes from all commits
39ac4d5
36f781d
2b385d7
db1ee90
d401b71
4ee192b
98da709
555d38c
ca16b9e
cb8ffea
ccec5af
b77b6a0
1f06c13
aa57ce8
bdba137
7f46ff9
98084c3
fea5eb5
497ce92
f8754b1
826b2bd
211d1f9
cd5ab41
61087d8
8a2ced0
1729bdf
922efe4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,11 +1,68 @@ | ||||||||||||||||||
| from typing import Callable, List, Tuple | ||||||||||||||||||
| from typing import Callable, List, Tuple, Optional | ||||||||||||||||||
|
|
||||||||||||||||||
| import equinox as eqx | ||||||||||||||||||
| import jax | ||||||||||||||||||
| import jax.numpy as jnp | ||||||||||||||||||
| from jaxtyping import Array, Float, PRNGKeyArray | ||||||||||||||||||
| from abc import abstractmethod | ||||||||||||||||||
|
|
||||||||||||||||||
| from flowMC.resource.nf_model.base import Bijection, Distribution | ||||||||||||||||||
|
|
||||||||||||||||||
| class Bijection(eqx.Module): | ||||||||||||||||||
| """Base class for bijective transformations. | ||||||||||||||||||
|
|
||||||||||||||||||
| This is an abstract template that should not be directly used. | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| @abstractmethod | ||||||||||||||||||
| def __init__(self): | ||||||||||||||||||
| raise NotImplementedError | ||||||||||||||||||
|
|
||||||||||||||||||
|
Comment on lines
+7
to
+19
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. abstractmethod without ABC won’t enforce abstractness.
Apply this diff to enforce the contract: -from abc import abstractmethod
+from abc import ABC, abstractmethod
@@
-class Bijection(eqx.Module):
+class Bijection(eqx.Module, ABC):
@@
-class Distribution(eqx.Module):
+class Distribution(eqx.Module, ABC):🤖 Prompt for AI Agents |
||||||||||||||||||
| def __call__( | ||||||||||||||||||
| self, | ||||||||||||||||||
| x: Float[Array, " n_dim"], | ||||||||||||||||||
| condition: Float[Array, " n_condition"], | ||||||||||||||||||
| ) -> tuple[Float[Array, " n_dim"], Float]: | ||||||||||||||||||
| return self.forward(x, condition) | ||||||||||||||||||
|
|
||||||||||||||||||
| @abstractmethod | ||||||||||||||||||
| def forward( | ||||||||||||||||||
| self, | ||||||||||||||||||
| x: Float[Array, " n_dim"], | ||||||||||||||||||
| condition: Float[Array, " n_condition"], | ||||||||||||||||||
| ) -> tuple[Float[Array, " n_dim"], Float]: | ||||||||||||||||||
| raise NotImplementedError | ||||||||||||||||||
|
|
||||||||||||||||||
| @abstractmethod | ||||||||||||||||||
| def inverse( | ||||||||||||||||||
| self, | ||||||||||||||||||
| x: Float[Array, " n_dim"], | ||||||||||||||||||
| condition: Float[Array, " n_condition"], | ||||||||||||||||||
| ) -> tuple[Float[Array, " n_dim"], Float]: | ||||||||||||||||||
| raise NotImplementedError | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class Distribution(eqx.Module): | ||||||||||||||||||
| """Base class for probability distributions. | ||||||||||||||||||
|
|
||||||||||||||||||
| This is an abstract template that should not be directly used. | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| @abstractmethod | ||||||||||||||||||
| def __init__(self): | ||||||||||||||||||
| raise NotImplementedError | ||||||||||||||||||
|
|
||||||||||||||||||
| def __call__(self, x: Array, key: Optional[PRNGKeyArray] = None) -> Array: | ||||||||||||||||||
| return self.log_prob(x) | ||||||||||||||||||
|
|
||||||||||||||||||
| @abstractmethod | ||||||||||||||||||
| def log_prob(self, x: Array) -> Array: | ||||||||||||||||||
| raise NotImplementedError | ||||||||||||||||||
|
|
||||||||||||||||||
| @abstractmethod | ||||||||||||||||||
| def sample( | ||||||||||||||||||
| self, rng_key: PRNGKeyArray, n_samples: int | ||||||||||||||||||
| ) -> Float[Array, " n_samples n_features"]: | ||||||||||||||||||
| raise NotImplementedError | ||||||||||||||||||
|
Comment on lines
+62
to
65
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Unify sample return shape variable: use n_dim consistently. The base Apply this diff: def sample(
self, rng_key: PRNGKeyArray, n_samples: int
- ) -> Float[Array, " n_samples n_features"]:
+ ) -> Float[Array, " n_samples n_dim"]:
raise NotImplementedErrorYou’ll also want to align concrete implementations—see the follow-up snippet for 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class MLP(eqx.Module): | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Verify no stale
local_kernelimports remain across the repoTo avoid runtime/import errors in docs or examples, verify all references were updated from
resource.local_kerneltoresource.kernel, and thatMALA.pyexists at the new location.Run:
🏁 Script executed:
Length of output: 3343
Missing
MALA.pyinflowMC/resource/kernelThe new imports reference
flowMC.resource.kernel.MALA, but there is noMALA.pyfile undersrc/flowMC/resource/kernel/. Please add or moveMALA.py(containing theMALAclass) into that directory so all imports resolve.All stale
flowMC.resource.local_kernelreferences have been removed, and new-style imports are present in:🤖 Prompt for AI Agents