diff --git a/src/scilpy/cli/scil_tractogram_segment_with_ROI_and_score.py b/src/scilpy/cli/scil_tractogram_segment_with_ROI_and_score.py index 33712d85d..45e58b000 100755 --- a/src/scilpy/cli/scil_tractogram_segment_with_ROI_and_score.py +++ b/src/scilpy/cli/scil_tractogram_segment_with_ROI_and_score.py @@ -239,26 +239,24 @@ def load_and_verify_everything(parser, args): logging.info("Loading and/or computing ground-truth masks, limits " "masks and any_masks.") gt_masks = compute_masks_from_bundles(gt_masks_files, parser, args) - inv_all_masks = compute_masks_from_bundles(all_masks_files, parser, args, - inverse_mask=True) + all_masks = compute_masks_from_bundles(all_masks_files, parser, args) any_masks = compute_masks_from_bundles(any_masks_files, parser, args) logging.info("Extracting ground-truth head and tail masks.") gt_tails, gt_heads = compute_endpoint_masks(roi_options, args.out_dir) - # Update the list of every ROI, remove duplicates + # Check that all ROIs are compatible (remove duplicates) + logging.info("Verifying tractogram compatibility with endpoint ROIs.") list_rois = gt_tails + gt_heads list_rois = list(dict.fromkeys(list_rois)) # Removes duplicates - - logging.info("Verifying tractogram compatibility with endpoint ROIs.") for file in list_rois: compatible = is_header_compatible(sft, file) if not compatible: parser.error("Input tractogram incompatible with {}".format(file)) - return (gt_tails, gt_heads, sft, bundle_names, list_rois, + return (gt_tails, gt_heads, sft, bundle_names, lengths, angles, orientation_lengths, abs_orientation_lengths, - inv_all_masks, gt_masks, any_masks, dimensions, json_outputs) + all_masks, gt_masks, any_masks, dimensions, json_outputs) def read_config_file(args): @@ -428,8 +426,8 @@ def main(): logging.getLogger().setLevel(logging.getLevelName(args.verbose)) # Load - (gt_tails, gt_heads, sft, bundle_names, list_rois, bundle_lengths, angles, - orientation_lengths, abs_orientation_lengths, inv_all_masks, gt_masks, + (gt_tails, gt_heads, sft, bundle_names, bundle_lengths, angles, + orientation_lengths, abs_orientation_lengths, all_masks, gt_masks, any_masks, dimensions, json_outputs) = load_and_verify_everything(parser, args) @@ -437,8 +435,10 @@ def main(): (vb_sft_list, wpc_sft_list, ib_sft_list, nc_sft, ib_names, bundle_stats) = segment_tractogram_from_roi( sft, gt_tails, gt_heads, bundle_names, bundle_lengths, angles, - orientation_lengths, abs_orientation_lengths, inv_all_masks, any_masks, - list_rois, args) + orientation_lengths, abs_orientation_lengths, all_masks, any_masks, + args.out_dir, args.compute_ic, args.save_wpc_separately, args.unique, + args.remove_wpc_belonging_to_another_bundle, + args.no_empty, args.bbox_check, args.dilate_endpoints) # Save results with open(json_outputs[0], "w") as f: diff --git a/src/scilpy/segment/tests/test_tractogram_from_roi.py b/src/scilpy/segment/tests/test_tractogram_from_roi.py index 2d451bedd..a2bae8b63 100644 --- a/src/scilpy/segment/tests/test_tractogram_from_roi.py +++ b/src/scilpy/segment/tests/test_tractogram_from_roi.py @@ -1,14 +1,153 @@ +import logging import os import tempfile + +from dipy.io.stateful_tractogram import Space, StatefulTractogram, Origin import nibabel as nib import numpy as np - from numpy.testing import (assert_array_equal, assert_equal) from scilpy.segment.tractogram_from_roi import (_extract_vb_one_bundle, - _extract_ib_one_bundle) -from dipy.io.stateful_tractogram import Space, StatefulTractogram + _extract_ib_one_bundle, + segment_tractogram_from_roi) + + +def test_compute_masks_from_bundles(): + pass + +def test_compute_endpoint_masks(): + pass + +def test_segment_tractogram_from_roi(): + + # Creating an example. + # This is already pretty complex, we will not test all complex + # configurations + # Inventing a volume of size 5, 5, 5 and bundles. See ROIs below. + # ROIs are y=0, y=4, z=0, z=4 slices. + + # Bundle 1: left to right (in y). Start at y=0, finish at y=4. + # - two streamlines, reversed + # - any mask: need to pass in voxel [1, 1, 1] + streamline_r_to_l = [[1.1, 0.2, 1.3], + [1.1, 2.2, 1.4], + [1.1, 3.3, 1.3], + [1.1, 4.3, 1.2]] + streamline_l_to_r = streamline_r_to_l[::-1] + bundle1_any_mask = np.zeros((5, 5, 5)) + bundle1_any_mask[1, 1, 1] = 1 # No real point there but they pass through + # wpc not passing by [1, 1, 1] + wpc_streamline1 = [[2.1, 0.2, 1.3], + [2.1, 2.2, 1.4], + [2.1, 3.3, 1.3], + [2.1, 4.3, 1.2]] + + # Bundle 2: vertical. Starts at z=0. Finishes at z=4. + streamline_vertical = [[3.4, 5.9, 0.9], + [3.5, 1.8, 1.8], + [3.4, 2.7, 2.6], + [3.4, 4.7, 4.5]] + bundle2_all_mask = np.zeros((5, 5, 5)) + bundle2_all_mask[3, :, :] = 1 + # wpc not entirely in [3, :, :] + wpc_streamline2 = [[3.4, 5.9, 0.9], + [3.5, 1.8, 1.8], + [5.4, 2.7, 2.6], + [3.4, 4.7, 4.5]] + + # Adding an IC streamline. From y=0, but reaches z=0 instead of reaching + # y=4 + ic_streamline = [[1.1, 0.2, 1.3], + [1.1, 2.2, 1.4], + [1.1, 3.3, 1.3], + [1.1, 2.3, 0.2]] + + # Adding a NC streamline. Finishes in the middle of the volume, so not in + # any ROI. + nc_streamline = [[1.1, 0.2, 1.3], + [1.1, 2.2, 1.4], + [1.1, 3.3, 1.3], + [3.1, 2.3, 3.2]] + + # Preparing data. + gt_heads = [] # will be prepared below + gt_tails = [] # Will be prepared below + bundle_names = ['l-r', 'vertical'] + lengths = [None, None] + angles = [359, 359] + orientation_length = [None, None] + abs_orientation_length = [None, None] + all_masks = [None, bundle2_all_mask] + any_masks = [bundle1_any_mask, None] + + # Many things must be saved on disk. + with tempfile.TemporaryDirectory() as tmpdirname: + print(f'Created temporary directory: {tmpdirname}') + + fake_ref = nib.Nifti1Image(np.zeros((5, 5, 5)), affine=np.eye(4)) + sft = StatefulTractogram([streamline_r_to_l, streamline_l_to_r, + streamline_vertical, wpc_streamline1, + wpc_streamline2, ic_streamline, + nc_streamline], + space=Space.VOX, origin=Origin('corner'), + reference=fake_ref) + + def save_img(array): + img = nib.Nifti1Image(array.astype('uint8'), affine=np.eye(4)) + nib.save(img, filename) + + # Bundle 1 : left to right (streamlines 1 and 2) + # Creating a fake ROI to the left and another to the right. + gt_head = np.zeros((5, 5, 5)) + gt_head[:, 0, :] = 1 + filename = os.path.join(tmpdirname, 'bundle1_head.nii.gz') + save_img(gt_head) + gt_heads.append(filename) + + gt_tail = np.zeros((5, 5, 5)) + gt_tail[:, 4, :] = 1 + filename = os.path.join(tmpdirname, 'bundle1_tail.nii.gz') + save_img(gt_tail) + gt_tails.append(filename) + + # Bundle 2: vertical (streamline 2) + # Creating a fake ROI at the bottom and another at the top + gt_head = np.zeros((5, 5, 5)) + gt_head[:, :, 0] = 1 + filename = os.path.join(tmpdirname, 'bundle2_head.nii.gz') + save_img(gt_head) + gt_heads.append(filename) + + gt_tail = np.zeros((5, 5, 5)) + gt_tail[:, :, 4] = 1 + filename = os.path.join(tmpdirname, 'bundle2_tail.nii.gz') + save_img(gt_tail) + gt_tails.append(filename) + + logging.getLogger().setLevel('INFO') + (vb_sft_list, wpc_sft_list, ib_sft_list, nc_sft, ib_names, + bundle_stats) = segment_tractogram_from_roi( + sft, gt_tails, gt_heads, bundle_names, lengths, angles, + orientation_length, abs_orientation_length, all_masks, any_masks, + out_dir=tmpdirname, compute_ic=True, save_wpc_separately=True) + + print("bundle stats:", bundle_stats) + + # 2 valid in bundle1, 1 valid in bundle 2 + assert len(vb_sft_list) == 2 + assert len(vb_sft_list[0]) == 2 + assert len(vb_sft_list[1]) == 1 + + # two wpc streamlines + assert len(wpc_sft_list) == 2 + assert len(wpc_sft_list[0]) == 1 + assert len(wpc_sft_list[1]) == 1 + + # One IB with one streamline, one NC + assert len(ib_sft_list) == 1 + assert len(ib_sft_list[0]) == 1 + assert len(nc_sft) == 1 def test_extract_vb_one_bundle(): diff --git a/src/scilpy/segment/tractogram_from_roi.py b/src/scilpy/segment/tractogram_from_roi.py index 1ef2f560b..cf669651a 100644 --- a/src/scilpy/segment/tractogram_from_roi.py +++ b/src/scilpy/segment/tractogram_from_roi.py @@ -32,7 +32,7 @@ def _extract_prefix(filename): return prefix -def compute_masks_from_bundles(gt_files, parser, args, inverse_mask=False): +def compute_masks_from_bundles(gt_files, parser, args): """ Compute ground-truth masks. If the file is already a mask, load it. If it is a bundle, compute the mask. If the filename is None, appends None @@ -49,8 +49,6 @@ def compute_masks_from_bundles(gt_files, parser, args, inverse_mask=False): args: Namespace List of arguments passed to the script. Used for its 'ref' and 'bbox_check' arguments. - inverse_mask: bool - If true, returns the list of inversed masks instead. Returns ------- @@ -69,7 +67,6 @@ def compute_masks_from_bundles(gt_files, parser, args, inverse_mask=False): if ext in ['.gz', '.nii.gz']: gt_img = nib.load(gt_bundle) gt_mask = get_data_as_mask(gt_img) - dimensions = gt_mask.shape else: # Cheating ref because it may send a lot of warning if loading # many trk with ref (reference was maybe added only for some @@ -87,10 +84,6 @@ def compute_masks_from_bundles(gt_files, parser, args, inverse_mask=False): dimensions).astype(np.int16) gt_mask[gt_mask > 0] = 1 - if inverse_mask: - gt_inv_mask = np.zeros(dimensions, dtype=np.int16) - gt_inv_mask[gt_mask == 0] = 1 - gt_mask = gt_inv_mask else: gt_mask = None @@ -177,8 +170,9 @@ def compute_endpoint_masks(roi_options, out_dir): def _extract_vb_and_wpc_all_bundles( gt_tails, gt_heads, sft, bundle_names, lengths, angles, - orientation_lengths, abs_orientation_lengths, inv_all_masks, - any_masks, args): + orientation_lengths, abs_orientation_lengths, all_masks, + any_masks, out_dir, unique, dilate_endpoints, save_wpc_separately, + remove_wpc_belonging_to_another_bundle): """ Loop on every ground truth bundles and extract VS and WPC. @@ -196,7 +190,7 @@ def _extract_vb_and_wpc_all_bundles( vb_sft_list: list List of StatefulTractograms of VS wpc_sft_list: list - List of StatefulTractograms of WPC if args.save_wpc_separately), else + List of StatefulTractograms of WPC if save_wpc_separately), else None. all_vs_wpc_ids: list List of list of all VS + WPC streamlines detected. @@ -214,7 +208,7 @@ def _extract_vb_and_wpc_all_bundles( wpc_ids_list = [] bundles_stats = [] - remaining_ids = np.arange(len(sft)) # For args.unique management. + remaining_ids = np.arange(len(sft)) # For unique management. # 1. Extract VB and WPC. for i in range(nb_bundles): @@ -225,9 +219,9 @@ def _extract_vb_and_wpc_all_bundles( _extract_vb_one_bundle( sft[remaining_ids], head_filename, tail_filename, lengths[i], angles[i], orientation_lengths[i], abs_orientation_lengths[i], - inv_all_masks[i], any_masks[i], args.dilate_endpoints) + all_masks[i], any_masks[i], dilate_endpoints) - if args.unique: + if unique: # Assign actual ids, not from subset vs_ids = remaining_ids[vs_ids] wpc_ids = remaining_ids[wpc_ids] @@ -247,8 +241,8 @@ def _extract_vb_and_wpc_all_bundles( all_gt_ids = list(itertools.chain(*vs_ids_list)) # 2. Remove duplicate WPC and then save. - if args.save_wpc_separately: - if args.remove_wpc_belonging_to_another_bundle or args.unique: + if save_wpc_separately: + if remove_wpc_belonging_to_another_bundle or unique: for i in range(nb_bundles): new_wpc_ids = np.setdiff1d(wpc_ids_list[i], all_gt_ids) nb_rejected = len(wpc_ids_list[i]) - len(new_wpc_ids) @@ -272,9 +266,9 @@ def _extract_vb_and_wpc_all_bundles( wpc_ids_list = [[] for _ in range(nb_bundles)] wpc_sft_list = None - # 3. If not args.unique, tell users if there were duplicates. Save + # 3. If not unique, tell users if there were duplicates. Save # duplicates separately in segmented_conflicts/duplicates_*_*.trk. - if not args.unique: + if not unique: for i in range(nb_bundles): for j in range(i + 1, nb_bundles): duplicate_ids = np.intersect1d(vs_ids_list[i], vs_ids_list[j]) @@ -288,7 +282,7 @@ def _extract_vb_and_wpc_all_bundles( # Duplicates directory only created if at least one # duplicate is found. - path_duplicates = os.path.join(args.out_dir, + path_duplicates = os.path.join(out_dir, 'segmented_conflicts') if not os.path.isdir(path_duplicates): os.makedirs(path_duplicates) @@ -311,7 +305,7 @@ def _extract_vb_and_wpc_all_bundles( def _extract_vb_one_bundle( sft, head_filename, tail_filename, limits_length, angle, - orientation_length, abs_orientation_length, inv_all_mask, + orientation_length, abs_orientation_length, all_mask, any_mask, dilate_endpoints): """ Extract valid bundle (and valid streamline ids) from a tractogram, based @@ -335,8 +329,8 @@ def _extract_vb_one_bundle( Bundle's length parameters in each direction: [[min_x, max_x], [min_y, max_y], [min_z, max_z]] abs_orientation_length: idem, computed in absolute values. - inv_all_mask: np.ndarray or None - Inverse ALL mask for this bundle: no point must be outside the mask. + all_mask: np.ndarray or None + The "ALL" mask for this bundle: no point must be outside the mask. any_mask: np.ndarray or None ANY mask for this bundle. Streamlines must pass through this mask (touch it) to be included @@ -371,8 +365,13 @@ def _extract_vb_one_bundle( bundle_stats = {"Initial count head to tail": len(vs_ids)} # Remove out of inclusion mask (limits_mask) - if len(vs_ids) > 0 and inv_all_mask is not None: + if len(vs_ids) > 0 and all_mask is not None: tmp_sft = sft[vs_ids] + + # ALL points inside = NO points outside = NOT ANY point outside + # Inversing the mask. + all_mask = all_mask.astype(bool) + inv_all_mask = ~all_mask out_of_mask_ids_from_vs = filter_grid_roi( tmp_sft, inv_all_mask, 'any', is_exclude=False) out_of_mask_ids = vs_ids[out_of_mask_ids_from_vs] @@ -522,7 +521,7 @@ def _extract_ib_one_bundle(sft, mask_1_filename, mask_2_filename, return fc_sft, fc_ids -def _extract_ib_all_bundles(comb_filename, sft, args): +def _extract_ib_all_bundles(comb_filename, sft, unique, dilate_endpoints): """ Loop on every bundle and compute false connections, defined as connections between ROIs pairs that do not form gt bundles. @@ -542,9 +541,9 @@ def _extract_ib_all_bundles(comb_filename, sft, args): prefix_2 = _extract_prefix(roi2_filename) ib_sft, ic_ids = _extract_ib_one_bundle( - sft[all_ids], roi1_filename, roi2_filename, args.dilate_endpoints) + sft[all_ids], roi1_filename, roi2_filename, dilate_endpoints) - if args.unique: + if unique: ic_ids = all_ids[ic_ids] all_ids = np.setdiff1d(all_ids, ic_ids, assume_unique=True) @@ -557,7 +556,7 @@ def _extract_ib_all_bundles(comb_filename, sft, args): ib_bundle_names.append(prefix_1 + '_' + prefix_2) # Duplicates? - if not args.unique: + if not unique: nb_pairs = len(ic_ids_list) for i in range(nb_pairs): for j in range(i + 1, nb_pairs): @@ -575,30 +574,79 @@ def _extract_ib_all_bundles(comb_filename, sft, args): def segment_tractogram_from_roi( sft, gt_tails, gt_heads, bundle_names, bundle_lengths, angles, - orientation_lengths, abs_orientation_lengths, inv_all_masks, any_masks, - list_rois, args): + orientation_lengths, abs_orientation_lengths, all_masks, any_masks, + out_dir, compute_ic=False, save_wpc_separately=False, unique=True, + remove_wpc_belonging_to_another_bundle=True, + no_empty=True, bbox_check=True, dilate_endpoints=0): """ - Segments valid bundles (VB). Based on args: - - args.compute_ic: computes invalid bundles (IB) - - args.save_wpc_separately: compute WPC + Segments valid bundles (VB). + + Parameters + ---------- + sft: StatefulTractogram + The tractogram to segment. + gt_tails: list[str] + List of filenames, each VB endpoint mask (first end) + gt_heads: list[str] + List of filenames, each VB endpoint mask (second end), in the same + order as gt_tails. Ex, VB #2 uses gt_tails[2] and gt_head[2] as + endpoints. + bundle_names: list[str] + Bundle names. + bundle_lengths: list[[float, float] or None] + Maximum length for each bundle. Either a limit range, [float, float] or + None for no limit. + angles: list[float] + Maximum angle (in loops) for each bundle (in degree). + orientation_lengths: list[[limitsx, limitsy, limitsz] or None] + For each bundle, the length parameters in each direction: + [[min_x, max_x], [min_y, max_y], [min_z, max_z]]. None for no limit. + abs_orientation_lengths: list[[limitsx, limitsy, limitsz] or None] + Idem, computed in absolute values. + all_masks: list[np.ndarray or None] + For each bundle, the "ALL" mask for this bundle: no point must + be outside the mask. + any_masks: list[np.ndarray or None] + For each bundle, the "ANY" mask for this bundle: at least one point + must pass through this mask. + out_dir: str + Output directory. We will save all VB, IC and WPC there. + compute_ic: bool + Also compute invalid connections (IC). + save_wpc_separately: bool + Separate wrong path connections (WPC) from other invalid connections + (IC). WPC = correct endpoint ROIs but wrong path based on other + criteria. + unique: bool + If True, streamlines are assigned to the first bundle they fit in and + not to all. + remove_wpc_belonging_to_another_bundle: bool + If true, WPC actually belonging to any VB (in the case of overlapping + ROIs) will be removed from the WPC classification. + no_empty: bool + If true, do not save empty bundles. + bbox_check: bool + If true, check bounding box validation. + dilate_endpoints: int + Dilate endpoint masks n-times. Default: 0. Returns ------- - vb_sft_list: list + vb_sft_list: list[StatefulTractogram] The list of valid bundles discovered. These files are also saved in segmented_VB/\\*_VS.trk. - wpc_sft_list: list + wpc_sft_list: list[StatefulTractogram or None] or None The list of wrong path connections: streamlines connecting the right - endpoint regions but not included in the ALL mask. - ** This is only computed if args.save_wpc_separately. Else, this is - None. - ib_sft_list: list + endpoint regions but not included in the ALL mask. This list has the + same length as vb_sft_list. + ** This is only computed if save_wpc_separately. Else, this is None. + ib_sft_list: list[StatefulTractogram] or None The list of invalid bundles: streamlines connecting regions that should not be connected. - ** This is only computed if args.compute_ic. Else, this is None. - nc_sft_list: list + ** This is only computed if compute_ic. Else, this is None. + nc_sft_list: list[StatefulTractogram] or None The list of rejected streamlines that were not included in any IB. - ib_names: list + ib_names: list[StatefulTractogram] or None The list of names for invalid bundles (IB). They are created from the combinations of ROIs used for IB computations. bundle_stats: dict @@ -612,14 +660,17 @@ def segment_tractogram_from_roi( _extract_vb_and_wpc_all_bundles( gt_tails, gt_heads, sft, bundle_names, bundle_lengths, angles, orientation_lengths, abs_orientation_lengths, - inv_all_masks, any_masks, args) + all_masks, any_masks, out_dir, unique, dilate_endpoints, + save_wpc_separately, remove_wpc_belonging_to_another_bundle) remaining_ids = np.arange(0, len(sft)) - if args.unique: + if unique: remaining_ids = np.setdiff1d(remaining_ids, detected_vs_wpc_ids) # IC - if args.compute_ic and len(remaining_ids) > 0: + list_rois = gt_tails + gt_heads + list_rois = list(dict.fromkeys(list_rois)) # Removes duplicates + if compute_ic and len(remaining_ids) > 0: logging.info("Extracting invalid bundles") # Keep all possible combinations @@ -633,8 +684,8 @@ def segment_tractogram_from_roi( vb_roi_pair = tuple(sorted(vb_roi_pair)) comb_filename.remove(vb_roi_pair) ib_sft_list, ic_ids_list, ib_names = _extract_ib_all_bundles( - comb_filename, sft[remaining_ids], args) - if args.unique and len(ic_ids_list) > 0: + comb_filename, sft[remaining_ids], unique, dilate_endpoints) + if unique and len(ic_ids_list) > 0: for i in range(len(ic_ids_list)): # Assign actual ids ic_ids_list[i] = remaining_ids[ic_ids_list[i]] @@ -650,11 +701,11 @@ def segment_tractogram_from_roi( # NC # = ids that are not VS, not wpc (if asked) and not IC (if asked). all_nc_ids = remaining_ids - if not args.unique: + if not unique: all_nc_ids = np.setdiff1d(all_nc_ids, detected_vs_wpc_ids) all_nc_ids = np.setdiff1d(all_nc_ids, all_ic_ids) - if args.compute_ic: + if compute_ic: logging.info("The remaining {} / {} streamlines will be scored as NC." .format(len(all_nc_ids), len(sft))) filename = "NC.trk" @@ -664,9 +715,9 @@ def segment_tractogram_from_roi( filename = "IS.trk" nc_sft = sft[all_nc_ids] - if len(nc_sft) > 0 or not args.no_empty: + if len(nc_sft) > 0 or not no_empty: save_tractogram(nc_sft, os.path.join( - args.out_dir, filename), bbox_valid_check=args.bbox_check) + out_dir, filename), bbox_valid_check=bbox_check) return (vb_sft_list, wpc_sft_list, ib_sft_list, nc_sft, ib_names, bundle_stats) diff --git a/src/scilpy/tractograms/streamline_operations.py b/src/scilpy/tractograms/streamline_operations.py index dfacc173b..2d24756b2 100644 --- a/src/scilpy/tractograms/streamline_operations.py +++ b/src/scilpy/tractograms/streamline_operations.py @@ -893,7 +893,7 @@ def remove_loops(streamlines, max_angle, num_processes=1): The list of streamlines from which to remove loops and sharp turns. max_angle: float Maximal winding angle a streamline can have before being classified as - a loop. + a loop (in degree). num_processes : int Split the calculation to a pool of children processes. @@ -967,7 +967,7 @@ def remove_loops_and_sharp_turns(streamlines, max_angle, qb_threshold=None, The list of streamlines from which to remove loops and sharp turns. max_angle: float Maximal winding angle a streamline can have before being classified as - a loop. + a loop (in degree) qb_threshold: float, optional If not None, do the additional QuickBundles pass. This will help remove sharp turns. Should only be used on bundled streamlines, not on