Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 36 additions & 20 deletions src/scilpy/segment/streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,23 @@

def streamlines_in_mask(sft, target_mask, all_in=False):
"""
Finds the streamlines that are either touching a mask (if all_in=False) or
entirely contained in the mask (if all_in=True).

Parameters
----------
sft : StatefulTractogram
StatefulTractogram containing the streamlines to segment.
target_mask : numpy.ndarray
Binary mask in which the streamlines should pass.
all_in: bool
If true, finds streamlines satisfying the 'all' criteria. Else, finds
streamlines satisfying the 'any' criteria.

Returns
-------
ids : list
Ids of the streamlines passing through the mask.
Ids of the streamlines passing the test.
"""
sft.to_vox()
sft.to_corner()
Expand All @@ -52,9 +59,11 @@ def streamlines_in_mask(sft, target_mask, all_in=False):
return np.where(streamlines_case == [0, 1][True])[0].tolist()


def filter_grid_roi_both(sft, mask_1, mask_2):
""" Filters streamlines with one end in a mask and the other in
another mask.
def filter_grid_roi_both_ends(sft, mask_1, mask_2):
"""
Filters streamlines with one end in a mask and the other in another mask.
See also filter_grid_roi, but here we may give two different masks for the
endpoints.

Parameters
----------
Expand All @@ -64,6 +73,7 @@ def filter_grid_roi_both(sft, mask_1, mask_2):
Binary mask in which the streamlines should start or end.
mask_2: numpy.ndarray
Binary mask in which the streamlines should start or end.

