-
Notifications
You must be signed in to change notification settings - Fork 7
Patch for unstable part in branches in postprocessing #128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -329,6 +329,38 @@ def load_injection_eos( | |
| return None | ||
|
|
||
|
|
||
| def _split_into_monotone_branches( | ||
| masses: np.ndarray, lambdas: np.ndarray | ||
| ) -> list[tuple[int, int]]: | ||
| """Split a mass-Lambda curve into monotone-decreasing segments. | ||
|
|
||
| A branch break is detected wherever Lambda increases as mass increases, | ||
| which signals that the stored curve contains two interleaved stable branches | ||
| (the "third family" / twin-star scenario arising from a fold-back in M(pc)). | ||
|
|
||
| Parameters | ||
| ---------- | ||
| masses : np.ndarray | ||
| Mass array (monotone increasing, uniform grid). | ||
| lambdas : np.ndarray | ||
| Tidal deformability array on the same grid. | ||
|
|
||
| Returns | ||
| ------- | ||
| list of (start, end) index pairs | ||
| Each pair defines a half-open slice ``masses[start:end]`` that is | ||
| monotone decreasing in Lambda. Always contains at least one segment. | ||
| """ | ||
| segments: list[tuple[int, int]] = [] | ||
| start = 0 | ||
| for j in range(1, len(lambdas)): | ||
| if lambdas[j] > lambdas[j - 1]: | ||
| segments.append((start, j)) | ||
| start = j | ||
| segments.append((start, len(masses))) | ||
| return segments | ||
|
Comment on lines
+354
to
+361
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Validate the paired array lengths before returning slice ranges. The function detects breaks over Proposed fix segments: list[tuple[int, int]] = []
+ if len(masses) != len(lambdas):
+ raise ValueError(
+ "masses and lambdas must have the same length to split monotone branches"
+ )
+
start = 0🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| def report_credible_interval( | ||
| values: np.ndarray, hdi_prob: float = HDI_PROB, verbose: bool = False | ||
| ) -> tuple: | ||
|
|
@@ -567,19 +599,33 @@ def make_mass_radius_plot( | |
| f"Plotting {len(valid_indices)} M-R curves (excluded {bad_counter} invalid samples)..." | ||
| ) | ||
|
|
||
| # Second pass: plot only valid samples | ||
| # Second pass: plot only valid samples, splitting multi-branch curves | ||
| n_unstable_mr = 0 | ||
| for i in valid_indices: | ||
| # Get color based on probability | ||
| normalized_value = norm(prob[i]) | ||
| color = cmap(normalized_value) | ||
|
|
||
| plt.plot( | ||
| r[i], | ||
| m[i], | ||
| color=color, | ||
| alpha=1.0, | ||
| rasterized=True, | ||
| zorder=1e10 + normalized_value, | ||
| branches = _split_into_monotone_branches(m[i], l[i]) | ||
| if len(branches) > 1: | ||
| n_unstable_mr += 1 | ||
| for start, end in branches: | ||
| plt.plot( | ||
| r[i][start:end], | ||
| m[i][start:end], | ||
| color=color, | ||
| alpha=1.0, | ||
| rasterized=True, | ||
| zorder=1e10 + normalized_value, | ||
| ) | ||
|
|
||
| if n_unstable_mr > 0: | ||
| pct = 100.0 * n_unstable_mr / len(valid_indices) | ||
| logger.warning( | ||
| f"{n_unstable_mr}/{len(valid_indices)} ({pct:.1f}%) samples had an " | ||
| "unstable part (multi-branch) in their NS solution. These samples are not " | ||
| "accounted for properly during inference; the plot shows each stable branch " | ||
| "separately. If this percentage is low, the impact on the posterior is " | ||
| "negligible." | ||
| ) | ||
|
|
||
| # Plot injection EOS if provided (on top of everything else) | ||
|
|
@@ -740,19 +786,33 @@ def make_mass_lambda_plot( | |
| f"Plotting {len(valid_indices)} M-Lambda curves (excluded {bad_counter} invalid samples)..." | ||
| ) | ||
|
|
||
| # Second pass: plot only valid samples | ||
| # Second pass: plot only valid samples, splitting multi-branch curves | ||
| n_unstable_ml = 0 | ||
| for i in valid_indices: | ||
| # Get color based on probability | ||
| normalized_value = norm(prob[i]) | ||
| color = cmap(normalized_value) | ||
|
|
||
| plt.plot( | ||
| m[i], | ||
| l[i], | ||
| color=color, | ||
| alpha=1.0, | ||
| rasterized=True, | ||
| zorder=1e10 + normalized_value, | ||
| branches = _split_into_monotone_branches(m[i], l[i]) | ||
| if len(branches) > 1: | ||
| n_unstable_ml += 1 | ||
| for start, end in branches: | ||
| plt.plot( | ||
| m[i][start:end], | ||
| l[i][start:end], | ||
| color=color, | ||
| alpha=1.0, | ||
| rasterized=True, | ||
| zorder=1e10 + normalized_value, | ||
| ) | ||
|
|
||
| if n_unstable_ml > 0: | ||
| pct = 100.0 * n_unstable_ml / len(valid_indices) | ||
| logger.warning( | ||
| f"{n_unstable_ml}/{len(valid_indices)} ({pct:.1f}%) samples had an " | ||
| "unstable part (multi-branch) in their NS solution. These samples are not " | ||
| "accounted for properly during inference; the plot shows each stable branch " | ||
| "separately. If this percentage is low, the impact on the posterior is " | ||
| "negligible." | ||
| ) | ||
|
|
||
| # Plot injection EOS if provided (on top of everything else) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| make_pressure_density_plot, | ||
| make_cs2_plot, | ||
| make_parameter_histograms, | ||
| _split_into_monotone_branches, | ||
| ) | ||
| from jesterTOV.inference.result import InferenceResult | ||
|
|
||
|
|
@@ -520,3 +521,107 @@ def test_full_workflow_flowmc(self, temp_dir): | |
| assert (temp_dir / "cs2_density_plot.pdf").exists() | ||
|
|
||
| plt.close("all") | ||
|
|
||
|
|
||
| class TestSplitIntoMonotoneBranches: | ||
| """Tests for the multi-branch detection helper.""" | ||
|
|
||
| def test_monotone_decreasing_returns_one_segment(self): | ||
| """Smooth Lambda(M) with no increases → single segment covering all data.""" | ||
| masses = np.array([0.5, 0.8, 1.1, 1.4, 1.7, 2.0]) | ||
| lambdas = np.array([2000.0, 800.0, 300.0, 120.0, 50.0, 20.0]) | ||
| branches = _split_into_monotone_branches(masses, lambdas) | ||
| assert len(branches) == 1 | ||
| assert branches[0] == (0, 6) | ||
|
|
||
| def test_single_lambda_increase_splits_into_two_branches(self): | ||
| """One Lambda increase mid-curve → two branches.""" | ||
| masses = np.array([0.5, 0.8, 1.1, 1.2, 1.5, 1.8]) | ||
| lambdas = np.array([2000.0, 800.0, 300.0, 1500.0, 400.0, 100.0]) | ||
| branches = _split_into_monotone_branches(masses, lambdas) | ||
| assert len(branches) == 2 | ||
| assert branches[0] == (0, 3) | ||
| assert branches[1] == (3, 6) | ||
|
|
||
| def test_multiple_lambda_increases_many_branches(self): | ||
| """Multiple Lambda increases → multiple branches.""" | ||
| masses = np.linspace(0.5, 2.0, 10) | ||
| lambdas = np.array( | ||
| [1000.0, 500.0, 800.0, 200.0, 600.0, 100.0, 300.0, 50.0, 80.0, 20.0] | ||
| ) | ||
| branches = _split_into_monotone_branches(masses, lambdas) | ||
| assert len(branches) > 2 | ||
| assert branches[0][0] == 0 | ||
| assert branches[-1][1] == len(masses) | ||
|
|
||
| def test_constant_lambda_returns_one_segment(self): | ||
| """Flat Lambda → no increases → single segment.""" | ||
| masses = np.linspace(0.5, 2.0, 5) | ||
| lambdas = np.ones(5) * 500.0 | ||
| branches = _split_into_monotone_branches(masses, lambdas) | ||
| assert len(branches) == 1 | ||
|
|
||
| def test_all_segments_cover_full_range(self): | ||
| """Branch segments are contiguous and cover the full array without gaps.""" | ||
| masses = np.linspace(0.5, 2.0, 20) | ||
| rng = np.random.default_rng(42) | ||
| lambdas = np.cumsum(rng.uniform(-200, 100, 20)) + 2000.0 | ||
| branches = _split_into_monotone_branches(masses, lambdas) | ||
| assert branches[0][0] == 0 | ||
| assert branches[-1][1] == len(masses) | ||
| for k in range(len(branches) - 1): | ||
| assert branches[k][1] == branches[k + 1][0] | ||
|
|
||
| def test_plot_runs_without_error_when_unstable(self, tmp_path): | ||
| """make_mass_radius_plot runs cleanly when multi-branch samples are present.""" | ||
| n_eos = 50 | ||
| masses_jagged = np.linspace(0.75, 2.5, n_eos) | ||
| lambdas_jagged = np.concatenate( | ||
| [ | ||
| np.linspace(2000, 400, 25), | ||
| np.linspace(1500, 50, 25), | ||
| ] | ||
| ) | ||
| radii_jagged = np.linspace(12.0, 10.0, n_eos) | ||
|
|
||
| masses_smooth = np.linspace(0.75, 2.5, n_eos) | ||
| lambdas_smooth = np.linspace(2000, 10, n_eos) | ||
| radii_smooth = np.linspace(12.0, 10.0, n_eos) | ||
|
|
||
| data = { | ||
| "masses": np.array([masses_jagged, masses_smooth]), | ||
| "radii": np.array([radii_jagged, radii_smooth]), | ||
| "lambdas": np.array([lambdas_jagged, lambdas_smooth]), | ||
| "densities": np.ones((2, n_eos)), | ||
| "pressures": np.ones((2, n_eos)), | ||
| "cs2": np.ones((2, n_eos)) * 0.3, | ||
| "log_prob": np.array([-10.0, -10.0]), | ||
| "prior_params": {}, | ||
| } | ||
|
|
||
| make_mass_radius_plot(data, prior_data=None, outdir=str(tmp_path)) | ||
| plt.close("all") | ||
| assert (tmp_path / "mass_radius_plot.pdf").exists() | ||
|
|
||
| def test_plot_runs_without_error_when_smooth(self, tmp_path): | ||
| """make_mass_radius_plot runs cleanly when all samples are smooth.""" | ||
| n_eos = 50 | ||
| n_samples = 3 | ||
| masses = np.tile(np.linspace(0.75, 2.5, n_eos), (n_samples, 1)) | ||
| lambdas = np.tile(np.linspace(2000, 10, n_eos), (n_samples, 1)) | ||
| radii = np.tile(np.linspace(12.0, 10.0, n_eos), (n_samples, 1)) | ||
|
|
||
| data = { | ||
| "masses": masses, | ||
| "radii": radii, | ||
| "lambdas": lambdas, | ||
| "densities": np.ones((n_samples, n_eos)), | ||
| "pressures": np.ones((n_samples, n_eos)), | ||
| "cs2": np.ones((n_samples, n_eos)) * 0.3, | ||
| "log_prob": np.full(n_samples, -10.0), | ||
| "prior_params": {}, | ||
| } | ||
|
|
||
| make_mass_radius_plot(data, prior_data=None, outdir=str(tmp_path)) | ||
| plt.close("all") | ||
| assert (tmp_path / "mass_radius_plot.pdf").exists() | ||
|
Comment on lines
+529
to
+627
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add return type annotations to the new tests. The new test methods should declare Proposed fix- def test_monotone_decreasing_returns_one_segment(self):
+ def test_monotone_decreasing_returns_one_segment(self) -> None:
@@
- def test_single_lambda_increase_splits_into_two_branches(self):
+ def test_single_lambda_increase_splits_into_two_branches(self) -> None:
@@
- def test_multiple_lambda_increases_many_branches(self):
+ def test_multiple_lambda_increases_many_branches(self) -> None:
@@
- def test_constant_lambda_returns_one_segment(self):
+ def test_constant_lambda_returns_one_segment(self) -> None:
@@
- def test_all_segments_cover_full_range(self):
+ def test_all_segments_cover_full_range(self) -> None:
@@
- def test_plot_runs_without_error_when_unstable(self, tmp_path):
+ def test_plot_runs_without_error_when_unstable(self, tmp_path) -> None:
@@
- def test_plot_runs_without_error_when_smooth(self, tmp_path):
+ def test_plot_runs_without_error_when_smooth(self, tmp_path) -> None:As per coding guidelines, "All new code MUST include comprehensive type hints with standard library types (Python 3.10+ syntax), jaxtyping for JAX arrays, Pydantic for configs, and type aliases for complex types". 🤖 Prompt for AI Agents |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Format the math notation in the new docstring.
The helper docstring introduces math notation but does not use Sphinx math roles; if you spell this as LaTeX, make the docstring raw as well.
Proposed fix
As per coding guidelines, "All mathematical expressions in docstrings must use Sphinx/reStructuredText formatting with
:math:role for inline math and.. math::directive for display equations" and "Use raw strings (r\"\"\") for docstrings containing LaTeX to avoid Python escape sequence warnings".🤖 Prompt for AI Agents