4 changes: 2 additions & 2 deletions jesterTOV/inference/flows/__init__.py
@@ -1,6 +1,6 @@
 """Normalizing flow models for gravitational wave inference."""

-from .flow import Flow, load_model
+from .flow import Flow, ConditionalFlow, load_model
 from .bilby_extract import extract_gw_posterior_from_bilby

-__all__ = ["Flow", "load_model", "extract_gw_posterior_from_bilby"]
+__all__ = ["Flow", "ConditionalFlow", "load_model", "extract_gw_posterior_from_bilby"]
7 changes: 6 additions & 1 deletion jesterTOV/inference/flows/config.py
@@ -40,7 +40,11 @@ class FlowTrainingConfig(BaseModel):
     invert : bool
         Whether to invert the flow (default: True)
     cond_dim : int | None
-        Conditional dimension for conditional flows (default: None)
+        If set, a conditional flow will be trained.
+        cond_dim is the dimension of the conditional data.
+        In this case, `cond_parameter_names` needs to be provided as well (default: None)
+    cond_parameter_names: list[str] | None
+        List of parameter names that are the conditional data (default: None)
     max_samples : int
         Maximum number of samples to use for training (default: 50,000)
     seed : int
@@ -84,6 +88,7 @@ class FlowTrainingConfig(BaseModel):
     flow_layers: int = 1
     invert: bool = True
     cond_dim: int | None = None
+    cond_parameter_names: list[str] | None = None
⚠️ Potential issue | 🟠 Major

Fail fast on inconsistent conditional configs.

This addition exposes cond_parameter_names, but the model still accepts cond_dim=0, negative cond_dim, cond_parameter_names without cond_dim, and lists whose length does not match cond_dim. The runtime guard in jesterTOV/inference/flows/train_flow.py only checks presence, so some bad configs are silently ignored while others fail much later inside flow creation/training.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@jesterTOV/inference/flows/config.py` at line 91, Validate conditional config
fields immediately when constructing the configuration: in the config class
where cond_dim and cond_parameter_names are declared, add a sanity check that
raises ValueError for invalid combinations — require cond_dim to be None or an
int >= 1 (reject 0 and negatives), require cond_parameter_names to be either
None or a list[str], forbid cond_parameter_names when cond_dim is None or <=0,
and require len(cond_parameter_names) == cond_dim when both are provided; remove
reliance on the weaker presence-only check in train_flow.py and let
train_flow.py assume config validation has already enforced consistency.
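
The consistency rules listed in this comment can be sketched as a single check. This is a minimal illustration, not the PR's code: the function name `validate_conditional_config` is hypothetical, and in the actual config class the same logic would presumably run inside a validator on `FlowTrainingConfig`, which is not shown here.

```python
from typing import Optional, Sequence


def validate_conditional_config(
    cond_dim: Optional[int],
    cond_parameter_names: Optional[Sequence[str]],
) -> None:
    """Raise ValueError when the two conditional fields are inconsistent.

    Hypothetical helper: the name and signature are assumptions made for
    illustration, not part of jesterTOV.
    """
    if cond_dim is None:
        # Names without a dimension would otherwise be silently ignored.
        if cond_parameter_names is not None:
            raise ValueError("cond_parameter_names requires cond_dim to be set")
        return
    if cond_dim < 1:
        raise ValueError(f"cond_dim must be >= 1, got {cond_dim}")
    if cond_parameter_names is None:
        raise ValueError("cond_dim is set but cond_parameter_names is missing")
    if len(cond_parameter_names) != cond_dim:
        raise ValueError(
            f"expected {cond_dim} conditional names, got {len(cond_parameter_names)}"
        )


# Consistent configurations pass silently.
validate_conditional_config(None, None)
validate_conditional_config(2, ["a", "b"])
```

Running the check at construction time means a bad combination such as `cond_dim=0` or a name list of the wrong length is rejected before any flow is built.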

     max_samples: int = 50_000
     seed: int = 0
     plot_corner: bool = True
128 changes: 128 additions & 0 deletions jesterTOV/inference/flows/flow.py
@@ -291,6 +291,134 @@ def log_prob(self, x: Array) -> Array:

         return log_p

+class ConditionalFlow(Flow):
+
+
+    def __init__(
+        self,
+        flow: AbstractDistribution,
+        metadata: Dict[str, Any],
+        flow_kwargs: Dict[str, Any],
+    ):
+        """
+        Initialize Flow wrapper.
+
+        Args:
+            flow: Trained flowjax flow model
+            metadata: Training metadata
+            flow_kwargs: Flow architecture kwargs
+        """
+
+        super().__init__(flow, metadata, flow_kwargs)
+
+        if self.standardization_method=="zscore":
+            self.cond_data_mean = jnp.array(metadata["cond_data_mean"])
+            self.cond_data_std = jnp.array(metadata["cond_data_std"])
+
+        elif self.standardization_method=="minmax":
+            self.cond_data_min = jnp.array(metadata["cond_data_min"])
+            self.cond_data_max = jnp.array(metadata["cond_data_max"])
+            self.cond_data_range = self.cond_data_max - self.cond_data_min
+
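+    def standardize_cond_data(self, data: Array) -> Array:
+
+        if self.standardization_method == "zscore":
+            return (data - self.cond_data_mean) / self.cond_data_std
+        elif self.standardization_method == "minmax":
+            return (data - self.cond_data_min) / self.cond_data_range
+        # No standardization: the min/max attributes are never set, so return the input unchanged
+        return data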
Comment on lines +314 to +328
⚠️ Potential issue | 🔴 Critical

Fix the conditional standardization math.

The z-score branch divides by cond_data_mean instead of cond_data_std, the "none" case falls through to min-max attributes that were never initialized, and the min-max path does not recreate the zero-range guard from training. ConditionalFlow.log_prob() will therefore be wrong for standardized models and can crash for unstandardized ones.

🛠️ Suggested fix
         if self.standardization_method=="zscore":
             self.cond_data_mean = jnp.array(metadata["cond_data_mean"])
             self.cond_data_std = jnp.array(metadata["cond_data_std"])
+            self.cond_data_std = jnp.where(
+                self.cond_data_std == 0, 1.0, self.cond_data_std
+            )

         elif self.standardization_method=="minmax":
             self.cond_data_min = jnp.array(metadata["cond_data_min"])
             self.cond_data_max = jnp.array(metadata["cond_data_max"])
             self.cond_data_range = self.cond_data_max - self.cond_data_min
+            self.cond_data_range = jnp.where(
+                self.cond_data_range == 0, 1.0, self.cond_data_range
+            )

     def standardize_cond_data(self, data: Array) -> Array:

         if self.standardization_method == "zscore":
-            return (data - self.cond_data_mean) / self.cond_data_mean
-        else:
+            return (data - self.cond_data_mean) / self.cond_data_std
+        if self.standardization_method == "minmax":
             return (data - self.cond_data_min) / self.cond_data_range
+        return data
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@jesterTOV/inference/flows/flow.py` around lines 314 - 328, Fix
standardize_cond_data in ConditionalFlow: for the "zscore" branch divide by
self.cond_data_std (not cond_data_mean); for the "minmax" branch guard against
zero ranges by using a safe_range = jnp.where(self.cond_data_range == 0, 1,
self.cond_data_range) and divide by that; and add an explicit "none" case that
returns the input data unchanged so it doesn't reference min/max attributes that
weren't initialized. Update the standardize_cond_data method accordingly
(referencing standardize_cond_data, self.cond_data_mean, self.cond_data_std,
self.cond_data_min, self.cond_data_range).
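
The three branches described in this comment can be tried outside the class. The sketch below is an assumption-laden stand-in: it uses NumPy rather than `jax.numpy` (the `where` call has the same form in both), the function name `standardize_cond` is hypothetical, and the `stats` dictionary only mimics the metadata keys that appear in the diff.

```python
import numpy as np


def standardize_cond(data, method, stats):
    """Standardize conditional data with zero-scale guards.

    Hypothetical free-function version of standardize_cond_data; `stats`
    mimics the metadata keys used in the PR.
    """
    if method == "zscore":
        std = np.asarray(stats["cond_data_std"], dtype=float)
        # A constant feature has std 0; dividing by 1 instead avoids NaN/inf.
        safe_std = np.where(std == 0, 1.0, std)
        return (data - np.asarray(stats["cond_data_mean"], dtype=float)) / safe_std
    if method == "minmax":
        low = np.asarray(stats["cond_data_min"], dtype=float)
        rng = np.asarray(stats["cond_data_max"], dtype=float) - low
        safe_rng = np.where(rng == 0, 1.0, rng)
        return (data - low) / safe_rng
    # No standardization: identity, without touching any statistics.
    return data


y = np.array([[2.0, 3.0]])
z = standardize_cond(y, "zscore", {"cond_data_mean": [1.0, 3.0], "cond_data_std": [0.5, 0.0]})
m = standardize_cond(y, "minmax", {"cond_data_min": [0.0, 3.0], "cond_data_max": [4.0, 3.0]})
n = standardize_cond(y, "none", {})
```

The second feature in each call has zero spread, which is the case the guard exists for: without it the division would produce NaN.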



+    @classmethod
+    def from_directory(cls, output_dir: str) -> "ConditionalFlow":
+        """
+        Load a trained conditional flow from a directory.
+
+        Args:
+            output_dir: Directory containing flow_weights.eqx, flow_kwargs.json, metadata.json
+
+        Returns:
+            ConditionalFlow instance with loaded model and metadata
+
+        """
+        # Load the flow model and metadata
+        flow_model, metadata = load_model(output_dir)
+
+        # Load kwargs
+        kwargs_path = os.path.join(output_dir, "flow_kwargs.json")
+        with open(kwargs_path, "r") as f:
+            flow_kwargs = json.load(f)
+
+        return cls(flow_model, metadata, flow_kwargs)
+
+    def log_prob(self, x: Array, y: Array) -> Array:
+        """
+        Evaluate log probability of samples under the flow.
+
+        If standardization was used, input data is automatically standardized
+        before evaluation and Jacobian correction is applied. If not, operations
+        are identity (no-op).
+
+        The Jacobian correction accounts for the change of variables:
+        - Z-score: log p(x|y) = log p(x_std|y_std) - sum(log(std))
+        - Min-max: log p(x|y) = log p(x_std|y_std) - sum(log(max - min))
+        - None: log p(x|y) = log p(x_std|y_std) (no correction)

Comment on lines +354 to +365
⚠️ Potential issue | 🟡 Minor

Format the new probability equations as Sphinx math.

The ConditionalFlow.log_prob() docstring adds math-heavy text, but it is plain prose right now. Please switch it to Sphinx math markup so the docs render consistently.

As per coding guidelines "All mathematical expressions in docstrings must use Sphinx/reStructuredText formatting with :math: role for inline math and .. math:: directive for display equations" and "Use raw strings (r""") for docstrings containing LaTeX to avoid Python escape sequence warnings".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@jesterTOV/inference/flows/flow.py` around lines 354 - 365, Update the
ConditionalFlow.log_prob() docstring to use a raw string (r""") and convert the
probability change-of-variable lines into Sphinx math: use :math: for inline
symbols and a ``.. math::`` block for the display equations showing Z-score,
Min-max, and None corrections (e.g. .. math:: \\log p(x\\mid y) = \\log
p(x_{std}\\mid y_{std}) - \\sum \\log(\\mathrm{std}), etc.), ensuring each
formula is valid LaTeX inside the math directive and the prose references those
math expressions via :math: where needed.
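
A raw-string docstring with Sphinx math might look like the sketch below. It is a toy, unconditional stand-in rather than the PR's method: the base density is assumed to be a standard normal, the function name `zscore_log_prob` is hypothetical, and the condition is omitted because it does not enter the Jacobian of the data standardization.

```python
import math


def zscore_log_prob(x, mean, std):
    r"""Log density of ``x`` under a standard-normal base with z-score scaling.

    With :math:`z_i = (x_i - \mu_i) / \sigma_i`, the change of variables gives

    .. math::

        \log p(x) = \log p_Z(z) - \sum_i \log \sigma_i

    Args:
        x: Point in the original scale (sequence of floats).
        mean: Per-feature means :math:`\mu_i`.
        std: Per-feature standard deviations :math:`\sigma_i`, all positive.

    Returns:
        Log density as a float.
    """
    # Standardize, evaluate the base density, then apply the Jacobian term.
    z = [(xi - m) / s for xi, m, s in zip(x, mean, std)]
    base = sum(-0.5 * zi * zi - 0.5 * math.log(2.0 * math.pi) for zi in z)
    return base - sum(math.log(s) for s in std)
```

The `r"""` prefix keeps sequences such as `\s` and `\l` from being read as Python escapes, and the formula can be checked numerically: for a standard-normal base it must agree with the log density of independent normals with the given means and standard deviations.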

+        Args:
+            x: Data in original scale, shape (n_samples, n_features).
+               JAX array.
+            y: Conditional parameter, shape (n_samples, cond_dim).
+               JAX array.
+
+        Returns:
+            Log probabilities as JAX array, shape (n_samples,)
+
+        Example:
+            >>> data = jnp.array([[1.4, 1.3, 100, 200]])
+            >>> y = jnp.array([[2.0, 3.0]])
+            >>> log_prob = flow.log_prob(data, y=y)
+        """
Comment thread
coderabbitai[bot] marked this conversation as resolved.
+        # Standardize input (method-dependent or identity)
+        x_std = self.standardize_input(x)
+        y_std = self.standardize_cond_data(y)
+
+        # Evaluate log probability in standardized space
+        log_p = self.flow.log_prob(x_std, y_std)
+
+        # Account for Jacobian of inverse transformation
+        if self.standardization_method == "zscore":
+            # Z-score: log |det J| = sum(log(std))
+            log_det_jacobian = -jnp.sum(jnp.log(self.data_std))
+        else:
+            # Min-max or none: log |det J| = sum(log(range))
+            # If standardization disabled (range=1), log_det_jacobian = 0
+            log_det_jacobian = -jnp.sum(jnp.log(self.data_range))
+
+        log_p = log_p + log_det_jacobian
+
+        return log_p

+    def sample(self, key: Array, shape: Tuple[int, ...], y: Array) -> Array:
+        """
+        Sample from the conditional flow and return in original scale.
+
+        If standardization was used during training, samples are automatically
+        converted back to the original scale using the inverse transformation
+        (z-score or min-max). If not, the transformation is identity (no-op).
+
+        Args:
+            key: JAX random key (jax.Array)
+            shape: Shape of samples to generate (e.g., (1000,) for 1000 samples)
+            y: Conditional parameters (Array)
+        Returns:
+            Samples in original scale as JAX array of shape (*shape, y.shape[0], n_features)
+        """
+
+        # Standardize the raw condition y, then condition the flow on the standardized value
+        y_std = self.standardize_cond_data(y)
+        samples = self.flow.sample(key, shape, condition=y_std)
+
+        samples = self.destandardize_output(samples)
Comment on lines +417 to +419
⚠️ Potential issue | 🔴 Critical

sample() is conditioning on the wrong scale.

Training standardizes conditional inputs before fitting when config.standardize is enabled, but this method passes raw y straight into the flow. With the default config, generated samples are conditioned on mismatched inputs.

🛠️ Suggested fix
-        samples = self.flow.sample(key, shape, condition=y)
+        y_std = self.standardize_cond_data(y)
+        samples = self.flow.sample(key, shape, condition=y_std)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@jesterTOV/inference/flows/flow.py` around lines 414 - 416, The code calls
self.flow.sample with the raw conditional y but training standardized
conditionals when config.standardize is true; update the method so that if
self.config.standardize (or equivalent flag) is enabled you first transform y
using the same routine used during training (e.g., call the existing
standardization helper such as self.standardize_input(y) or
self.standardize_condition(y) — if that helper doesn't exist, implement the same
standardization logic used during fit), then pass the standardized conditional
into self.flow.sample(key, shape, condition=standardized_y), and finally keep
the existing destandardize_output(samples) step.
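
The effect of conditioning on the wrong scale can be seen with a toy sampler. Everything in this sketch is an assumption made for illustration: `toy_flow_sample` stands in for the trained flow and simply draws around the standardized condition, and the z-score statistics are made-up numbers.

```python
import random


def toy_flow_sample(rng, n, condition_std):
    """Stand-in for a trained flow: draws x_std around the standardized condition."""
    return [rng.gauss(condition_std, 0.1) for _ in range(n)]


def sample_original_scale(rng, n, y, cond_mean, cond_std, data_mean, data_std):
    """Standardize y, sample in standardized space, then destandardize the output."""
    y_std = (y - cond_mean) / cond_std        # same transform as during training
    x_std = toy_flow_sample(rng, n, y_std)    # condition on the standardized value
    return [data_mean + data_std * v for v in x_std]


rng = random.Random(0)
good = sample_original_scale(rng, 2000, y=12.0, cond_mean=10.0, cond_std=2.0,
                             data_mean=100.0, data_std=5.0)
# Passing raw y to the sampler, as in the flagged code path, skips the transform.
bad = [100.0 + 5.0 * v for v in toy_flow_sample(rng, 2000, 12.0)]

good_mean = sum(good) / len(good)
bad_mean = sum(bad) / len(bad)
```

With these numbers the standardized condition is 1, so correctly conditioned samples centre near 105 in the original scale, while the raw-condition samples centre near 160.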


+        return samples

 def create_transformer(
     transformer_type: str = "affine",