Skip to content

Conversation

eclipse0922
Copy link

@eclipse0922 eclipse0922 commented Sep 21, 2025

Fixes #3328 .

Description

A few sentences describing the changes proposed in this pull request.

This pull request introduces GenerateHeatmap and GenerateHeatmapd transforms for creating Gaussian heatmaps from landmark coordinates.
The input points are currently expected in ZYX order, but this can be changed to support XYZ if preferred.
The transforms support both batched (B, N, D) and non-batched (N, D) inputs.

Example notebooks are included for demonstration and will be removed before the PR is merged.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Adds a `GenerateHeatmap` transform to create gaussian response maps from landmark coordinates.
This transform is implemented for both array and dictionary-based workflows.
It enables the generation of heatmaps from landmark data, facilitating tasks
like landmark localization and visualization.
The transform supports 2D and 3D coordinates and offers options for controlling
the gaussian standard deviation, spatial shape, truncation, normalization, and data type.
Introduces a new interactive notebook demonstrating landmark to heatmap conversion using MONAI transforms.

This includes:
- A notebook with array and dictionary transform modes.
- A test suite for the `GenerateHeatmap` transform.

This enhancement enables users to visualize and interact with heatmap generation, facilitating a better understanding and application of the MONAI transforms.
Extends the `GenerateHeatmap` transform to support batched inputs,
allowing for more efficient processing of multiple landmark sets.

This change modifies the transform to handle inputs with a batch dimension (B, N, spatial_dims) in addition to single-point inputs (N, spatial_dims).
It also includes a demonstration of 3D heatmap generation using PyVista for visualization.
@eclipse0922 eclipse0922 marked this pull request as draft September 21, 2025 10:42
Streamlines the GenerateHeatmap and GenerateHeatmapd transforms for better usability and code clarity.

Specifically:
- Improves the input landmark array validation to provide a more descriptive error message.
- Removes example notebooks.

DCO Remediation Commit for sewon.jeon <[email protected]>

I, sewon.jeon <[email protected]>, hereby add my Signed-off-by to this commit: 8ef905b
I, sewon.jeon <[email protected]>, hereby add my Signed-off-by to this commit: 226bf90
I, sewon.jeon <[email protected]>, hereby add my Signed-off-by to this commit: 3097baf
I, sewon.jeon <[email protected]>, hereby add my Signed-off-by to this commit: 0072cb0

Signed-off-by: sewon.jeon <[email protected]>
@eclipse0922 eclipse0922 force-pushed the generate_heatmap_transforms branch from 0072cb0 to 25ceb7f Compare September 21, 2025 11:41
Signed-off-by: sewon.jeon <[email protected]>
@eclipse0922 eclipse0922 force-pushed the generate_heatmap_transforms branch from 4443705 to 9e33e7c Compare September 21, 2025 12:29
Copy link
Contributor

coderabbitai bot commented Sep 21, 2025

Walkthrough

  • Added GenerateHeatmap transform in monai/transforms/post/array.py to create Gaussian heatmaps from 2D/3D landmark points, with sigma handling, spatial shape resolution, truncation/windowing, optional normalization, dtype/backends (NumPy/Torch) support, and helper utilities. Export list updated and get_equivalent_dtype imported.
  • Added GenerateHeatmapd in monai/transforms/post/dictionary.py to map-wise wrap GenerateHeatmap with per-key/ref-image handling, dtype/device alignment, metadata propagation, and aliases GenerateHeatmapD/GenerateHeatmapDict. Exports and imports updated.
  • Added comprehensive unit tests: tests/transforms/test_generate_heatmap.py and tests/transforms/test_generate_heatmapd.py covering shapes, dtypes/devices, normalization, truncation, batching, error cases, and multiple keys.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 22.50% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (4 passed)
