Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 78 additions & 18 deletions jesterTOV/inference/postprocessing/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,38 @@ def load_injection_eos(
return None


def _split_into_monotone_branches(
masses: np.ndarray, lambdas: np.ndarray
) -> list[tuple[int, int]]:
"""Split a mass-Lambda curve into monotone-decreasing segments.

A branch break is detected wherever Lambda increases as mass increases,
which signals that the stored curve contains two interleaved stable branches
(the "third family" / twin-star scenario arising from a fold-back in M(pc)).

Parameters
----------
masses : np.ndarray
Mass array (monotone increasing, uniform grid).
lambdas : np.ndarray
Tidal deformability array on the same grid.

Returns
-------
list of (start, end) index pairs
Each pair defines a half-open slice ``masses[start:end]`` that is
monotone decreasing in Lambda. Always contains at least one segment.
"""
Comment on lines +335 to +353
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Format the math notation in the new docstring.

The helper docstring introduces math notation but does not use Sphinx math roles; if you spell this as LaTeX, make the docstring raw as well.

Proposed fix
-def _split_into_monotone_branches(
-    masses: np.ndarray, lambdas: np.ndarray
-) -> list[tuple[int, int]]:
-    """Split a mass-Lambda curve into monotone-decreasing segments.
+def _split_into_monotone_branches(
+    masses: np.ndarray, lambdas: np.ndarray
+) -> list[tuple[int, int]]:
+    r"""Split a :math:`M-\Lambda` curve into monotone-decreasing segments.
 
-    A branch break is detected wherever Lambda increases as mass increases,
+    A branch break is detected wherever :math:`\Lambda` increases as mass increases,
     which signals that the stored curve contains two interleaved stable branches
-    (the "third family" / twin-star scenario arising from a fold-back in M(pc)).
+    (the "third family" / twin-star scenario arising from a fold-back in
+    :math:`M(p_c)`).

As per coding guidelines, "All mathematical expressions in docstrings must use Sphinx/reStructuredText formatting with :math: role for inline math and .. math:: directive for display equations" and "Use raw strings (r\"\"\") for docstrings containing LaTeX to avoid Python escape sequence warnings".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@jesterTOV/inference/postprocessing/postprocessing.py` around lines 335 - 353,
Convert the docstring for the function that splits a mass-Lambda curve (the
docstring that documents parameters masses and lambdas in postprocessing.py)
into a raw string (prefix with r""") and replace any LaTeX/math fragments with
Sphinx roles: use :math:`...` for inline expressions (e.g. variable names like
m, Lambda) and use a ``.. math::`` block for any displayed equations; ensure
backslashes are preserved by the raw string and update examples/phrasing to use
:math: roles consistently throughout the description.

segments: list[tuple[int, int]] = []
start = 0
for j in range(1, len(lambdas)):
if lambdas[j] > lambdas[j - 1]:
segments.append((start, j))
start = j
segments.append((start, len(masses)))
return segments
Comment on lines +354 to +361
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Validate the paired array lengths before returning slice ranges.

The function detects breaks over lambdas but ends the last slice at len(masses). If malformed input has mismatched lengths, callers can plot ranges that were not checked for monotonicity.

Proposed fix
     segments: list[tuple[int, int]] = []
+    if len(masses) != len(lambdas):
+        raise ValueError(
+            "masses and lambdas must have the same length to split monotone branches"
+        )
+
     start = 0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@jesterTOV/inference/postprocessing/postprocessing.py` around lines 354 - 361,
Before returning slice ranges, validate that the paired arrays have equal
lengths: check that len(lambdas) == len(masses) at start of the routine (the
code that builds segments using lambdas, masses, segments, start) and if they
differ either raise a ValueError with a clear message or truncate to the smaller
length and use that for the final segment boundary; ensure the final
segments.append uses the validated length (e.g., len(lambdas_validated) or
min_len) rather than unvalidated len(masses).



def report_credible_interval(
values: np.ndarray, hdi_prob: float = HDI_PROB, verbose: bool = False
) -> tuple:
Expand Down Expand Up @@ -567,19 +599,33 @@ def make_mass_radius_plot(
f"Plotting {len(valid_indices)} M-R curves (excluded {bad_counter} invalid samples)..."
)

# Second pass: plot only valid samples
# Second pass: plot only valid samples, splitting multi-branch curves
n_unstable_mr = 0
for i in valid_indices:
# Get color based on probability
normalized_value = norm(prob[i])
color = cmap(normalized_value)

plt.plot(
r[i],
m[i],
color=color,
alpha=1.0,
rasterized=True,
zorder=1e10 + normalized_value,
branches = _split_into_monotone_branches(m[i], l[i])
if len(branches) > 1:
n_unstable_mr += 1
for start, end in branches:
plt.plot(
r[i][start:end],
m[i][start:end],
color=color,
alpha=1.0,
rasterized=True,
zorder=1e10 + normalized_value,
)

if n_unstable_mr > 0:
pct = 100.0 * n_unstable_mr / len(valid_indices)
logger.warning(
f"{n_unstable_mr}/{len(valid_indices)} ({pct:.1f}%) samples had an "
"unstable part (multi-branch) in their NS solution. These samples are not "
"accounted for properly during inference; the plot shows each stable branch "
"separately. If this percentage is low, the impact on the posterior is "
"negligible."
)

# Plot injection EOS if provided (on top of everything else)
Expand Down Expand Up @@ -740,19 +786,33 @@ def make_mass_lambda_plot(
f"Plotting {len(valid_indices)} M-Lambda curves (excluded {bad_counter} invalid samples)..."
)

# Second pass: plot only valid samples
# Second pass: plot only valid samples, splitting multi-branch curves
n_unstable_ml = 0
for i in valid_indices:
# Get color based on probability
normalized_value = norm(prob[i])
color = cmap(normalized_value)

plt.plot(
m[i],
l[i],
color=color,
alpha=1.0,
rasterized=True,
zorder=1e10 + normalized_value,
branches = _split_into_monotone_branches(m[i], l[i])
if len(branches) > 1:
n_unstable_ml += 1
for start, end in branches:
plt.plot(
m[i][start:end],
l[i][start:end],
color=color,
alpha=1.0,
rasterized=True,
zorder=1e10 + normalized_value,
)

if n_unstable_ml > 0:
pct = 100.0 * n_unstable_ml / len(valid_indices)
logger.warning(
f"{n_unstable_ml}/{len(valid_indices)} ({pct:.1f}%) samples had an "
"unstable part (multi-branch) in their NS solution. These samples are not "
"accounted for properly during inference; the plot shows each stable branch "
"separately. If this percentage is low, the impact on the posterior is "
"negligible."
)

# Plot injection EOS if provided (on top of everything else)
Expand Down
105 changes: 105 additions & 0 deletions tests/test_inference/test_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
make_pressure_density_plot,
make_cs2_plot,
make_parameter_histograms,
_split_into_monotone_branches,
)
from jesterTOV.inference.result import InferenceResult

Expand Down Expand Up @@ -520,3 +521,107 @@ def test_full_workflow_flowmc(self, temp_dir):
assert (temp_dir / "cs2_density_plot.pdf").exists()

plt.close("all")


class TestSplitIntoMonotoneBranches:
"""Tests for the multi-branch detection helper."""

def test_monotone_decreasing_returns_one_segment(self):
"""Smooth Lambda(M) with no increases → single segment covering all data."""
masses = np.array([0.5, 0.8, 1.1, 1.4, 1.7, 2.0])
lambdas = np.array([2000.0, 800.0, 300.0, 120.0, 50.0, 20.0])
branches = _split_into_monotone_branches(masses, lambdas)
assert len(branches) == 1
assert branches[0] == (0, 6)

def test_single_lambda_increase_splits_into_two_branches(self):
"""One Lambda increase mid-curve → two branches."""
masses = np.array([0.5, 0.8, 1.1, 1.2, 1.5, 1.8])
lambdas = np.array([2000.0, 800.0, 300.0, 1500.0, 400.0, 100.0])
branches = _split_into_monotone_branches(masses, lambdas)
assert len(branches) == 2
assert branches[0] == (0, 3)
assert branches[1] == (3, 6)

def test_multiple_lambda_increases_many_branches(self):
"""Multiple Lambda increases → multiple branches."""
masses = np.linspace(0.5, 2.0, 10)
lambdas = np.array(
[1000.0, 500.0, 800.0, 200.0, 600.0, 100.0, 300.0, 50.0, 80.0, 20.0]
)
branches = _split_into_monotone_branches(masses, lambdas)
assert len(branches) > 2
assert branches[0][0] == 0
assert branches[-1][1] == len(masses)

def test_constant_lambda_returns_one_segment(self):
"""Flat Lambda → no increases → single segment."""
masses = np.linspace(0.5, 2.0, 5)
lambdas = np.ones(5) * 500.0
branches = _split_into_monotone_branches(masses, lambdas)
assert len(branches) == 1

def test_all_segments_cover_full_range(self):
"""Branch segments are contiguous and cover the full array without gaps."""
masses = np.linspace(0.5, 2.0, 20)
rng = np.random.default_rng(42)
lambdas = np.cumsum(rng.uniform(-200, 100, 20)) + 2000.0
branches = _split_into_monotone_branches(masses, lambdas)
assert branches[0][0] == 0
assert branches[-1][1] == len(masses)
for k in range(len(branches) - 1):
assert branches[k][1] == branches[k + 1][0]

def test_plot_runs_without_error_when_unstable(self, tmp_path):
"""make_mass_radius_plot runs cleanly when multi-branch samples are present."""
n_eos = 50
masses_jagged = np.linspace(0.75, 2.5, n_eos)
lambdas_jagged = np.concatenate(
[
np.linspace(2000, 400, 25),
np.linspace(1500, 50, 25),
]
)
radii_jagged = np.linspace(12.0, 10.0, n_eos)

masses_smooth = np.linspace(0.75, 2.5, n_eos)
lambdas_smooth = np.linspace(2000, 10, n_eos)
radii_smooth = np.linspace(12.0, 10.0, n_eos)

data = {
"masses": np.array([masses_jagged, masses_smooth]),
"radii": np.array([radii_jagged, radii_smooth]),
"lambdas": np.array([lambdas_jagged, lambdas_smooth]),
"densities": np.ones((2, n_eos)),
"pressures": np.ones((2, n_eos)),
"cs2": np.ones((2, n_eos)) * 0.3,
"log_prob": np.array([-10.0, -10.0]),
"prior_params": {},
}

make_mass_radius_plot(data, prior_data=None, outdir=str(tmp_path))
plt.close("all")
assert (tmp_path / "mass_radius_plot.pdf").exists()

def test_plot_runs_without_error_when_smooth(self, tmp_path):
"""make_mass_radius_plot runs cleanly when all samples are smooth."""
n_eos = 50
n_samples = 3
masses = np.tile(np.linspace(0.75, 2.5, n_eos), (n_samples, 1))
lambdas = np.tile(np.linspace(2000, 10, n_eos), (n_samples, 1))
radii = np.tile(np.linspace(12.0, 10.0, n_eos), (n_samples, 1))

data = {
"masses": masses,
"radii": radii,
"lambdas": lambdas,
"densities": np.ones((n_samples, n_eos)),
"pressures": np.ones((n_samples, n_eos)),
"cs2": np.ones((n_samples, n_eos)) * 0.3,
"log_prob": np.full(n_samples, -10.0),
"prior_params": {},
}

make_mass_radius_plot(data, prior_data=None, outdir=str(tmp_path))
plt.close("all")
assert (tmp_path / "mass_radius_plot.pdf").exists()
Comment on lines +529 to +627
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add return type annotations to the new tests.

The new test methods should declare -> None to satisfy the project’s type-hinting rule.

Proposed fix
-    def test_monotone_decreasing_returns_one_segment(self):
+    def test_monotone_decreasing_returns_one_segment(self) -> None:
@@
-    def test_single_lambda_increase_splits_into_two_branches(self):
+    def test_single_lambda_increase_splits_into_two_branches(self) -> None:
@@
-    def test_multiple_lambda_increases_many_branches(self):
+    def test_multiple_lambda_increases_many_branches(self) -> None:
@@
-    def test_constant_lambda_returns_one_segment(self):
+    def test_constant_lambda_returns_one_segment(self) -> None:
@@
-    def test_all_segments_cover_full_range(self):
+    def test_all_segments_cover_full_range(self) -> None:
@@
-    def test_plot_runs_without_error_when_unstable(self, tmp_path):
+    def test_plot_runs_without_error_when_unstable(self, tmp_path) -> None:
@@
-    def test_plot_runs_without_error_when_smooth(self, tmp_path):
+    def test_plot_runs_without_error_when_smooth(self, tmp_path) -> None:

As per coding guidelines, "All new code MUST include comprehensive type hints with standard library types (Python 3.10+ syntax), jaxtyping for JAX arrays, Pydantic for configs, and type aliases for complex types".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_inference/test_postprocessing.py` around lines 529 - 627, Add
explicit return type annotations "-> None" to each new test function definition
so they comply with the project's type-hinting rule: update the defs for
test_monotone_decreasing_returns_one_segment,
test_single_lambda_increase_splits_into_two_branches,
test_multiple_lambda_increases_many_branches,
test_constant_lambda_returns_one_segment, test_all_segments_cover_full_range,
test_plot_runs_without_error_when_unstable, and
test_plot_runs_without_error_when_smooth to include "-> None" after the
parameter list.

Loading