class ConditionalFlow(Flow):
    """
    Wrapper around a trained conditional flowjax flow modelling p(x | y).

    In addition to the data statistics handled by ``Flow``, this class keeps
    the statistics of the conditional data ``y`` so that conditions given in
    the original scale are standardized the same way as during training
    before they are passed to the underlying flow.
    """

    def __init__(
        self,
        flow: AbstractDistribution,
        metadata: Dict[str, Any],
        flow_kwargs: Dict[str, Any],
    ):
        """
        Initialize the conditional flow wrapper.

        Args:
            flow: Trained flowjax flow model
            metadata: Training metadata. For z-score standardization it must
                contain ``cond_data_mean`` and ``cond_data_std``; for min-max
                standardization ``cond_data_min`` and ``cond_data_max``.
            flow_kwargs: Flow architecture kwargs

        Raises:
            KeyError: If the conditional statistics required by the
                standardization method are missing from ``metadata``.
        """
        super().__init__(flow, metadata, flow_kwargs)

        if self.standardization_method == "zscore":
            self.cond_data_mean = jnp.array(metadata["cond_data_mean"])
            self.cond_data_std = jnp.array(metadata["cond_data_std"])
        elif self.standardization_method == "minmax":
            self.cond_data_min = jnp.array(metadata["cond_data_min"])
            self.cond_data_max = jnp.array(metadata["cond_data_max"])
            self.cond_data_range = self.cond_data_max - self.cond_data_min
        else:
            # No (or unknown) standardization: use an identity min-max
            # transform so that standardize_cond_data() is a no-op instead
            # of failing on missing attributes.
            # NOTE(review): how Flow encodes "standardization disabled" is
            # not visible here -- confirm this branch matches it.
            self.cond_data_min = jnp.array(0.0)
            self.cond_data_max = jnp.array(1.0)
            self.cond_data_range = self.cond_data_max - self.cond_data_min

    def standardize_cond_data(self, data: Array) -> Array:
        """
        Standardize conditional data with the statistics stored at init.

        Args:
            data: Conditional data in original scale.

        Returns:
            Conditional data in the standardized space (z-score if that
            method is active, otherwise min-max, which is the identity when
            no standardization statistics were loaded).
        """
        if self.standardization_method == "zscore":
            return (data - self.cond_data_mean) / self.cond_data_std
        return (data - self.cond_data_min) / self.cond_data_range

    @classmethod
    def from_directory(cls, output_dir: str) -> "ConditionalFlow":
        """
        Load a trained conditional flow from a directory.

        Args:
            output_dir: Directory containing flow_weights.eqx, flow_kwargs.json, metadata.json

        Returns:
            ConditionalFlow instance with loaded model and metadata
        """
        # Load the flow model and metadata
        flow_model, metadata = load_model(output_dir)

        # Load kwargs
        kwargs_path = os.path.join(output_dir, "flow_kwargs.json")
        with open(kwargs_path, "r") as f:
            flow_kwargs = json.load(f)

        return cls(flow_model, metadata, flow_kwargs)

    def log_prob(self, x: Array, y: Array) -> Array:
        """
        Evaluate the conditional log probability log p(x | y).

        Both ``x`` and ``y`` are given in the original scale and are
        standardized before the flow is evaluated. A Jacobian correction is
        applied for the transformation of ``x`` only; the transformation of
        ``y`` needs no correction because the density is over ``x``.

        The Jacobian correction accounts for the change of variables:
        - Z-score: log p(x|y) = log p(x_std|y_std) - sum(log(std))
        - Min-max: log p(x|y) = log p(x_std|y_std) - sum(log(max - min))
        - None: log p(x|y) = log p(x_std|y_std) (no correction)

        Args:
            x: Data in original scale, shape (n_samples, n_features).
            y: Conditional parameters in original scale,
                shape (n_samples, cond_dim).

        Returns:
            Log probabilities as JAX array, shape (n_samples,)

        Example:
            >>> data = jnp.array([[1.4, 1.3, 100, 200]])
            >>> y = jnp.array([[2.0, 3.0]])
            >>> log_prob = flow.log_prob(data, y=y)
        """
        # Standardize input and condition (method-dependent or identity)
        x_std = self.standardize_input(x)
        y_std = self.standardize_cond_data(y)

        # Evaluate log probability in standardized space
        log_p = self.flow.log_prob(x_std, y_std)

        # Account for the Jacobian of the x-standardization
        if self.standardization_method == "zscore":
            # Z-score: correction is -sum(log(std))
            log_det_jacobian = -jnp.sum(jnp.log(self.data_std))
        else:
            # Min-max or none: correction is -sum(log(range)); with
            # standardization disabled (range=1) this is zero.
            log_det_jacobian = -jnp.sum(jnp.log(self.data_range))

        return log_p + log_det_jacobian

    def sample(self, key: Array, shape: Tuple[int, ...], y: Array) -> Array:
        """
        Sample from the conditional flow and return in original scale.

        The condition ``y`` is given in the original scale and standardized
        before sampling. The drawn samples are converted back to the original
        scale using the inverse transformation (z-score or min-max, identity
        if no standardization was used).

        Args:
            key: JAX random key (jax.Array)
            shape: Shape of samples to generate per condition point
                (e.g., (1000,) for 1000 samples)
            y: Conditional parameters in original scale (Array)

        Returns:
            Samples in original scale as JAX array of shape
            (*shape, y.shape[0], n_features)
        """
        # The flow was trained on standardized conditions, so it must be
        # queried with the standardized condition, not the raw one.
        y_std = self.standardize_cond_data(y)
        samples = self.flow.sample(key, shape, condition=y_std)

        return self.destandardize_output(samples)