Check name Status Explanation
Title Check ✅ Passed The title "Generate heatmap transforms" is concise and accurately summarizes the primary change in this PR—adding GenerateHeatmap and GenerateHeatmapd transforms (and tests) to produce Gaussian landmark heatmaps.
Linked Issues Check ✅ Passed The implementation meets the core coding objectives from issue [#3328]: it adds GenerateHeatmap and GenerateHeatmapd to convert landmark coordinates into Gaussian heatmaps, supports configurable heatmap size via spatial_shape, accepts scalar or per-dimension sigmas, writes outputs to specified dictionary keys, handles batched and single-point inputs, preserves dtype/device/metadata across torch and numpy backends, and includes tests; however the PR does not expose explicit label-range/step parameters described in the issue (it relies on spatial_shape and reference-image mapping rather than a separate range/step API).
Out of Scope Changes Check ✅ Passed The changes are focused on the new heatmap transforms, their dictionary wrapper, type/dtype handling utilities, and tests; added imports and all exports are relevant to the feature and no unrelated code edits are apparent, though example notebooks included in the branch are out-of-scope for a code PR.
Description Check ✅ Passed The PR description follows the repository template by including "Fixes #3328", a Description that explains the added GenerateHeatmap/GenerateHeatmapd transforms and test coverage, and a Types of changes checklist; it also notes example notebooks included and that docstrings/documentation were not updated.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Signed-off-by: sewon.jeon <[email protected]>
Signed-off-by: sewon.jeon <[email protected]>
Signed-off-by: sewon.jeon <[email protected]>
Signed-off-by: sewon.jeon <[email protected]>
@eclipse0922 eclipse0922 marked this pull request as ready for review September 21, 2025 14:27
@eclipse0922 eclipse0922 force-pushed the generate_heatmap_transforms branch from 60b58f5 to 54a81a5 Compare September 21, 2025 14:27
@eclipse0922
Copy link
Author

Signed-off-by: sewon.jeon <[email protected]>
@eclipse0922
Copy link
Author

@coderabbitai review

@eclipse0922
Copy link
Author

@coderabbitai help

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (8)
monai/transforms/post/dictionary.py (2)

516-652: LGTM with minor improvements needed.

The implementation is solid. A few observations:

  1. The _update_spatial_metadata method has complex logic for determining spatial shape from different tensor dimensions. Consider simplifying or adding more explicit comments about the logic.

  2. Error messages could be extracted to constants for better maintainability (lines 573, 583, 598, 619-621, 633).

  3. The type hints could be more specific - Any is used extensively where more concrete types might be known.

Consider extracting error messages:

+_ERR_HEATMAP_KEYS_LEN = "heatmap_keys length must match keys length."
+_ERR_REF_KEYS_LEN = "ref_image_keys length must match keys length when provided."
+_ERR_SHAPE_LEN = "spatial_shape length must match keys length when providing per-key shapes."
+_ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys."
+_ERR_INVALID_POINTS = "landmark arrays must be 2D or 3D with shape (N, D) or (B, N, D)."
+_ERR_REF_NO_SHAPE = "Reference data must define a shape attribute."

 def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]:
     if heatmap_keys is None:
         return tuple(f"{key}_heatmap" for key in self.keys)
     keys_tuple = ensure_tuple(heatmap_keys)
     if len(keys_tuple) == 1 and len(self.keys) > 1:
         keys_tuple = keys_tuple * len(self.keys)
     if len(keys_tuple) != len(self.keys):
-        raise ValueError("heatmap_keys length must match keys length.")
+        raise ValueError(_ERR_HEATMAP_KEYS_LEN)
     return keys_tuple

636-650: Simplify spatial metadata update logic.

The _update_spatial_metadata method has nested conditionals that make it hard to follow. The logic for distinguishing batched 2D from non-batched 3D is particularly complex.

Consider a clearer approach:

 def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None:
     """Update spatial metadata of heatmap based on its dimensions."""
-    # Update spatial_shape metadata based on heatmap dimensions
-    if heatmap.ndim == 5:  # 3D batched: (B, C, H, W, D)
-        heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:])
-    elif heatmap.ndim == 4:  # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
-        # Need to check if this is batched 2D or non-batched 3D
-        if len(heatmap.shape[1:]) == len(reference.meta.get("spatial_shape", [])):
-            # Non-batched 3D
-            heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
-        else:
-            # Batched 2D
-            heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:])
-    else:  # 2D non-batched: (C, H, W)
-        heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
+    # Determine if batched based on reference's batch dimension
+    ref_is_batched = len(reference.shape) > len(reference.meta.get("spatial_shape", [])) + 1
+    
+    if heatmap.ndim == 5:  # 3D batched
+        spatial_shape = heatmap.shape[2:]
+    elif heatmap.ndim == 4:
+        # Disambiguate: 2D batched vs 3D non-batched
+        spatial_shape = heatmap.shape[2:] if ref_is_batched else heatmap.shape[1:]
+    else:  # ndim == 3, 2D non-batched
+        spatial_shape = heatmap.shape[1:]
+    
+    heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape)
tests/transforms/test_generate_heatmapd.py (3)

