Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
39ac4d5
Rename folder from nf_model to model
kazewong Aug 5, 2025
36f781d
Revert "Rename folder from nf_model to model"
kazewong Aug 5, 2025
2b385d7
Move nf_model to model/nf_model and update imports
kazewong Aug 5, 2025
db1ee90
Refactor model packaging
kazewong Aug 12, 2025
d401b71
add base.py
kazewong Aug 12, 2025
4ee192b
Rename local_kernel to kernel
kazewong Aug 13, 2025
98da709
Move nf_proposal to kernel
kazewong Aug 13, 2025
555d38c
Add SMC skeleton
kazewong Aug 13, 2025
ca16b9e
add flow matching scaffolding
kazewong Aug 13, 2025
cb8ffea
Add diffrax as dependency
kazewong Aug 16, 2025
ccec5af
Add path example
kazewong Aug 16, 2025
b77b6a0
Constructing solver for sample
kazewong Aug 16, 2025
1f06c13
Add initial training loop. Not working yet
kazewong Aug 16, 2025
aa57ce8
It seems training now
kazewong Aug 16, 2025
bdba137
Implement code changes to enhance functionality and improve performance
kazewong Aug 16, 2025
7f46ff9
Add dt parameter to sample method in FlowMatchingModel
kazewong Aug 16, 2025
98084c3
Update train_flow_match.ipynb
kazewong Aug 16, 2025
fea5eb5
Add log_prob function
kazewong Aug 16, 2025
497ce92
Update flowMC log_prob notebook
kazewong Aug 16, 2025
f8754b1
Update train_flow_match.ipynb
kazewong Aug 16, 2025
826b2bd
Add whitening procedure to FlowMatchingModel
google-labs-jules[bot] Aug 17, 2025
211d1f9
Refactor flow matching code for readability and add diffrax to
kazewong Aug 17, 2025
cd5ab41
Merge pull request #233 from kazewong/add-whitening-to-flow-matching
kazewong Aug 17, 2025
61087d8
add tests
kazewong Aug 18, 2025
8a2ced0
Update test_flowmatching.py
kazewong Aug 18, 2025
1729bdf
Format test_flowmatching.py for improved readability
kazewong Aug 18, 2025
922efe4
Merge pull request #235 from kazewong/add-whitening-to-flow-matching
kazewong Aug 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
rev: v1.1.396
hooks:
- id: pyright
additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, typing_extensions, equinox, optax, tqdm]
additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, typing_extensions, equinox, optax, tqdm, diffrax]
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.9.1
hooks:
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/custom_strategy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flowMC.resource.local_kernel.Gaussian_random_walk import GaussianRandomWalk\n",
"from flowMC.resource.kernel.Gaussian_random_walk import GaussianRandomWalk\n",
"from flowMC.resource.buffers import Buffer\n",
"from flowMC.resource.logPDF import LogPDF\n",
"from flowMC.strategy.optimization import AdamOptimization\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/parallel_tempering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"from jaxtyping import Float, Array\n",
"from typing import Any\n",
"\n",
"from flowMC.resource.local_kernel.MALA import MALA\n",
"from flowMC.resource.kernel.MALA import MALA\n",
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Verify no stale local_kernel imports remain across the repo

To avoid runtime/import errors in docs or examples, verify all references were updated from resource.local_kernel to resource.kernel, and that MALA.py exists at the new location.

Run:


🏁 Script executed:

#!/bin/bash
# Verify the new MALA module exists under the reorganized path.
fd -HI 'MALA.py' | rg -nP 'flowMC/.*/resource/.*/kernel/.*/MALA\.py|flowMC/.*/resource/kernel/MALA\.py' || true

# Find any lingering old imports anywhere (expect: no matches).
rg -nP -C2 '\bflowMC\.resource\.local_kernel\b' || true