def standardize_data(
    data: np.ndarray,
    standardization_method: str,
    parameter_names: list[str] | None = None,
) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
    """
    Standardize data based on the selected standardization method.

    Args:
        data: Array of shape (n_samples, n_features)
        standardization_method: Which method to standardize with.
            'zscore' rescales to zero mean and unit std; any other value
            applies min-max scaling.
        parameter_names: Parameter names for which the rescaled data
            statistics should be logged, matched to the columns of ``data``
            by position. Defaults to None (nothing is logged per parameter).

    Returns:
        Tuple of the standardized data and the statistics dictionary
        returned by the method-specific helper, which is needed for the
        inverse transformation.
    """
    # Avoid a shared mutable default; None and [] both mean "log nothing".
    names = parameter_names or []

    if standardization_method == "zscore":
        logger.info("Standardizing data using z-score (mean=0, std=1)...")
        data, data_statistics = standardize_data_zscore(data)
        logger.info("Standardized data statistics:")
        for i, name in enumerate(names):
            logger.info(
                f"  {name}: mean={data[:, i].mean():.3f}, std={data[:, i].std():.3f}"
            )
        logger.info("Data mean and std saved for inverse transformation")
    else:  # minmax
        logger.info("Standardizing data using min-max [0, 1] scaling...")
        data, data_statistics = standardize_data_minmax(data)
        logger.info("Standardized data ranges:")
        for i, name in enumerate(names):
            logger.info(
                f"  {name}: [{data[:, i].min():.3f}, {data[:, i].max():.3f}]"
            )
        logger.info("Data bounds saved for inverse transformation")

    return data, data_statistics
