Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions jesterTOV/inference/flows/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Normalizing flow models for gravitational wave inference."""

from .flow import Flow, load_model
from .flow import Flow, ConditionalFlow, load_model
from .bilby_extract import extract_gw_posterior_from_bilby

__all__ = ["Flow", "load_model", "extract_gw_posterior_from_bilby"]
__all__ = ["Flow", "ConditionalFlow", "load_model", "extract_gw_posterior_from_bilby"]
7 changes: 6 additions & 1 deletion jesterTOV/inference/flows/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ class FlowTrainingConfig(BaseModel):
invert : bool
Whether to invert the flow (default: True)
cond_dim : int | None
Conditional dimension for conditional flows (default: None)
If set, a conditional flow will be trained.
cond_dim is the dimension for the conditional data.
In this case, `cond_parameter_names` needs to be provided as well (default: None)
cond_parameter_names: list[str] | None
List of parameter names that are the conditional data (default: None)
max_samples : int
Maximum number of samples to use for training (default: 50,000)
seed : int
Expand Down Expand Up @@ -84,6 +88,7 @@ class FlowTrainingConfig(BaseModel):
flow_layers: int = 1
invert: bool = True
cond_dim: int | None = None
cond_parameter_names: list[str] | None = None
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fail fast on inconsistent conditional configs.

This addition exposes cond_parameter_names, but the model still accepts cond_dim=0, negative cond_dim, cond_parameter_names without cond_dim, and lists whose length does not match cond_dim. The runtime guard in jesterTOV/inference/flows/train_flow.py only checks presence, so some bad configs are silently ignored while others fail much later inside flow creation/training.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@jesterTOV/inference/flows/config.py` at line 91, Validate conditional config
fields immediately when constructing the configuration: in the config class
where cond_dim and cond_parameter_names are declared, add a sanity check that
raises ValueError for invalid combinations — require cond_dim to be None or an
int >= 1 (reject 0 and negatives), require cond_parameter_names to be either
None or a list[str], forbid cond_parameter_names when cond_dim is None or <=0,
and require len(cond_parameter_names) == cond_dim when both are provided; remove
reliance on the weaker presence-only check in train_flow.py and let
train_flow.py assume config validation has already enforced consistency.

max_samples: int = 50_000
seed: int = 0
plot_corner: bool = True
Expand Down
128 changes: 128 additions & 0 deletions jesterTOV/inference/flows/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,134 @@ def log_prob(self, x: Array) -> Array:

return log_p

class ConditionalFlow(Flow):


def __init__(
    self,
    flow: AbstractDistribution,
    metadata: Dict[str, Any],
    flow_kwargs: Dict[str, Any],
):
    """
    Set up the conditional flow wrapper.

    Delegates to the parent initializer and afterwards reads the statistics
    of the conditional data from ``metadata``. Which entries are read depends
    on ``self.standardization_method`` (presumably set by the parent
    initializer -- confirm against ``Flow.__init__``):

    - ``"zscore"``: ``cond_data_mean`` and ``cond_data_std``
    - ``"minmax"``: ``cond_data_min`` and ``cond_data_max``; their difference
      is stored as ``cond_data_range``

    For any other method no conditional statistics are stored.

    Args:
        flow: Trained flowjax flow model
        metadata: Training metadata
        flow_kwargs: Flow architecture kwargs
    """
    super().__init__(flow, metadata, flow_kwargs)

    method = self.standardization_method

    if method == "zscore":
        self.cond_data_mean = jnp.array(metadata["cond_data_mean"])
        self.cond_data_std = jnp.array(metadata["cond_data_std"])
        return

    if method == "minmax":
        self.cond_data_min = lower = jnp.array(metadata["cond_data_min"])
        self.cond_data_max = upper = jnp.array(metadata["cond_data_max"])
        self.cond_data_range = upper - lower

def standardize_cond_data(self, data: Array) -> Array:
    """
    Rescale conditional data with the statistics loaded at construction.

    Args:
        data: Conditional parameters in their original scale.

    Returns:
        ``(data - mean) / std`` for the ``"zscore"`` method,
        ``(data - min) / range`` for the ``"minmax"`` method, and ``data``
        unchanged for any other method.
    """
    method = self.standardization_method

    if method == "zscore":
        return (data - self.cond_data_mean) / self.cond_data_std

    if method == "minmax":
        return (data - self.cond_data_min) / self.cond_data_range

    # The constructor only stores conditional statistics for the two methods
    # above; touching the min/max attributes here would raise AttributeError,
    # so every other method is treated as the identity transformation.
    return data
Comment on lines +314 to +328
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix the conditional standardization math.

The z-score branch divides by cond_data_mean instead of cond_data_std, the "none" case falls through to min-max attributes that were never initialized, and the min-max path does not recreate the zero-range guard from training. ConditionalFlow.log_prob() will therefore be wrong for standardized models and can crash for unstandardized ones.

🛠️ Suggested fix
         if self.standardization_method=="zscore":
             self.cond_data_mean = jnp.array(metadata["cond_data_mean"])
             self.cond_data_std = jnp.array(metadata["cond_data_std"])
+            self.cond_data_std = jnp.where(
+                self.cond_data_std == 0, 1.0, self.cond_data_std
+            )

         elif self.standardization_method=="minmax":
             self.cond_data_min = jnp.array(metadata["cond_data_min"])
             self.cond_data_max = jnp.array(metadata["cond_data_max"])
             self.cond_data_range = self.cond_data_max - self.cond_data_min
+            self.cond_data_range = jnp.where(
+                self.cond_data_range == 0, 1.0, self.cond_data_range
+            )

     def standardize_cond_data(self, data: Array) -> Array:

         if self.standardization_method == "zscore":
-            return (data - self.cond_data_mean) / self.cond_data_mean
-        else:
+            return (data - self.cond_data_mean) / self.cond_data_std
+        if self.standardization_method == "minmax":
             return (data - self.cond_data_min) / self.cond_data_range
+        return data
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@jesterTOV/inference/flows/flow.py` around lines 314 - 328, Fix
standardize_cond_data in ConditionalFlow: for the "zscore" branch divide by
self.cond_data_std (not cond_data_mean); for the "minmax" branch guard against
zero ranges by using a safe_range = jnp.where(self.cond_data_range == 0, 1,
self.cond_data_range) and divide by that; and add an explicit "none" case that
returns the input data unchanged so it doesn't reference min/max attributes that
weren't initialized. Update the standardize_cond_data method accordingly
(referencing standardize_cond_data, self.cond_data_mean, self.cond_data_std,
self.cond_data_min, self.cond_data_range).



@classmethod
def from_directory(cls, output_dir: str) -> "ConditionalFlow":
    """
    Build a ConditionalFlow from the files of a finished training run.

    Args:
        output_dir: Directory containing flow_weights.eqx, flow_kwargs.json, metadata.json

    Returns:
        ConditionalFlow instance with loaded model and metadata
    """
    # The flow model and its metadata are restored first ...
    model, meta = load_model(output_dir)

    # ... then the architecture kwargs are read from their JSON file.
    with open(os.path.join(output_dir, "flow_kwargs.json"), "r") as handle:
        architecture_kwargs = json.load(handle)

    return cls(model, meta, architecture_kwargs)

def log_prob(self, x: Array, y: Array) -> Array:
"""
Evaluate log probability of samples under the flow.

