-
Notifications
You must be signed in to change notification settings - Fork 7
First attempt to add conditional NFs to jester infrastructure #125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
f618d38
fd31ffd
84a4bfb
41640ae
65ece2e
0dd578a
09c5741
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| """Normalizing flow models for gravitational wave inference.""" | ||
|
|
||
| from .flow import Flow, load_model | ||
| from .flow import Flow, ConditionalFlow, load_model | ||
| from .bilby_extract import extract_gw_posterior_from_bilby | ||
|
|
||
| __all__ = ["Flow", "load_model", "extract_gw_posterior_from_bilby"] | ||
| __all__ = ["Flow", "ConditionalFlow", "load_model", "extract_gw_posterior_from_bilby"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -291,6 +291,134 @@ def log_prob(self, x: Array) -> Array: | |
|
|
||
| return log_p | ||
|
|
||
| class ConditionalFlow(Flow): | ||
|
|
||
|
|
||
| def __init__( | ||
| self, | ||
| flow: AbstractDistribution, | ||
| metadata: Dict[str, Any], | ||
| flow_kwargs: Dict[str, Any], | ||
| ): | ||
| """ | ||
| Initialize Flow wrapper. | ||
|
|
||
| Args: | ||
| flow: Trained flowjax flow model | ||
| metadata: Training metadata | ||
| flow_kwargs: Flow architecture kwargs | ||
| """ | ||
|
|
||
| super().__init__(flow, metadata, flow_kwargs) | ||
|
|
||
| if self.standardization_method=="zscore": | ||
| self.cond_data_mean = jnp.array(metadata["cond_data_mean"]) | ||
| self.cond_data_std = jnp.array(metadata["cond_data_std"]) | ||
|
|
||
| elif self.standardization_method=="minmax": | ||
| self.cond_data_min = jnp.array(metadata["cond_data_min"]) | ||
| self.cond_data_max = jnp.array(metadata["cond_data_max"]) | ||
| self.cond_data_range = self.cond_data_max - self.cond_data_min | ||
|
|
||
| def standardize_cond_data(self, data: Array) -> Array: | ||
|
|
||
| if self.standardization_method == "zscore": | ||
| return (data - self.cond_data_mean) / self.cond_data_std | ||
| else: | ||
| return (data - self.cond_data_min) / self.cond_data_range | ||
|
Comment on lines
+314
to
+328
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix the conditional standardization math. The z-score branch divides by 🛠️ Suggested fix if self.standardization_method=="zscore":
self.cond_data_mean = jnp.array(metadata["cond_data_mean"])
self.cond_data_std = jnp.array(metadata["cond_data_std"])
+ self.cond_data_std = jnp.where(
+ self.cond_data_std == 0, 1.0, self.cond_data_std
+ )
elif self.standardization_method=="minmax":
self.cond_data_min = jnp.array(metadata["cond_data_min"])
self.cond_data_max = jnp.array(metadata["cond_data_max"])
self.cond_data_range = self.cond_data_max - self.cond_data_min
+ self.cond_data_range = jnp.where(
+ self.cond_data_range == 0, 1.0, self.cond_data_range
+ )
def standardize_cond_data(self, data: Array) -> Array:
if self.standardization_method == "zscore":
- return (data - self.cond_data_mean) / self.cond_data_mean
- else:
+ return (data - self.cond_data_mean) / self.cond_data_std
+ if self.standardization_method == "minmax":
return (data - self.cond_data_min) / self.cond_data_range
+ return data🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| @classmethod | ||
| def from_directory(cls, output_dir: str) -> "ConditionalFlow": | ||
| """ | ||
| Load a trained conditional flow from a directory. | ||
|
|
||
| Args: | ||
| output_dir: Directory containing flow_weights.eqx, flow_kwargs.json, metadata.json | ||
|
|
||
| Returns: | ||
| ConditionalFlow instance with loaded model and metadata | ||
|
|
||
| """ | ||
| # Load the flow model and metadata | ||
| flow_model, metadata = load_model(output_dir) | ||
|
|
||
| # Load kwargs | ||
| kwargs_path = os.path.join(output_dir, "flow_kwargs.json") | ||
| with open(kwargs_path, "r") as f: | ||
| flow_kwargs = json.load(f) | ||
|
|
||
| return cls(flow_model, metadata, flow_kwargs) | ||
|
|
||
| def log_prob(self, x: Array, y: Array) -> Array: | ||
| """ | ||
| Evaluate log probability of samples under the flow. | ||
|
|
||
| If standardization was used, input data is automatically standardized | ||
| before evaluation and Jacobian correction is applied. If not, operations | ||
| are identity (no-op). | ||
|
|
||
| The Jacobian correction accounts for the change of variables: | ||
| - Z-score: log p(x|y) = log p(x_std|y_std) - sum(log(std)) | ||
| - Min-max: log p(x|y) = log p(x_std|y_std) - sum(log(max - min)) | ||
| - None: log p(x|y) = log p(x_std|y_std) (no correction) | ||
|
|
||
|
Comment on lines
+354
to
+365
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Format the new probability equations as Sphinx math. The As per coding guidelines "All mathematical expressions in docstrings must use Sphinx/reStructuredText formatting with 🤖 Prompt for AI Agents |
||
| Args: | ||
| x: Data in original scale, shape (n_samples, n_features). | ||
| JAX array. | ||
| y: Conditional parameter, shape (n_samples, cond_dim). | ||
| JAX array. | ||
|
|
||
| Returns: | ||
| Log probabilities as JAX array, shape (n_samples,) | ||
|
|
||
| Example: | ||
| >>> data = jnp.array([[1.4, 1.3, 100, 200]]) | ||
| >>> y = jnp.array([[2.0, 3.0]]) | ||
| >>> log_prob = flow.log_prob(data, y=y) | ||
| """ | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| # Standardize input (method-dependent or identity) | ||
| x_std = self.standardize_input(x) | ||
| y_std = self.standardize_cond_data(y) | ||
|
|
||
| # Evaluate log probability in standardized space | ||
| log_p = self.flow.log_prob(x_std, y_std) | ||
|
|
||
| # Account for Jacobian of inverse transformation | ||
| if self.standardization_method == "zscore": | ||
| # Z-score: log |det J| = sum(log(std)) | ||
| log_det_jacobian = -jnp.sum(jnp.log(self.data_std)) | ||
| else: | ||
| # Min-max or none: log |det J| = sum(log(range)) | ||
| # If standardization disabled (range=1), log_det_jacobian = 0 | ||
| log_det_jacobian = -jnp.sum(jnp.log(self.data_range)) | ||
|
|
||
| log_p = log_p + log_det_jacobian | ||
|
|
||
| return log_p | ||
|
|
||
| def sample(self, key: Array, shape: Tuple[int, ...], y: Array) -> Array: | ||
| """ | ||
| Sample from the conditional flow and return in original scale. | ||
|
|
||
| If standardization was used during training, samples are automatically | ||
| converted back to the original scale using the inverse transformation | ||
| (z-score or min-max). If not, the transformation is identity (no-op). | ||
|
|
||
| Args: | ||
| key: JAX random key (jax.Array) | ||
| shape: Shape of samples to generate (e.g., (1000,) for 1000 samples) | ||
| y: Conditional parameters (Array) | ||
| Returns: | ||
| Samples in original scale as JAX array of shape (*shape, y.shape[0], n_features) | ||
| """ | ||
|
|
||
| y_std = self.standardize_cond_data(y_std) | ||
| samples = self.flow.sample(key, shape, condition=y) | ||
|
|
||
| samples = self.destandardize_output(samples) | ||
|
Comment on lines
+417
to
+419
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Training standardizes conditional inputs before fitting when 🛠️ Suggested fix- samples = self.flow.sample(key, shape, condition=y)
+ y_std = self.standardize_cond_data(y)
+ samples = self.flow.sample(key, shape, condition=y_std)🤖 Prompt for AI Agents |
||
|
|
||
| return samples | ||
|
|
||
| def create_transformer( | ||
| transformer_type: str = "affine", | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -88,7 +88,7 @@ | |
| import os | ||
| import sys | ||
| from pathlib import Path | ||
| from typing import Any, Dict, Tuple, Mapping | ||
| from typing import Any, Dict, Tuple, Mapping, Iterable | ||
|
|
||
| import equinox as eqx | ||
| import jax | ||
|
|
@@ -167,6 +167,45 @@ def load_posterior( | |
|
|
||
| return data, metadata | ||
|
|
||
| def standardize_data( | ||
| data: np.ndarray, | ||
| standardization_method: str, | ||
| parameter_names: list[str] = [] | ||
| ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]: | ||
| """ | ||
| Standardize data based on the selected standardiation method. | ||
|
|
||
| Args: | ||
| data: Array of shape (n_samples, n_features) | ||
| standardization_method: str | ||
| Which method to standardize with. | ||
| Can either be 'zscore' for zero mean and unit std, | ||
| otherwise min-max scaling will be applied. | ||
| parameter_names: | ||
| Parameter names for which the rescaled data range should be printed. | ||
| Defaults to []. | ||
| """ | ||
|
Comment on lines
+175
to
+187
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Format math in docstrings using Sphinx math roles. This new docstring includes mathematical expressions ( As per coding guidelines, "All mathematical expressions in docstrings must use Sphinx/reStructuredText formatting with 🤖 Prompt for AI Agents |
||
|
|
||
| if standardization_method == "zscore": | ||
| logger.info("Standardizing data using z-score (mean=0, std=1)...") | ||
| data, data_statistics = standardize_data_zscore(data) | ||
| logger.info("Standardized data statistics:") | ||
| for i, name in enumerate(parameter_names): | ||
| logger.info( | ||
| f" {name}: mean={data[:, i].mean():.3f}, std={data[:, i].std():.3f}" | ||
| ) | ||
| logger.info("Data mean and std saved for inverse transformation") | ||
| else: # minmax | ||
| logger.info("Standardizing data using min-max [0, 1] scaling...") | ||
| data, data_statistics = standardize_data_minmax(data) | ||
| logger.info("Standardized data ranges:") | ||
| for i, name in enumerate(parameter_names): | ||
| logger.info( | ||
| f" {name}: [{data[:, i].min():.3f}, {data[:, i].max():.3f}]" | ||
| ) | ||
| logger.info("Data bounds saved for inverse transformation") | ||
|
|
||
|
Comment on lines
+198
to
+207
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Logging statement inside loop causes repeated output. Line 206 is indented inside the 🛠️ Suggested fix else: # minmax
logger.info("Standardizing data using min-max [0, 1] scaling...")
data, data_statistics = standardize_data_minmax(data)
logger.info("Standardized data ranges:")
for i, name in enumerate(parameter_names):
logger.info(
f" {name}: [{data[:, i].min():.3f}, {data[:, i].max():.3f}]"
)
- logger.info("Data bounds saved for inverse transformation")
+ logger.info("Data bounds saved for inverse transformation")🤖 Prompt for AI Agents |
||
| return data, data_statistics | ||
|
|
||
| def standardize_data_zscore( | ||
| data: np.ndarray, | ||
|
|
@@ -268,7 +307,7 @@ def inverse_standardize_data_minmax( | |
|
|
||
| def train_flow( | ||
| flow: Any, | ||
| data: np.ndarray, | ||
| data: np.ndarray | Iterable[np.ndarray], | ||
| key: Array, | ||
| learning_rate: float = 1e-3, | ||
| max_epochs: int = 600, | ||
|
|
@@ -281,7 +320,9 @@ def train_flow( | |
|
|
||
| Args: | ||
| flow: Untrained flowjax flow | ||
| data: Training data of shape (n_samples, n_dims) | ||
| data: Training data of shape (n_samples, n_dims) | ||
| or if conditional flow is trained iterable of two arrays | ||
| where the last array are the conditional parameters. | ||
| key: JAX random key | ||
| learning_rate: Learning rate for optimizer | ||
| max_epochs: Maximum number of epochs | ||
|
|
@@ -455,6 +496,7 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None: | |
| logger.info(f"Flow layers: {config.flow_layers}") | ||
| logger.info(f"Invert: {config.invert}") | ||
| logger.info(f"Cond dim: {config.cond_dim}") | ||
| logger.info(f"Cond. parameters: {config.cond_parameter_names}") | ||
| logger.info(f"Transformer: {config.transformer}") | ||
| logger.info(f"Transformer knots: {config.transformer_knots}") | ||
| logger.info(f"Transformer interval: {config.transformer_interval}") | ||
|
|
@@ -490,30 +532,46 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None: | |
| # Standardize data if requested | ||
| data_statistics = None | ||
| if config.standardize: | ||
| if config.standardization_method == "zscore": | ||
| logger.info("Standardizing data using z-score (mean=0, std=1)...") | ||
| data, data_statistics = standardize_data_zscore(data) | ||
| logger.info("Standardized data statistics:") | ||
| for i, name in enumerate(parameter_names): | ||
| logger.info( | ||
| f" {name}: mean={data[:, i].mean():.3f}, std={data[:, i].std():.3f}" | ||
| ) | ||
| logger.info("Data mean and std saved for inverse transformation") | ||
| else: # minmax | ||
| logger.info("Standardizing data using min-max [0, 1] scaling...") | ||
| data, data_statistics = standardize_data_minmax(data) | ||
| logger.info("Standardized data ranges:") | ||
| for i, name in enumerate(parameter_names): | ||
| logger.info( | ||
| f" {name}: [{data[:, i].min():.3f}, {data[:, i].max():.3f}]" | ||
| ) | ||
| logger.info("Data bounds saved for inverse transformation") | ||
| data, data_statistics = standardize_data( | ||
| data, | ||
| config.standardization_method, | ||
| parameter_names, | ||
| ) | ||
| dim = data.shape[1] # Infer dimensionality from data | ||
|
|
||
| if config.cond_dim: | ||
| logger.info("[1.5/5] Loading conditional samples...") | ||
|
|
||
| if len(config.cond_parameter_names) != config.cond_dim: | ||
| raise ValueError( | ||
| f"If conditional dimension is set, " | ||
| f"you also need to provide {config.cond_dim} conditional " | ||
| f"parameter names. You provided {len(config.cond_parameter_names)}." | ||
| ) | ||
|
coderabbitai[bot] marked this conversation as resolved.
Outdated
|
||
|
|
||
| cond_samples, cond_samples_metadata = load_posterior( | ||
| config.posterior_file, | ||
| parameter_names=config.cond_parameter_names, | ||
| max_samples=config.max_samples | ||
| ) | ||
|
|
||
| # Standardize data if requested | ||
| original_cond_samples = cond_samples.copy() | ||
| cond_data_statistics = None | ||
| if config.standardize: | ||
| cond_samples, cond_data_statistics = standardize_data( | ||
| cond_samples, | ||
| config.standardization_method, | ||
| ) | ||
|
|
||
| original_data = np.hstack((original_data, original_cond_samples)) | ||
| data = (data, cond_samples) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| # Create flow | ||
| logger.info("[2/5] Creating flow architecture...") | ||
| flow_key, train_key, sample_key = jax.random.split(jax.random.key(config.seed), 3) | ||
| dim = data.shape[1] # Infer dimensionality from data | ||
| logger.info(f"Flow dimensionality: {dim}D") | ||
|
|
||
| flow = create_flow( | ||
| key=flow_key, | ||
| dim=dim, | ||
|
|
@@ -531,7 +589,6 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None: | |
|
|
||
| # Train flow | ||
| logger.info("[3/5] Training flow...") | ||
| logger.info(f"Training dataset shape: {data.shape}") | ||
| trained_flow, losses = train_flow( | ||
| flow, | ||
| data, | ||
|
|
@@ -591,6 +648,15 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None: | |
| else: # minmax | ||
| metadata["data_bounds_min"] = data_statistics["min"].tolist() | ||
| metadata["data_bounds_max"] = data_statistics["max"].tolist() | ||
|
|
||
| # Add conditional data statistics to metadata if standardization was used | ||
| if config.cond_dim: | ||
| if config.standardization_method == "zscore": | ||
| metadata["cond_data_mean"] = cond_data_statistics["mean"].tolist() | ||
| metadata["cond_data_std"] = cond_data_statistics["std"].tolist() | ||
| else: | ||
| metadata["cond_data_min"] = cond_data_statistics["min"].tolist() | ||
| metadata["cond_data_max"] = cond_data_statistics["max"].tolist() | ||
|
coderabbitai[bot] marked this conversation as resolved.
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| save_model(trained_flow, config.output_dir, flow_kwargs, metadata) | ||
|
|
||
|
|
@@ -608,26 +674,46 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None: | |
| if config.plot_corner: | ||
| try: | ||
| # Sample from trained flow | ||
| n_plot_samples = min(10_000, data.shape[0]) | ||
| flow_samples = trained_flow.sample(sample_key, (n_plot_samples,)) | ||
| flow_samples_np = np.array(flow_samples) | ||
|
|
||
| # Inverse transform samples if data was standardized | ||
| if config.standardize and data_statistics is not None: | ||
| if config.standardization_method == "zscore": | ||
| flow_samples_np = inverse_standardize_data_zscore( | ||
| flow_samples_np, data_statistics | ||
| ) | ||
| else: # minmax | ||
| flow_samples_np = inverse_standardize_data_minmax( | ||
| flow_samples_np, data_statistics | ||
| ) | ||
| n_plot_samples = min(10_000, original_data.shape[0]) | ||
|
|
||
| if config.cond_dim: | ||
| # get flow samples and untransform them | ||
| flow_samples = trained_flow.sample(sample_key, (1,), condition=cond_samples) | ||
| flow_samples_np = np.array(flow_samples) | ||
| if config.standardize and data_statistics is not None: | ||
| if config.standardization_method == "zscore": | ||
| flow_samples_np = inverse_standardize_data_zscore( | ||
| flow_samples_np, data_statistics | ||
| ) | ||
| else: # minmax | ||
| flow_samples_np = inverse_standardize_data_minmax( | ||
| flow_samples_np, data_statistics | ||
| ) | ||
| # stack together with conditional data for plot | ||
| flow_samples_np = np.hstack((flow_samples_np.reshape(-1, 1), original_data[:, -config.cond_dim:])) | ||
| labels = [*config.parameter_names, *config.cond_parameter_names] | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| else: | ||
| flow_samples = trained_flow.sample(sample_key, (n_plot_samples,)) | ||
| flow_samples_np = np.array(flow_samples) | ||
| labels = parameter_names | ||
|
|
||
| # Inverse transform samples if data was standardized | ||
| if config.standardize and data_statistics is not None: | ||
| if config.standardization_method == "zscore": | ||
| flow_samples_np = inverse_standardize_data_zscore( | ||
| flow_samples_np, data_statistics | ||
| ) | ||
| else: # minmax | ||
| flow_samples_np = inverse_standardize_data_minmax( | ||
| flow_samples_np, data_statistics | ||
| ) | ||
|
|
||
| corner_path = os.path.join(figures_dir, "corner.png") | ||
| # Use original_data for corner plot comparison | ||
| # Update labels based on parameter names | ||
| plot_corner( | ||
| original_data, flow_samples_np, corner_path, labels=parameter_names | ||
| original_data, flow_samples_np, corner_path, labels=labels | ||
| ) | ||
| except Exception as e: | ||
| logger.warning( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fail fast on inconsistent conditional configs.
This addition exposes
cond_parameter_names, but the model still acceptscond_dim=0, negativecond_dim,cond_parameter_nameswithoutcond_dim, and lists whose length does not matchcond_dim. The runtime guard injesterTOV/inference/flows/train_flow.pyonly checks presence, so some bad configs are silently ignored while others fail much later inside flow creation/training.🤖 Prompt for AI Agents