parameters: {config.cond_parameter_names}") logger.info(f"Transformer: {config.transformer}") logger.info(f"Transformer knots: {config.transformer_knots}") logger.info(f"Transformer interval: {config.transformer_interval}") @@ -490,30 +532,47 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None: # Standardize data if requested data_statistics = None if config.standardize: - if config.standardization_method == "zscore": - logger.info("Standardizing data using z-score (mean=0, std=1)...") - data, data_statistics = standardize_data_zscore(data) - logger.info("Standardized data statistics:") - for i, name in enumerate(parameter_names): - logger.info( - f" {name}: mean={data[:, i].mean():.3f}, std={data[:, i].std():.3f}" - ) - logger.info("Data mean and std saved for inverse transformation") - else: # minmax - logger.info("Standardizing data using min-max [0, 1] scaling...") - data, data_statistics = standardize_data_minmax(data) - logger.info("Standardized data ranges:") - for i, name in enumerate(parameter_names): - logger.info( - f" {name}: [{data[:, i].min():.3f}, {data[:, i].max():.3f}]" - ) - logger.info("Data bounds saved for inverse transformation") + data, data_statistics = standardize_data( + data, + config.standardization_method, + parameter_names, + ) + dim = data.shape[1] # Infer dimensionality from data + + if config.cond_dim: + logger.info("[1.5/5] Loading conditional samples...") + + cond_parameter_names = config.cond_parameter_names or [] + if len(cond_parameter_names) != config.cond_dim: + raise ValueError( + f"If conditional dimension is set, " + f"you also need to provide {config.cond_dim} conditional " + f"parameter names. You provided {len(cond_parameter_names)}." 
+ ) + + cond_samples, cond_samples_metadata = load_posterior( + config.posterior_file, + parameter_names=cond_parameter_names, + max_samples=config.max_samples + ) + + # Standardize data if requested + original_cond_samples = cond_samples.copy() + cond_data_statistics = None + if config.standardize: + cond_samples, cond_data_statistics = standardize_data( + cond_samples, + config.standardization_method, + ) + + original_data = np.hstack((original_data, original_cond_samples)) + data = (data, cond_samples) # Create flow logger.info("[2/5] Creating flow architecture...") flow_key, train_key, sample_key = jax.random.split(jax.random.key(config.seed), 3) - dim = data.shape[1] # Infer dimensionality from data logger.info(f"Flow dimensionality: {dim}D") + flow = create_flow( key=flow_key, dim=dim, @@ -531,7 +590,6 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None: # Train flow logger.info("[3/5] Training flow...") - logger.info(f"Training dataset shape: {data.shape}") trained_flow, losses = train_flow( flow, data, @@ -591,6 +649,15 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None: else: # minmax metadata["data_bounds_min"] = data_statistics["min"].tolist() metadata["data_bounds_max"] = data_statistics["max"].tolist() + + # Add conditional data statistics to metadata if standardization was used + if config.cond_dim and cond_data_statistics is not None: + if config.standardization_method == "zscore": + metadata["cond_data_mean"] = cond_data_statistics["mean"].tolist() + metadata["cond_data_std"] = cond_data_statistics["std"].tolist() + else: + metadata["cond_data_min"] = cond_data_statistics["min"].tolist() + metadata["cond_data_max"] = cond_data_statistics["max"].tolist() save_model(trained_flow, config.output_dir, flow_kwargs, metadata) @@ -608,26 +675,49 @@ def train_flow_from_config(config: FlowTrainingConfig) -> None: if config.plot_corner: try: # Sample from trained flow - n_plot_samples = min(10_000, data.shape[0]) - flow_samples = 
trained_flow.sample(sample_key, (n_plot_samples,)) - flow_samples_np = np.array(flow_samples) - - # Inverse transform samples if data was standardized - if config.standardize and data_statistics is not None: - if config.standardization_method == "zscore": - flow_samples_np = inverse_standardize_data_zscore( - flow_samples_np, data_statistics - ) - else: # minmax - flow_samples_np = inverse_standardize_data_minmax( - flow_samples_np, data_statistics - ) + n_plot_samples = min(10_000, original_data.shape[0]) + + if config.cond_dim: + # get flow samples and untransform them + n_cond = cond_samples.shape[0] + flow_samples = trained_flow.sample(sample_key, (1,), condition=cond_samples) # generate one sample per condition point + flow_samples_np = np.array(flow_samples) + if config.standardize and data_statistics is not None: + if config.standardization_method == "zscore": + flow_samples_np = inverse_standardize_data_zscore( + flow_samples_np, data_statistics + ) + else: # minmax + flow_samples_np = inverse_standardize_data_minmax( + flow_samples_np, data_statistics + ) + # stack together with conditional data for plot + flow_samples_np = np.hstack( + (flow_samples_np.reshape(-1, len(config.parameter_names)), original_data[:n_cond, -config.cond_dim:]) + ) + labels = [*config.parameter_names, *config.cond_parameter_names] + + else: + flow_samples = trained_flow.sample(sample_key, (n_plot_samples,)) + flow_samples_np = np.array(flow_samples) + labels = parameter_names + + # Inverse transform samples if data was standardized + if config.standardize and data_statistics is not None: + if config.standardization_method == "zscore": + flow_samples_np = inverse_standardize_data_zscore( + flow_samples_np, data_statistics + ) + else: # minmax + flow_samples_np = inverse_standardize_data_minmax( + flow_samples_np, data_statistics + ) corner_path = os.path.join(figures_dir, "corner.png") # Use original_data for corner plot comparison # Update labels based on parameter names 
plot_corner( - original_data, flow_samples_np, corner_path, labels=parameter_names + original_data, flow_samples_np, corner_path, labels=labels ) except Exception as e: logger.warning(