If standardization was used, input data is automatically standardized
before evaluation and Jacobian correction is applied. If not, operations
are identity (no-op).

The Jacobian correction accounts for the change of variables:
- Z-score: log p(x|y) = log p(x_std|y_std) - sum(log(std))
- Min-max: log p(x|y) = log p(x_std|y_std) - sum(log(max - min))
- None: log p(x|y) = log p(x_std|y_std) (no correction)

Comment on lines +354 to +365
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Format the new probability equations as Sphinx math.

The ConditionalFlow.log_prob() docstring adds math-heavy text, but it is plain prose right now. Please switch it to Sphinx math markup so the docs render consistently.

As per coding guidelines "All mathematical expressions in docstrings must use Sphinx/reStructuredText formatting with :math: role for inline math and .. math:: directive for display equations" and "Use raw strings (r""") for docstrings containing LaTeX to avoid Python escape sequence warnings".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@jesterTOV/inference/flows/flow.py` around lines 354 - 365, Update the
ConditionalFlow.log_prob() docstring to use a raw string (r""") and convert the
probability change-of-variable lines into Sphinx math: use :math: for inline
symbols and a ``.. math::`` block for the display equations showing Z-score,
Min-max, and None corrections (e.g. .. math:: \\log p(x\\mid y) = \\log
p(x_{std}\\mid y_{std}) - \\sum \\log(\\mathrm{std}), etc.), ensuring each
formula is valid LaTeX inside the math directive and the prose references those
math expressions via :math: where needed.

Args:
x: Data in original scale, shape (n_samples, n_features).
JAX array.
y: Conditional parameter, shape (n_samples, cond_dim).
JAX array.

Returns:
Log probabilities as JAX array, shape (n_samples,)

Example:
>>> data = jnp.array([[1.4, 1.3, 100, 200]])
>>> y = jnp.array([[2.0, 3.0]])
>>> log_prob = flow.log_prob(data, y=y)
"""
Comment thread
coderabbitai[bot] marked this conversation as resolved.
# Standardize input (method-dependent or identity)
x_std = self.standardize_input(x)
y_std = self.standardize_cond_data(y)

# Evaluate log probability in standardized space
log_p = self.flow.log_prob(x_std, y_std)

# Account for Jacobian of inverse transformation
if self.standardization_method == "zscore":
# Z-score: log |det J| = sum(log(std))
log_det_jacobian = -jnp.sum(jnp.log(self.data_std))
else:
# Min-max or none: log |det J| = sum(log(range))
# If standardization disabled (range=1), log_det_jacobian = 0
log_det_jacobian = -jnp.sum(jnp.log(self.data_range))

log_p = log_p + log_det_jacobian

return log_p

def sample(self, key: Array, shape: Tuple[int, ...], y: Array) -> Array:
    """
    Sample from the conditional flow and return in original scale.

    The conditional parameters are standardized with
    ``standardize_cond_data`` before they are handed to the flow, matching
    what ``log_prob`` does with its ``y`` argument. The drawn samples are
    converted back to the original scale with ``destandardize_output``.

    Args:
        key: JAX random key (jax.Array)
        shape: Shape of samples to generate (e.g., (1000,) for 1000 samples)
        y: Conditional parameters in their original scale (Array)
    Returns:
        Samples in original scale as JAX array; the shape is documented as
        (*shape, y.shape[0], n_features) -- confirm against flowjax.
    """

    # Standardize the raw condition; the flow must never see the unscaled `y`.
    y_std = self.standardize_cond_data(y)
    samples = self.flow.sample(key, shape, condition=y_std)

    samples = self.destandardize_output(samples)
Comment on lines +417 to +419
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

sample() is conditioning on the wrong scale.

Training standardizes conditional inputs before fitting when config.standardize is enabled, but this method passes raw y straight into the flow. With the default config, generated samples are conditioned on mismatched inputs.

🛠️ Suggested fix
-        samples = self.flow.sample(key, shape, condition=y)
+        y_std = self.standardize_cond_data(y)
+        samples = self.flow.sample(key, shape, condition=y_std)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@jesterTOV/inference/flows/flow.py` around lines 414 - 416, The code calls
self.flow.sample with the raw conditional y but training standardized
conditionals when config.standardize is true; update the method so that if
self.config.standardize (or equivalent flag) is enabled you first transform y
using the same routine used during training (e.g., call the existing
standardization helper such as self.standardize_input(y) or
self.standardize_condition(y) — if that helper doesn't exist, implement the same
standardization logic used during fit), then pass the standardized conditional
into self.flow.sample(key, shape, condition=standardized_y), and finally keep
the existing destandardize_output(samples) step.


return samples

def create_transformer(
transformer_type: str = "affine",
Expand Down
162 changes: 124 additions & 38 deletions jesterTOV/inference/flows/train_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
import os
import sys
from pathlib import Path
from typing import Any, Dict, Tuple, Mapping
from typing import Any, Dict, Tuple, Mapping, Iterable

import equinox as eqx
import jax
Expand Down Expand Up @@ -167,6 +167,45 @@ def load_posterior(

return data, metadata

def standardize_data(
data: np.ndarray,
standardization_method: str,
parameter_names: list[str] = []
) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
"""
Standardize data based on the selected standardization method.

Args:
data: Array of shape (n_samples, n_features)
standardization_method: str
Which method to standardize with.
Can either be 'zscore' for zero mean and unit std,
otherwise min-max scaling will be applied.
parameter_names:
Parameter names for which the rescaled data range should be printed.
Defaults to [].
"""
Comment on lines +175 to +187
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Format math in docstrings using Sphinx math roles.

This new docstring includes mathematical expressions (mean=0, std=1, [0, 1]) but not Sphinx math markup.

As per coding guidelines, "All mathematical expressions in docstrings must use Sphinx/reStructuredText formatting with :math: role for inline math and .. math:: directive for display equations".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@jesterTOV/inference/flows/train_flow.py` around lines 175 - 187, Update the
docstring for the function that standardizes data (references: parameter names
'standardization_method' and 'parameter_names') to use Sphinx math roles:
replace inline math like "mean=0" and "std=1" with :math:`mean = 0` and
:math:`std = 1`, and replace range notation "[0, 1]" with :math:`[0, 1]`; if any
expression should be displayed on its own line convert it to a block using the
.. math:: directive. Ensure the textual description still mentions 'zscore' and
min-max scaling but uses :math:`...` or .. math:: for all mathematical
fragments.


if standardization_method == "zscore":
logger.info("Standardizing data using z-score (mean=0, std=1)...")
data, data_statistics = standardize_data_zscore(data)
logger.info("Standardized data statistics:")
for i, name in enumerate(parameter_names):
logger.info(
f" {name}: mean={data[:, i].mean():.3f}, std={data[:, i].std():.3f}"
)
logger.info("Data mean and std saved for inverse transformation")
else: # minmax
logger.info("Standardizing data using min-max [0, 1] scaling...")
data, data_statistics = standardize_data_minmax(data)
logger.info("Standardized data ranges:")
for i, name in enumerate(parameter_names):
logger.info(
f" {name}: [{data[:, i].min():.3f}, {data[:, i].max():.3f}]"
)
logger.info("Data bounds saved for inverse transformation")

Comment on lines +198 to +207
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Logging statement inside loop causes repeated output.

Line 206 is indented inside the for loop, so "Data bounds saved for inverse transformation" will be logged once per parameter instead of once total. Compare with the z-score branch where the equivalent message (line 197) is correctly outside the loop.

🛠️ Suggested fix
     else:  # minmax
         logger.info("Standardizing data using min-max [0, 1] scaling...")
         data, data_statistics = standardize_data_minmax(data)
         logger.info("Standardized data ranges:")
         for i, name in enumerate(parameter_names):
             logger.info(
                 f"  {name}: [{data[:, i].min():.3f}, {data[:, i].max():.3f}]"
             )
-            logger.info("Data bounds saved for inverse transformation")
+        logger.info("Data bounds saved for inverse transformation")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@jesterTOV/inference/flows/train_flow.py` around lines 198 - 207, The logging
message "Data bounds saved for inverse transformation" is inside the for-loop
that iterates over parameter_names, causing it to be emitted once per parameter;
move that logger.info call out of the loop so it runs once after the loop that
logs per-parameter ranges in the minmax branch (the block using
standardize_data_minmax, data, data_statistics, parameter_names, and logger).
Ensure the loop only logs the per-parameter ranges and then, immediately after
the loop completes, call logger.info("Data bounds saved for inverse
transformation") once.

return data, data_statistics

def standardize_data_zscore(
data: np.ndarray,
Expand Down Expand Up @@ -268,7 +307,7 @@ def inverse_standardize_data_minmax(

def train_flow(
flow: Any,
data: np.ndarray,
data: np.ndarray | Iterable[np.ndarray],
key: Array,
learning_rate: float = 1e-3,
max_epochs: int = 600,
Expand All @@ -281,7 +320,9 @@ def train_flow(

Args:
flow: Untrained flowjax flow
data: Training data of shape (n_samples, n_dims)
data: Training data of shape (n_samples, n_dims)
or, if a conditional flow is trained, an iterable of two arrays
where the last array holds the conditional parameters.
key: JAX random key
learning_rate: Learning rate for optimizer
max_epochs: Maximum number of epochs
Expand Down Expand Up @@ -455,6 +496,7 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None:
logger.info(f"Flow layers: {config.flow_layers}")
logger.info(f"Invert: {config.invert}")
logger.info(f"Cond dim: {config.cond_dim}")
logger.info(f"Cond. parameters: {config.cond_parameter_names}")
logger.info(f"Transformer: {config.transformer}")
logger.info(f"Transformer knots: {config.transformer_knots}")
logger.info(f"Transformer interval: {config.transformer_interval}")
Expand Down Expand Up @@ -490,30 +532,46 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None:
# Standardize data if requested
data_statistics = None
if config.standardize:
if config.standardization_method == "zscore":
logger.info("Standardizing data using z-score (mean=0, std=1)...")
data, data_statistics = standardize_data_zscore(data)
logger.info("Standardized data statistics:")
for i, name in enumerate(parameter_names):
logger.info(
f" {name}: mean={data[:, i].mean():.3f}, std={data[:, i].std():.3f}"
)
logger.info("Data mean and std saved for inverse transformation")
else: # minmax
logger.info("Standardizing data using min-max [0, 1] scaling...")
data, data_statistics = standardize_data_minmax(data)
logger.info("Standardized data ranges:")
for i, name in enumerate(parameter_names):
logger.info(
f" {name}: [{data[:, i].min():.3f}, {data[:, i].max():.3f}]"
)
logger.info("Data bounds saved for inverse transformation")
data, data_statistics = standardize_data(
data,
config.standardization_method,
parameter_names,
)
dim = data.shape[1] # Infer dimensionality from data

if config.cond_dim:
logger.info("[1.5/5] Loading conditional samples...")

if len(config.cond_parameter_names) != config.cond_dim:
raise ValueError(
f"If conditional dimension is set, "
f"you also need to provide {config.cond_dim} conditional "
f"parameter names. You provided {len(config.cond_parameter_names)}."
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

cond_samples, cond_samples_metadata = load_posterior(
config.posterior_file,
parameter_names=config.cond_parameter_names,
max_samples=config.max_samples
)

# Standardize data if requested
original_cond_samples = cond_samples.copy()
cond_data_statistics = None
if config.standardize:
cond_samples, cond_data_statistics = standardize_data(
cond_samples,
config.standardization_method,
)

original_data = np.hstack((original_data, original_cond_samples))
data = (data, cond_samples)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

# Create flow
logger.info("[2/5] Creating flow architecture...")
flow_key, train_key, sample_key = jax.random.split(jax.random.key(config.seed), 3)
dim = data.shape[1] # Infer dimensionality from data
logger.info(f"Flow dimensionality: {dim}D")

flow = create_flow(
key=flow_key,
dim=dim,
Expand All @@ -531,7 +589,6 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None:

# Train flow
logger.info("[3/5] Training flow...")
logger.info(f"Training dataset shape: {data.shape}")
trained_flow, losses = train_flow(
flow,
data,
Expand Down Expand Up @@ -591,6 +648,15 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None:
else: # minmax
metadata["data_bounds_min"] = data_statistics["min"].tolist()
metadata["data_bounds_max"] = data_statistics["max"].tolist()

# Add conditional data statistics to metadata if standardization was used
if config.cond_dim:
if config.standardization_method == "zscore":
metadata["cond_data_mean"] = cond_data_statistics["mean"].tolist()
metadata["cond_data_std"] = cond_data_statistics["std"].tolist()
else:
metadata["cond_data_min"] = cond_data_statistics["min"].tolist()
metadata["cond_data_max"] = cond_data_statistics["max"].tolist()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.

save_model(trained_flow, config.output_dir, flow_kwargs, metadata)

Expand All @@ -608,26 +674,46 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None:
if config.plot_corner:
try:
# Sample from trained flow
n_plot_samples = min(10_000, data.shape[0])
flow_samples = trained_flow.sample(sample_key, (n_plot_samples,))
flow_samples_np = np.array(flow_samples)

# Inverse transform samples if data was standardized
if config.standardize and data_statistics is not None:
if config.standardization_method == "zscore":
flow_samples_np = inverse_standardize_data_zscore(
flow_samples_np, data_statistics
)
else: # minmax
flow_samples_np = inverse_standardize_data_minmax(
flow_samples_np, data_statistics
)
n_plot_samples = min(10_000, original_data.shape[0])

if config.cond_dim:
# get flow samples and untransform them
flow_samples = trained_flow.sample(sample_key, (1,), condition=cond_samples)
flow_samples_np = np.array(flow_samples)
if config.standardize and data_statistics is not None:
if config.standardization_method == "zscore":
flow_samples_np = inverse_standardize_data_zscore(
flow_samples_np, data_statistics
)
else: # minmax
flow_samples_np = inverse_standardize_data_minmax(
flow_samples_np, data_statistics
)
# stack together with conditional data for plot
flow_samples_np = np.hstack((flow_samples_np.reshape(-1, 1), original_data[:, -config.cond_dim:]))
labels = [*config.parameter_names, *config.cond_parameter_names]
Comment thread
coderabbitai[bot] marked this conversation as resolved.

else:
flow_samples = trained_flow.sample(sample_key, (n_plot_samples,))
flow_samples_np = np.array(flow_samples)
labels = parameter_names

# Inverse transform samples if data was standardized
if config.standardize and data_statistics is not None:
if config.standardization_method == "zscore":
flow_samples_np = inverse_standardize_data_zscore(
flow_samples_np, data_statistics
)
else: # minmax
flow_samples_np = inverse_standardize_data_minmax(
flow_samples_np, data_statistics
)

corner_path = os.path.join(figures_dir, "corner.png")
# Use original_data for corner plot comparison
# Update labels based on parameter names
plot_corner(
original_data, flow_samples_np, corner_path, labels=parameter_names
original_data, flow_samples_np, corner_path, labels=labels
)
except Exception as e:
logger.warning(
Expand Down
Loading