-
Notifications
You must be signed in to change notification settings - Fork 7
First attempt to add conditional NFs to jester infrastructure #125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f618d38
fd31ffd
84a4bfb
41640ae
65ece2e
0dd578a
09c5741
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| """Normalizing flow models for gravitational wave inference.""" | ||
|
|
||
| from .flow import Flow, load_model | ||
| from .flow import Flow, ConditionalFlow, load_model | ||
| from .bilby_extract import extract_gw_posterior_from_bilby | ||
|
|
||
| __all__ = ["Flow", "load_model", "extract_gw_posterior_from_bilby"] | ||
| __all__ = ["Flow", "ConditionalFlow", "load_model", "extract_gw_posterior_from_bilby"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -291,6 +291,134 @@ def log_prob(self, x: Array) -> Array: | |
|
|
||
| return log_p | ||
|
|
||
| class ConditionalFlow(Flow): | ||
|
|
||
|
|
||
| def __init__( | ||
| self, | ||
| flow: AbstractDistribution, | ||
| metadata: Dict[str, Any], | ||
| flow_kwargs: Dict[str, Any], | ||
| ): | ||
| """ | ||
| Initialize Flow wrapper. | ||
|
|
||
| Args: | ||
| flow: Trained flowjax flow model | ||
| metadata: Training metadata | ||
| flow_kwargs: Flow architecture kwargs | ||
| """ | ||
|
|
||
| super().__init__(flow, metadata, flow_kwargs) | ||
|
|
||
| if self.standardization_method=="zscore": | ||
| self.cond_data_mean = jnp.array(metadata["cond_data_mean"]) | ||
| self.cond_data_std = jnp.array(metadata["cond_data_std"]) | ||
|
|
||
| elif self.standardization_method=="minmax": | ||
| self.cond_data_min = jnp.array(metadata["cond_data_min"]) | ||
| self.cond_data_max = jnp.array(metadata["cond_data_max"]) | ||
| self.cond_data_range = self.cond_data_max - self.cond_data_min | ||
|
|
||
| def standardize_cond_data(self, data: Array) -> Array: | ||
|
|
||
| if self.standardization_method == "zscore": | ||
| return (data - self.cond_data_mean) / self.cond_data_std | ||
| else: | ||
| return (data - self.cond_data_min) / self.cond_data_range | ||
|
Comment on lines
+314
to
+328
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix the conditional standardization math. The z-score branch divides by 🛠️ Suggested fix if self.standardization_method=="zscore":
self.cond_data_mean = jnp.array(metadata["cond_data_mean"])
self.cond_data_std = jnp.array(metadata["cond_data_std"])
+ self.cond_data_std = jnp.where(
+ self.cond_data_std == 0, 1.0, self.cond_data_std
+ )
elif self.standardization_method=="minmax":
self.cond_data_min = jnp.array(metadata["cond_data_min"])
self.cond_data_max = jnp.array(metadata["cond_data_max"])
self.cond_data_range = self.cond_data_max - self.cond_data_min
+ self.cond_data_range = jnp.where(
+ self.cond_data_range == 0, 1.0, self.cond_data_range
+ )
def standardize_cond_data(self, data: Array) -> Array:
if self.standardization_method == "zscore":
- return (data - self.cond_data_mean) / self.cond_data_mean
- else:
+ return (data - self.cond_data_mean) / self.cond_data_std
+ if self.standardization_method == "minmax":
return (data - self.cond_data_min) / self.cond_data_range
+ return data🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| @classmethod | ||
| def from_directory(cls, output_dir: str) -> "ConditionalFlow": | ||
| """ | ||
| Load a trained conditional flow from a directory. | ||
|
|
||
| Args: | ||
| output_dir: Directory containing flow_weights.eqx, flow_kwargs.json, metadata.json | ||
|
|
||
| Returns: | ||
| ConditionalFlow instance with loaded model and metadata | ||
|
|
||
| """ | ||
| # Load the flow model and metadata | ||
| flow_model, metadata = load_model(output_dir) | ||
|
|
||
| # Load kwargs | ||
| kwargs_path = os.path.join(output_dir, "flow_kwargs.json") | ||
| with open(kwargs_path, "r") as f: | ||
| flow_kwargs = json.load(f) | ||
|
|
||
| return cls(flow_model, metadata, flow_kwargs) | ||
|
|
||
| def log_prob(self, x: Array, y: Array) -> Array: | ||
| """ | ||
| Evaluate log probability of samples under the flow. | ||
|
|
||
| If standardization was used, input data is automatically standardized | ||
| before evaluation and Jacobian correction is applied. If not, operations | ||
| are identity (no-op). | ||
|
|
||
| The Jacobian correction accounts for the change of variables: | ||
| - Z-score: log p(x|y) = log p(x_std|y_std) - sum(log(std)) | ||
| - Min-max: log p(x|y) = log p(x_std|y_std) - sum(log(max - min)) | ||
| - None: log p(x|y) = log p(x_std|y_std) (no correction) | ||
|
|
||
|
Comment on lines
+354
to
+365
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Format the new probability equations as Sphinx math. The As per coding guidelines "All mathematical expressions in docstrings must use Sphinx/reStructuredText formatting with 🤖 Prompt for AI Agents |
||
| Args: | ||
| x: Data in original scale, shape (n_samples, n_features). | ||
| JAX array. | ||
| y: Conditional parameter, shape (n_samples, cond_dim). | ||
| JAX array. | ||
|
|
||
| Returns: | ||
| Log probabilities as JAX array, shape (n_samples,) | ||
|
|
||
| Example: | ||
| >>> data = jnp.array([[1.4, 1.3, 100, 200]]) | ||
| >>> y = jnp.array([[2.0, 3.0]]) | ||
| >>> log_prob = flow.log_prob(data, y=y) | ||
| """ | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| # Standardize input (method-dependent or identity) | ||
| x_std = self.standardize_input(x) | ||
| y_std = self.standardize_cond_data(y) | ||
|
|
||
| # Evaluate log probability in standardized space | ||
| log_p = self.flow.log_prob(x_std, y_std) | ||
|
|
||
| # Account for Jacobian of inverse transformation | ||
| if self.standardization_method == "zscore": | ||
| # Z-score: log |det J| = sum(log(std)) | ||
| log_det_jacobian = -jnp.sum(jnp.log(self.data_std)) | ||
| else: | ||
| # Min-max or none: log |det J| = sum(log(range)) | ||
| # If standardization disabled (range=1), log_det_jacobian = 0 | ||
| log_det_jacobian = -jnp.sum(jnp.log(self.data_range)) | ||
|
|
||
| log_p = log_p + log_det_jacobian | ||
|
|
||
| return log_p | ||
|
|
||
| def sample(self, key: Array, shape: Tuple[int, ...], y: Array) -> Array: | ||
| """ | ||
| Sample from the conditional flow and return in original scale. | ||
|
|
||
| If standardization was used during training, samples are automatically | ||
| converted back to the original scale using the inverse transformation | ||
| (z-score or min-max). If not, the transformation is identity (no-op). | ||
|
|
||
| Args: | ||
| key: JAX random key (jax.Array) | ||
| shape: Shape of samples to generate (e.g., (1000,) for 1000 samples) | ||
| y: Conditional parameters (Array) | ||
| Returns: | ||
| Samples in original scale as JAX array of shape (*shape, y.shape[0], n_features) | ||
| """ | ||
|
|
||
| y_std = self.standardize_cond_data(y_std) | ||
| samples = self.flow.sample(key, shape, condition=y) | ||
|
|
||
| samples = self.destandardize_output(samples) | ||
|
Comment on lines
+417
to
+419
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Training standardizes conditional inputs before fitting when 🛠️ Suggested fix- samples = self.flow.sample(key, shape, condition=y)
+ y_std = self.standardize_cond_data(y)
+ samples = self.flow.sample(key, shape, condition=y_std)🤖 Prompt for AI Agents |
||
|
|
||
| return samples | ||
|
|
||
| def create_transformer( | ||
| transformer_type: str = "affine", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fail fast on inconsistent conditional configs.
This addition exposes
cond_parameter_names, but the model still acceptscond_dim=0, negativecond_dim,cond_parameter_nameswithoutcond_dim, and lists whose length does not matchcond_dim. The runtime guard injesterTOV/inference/flows/train_flow.pyonly checks presence, so some bad configs are silently ignored while others fail much later inside flow creation/training.🤖 Prompt for AI Agents