Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/scilpy/cli/scil_tractogram_segment_with_ROI_and_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,26 +239,24 @@ def load_and_verify_everything(parser, args):
logging.info("Loading and/or computing ground-truth masks, limits "
"masks and any_masks.")
gt_masks = compute_masks_from_bundles(gt_masks_files, parser, args)
inv_all_masks = compute_masks_from_bundles(all_masks_files, parser, args,
inverse_mask=True)
all_masks = compute_masks_from_bundles(all_masks_files, parser, args)
any_masks = compute_masks_from_bundles(any_masks_files, parser, args)

logging.info("Extracting ground-truth head and tail masks.")
gt_tails, gt_heads = compute_endpoint_masks(roi_options, args.out_dir)

# Update the list of every ROI, remove duplicates
# Check that all ROIs are compatible (remove duplicates)
logging.info("Verifying tractogram compatibility with endpoint ROIs.")
list_rois = gt_tails + gt_heads
list_rois = list(dict.fromkeys(list_rois)) # Removes duplicates

logging.info("Verifying tractogram compatibility with endpoint ROIs.")
for file in list_rois:
compatible = is_header_compatible(sft, file)
if not compatible:
parser.error("Input tractogram incompatible with {}".format(file))

return (gt_tails, gt_heads, sft, bundle_names, list_rois,
return (gt_tails, gt_heads, sft, bundle_names,
lengths, angles, orientation_lengths, abs_orientation_lengths,
inv_all_masks, gt_masks, any_masks, dimensions, json_outputs)
all_masks, gt_masks, any_masks, dimensions, json_outputs)


def read_config_file(args):
Expand Down Expand Up @@ -428,17 +426,19 @@ def main():
logging.getLogger().setLevel(logging.getLevelName(args.verbose))

# Load
(gt_tails, gt_heads, sft, bundle_names, list_rois, bundle_lengths, angles,
orientation_lengths, abs_orientation_lengths, inv_all_masks, gt_masks,
(gt_tails, gt_heads, sft, bundle_names, bundle_lengths, angles,
orientation_lengths, abs_orientation_lengths, all_masks, gt_masks,
any_masks, dimensions,
json_outputs) = load_and_verify_everything(parser, args)

# Segment VB, WPC, IB
(vb_sft_list, wpc_sft_list, ib_sft_list, nc_sft,
ib_names, bundle_stats) = segment_tractogram_from_roi(
sft, gt_tails, gt_heads, bundle_names, bundle_lengths, angles,
orientation_lengths, abs_orientation_lengths, inv_all_masks, any_masks,
list_rois, args)
orientation_lengths, abs_orientation_lengths, all_masks, any_masks,
args.out_dir, args.compute_ic, args.save_wpc_separately, args.unique,
args.remove_wpc_belonging_to_another_bundle,
args.no_empty, args.bbox_check, args.dilate_endpoints)

# Save results
with open(json_outputs[0], "w") as f:
Expand Down
145 changes: 142 additions & 3 deletions src/scilpy/segment/tests/test_tractogram_from_roi.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,153 @@
import logging
import os
import tempfile

from dipy.io.stateful_tractogram import Space, StatefulTractogram, Origin
import nibabel as nib
import numpy as np

from numpy.testing import (assert_array_equal,
assert_equal)

from scilpy.segment.tractogram_from_roi import (_extract_vb_one_bundle,
_extract_ib_one_bundle)
from dipy.io.stateful_tractogram import Space, StatefulTractogram
_extract_ib_one_bundle,
segment_tractogram_from_roi)


def test_compute_masks_from_bundles():
pass

def test_compute_endpoint_masks():
pass

def test_segment_tractogram_from_roi():

# Creating an example.
# This is already pretty complex, we will not test all complex
# configurations
# Inventing a volume of size 5, 5, 5 and bundles. See ROIs below.
# ROIs are y=0, y=4, z=0, z=4 slices.

# Bundle 1: left to right (in y). Start at y=0, finish at y=4.
# - two streamlines, reversed
# - any mask: need to pass in voxel [1, 1, 1]
streamline_r_to_l = [[1.1, 0.2, 1.3],
[1.1, 2.2, 1.4],
[1.1, 3.3, 1.3],
[1.1, 4.3, 1.2]]
streamline_l_to_r = streamline_r_to_l[::-1]
bundle1_any_mask = np.zeros((5, 5, 5))
bundle1_any_mask[1, 1, 1] = 1 # No real point there but they pass through
# wpc not passing by [1, 1, 1]
wpc_streamline1 = [[2.1, 0.2, 1.3],
[2.1, 2.2, 1.4],
[2.1, 3.3, 1.3],
[2.1, 4.3, 1.2]]

