From 0da3d046baea3f326b606d311dca9d2195d7b41a Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Fri, 31 Oct 2025 09:57:25 -0400 Subject: [PATCH] Add verification of out-of-bound, as suggested by Antoine --- src/scilpy/cli/scil_bundle_label_map.py | 6 +-- .../scil_tractogram_assign_custom_color.py | 6 ++- .../cli/scil_viz_bundle_screenshot_mni.py | 17 +++++---- src/scilpy/connectivity/connectivity.py | 6 ++- src/scilpy/image/volume_space_management.py | 38 ++++++++++++++++++- src/scilpy/tractanalysis/bundle_operations.py | 8 ++-- .../streamline_and_mask_operations.py | 11 ++++-- 7 files changed, 70 insertions(+), 22 deletions(-) diff --git a/src/scilpy/cli/scil_bundle_label_map.py b/src/scilpy/cli/scil_bundle_label_map.py index 496f2c74d..3d9893dc0 100755 --- a/src/scilpy/cli/scil_bundle_label_map.py +++ b/src/scilpy/cli/scil_bundle_label_map.py @@ -72,6 +72,7 @@ import scipy.ndimage as ndi from scilpy.image.volume_math import neighborhood_correlation_ +from scilpy.image.volume_space_management import map_coordinates_in_volume from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.io.utils import (add_overwrite_arg, add_reference_arg, @@ -353,9 +354,8 @@ def main(): continue if len(cut_sft): - tmp_data = ndi.map_coordinates( - map, cut_sft.streamlines._data.T - 0.5, order=0, - mode='nearest') + tmp_data = map_coordinates_in_volume( + map, cut_sft.streamlines._data.T - 0.5, order=0) if basename == 'labels': max_val = args.nb_pts diff --git a/src/scilpy/cli/scil_tractogram_assign_custom_color.py b/src/scilpy/cli/scil_tractogram_assign_custom_color.py index 78ecefc09..562b27a56 100755 --- a/src/scilpy/cli/scil_tractogram_assign_custom_color.py +++ b/src/scilpy/cli/scil_tractogram_assign_custom_color.py @@ -52,6 +52,7 @@ import matplotlib.pyplot as plt from scipy.ndimage import map_coordinates +from scilpy.image.volume_space_management import map_coordinates_in_volume from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.io.utils import (add_overwrite_arg, add_reference_arg, @@ -199,9 +200,11 @@ def main(): '({}) but found {} values.' .format(len(sft), len(data))) elif args.load_dpp or args.from_anatomy: + # Sending data to vox space, center origin for scipy's interpolation sft.to_vox() concat_points = np.vstack(sft.streamlines).T expected_shape = len(concat_points) + # Back to normal space sft.to_rasmm() if args.load_dpp: data = np.squeeze(load_matrix_in_any_format(args.load_dpp)) @@ -210,8 +213,7 @@ def main(): 'but got {}'.format(expected_shape, len(data))) else: # args.from_anatomy: data = nib.load(args.from_anatomy).get_fdata() - data = map_coordinates(data, concat_points, order=0, - mode='nearest') + data = map_coordinates_in_volume(data, concat_points, order=0) elif args.along_profile: data = get_streamlines_as_linspaces(sft) data = np.hstack(data) diff --git a/src/scilpy/cli/scil_viz_bundle_screenshot_mni.py b/src/scilpy/cli/scil_viz_bundle_screenshot_mni.py index a03fde2d3..38aeef44d 100755 --- a/src/scilpy/cli/scil_viz_bundle_screenshot_mni.py +++ b/src/scilpy/cli/scil_viz_bundle_screenshot_mni.py @@ -30,6 +30,7 @@ import numpy as np from scipy.ndimage import map_coordinates +from scilpy.image.volume_space_management import map_coordinates_in_volume from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, assert_headers_compatible, @@ -64,9 +65,9 @@ def _build_arg_parser(): metavar=('R', 'G', 'B'), type=float, help='Color streamlines with uniform coloring.') sub_color.add_argument('--reference_coloring', - metavar='COLORBAR', - help='Color streamlines with reference coloring ' - '(0-255).') + metavar='COLORMAP', + help='Color streamlines with reference coloring. ' + 'Name of a matlab colormap. (0-255).') p.add_argument('--roi', nargs='+', action='append', help='Path to a ROI file (.nii or nii.gz).') p.add_argument('--right', action='store_true', @@ -261,14 +262,16 @@ def main(): args.uniform_coloring[1] / 255.0, args.uniform_coloring[2] / 255.0) elif args.reference_coloring: + # Sending to vox space, center origin for interpolation sft.to_vox() - streamlines_vox = sft.get_streamlines_copy() + coords_vox = np.vstack(sft.streamlines).T + # Back to normal space sft.to_rasmm() + normalized_data = reference_data / np.max(reference_data) cmap = get_lookup_table(args.reference_coloring) - values = map_coordinates(normalized_data, - streamlines_vox.streamlines._data.T, - order=1, mode='nearest') + values = map_coordinates_in_volume(normalized_data, coords_vox, + order=1) colors = cmap(values)[:, 0:3] else: colors = None diff --git a/src/scilpy/connectivity/connectivity.py b/src/scilpy/connectivity/connectivity.py index 844fa10cd..db24d08c6 100644 --- a/src/scilpy/connectivity/connectivity.py +++ b/src/scilpy/connectivity/connectivity.py @@ -13,6 +13,7 @@ from scipy.ndimage import map_coordinates from scilpy.image.labels import get_data_as_labels +from scilpy.image.volume_space_management import map_coordinates_in_volume from scilpy.io.hdf5 import reconstruct_streamlines_from_hdf5 from scilpy.tractanalysis.reproducibility_measures import \ compute_bundle_adjacency_voxel @@ -78,8 +79,9 @@ def compute_triu_connectivity_from_labels(tractogram, data_labels, .format(nb_labels)) matrix = np.zeros((nb_labels, nb_labels), dtype=int) - labels = map_coordinates(data_labels, streamlines._data.T, order=0, - mode='nearest') + # Taking coords in vox space, center origin for interpolation + coords = np.vstack(streamlines).T + labels = map_coordinates_in_volume(data_labels, coords, order=0) start_labels = labels[0::2] end_labels = labels[1::2] diff --git a/src/scilpy/image/volume_space_management.py b/src/scilpy/image/volume_space_management.py index 6f5b54429..25c040ed9 100644 --- a/src/scilpy/image/volume_space_management.py +++ b/src/scilpy/image/volume_space_management.py @@ -1,8 +1,11 @@ # -*- coding: utf-8 -*- -import numpy as np +import logging +import numpy as np from numba_kdtree import KDTree from numba import njit +from scipy.ndimage import map_coordinates + from scilpy.tracking.fibertube_utils import (streamlines_to_segments, point_in_cylinder, sphere_cylinder_intersection) @@ -16,6 +19,39 @@ from dipy.reconst.shm import sf_to_sh +def map_coordinates_in_volume(data, points, order): + """ + Uses map_coordinates, from scipy. But by default, in scipy, half of the + border voxels are considered out-of-bound. Using mode=nearest to make sure + we interpolate correctly in border voxels. Verifying if some coordinates + are *actually* out-of-bound first. + See here for more explanation: https://github.com/scilus/scilpy/pull/1102 + + An alternative is to use dipy's trilinear function, but in some cases + scipy's function is easier to use. + + Parameters + ---------- + data: np.ndarray + The volume + points: np.ndarray + The coordinates in vox space, center origin. Shape: [3, N] + order: int + The order of the interpolation + + Returns + ------- + data: np.ndarray + The interpolated data. + """ + if (np.any(np.logical_or(points[0] < 0, points[0] > data.shape[0])) or + np.any(np.logical_or(points[1] < 0, points[1] > data.shape[1])) or + np.any(np.logical_or(points[2] < 0,points[2] > data.shape[2]))) : + logging.warning("Careful! You are interpolating outside of boundaries " + "of your volume. Using padding to nearest value.") + return map_coordinates(data, points, order=order, mode='nearest') + + class DataVolume(object): """ Class to access/interpolate data from nibabel object diff --git a/src/scilpy/tractanalysis/bundle_operations.py b/src/scilpy/tractanalysis/bundle_operations.py index edfd33fa3..4c03f349c 100644 --- a/src/scilpy/tractanalysis/bundle_operations.py +++ b/src/scilpy/tractanalysis/bundle_operations.py @@ -12,6 +12,7 @@ from scipy.spatial import cKDTree from sklearn.cluster import KMeans +from scilpy.image.volume_space_management import map_coordinates_in_volume from scilpy.maths.utils import fit_circle_planar from scilpy.tractograms.streamline_and_mask_operations import \ get_endpoints_density_map @@ -433,9 +434,10 @@ def compute_bundle_diameter(sft, data_labels, fitting_func): counter = 0 labels_dict = {label: ([], []) for label in unique_labels} - pts_labels = map_coordinates(data_labels, - sft.streamlines._data.T - 0.5, - order=0, mode='nearest') + + # Must bring data to vox space, center origin to use map_coordinates + pts_labels = map_coordinates_in_volume( + data_labels, sft.streamlines._data.T - 0.5, order=0) # For each label, all positions and directions are needed to get # a tube estimation per label. diff --git a/src/scilpy/tractograms/streamline_and_mask_operations.py b/src/scilpy/tractograms/streamline_and_mask_operations.py index faa0fcce5..f9be8f075 100644 --- a/src/scilpy/tractograms/streamline_and_mask_operations.py +++ b/src/scilpy/tractograms/streamline_and_mask_operations.py @@ -10,6 +10,8 @@ from scipy.ndimage import map_coordinates from scilpy.tractograms.uncompress import streamlines_to_voxel_coordinates + +from scilpy.image.volume_space_management import map_coordinates_in_volume from scilpy.tractograms.streamline_operations import \ (_get_point_on_line, _get_streamline_pt_index, _get_next_real_point, _get_previous_real_point, @@ -620,7 +622,8 @@ def _intersects_two_rois(roi_data_1, roi_data_2, strl_indices, roi_data_2: np.ndarray Boolean array representing the region #2 strl_indices: list of tuple (N, 3) - 3D indices of the voxels intersected by the streamline + 3D indices of the voxels intersected by the streamline, in vox space, + corner origin. one_point_in_roi: bool If True, one point in each ROI will be kept. no_point_in_roi: bool @@ -635,10 +638,10 @@ def _intersects_two_rois(roi_data_1, roi_data_2, strl_indices, """ # Find all the points of the streamline that are in the ROIs - roi_data_1_intersect = map_coordinates( - roi_data_1, strl_indices.T, order=0, mode='nearest') + roi_data_1_intersect = map_coordinates_in_volume( + roi_data_1, strl_indices.T, order=0) roi_data_2_intersect = map_coordinates( - roi_data_2, strl_indices.T, order=0, mode='nearest') + roi_data_2, strl_indices.T, order=0) # Get the indices of the voxels intersecting with the ROIs in_strl_indices = np.argwhere(roi_data_1_intersect).squeeze(-1)