# Confirm new-style imports are present (expect: matches in updated notebooks/tests).
rg -nP -C2 '\bfrom\s+flowMC\.resource\.kernel\.MALA\s+import\s+MALA\b' || true

Length of output: 3343


Missing MALA.py in flowMC/resource/kernel

The new imports reference flowMC.resource.kernel.MALA, but there is no MALA.py file under src/flowMC/resource/kernel/. Please add or move MALA.py (containing the MALA class) into that directory so all imports resolve.

All stale flowMC.resource.local_kernel references have been removed, and new-style imports are present in:

  • docs/tutorials/parallel_tempering.ipynb (line 22)
  • test/unit/test_strategies.py (line 9)
  • src/flowMC/resource_strategy_bundle/RQSpline_MALA.py (line 11)
  • src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py (line 12)
  • test/unit/test_resources.py (line 5)
  • test/unit/test_kernels.py (line 7)
  • test/integration/test_MALA.py (line 6)
🤖 Prompt for AI Agents
In docs/tutorials/parallel_tempering.ipynb around line 22, the import references
flowMC.resource.kernel.MALA but there is no src/flowMC/resource/kernel/MALA.py;
add (or move) a MALA.py file into src/flowMC/resource/kernel/ that defines the
MALA class with the expected public API used by the listed files, and ensure the
module is importable (add/update __init__.py if necessary) so all imports in the
notebook and tests resolve.