101-101: Remove unused parameters.

Parameters expected_dtype and uses_ref are not used in the test method.

-def test_dict_with_reference_meta(self, _, points, params, expected_shape, expected_dtype, uses_ref):
+def test_dict_with_reference_meta(self, _, points, params, expected_shape, *_unused):

151-151: Remove unused parameter.

Parameter expected_dtype is not used.

-def test_dict_batched_with_ref(self, _, points, params, expected_shape, expected_dtype):
+def test_dict_batched_with_ref(self, _, points, params, expected_shape, _expected_dtype):

205-229: Document current behavior limitation.

The test acknowledges that MetaTensor points may inherit incorrect affine. This should be tracked as a known issue.

Should I create an issue to track the affine inheritance behavior when using MetaTensor points with reference images?

monai/transforms/post/array.py (3)

753-893: Well-implemented transform with room for minor improvements.

The GenerateHeatmap transform is well-structured. A few suggestions:

  1. The backend class attribute should be annotated with ClassVar (line 769).
  2. Consider extracting error messages to constants for maintainability.
  3. The _evaluate_gaussian method could benefit from a brief docstring.

Fix the class attribute annotation:

+from typing import ClassVar
 
 class GenerateHeatmap(Transform):
-    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]
+    backend: ClassVar[list] = [TransformBackends.NUMPY, TransformBackends.TORCH]

863-865: Add boundary check optimization.

The _is_inside method could short-circuit on first failure.

 @staticmethod
 def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool:
-    return all(0 <= c < size for c, size in zip(center, bounds))
+    for c, size in zip(center, bounds):
+        if not (0 <= c < size):
+            return False
+    return True

881-892: Add docstring for clarity.

The _evaluate_gaussian method would benefit from documentation.

 def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor:
+    """Evaluate Gaussian at given coordinate shifts with specified sigmas.
+    
+    Args:
+        coord_shifts: Per-dimension coordinate offsets from center.
+        sigma: Per-dimension standard deviations.
+    
+    Returns:
+        Gaussian values at the specified coordinates.
+    """
     device = coord_shifts[0].device
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 53382d8 and 9f10dcf.

📒 Files selected for processing (4)
  • monai/transforms/post/array.py (3 hunks)
  • monai/transforms/post/dictionary.py (4 hunks)
  • tests/transforms/test_generate_heatmap.py (1 hunks)
  • tests/transforms/test_generate_heatmapd.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/transforms/post/dictionary.py
  • monai/transforms/post/array.py
  • tests/transforms/test_generate_heatmap.py
  • tests/transforms/test_generate_heatmapd.py
🪛 Ruff (0.13.1)
monai/transforms/post/dictionary.py

573-573: Avoid specifying long messages outside the exception class

(TRY003)


583-583: Avoid specifying long messages outside the exception class

(TRY003)


598-598: Avoid specifying long messages outside the exception class

(TRY003)


615-615: Avoid specifying long messages outside the exception class

(TRY003)


619-621: Avoid specifying long messages outside the exception class

(TRY003)


633-633: Avoid specifying long messages outside the exception class

(TRY003)

monai/transforms/post/array.py

769-769: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


781-781: Avoid specifying long messages outside the exception class

(TRY003)


785-785: Avoid specifying long messages outside the exception class

(TRY003)


788-788: Avoid specifying long messages outside the exception class

(TRY003)


802-804: Avoid specifying long messages outside the exception class

(TRY003)


808-808: Avoid specifying long messages outside the exception class

(TRY003)


847-847: Avoid specifying long messages outside the exception class

(TRY003)


853-853: Avoid specifying long messages outside the exception class

(TRY003)


861-861: Avoid specifying long messages outside the exception class

(TRY003)

tests/transforms/test_generate_heatmapd.py

56-56: Consider (1, *shape) instead of concatenation

Replace with (1, *shape)

(RUF005)


101-101: Unused method argument: expected_dtype

(ARG002)


101-101: Unused method argument: uses_ref

(ARG002)


151-151: Unused method argument: expected_dtype

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: packaging
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.11)
🔇 Additional comments (2)
tests/transforms/test_generate_heatmap.py (2)

24-29: LGTM!

Clean helper function for finding peak coordinates.


233-240: Good error handling test coverage.

Testing multiple invalid input scenarios is thorough.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Heatmap generation transforms
1 participant