Returns
-------
new_sft: StatefulTractogram
Expand Down Expand Up @@ -108,6 +118,9 @@ def filter_grid_roi_both(sft, mask_1, mask_2):
def filter_grid_roi(sft, mask, filter_type, is_exclude, filter_distance=0,
return_sft=False, return_rejected_sft=False):
"""
Filters streamlines based on a given criteria (any, all, either_end,
both_ends).

Parameters
----------
sft : StatefulTractogram
Expand Down Expand Up @@ -199,8 +212,8 @@ def filter_grid_roi(sft, mask, filter_type, is_exclude, filter_distance=0,
return line_based_indices


def pre_filtering_for_geometrical_shape(sft, size, center, filter_type,
is_in_vox):
def _pre_filtering_for_geometrical_shape(sft, size, center, filter_type,
is_in_vox):
"""
Parameters
----------
Expand Down Expand Up @@ -254,14 +267,17 @@ def pre_filtering_for_geometrical_shape(sft, size, center, filter_type,
def filter_ellipsoid(sft, ellipsoid_radius, ellipsoid_center,
filter_type, is_exclude, is_in_vox=False):
"""
Finds streamlines that respect some criteria in a ROI, where the ROI is
a bounding box of ellipsoid type.

Parameters
----------
sft : StatefulTractogram
Tractogram containing the streamlines to segment.
ellipsoid_radius : numpy.ndarray (3)
Size in mm, x/y/z of the ellipsoid.
ellipsoid_center: numpy.ndarray (3)
Center x/y/z of the ellipsoid.
Center x/y/z of the ellipsoid, in RASMM space, center origin.
filter_type: str
One of the 4 following choices, 'any', 'all', 'either_end', 'both_ends'.
is_exclude: bool
Expand All @@ -280,12 +296,11 @@ def filter_ellipsoid(sft, ellipsoid_radius, ellipsoid_center,
return np.array([]), sft

pre_filtered_indices, pre_filtered_sft = \
pre_filtering_for_geometrical_shape(sft, ellipsoid_radius,
ellipsoid_center, filter_type,
is_in_vox)
_pre_filtering_for_geometrical_shape(sft, ellipsoid_radius,
ellipsoid_center, filter_type,
is_in_vox)
pre_filtered_sft.to_rasmm()
pre_filtered_sft.to_center()
pre_filtered_streamlines = pre_filtered_sft.streamlines
transfo, _, res, _ = sft.space_attributes

if is_in_vox:
Expand All @@ -303,7 +318,7 @@ def filter_ellipsoid(sft, ellipsoid_radius, ellipsoid_center,
ellipsoid_radius = np.asarray(ellipsoid_radius, dtype=float)
ellipsoid_center = np.asarray(ellipsoid_center, dtype=float)

for i, line in enumerate(pre_filtered_streamlines):
for i, line in enumerate(pre_filtered_sft.streamlines):
if filter_type in ['any', 'all']:
# Resample to 1/10 of the voxel size
nb_points = max(int(length(line) / np.average(res) * 10), 2)
Expand Down Expand Up @@ -360,23 +375,24 @@ def filter_ellipsoid(sft, ellipsoid_radius, ellipsoid_center,
return line_based_indices, new_sft


def filter_cuboid(sft, cuboid_radius, cuboid_center,
filter_type, is_exclude):
def filter_cuboid(sft, cuboid_radius, cuboid_center, filter_type, is_exclude):
"""
Finds streamlines that respect some criteria in a ROI, where the ROI is
a bounding box of cuboid type.

Parameters
----------
sft : StatefulTractogram
StatefulTractogram containing the streamlines to segment.
cuboid_radius : numpy.ndarray (3)
Size in mm, x/y/z of the cuboid.
cuboid_center: numpy.ndarray (3)
Center x/y/z of the cuboid.
Center x/y/z of the cuboid, in RASMM space, center origin.
filter_type: str
One of the 4 following choices: 'any', 'all', 'either_end', 'both_ends'.
is_exclude: bool
Value to indicate if the ROI is an AND (false) or a NOT (true).
is_in_vox: bool
Value to indicate if the ROI is in voxel space.

Returns
-------
ids : list
Expand All @@ -388,9 +404,9 @@ def filter_cuboid(sft, cuboid_radius, cuboid_center,
return np.array([]), sft

pre_filtered_indices, pre_filtered_sft = \
pre_filtering_for_geometrical_shape(sft, cuboid_radius,
cuboid_center, filter_type,
is_in_vox=False)
_pre_filtering_for_geometrical_shape(sft, cuboid_radius,
cuboid_center, filter_type,
is_in_vox=False)
pre_filtered_sft.to_rasmm()
pre_filtered_sft.to_center()
pre_filtered_streamlines = pre_filtered_sft.streamlines
Expand Down
7 changes: 7 additions & 0 deletions src/scilpy/segment/tests/test_bundleseg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

def test_recognize():
# todo:
# bundleseg = BundleSeg(...)
# bundleseg.recognize.
pass
138 changes: 138 additions & 0 deletions src/scilpy/segment/tests/test_streamlines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# -*- coding: utf-8 -*-
from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin
import nibabel as nib
import numpy as np
from scilpy.segment.streamlines import streamlines_in_mask, \
filter_grid_roi_both_ends, filter_grid_roi, filter_ellipsoid, filter_cuboid

# Preparing SFT for all tests here.
# 3 ways to crete ROI, all ~voxels 1 and 2 in x, y, z
# - mask
# - bdo_ellipsoid
# - bdo_cuboid

# Binary mask
mask = np.zeros((10, 10, 10))
mask[1:3, 1:3, 1:3] = 1

# bdo.
# Note. Everything else here is in vox, corner, except bdo_center
bdo_center_rasmm_centerorigin = np.asarray([1.5, 1.5, 1.5])
bdo_radius_mm = 1.5

# Preparing streamlines fitting criteria.
# line 0 'all': entirely in mask.
line0 = [[1.5, 1.5, 1.5],
[1.6, 1.6, 1.6],
[1.7, 1.7, 1.7],
[2.5, 2.5, 2.5]]

# line 1 'both_ends': both ends in mask but not in the middle.
line1 = [[1.5, 1.5, 1.5],
[9, 9, 9],
[1.7, 1.7, 1.7],
[2.5, 2.5, 2.5]]

# line 2 'any': Only one end in mask
line2 = [[1.5, 1.5, 1.5],
[1.6, 1.6, 1.6],
[1.7, 1.7, 1.7],
[9, 9, 9]]

# line 3 'any': Both ends out of mask but touches in the middle
line3 = [[9, 9, 9],
[1.6, 1.6, 1.6],
[1.7, 1.7, 1.7],
[9, 9, 9]]

# line 4 (exclude): never in mask
line4 = [[9, 9, 9.0],
[8, 8, 8.0],
[7, 7, 7.0],
[9, 9, 9.0]]

# line 5 'any': passes through, but no real point inside.
line5 = [[0.1, 0.1, 0.1],
[8, 8, 8.0],
[7, 7, 7.0],
[9, 9, 9.0]]

fake_reference = nib.Nifti1Image(
np.zeros((10, 10, 10, 1)), affine=np.eye(4))
sft = StatefulTractogram([line0, line1, line2, line3, line4, line5],
fake_reference, space=Space.VOXMM,
origin=Origin('corner'))


def test_streamlines_in_mask():
# Test option 'all'
ids = streamlines_in_mask(sft, mask, all_in=True)
assert np.array_equal(ids, [0])

# Test option 'any'
ids = streamlines_in_mask(sft, mask, all_in=False)
assert np.array_equal(ids, [0, 1, 2, 3, 5])


def test_filter_grid_roi_both_ends():
# Pretending to have two masks
# Test option 'both ends'
new_sft, ids = filter_grid_roi_both_ends(sft, mask_1=mask, mask_2=mask)
assert np.array_equal(ids, [0, 1])
assert len(new_sft) == 2


def test_filter_grid_roi():
# Note. Distance not tested yet. (toDo)
roi_options = (mask,)
_test_all_criteria(filter_grid_roi, roi_options, fct_returns_sft=False)


def test_filter_ellipsoid():
roi_options = (bdo_radius_mm, bdo_center_rasmm_centerorigin)
_test_all_criteria(filter_ellipsoid, roi_options, fct_returns_sft=True)


def test_filter_cuboid():
roi_options = (bdo_radius_mm, bdo_center_rasmm_centerorigin)
_test_all_criteria(filter_cuboid, roi_options, fct_returns_sft=True)


def _test_all_criteria(fct, roi_args, fct_returns_sft):
"""
The three filtering methods (filter_grid_roi, filter_ellipsoid,
filter_cuboid) test the same criteria, but with a different way to treat
the ROI. Here are the tests, the roi_args should be different based on fct.
"""
# Parameter is "is_exclude" for all three methods. So:
include=False
exclude=True

def get_ids(output):
if fct_returns_sft:
return output[0]
return output

# Test 'any'
ids = get_ids(fct(sft, *roi_args, 'any', include))
assert np.array_equal(ids, [0, 1, 2, 3, 5])
ids = get_ids(fct(sft, *roi_args, 'any', exclude))
assert np.array_equal(ids, [4])

# Test 'all'
ids = get_ids(fct(sft, *roi_args, 'all', include))
assert np.array_equal(ids, [0])
ids = get_ids(fct(sft, *roi_args, 'all', exclude))
assert np.array_equal(ids, [1, 2, 3, 4, 5])

# Test 'either_end'
ids = get_ids(fct(sft, *roi_args, 'either_end', include))
assert np.array_equal(ids, [0, 1, 2])
ids = get_ids(fct(sft, *roi_args, 'either_end', exclude))
assert np.array_equal(ids, [3, 4, 5])

# Test 'both_ends'
ids = get_ids(fct(sft, *roi_args, 'both_ends', include))
assert np.array_equal(ids, [0, 1])
ids = get_ids(fct(sft, *roi_args, 'both_ends', exclude))
assert np.array_equal(ids, [2, 3, 4, 5])
1 change: 1 addition & 0 deletions src/scilpy/segment/tests/test_tractogram_from_roi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import os
import tempfile
import nibabel as nib
Expand Down
6 changes: 3 additions & 3 deletions src/scilpy/segment/tractogram_from_roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
split_mask_blobs_kmeans
from scilpy.io.image import get_data_as_mask
from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.segment.streamlines import filter_grid_roi, filter_grid_roi_both
from scilpy.segment.streamlines import filter_grid_roi, filter_grid_roi_both_ends
from scilpy.tractograms.streamline_operations import \
remove_loops_and_sharp_turns
from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map
Expand Down Expand Up @@ -363,7 +363,7 @@ def _extract_vb_one_bundle(
mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)

_, vs_ids = filter_grid_roi_both(sft, mask_1, mask_2)
_, vs_ids = filter_grid_roi_both_ends(sft, mask_1, mask_2)
else:
vs_ids = np.array([])

Expand Down Expand Up @@ -514,7 +514,7 @@ def _extract_ib_one_bundle(sft, mask_1_filename, mask_2_filename,
mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)

_, fc_ids = filter_grid_roi_both(sft, mask_1, mask_2)
_, fc_ids = filter_grid_roi_both_ends(sft, mask_1, mask_2)
else:
fc_ids = []

Expand Down
8 changes: 8 additions & 0 deletions src/scilpy/tractograms/tests/test_intersection_finder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-


def test_find_intersection():
# Todo:
# finder = IntersectionFinder(...)
# finder.find_intersection()
pass