"from flowMC.resource.buffers import Buffer\n",
"from flowMC.resource.states import State\n",
"from flowMC.strategy.take_steps import TakeSerialSteps\n",
Expand Down
215 changes: 215 additions & 0 deletions docs/tutorials/train_flow_match.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/tutorials/train_normalizing_flow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
"import equinox as eqx # Equinox\n",
"\n",
"\n",
"from flowMC.resource.nf_model.realNVP import RealNVP\n",
"from flowMC.resource.nf_model.rqSpline import MaskedCouplingRQSpline"
"from flowMC.resource.model.nf_model.realNVP import RealNVP\n",
"from flowMC.resource.model.nf_model.rqSpline import MaskedCouplingRQSpline"
]
},
{
Expand Down
14 changes: 7 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ requires-python = ">=3.11"
keywords = ["sampling", "inference", "machine learning", "normalizing", "autodiff", "jax"]
dependencies = [
"chex>=0.1.87",
"diffrax>=0.7.0",
"equinox>=0.11.9",
"jax[cpu]>=0.5.0",
"jaxtyping>=0.2.36",
Expand Down Expand Up @@ -55,18 +56,17 @@ visualize = [
cuda = [
"jax[cuda12]>=0.5.0",
]
codeqa = [
"coveralls>=4.0.1",
"pre-commit>=4.0.1",
"pyright>=1.1.389",
"pytest>=8.3.3",
"ruff>=0.8.0",
]

[dependency-groups]
dev = [
"flowMC",
"ipykernel>=6.29.5",
"coveralls>=4.0.1",
"pre-commit>=4.0.1",
"pyright>=1.1.389",
"pytest>=8.3.3",
"ruff>=0.8.0",
"ipython>=8.30.0",
]

[tool.uv.sources]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
from typing import Callable

from flowMC.resource.local_kernel.base import ProposalBase
from flowMC.resource.kernel.base import ProposalBase
from flowMC.resource.logPDF import LogPDF


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree

from flowMC.resource.local_kernel.base import ProposalBase
from flowMC.resource.kernel.base import ProposalBase
from flowMC.resource.logPDF import LogPDF


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Callable

from flowMC.resource.logPDF import LogPDF
from flowMC.resource.local_kernel.base import ProposalBase
from flowMC.resource.kernel.base import ProposalBase


class MALA(ProposalBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from typing import Callable
import equinox as eqx

from flowMC.resource.nf_model.base import NFModel
from flowMC.resource.local_kernel.base import ProposalBase
from flowMC.resource.model.nf_model.base import NFModel
from flowMC.resource.kernel.base import ProposalBase
from flowMC.resource.logPDF import LogPDF


Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,68 @@
from typing import Callable, List, Tuple
from typing import Callable, List, Tuple, Optional

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PRNGKeyArray
from abc import abstractmethod

from flowMC.resource.nf_model.base import Bijection, Distribution

class Bijection(eqx.Module):
"""Base class for bijective transformations.

This is an abstract template that should not be directly used.
"""

@abstractmethod
def __init__(self):
raise NotImplementedError

Comment on lines +7 to +19
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

abstractmethod without ABC won’t enforce abstractness.

Bijection and Distribution don’t inherit from ABC, so @abstractmethod has no effect at class level. This allows accidental instantiation at runtime.

Apply this diff to enforce the contract:

-from abc import abstractmethod
+from abc import ABC, abstractmethod
@@
-class Bijection(eqx.Module):
+class Bijection(eqx.Module, ABC):
@@
-class Distribution(eqx.Module):
+class Distribution(eqx.Module, ABC):
🤖 Prompt for AI Agents
In src/flowMC/resource/model/common.py around lines 7 to 19, the @abstractmethod
has no effect because Bijection doesn't inherit from ABC; update the class to
inherit from abc.ABC (import ABC from abc) in addition to eqx.Module (e.g.,
class Bijection(ABC, eqx.Module)): this will enforce the abstract contract so
__init__ cannot be instantiated, and apply the same pattern for Distribution if
present elsewhere.

def __call__(
self,
x: Float[Array, " n_dim"],
condition: Float[Array, " n_condition"],
) -> tuple[Float[Array, " n_dim"], Float]:
return self.forward(x, condition)

@abstractmethod
def forward(
self,
x: Float[Array, " n_dim"],
condition: Float[Array, " n_condition"],
) -> tuple[Float[Array, " n_dim"], Float]:
raise NotImplementedError

@abstractmethod
def inverse(
self,
x: Float[Array, " n_dim"],
condition: Float[Array, " n_condition"],
) -> tuple[Float[Array, " n_dim"], Float]:
raise NotImplementedError


class Distribution(eqx.Module):
"""Base class for probability distributions.

This is an abstract template that should not be directly used.
"""

@abstractmethod
def __init__(self):
raise NotImplementedError

def __call__(self, x: Array, key: Optional[PRNGKeyArray] = None) -> Array:
return self.log_prob(x)

@abstractmethod
def log_prob(self, x: Array) -> Array:
raise NotImplementedError

@abstractmethod
def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> Float[Array, " n_samples n_features"]:
raise NotImplementedError
Comment on lines +62 to 65
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Unify sample return shape variable: use n_dim consistently.

The base Distribution.sample returns shape " n_samples n_features" while the rest of the module uses " n_dim". This inconsistency will trip pyright/jaxtyping checks and downstream type expectations.

Apply this diff:

     def sample(
         self, rng_key: PRNGKeyArray, n_samples: int
-    ) -> Float[Array, " n_samples n_features"]:
+    ) -> Float[Array, " n_samples n_dim"]:
         raise NotImplementedError

You’ll also want to align concrete implementations—see the follow-up snippet for Gaussian.sample.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> Float[Array, " n_samples n_features"]:
raise NotImplementedError
def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> Float[Array, " n_samples n_dim"]:
raise NotImplementedError
🤖 Prompt for AI Agents
In src/flowMC/resource/model/common.py around lines 62-65, the sample return
shape annotation uses " n_samples n_features" but the project standard is
"n_dim"; change the type string to " n_samples n_dim" and update any matching
type alias usage. Also update all concrete implementations (e.g.,
Gaussian.sample) to use the same " n_samples n_dim" return annotation so
pyright/jaxtyping types match end-to-end.



class MLP(eqx.Module):
Expand Down
Empty file.
Loading
Loading