diff --git a/jesterTOV/inference/postprocessing/postprocessing.py b/jesterTOV/inference/postprocessing/postprocessing.py index 119ae58d..721a1f53 100644 --- a/jesterTOV/inference/postprocessing/postprocessing.py +++ b/jesterTOV/inference/postprocessing/postprocessing.py @@ -329,6 +329,38 @@ def load_injection_eos( return None +def _split_into_monotone_branches( + masses: np.ndarray, lambdas: np.ndarray +) -> list[tuple[int, int]]: + """Split a mass-Lambda curve into monotone-decreasing segments. + + A branch break is detected wherever Lambda increases as mass increases, + which signals that the stored curve contains two interleaved stable branches + (the "third family" / twin-star scenario arising from a fold-back in M(pc)). + + Parameters + ---------- + masses : np.ndarray + Mass array (monotone increasing, uniform grid). + lambdas : np.ndarray + Tidal deformability array on the same grid. + + Returns + ------- + list of (start, end) index pairs + Each pair defines a half-open slice ``masses[start:end]`` that is + monotone decreasing in Lambda. Always contains at least one segment. + """ + segments: list[tuple[int, int]] = [] + start = 0 + for j in range(1, len(lambdas)): + if lambdas[j] > lambdas[j - 1]: + segments.append((start, j)) + start = j + segments.append((start, len(masses))) + return segments + + def report_credible_interval( values: np.ndarray, hdi_prob: float = HDI_PROB, verbose: bool = False ) -> tuple: @@ -567,19 +599,33 @@ def make_mass_radius_plot( f"Plotting {len(valid_indices)} M-R curves (excluded {bad_counter} invalid samples)..." ) - # Second pass: plot only valid samples + # Second pass: plot only valid samples, splitting multi-branch curves + n_unstable_mr = 0 for i in valid_indices: - # Get color based on probability normalized_value = norm(prob[i]) color = cmap(normalized_value) - plt.plot( - r[i], - m[i], - color=color, - alpha=1.0, - rasterized=True, - zorder=1e10 + normalized_value, + branches = _split_into_monotone_branches(m[i], l[i]) + if len(branches) > 1: + n_unstable_mr += 1 + for start, end in branches: + plt.plot( + r[i][start:end], + m[i][start:end], + color=color, + alpha=1.0, + rasterized=True, + zorder=1e10 + normalized_value, + ) + + if n_unstable_mr > 0: + pct = 100.0 * n_unstable_mr / len(valid_indices) + logger.warning( + f"{n_unstable_mr}/{len(valid_indices)} ({pct:.1f}%) samples had an " + "unstable part (multi-branch) in their NS solution. These samples are not " + "accounted for properly during inference; the plot shows each stable branch " + "separately. If this percentage is low, the impact on the posterior is " + "negligible." ) # Plot injection EOS if provided (on top of everything else) @@ -740,19 +786,33 @@ def make_mass_lambda_plot( f"Plotting {len(valid_indices)} M-Lambda curves (excluded {bad_counter} invalid samples)..." ) - # Second pass: plot only valid samples + # Second pass: plot only valid samples, splitting multi-branch curves + n_unstable_ml = 0 for i in valid_indices: - # Get color based on probability normalized_value = norm(prob[i]) color = cmap(normalized_value) - plt.plot( - m[i], - l[i], - color=color, - alpha=1.0, - rasterized=True, - zorder=1e10 + normalized_value, + branches = _split_into_monotone_branches(m[i], l[i]) + if len(branches) > 1: + n_unstable_ml += 1 + for start, end in branches: + plt.plot( + m[i][start:end], + l[i][start:end], + color=color, + alpha=1.0, + rasterized=True, + zorder=1e10 + normalized_value, + ) + + if n_unstable_ml > 0: + pct = 100.0 * n_unstable_ml / len(valid_indices) + logger.warning( + f"{n_unstable_ml}/{len(valid_indices)} ({pct:.1f}%) samples had an " + "unstable part (multi-branch) in their NS solution. These samples are not " + "accounted for properly during inference; the plot shows each stable branch " + "separately. If this percentage is low, the impact on the posterior is " + "negligible." ) # Plot injection EOS if provided (on top of everything else) diff --git a/tests/test_inference/test_postprocessing.py b/tests/test_inference/test_postprocessing.py index b7581a74..39986fc8 100644 --- a/tests/test_inference/test_postprocessing.py +++ b/tests/test_inference/test_postprocessing.py @@ -17,6 +17,7 @@ make_pressure_density_plot, make_cs2_plot, make_parameter_histograms, + _split_into_monotone_branches, ) from jesterTOV.inference.result import InferenceResult @@ -520,3 +521,107 @@ def test_full_workflow_flowmc(self, temp_dir): assert (temp_dir / "cs2_density_plot.pdf").exists() plt.close("all") + + +class TestSplitIntoMonotoneBranches: + """Tests for the multi-branch detection helper.""" + + def test_monotone_decreasing_returns_one_segment(self): + """Smooth Lambda(M) with no increases → single segment covering all data.""" + masses = np.array([0.5, 0.8, 1.1, 1.4, 1.7, 2.0]) + lambdas = np.array([2000.0, 800.0, 300.0, 120.0, 50.0, 20.0]) + branches = _split_into_monotone_branches(masses, lambdas) + assert len(branches) == 1 + assert branches[0] == (0, 6) + + def test_single_lambda_increase_splits_into_two_branches(self): + """One Lambda increase mid-curve → two branches.""" + masses = np.array([0.5, 0.8, 1.1, 1.2, 1.5, 1.8]) + lambdas = np.array([2000.0, 800.0, 300.0, 1500.0, 400.0, 100.0]) + branches = _split_into_monotone_branches(masses, lambdas) + assert len(branches) == 2 + assert branches[0] == (0, 3) + assert branches[1] == (3, 6) + + def test_multiple_lambda_increases_many_branches(self): + """Multiple Lambda increases → multiple branches.""" + masses = np.linspace(0.5, 2.0, 10) + lambdas = np.array( + [1000.0, 500.0, 800.0, 200.0, 600.0, 100.0, 300.0, 50.0, 80.0, 20.0] + ) + branches = _split_into_monotone_branches(masses, lambdas) + assert len(branches) > 2 + assert branches[0][0] == 0 + assert branches[-1][1] == len(masses) + + def test_constant_lambda_returns_one_segment(self): + """Flat Lambda → no increases → single segment.""" + masses = np.linspace(0.5, 2.0, 5) + lambdas = np.ones(5) * 500.0 + branches = _split_into_monotone_branches(masses, lambdas) + assert len(branches) == 1 + + def test_all_segments_cover_full_range(self): + """Branch segments are contiguous and cover the full array without gaps.""" + masses = np.linspace(0.5, 2.0, 20) + rng = np.random.default_rng(42) + lambdas = np.cumsum(rng.uniform(-200, 100, 20)) + 2000.0 + branches = _split_into_monotone_branches(masses, lambdas) + assert branches[0][0] == 0 + assert branches[-1][1] == len(masses) + for k in range(len(branches) - 1): + assert branches[k][1] == branches[k + 1][0] + + def test_plot_runs_without_error_when_unstable(self, tmp_path): + """make_mass_radius_plot runs cleanly when multi-branch samples are present.""" + n_eos = 50 + masses_jagged = np.linspace(0.75, 2.5, n_eos) + lambdas_jagged = np.concatenate( + [ + np.linspace(2000, 400, 25), + np.linspace(1500, 50, 25), + ] + ) + radii_jagged = np.linspace(12.0, 10.0, n_eos) + + masses_smooth = np.linspace(0.75, 2.5, n_eos) + lambdas_smooth = np.linspace(2000, 10, n_eos) + radii_smooth = np.linspace(12.0, 10.0, n_eos) + + data = { + "masses": np.array([masses_jagged, masses_smooth]), + "radii": np.array([radii_jagged, radii_smooth]), + "lambdas": np.array([lambdas_jagged, lambdas_smooth]), + "densities": np.ones((2, n_eos)), + "pressures": np.ones((2, n_eos)), + "cs2": np.ones((2, n_eos)) * 0.3, + "log_prob": np.array([-10.0, -10.0]), + "prior_params": {}, + } + + make_mass_radius_plot(data, prior_data=None, outdir=str(tmp_path)) + plt.close("all") + assert (tmp_path / "mass_radius_plot.pdf").exists() + + def test_plot_runs_without_error_when_smooth(self, tmp_path): + """make_mass_radius_plot runs cleanly when all samples are smooth.""" + n_eos = 50 + n_samples = 3 + masses = np.tile(np.linspace(0.75, 2.5, n_eos), (n_samples, 1)) + lambdas = np.tile(np.linspace(2000, 10, n_eos), (n_samples, 1)) + radii = np.tile(np.linspace(12.0, 10.0, n_eos), (n_samples, 1)) + + data = { + "masses": masses, + "radii": radii, + "lambdas": lambdas, + "densities": np.ones((n_samples, n_eos)), + "pressures": np.ones((n_samples, n_eos)), + "cs2": np.ones((n_samples, n_eos)) * 0.3, + "log_prob": np.full(n_samples, -10.0), + "prior_params": {}, + } + + make_mass_radius_plot(data, prior_data=None, outdir=str(tmp_path)) + plt.close("all") + assert (tmp_path / "mass_radius_plot.pdf").exists()