# Bundle 2: vertical. Starts at z=0. Finishes at z=4.
streamline_vertical = [[3.4, 5.9, 0.9],
[3.5, 1.8, 1.8],
[3.4, 2.7, 2.6],
[3.4, 4.7, 4.5]]
bundle2_all_mask = np.zeros((5, 5, 5))
bundle2_all_mask[3, :, :] = 1
# wpc not entirely in [3, :, :]
wpc_streamline2 = [[3.4, 5.9, 0.9],
[3.5, 1.8, 1.8],
[5.4, 2.7, 2.6],
[3.4, 4.7, 4.5]]

# Adding an IC streamline. From y=0, but reaches z=0 instead of reaching
# y=4
ic_streamline = [[1.1, 0.2, 1.3],
[1.1, 2.2, 1.4],
[1.1, 3.3, 1.3],
[1.1, 2.3, 0.2]]

# Adding a NC streamline. Finishes in the middle of the volume, so not in
# any ROI.
nc_streamline = [[1.1, 0.2, 1.3],
[1.1, 2.2, 1.4],
[1.1, 3.3, 1.3],
[3.1, 2.3, 3.2]]

# Preparing data.
gt_heads = [] # will be prepared below
gt_tails = [] # Will be prepared below
bundle_names = ['l-r', 'vertical']
lengths = [None, None]
angles = [359, 359]
orientation_length = [None, None]
abs_orientation_length = [None, None]
all_masks = [None, bundle2_all_mask]
any_masks = [bundle1_any_mask, None]

# Many things must be saved on disk.
with tempfile.TemporaryDirectory() as tmpdirname:
print(f'Created temporary directory: {tmpdirname}')

fake_ref = nib.Nifti1Image(np.zeros((5, 5, 5)), affine=np.eye(4))
sft = StatefulTractogram([streamline_r_to_l, streamline_l_to_r,
streamline_vertical, wpc_streamline1,
wpc_streamline2, ic_streamline,
nc_streamline],
space=Space.VOX, origin=Origin('corner'),
reference=fake_ref)

def save_img(array):
img = nib.Nifti1Image(array.astype('uint8'), affine=np.eye(4))
nib.save(img, filename)

# Bundle 1 : left to right (streamlines 1 and 2)
# Creating a fake ROI to the left and another to the right.
gt_head = np.zeros((5, 5, 5))
gt_head[:, 0, :] = 1
filename = os.path.join(tmpdirname, 'bundle1_head.nii.gz')
save_img(gt_head)
gt_heads.append(filename)

gt_tail = np.zeros((5, 5, 5))
gt_tail[:, 4, :] = 1
filename = os.path.join(tmpdirname, 'bundle1_tail.nii.gz')
save_img(gt_tail)
gt_tails.append(filename)

# Bundle 2: vertical (streamline 2)
# Creating a fake ROI at the bottom and another at the top
gt_head = np.zeros((5, 5, 5))
gt_head[:, :, 0] = 1
filename = os.path.join(tmpdirname, 'bundle2_head.nii.gz')
save_img(gt_head)
gt_heads.append(filename)

gt_tail = np.zeros((5, 5, 5))
gt_tail[:, :, 4] = 1
filename = os.path.join(tmpdirname, 'bundle2_tail.nii.gz')
save_img(gt_tail)
gt_tails.append(filename)

logging.getLogger().setLevel('INFO')
(vb_sft_list, wpc_sft_list, ib_sft_list, nc_sft, ib_names,
bundle_stats) = segment_tractogram_from_roi(
sft, gt_tails, gt_heads, bundle_names, lengths, angles,
orientation_length, abs_orientation_length, all_masks, any_masks,
out_dir=tmpdirname, compute_ic=True, save_wpc_separately=True)

print("bundle stats:", bundle_stats)

# 2 valid in bundle1, 1 valid in bundle 2
assert len(vb_sft_list) == 2
assert len(vb_sft_list[0]) == 2
assert len(vb_sft_list[1]) == 1

# two wpc streamlines
assert len(wpc_sft_list) == 2
assert len(wpc_sft_list[0]) == 1
assert len(wpc_sft_list[1]) == 1

# One IB with one streamline, one NC
assert len(ib_sft_list) == 1
assert len(ib_sft_list[0]) == 1
assert len(nc_sft) == 1


def test_extract_vb_one_bundle():
Expand Down
Loading