diff --git a/tractseg/libs/tracking.py b/tractseg/libs/tracking.py index 8d7e9bc..f2f3750 100644 --- a/tractseg/libs/tracking.py +++ b/tractseg/libs/tracking.py @@ -84,176 +84,175 @@ def track(bundle, peaks, output_dir, tracking_on_FODs, tracking_software, tracki ################### Tracking ################### - if not bundle_mask_ok or not beginnings_mask_ok or not endings_mask_ok: - fiber_utils.create_empty_tractogram(output_dir + "/" + tracking_folder + "/" + - bundle + "." + output_format, - output_dir + "/bundle_segmentations" + dir_postfix + "/" + - bundle + ".nii.gz", - tracking_format=output_format) - else: + if not bundle_mask_ok or not beginnings_mask_ok or not endings_mask_ok: + fiber_utils.create_empty_tractogram(output_dir + "/" + tracking_folder + "/" + + bundle + "." + output_format, + output_dir + "/bundle_segmentations" + dir_postfix + "/" + + bundle + ".nii.gz", + tracking_format=output_format) + shutil.rmtree(tmp_dir) + return + # Filtering - if filter_by_endpoints: - - # Mrtrix Tracking - if tracking_software == "mrtrix": - - # Prepare files - img_utils.dilate_binary_mask(output_dir + "/bundle_segmentations" + dir_postfix + "/" + bundle + ".nii.gz", - tmp_dir + "/" + bundle + ".nii.gz", dilation=dilation) - img_utils.dilate_binary_mask(output_dir + "/endings_segmentations/" + bundle + "_e.nii.gz", - tmp_dir + "/" + bundle + "_e.nii.gz", dilation=dilation + 1) - img_utils.dilate_binary_mask(output_dir + "/endings_segmentations/" + bundle + "_b.nii.gz", - tmp_dir + "/" + bundle + "_b.nii.gz", dilation=dilation + 1) - - # Mrtrix tracking on original FODs (have to be provided to -i) - if tracking_on_FODs: - if tracking_algorithm == "FACT" or tracking_algorithm == "SD_STREAM": - seeds = 1000000 - else: - seeds = 200000 - # Quite slow - # cutoff 0.1 gives more sensitive results than 0.05 (default) (tested for HCP msmt) - # - better for CA & FX (less oversegmentation) - # - worse for CST (missing lateral projections) - subprocess.call("tckgen -algorithm " + tracking_algorithm + " " + - peaks + " " + + + # Mrtrix Tracking + if tracking_software == "mrtrix": + + # Prepare files + img_utils.dilate_binary_mask(output_dir + "/bundle_segmentations" + dir_postfix + "/" + bundle + ".nii.gz", + tmp_dir + "/" + bundle + ".nii.gz", dilation=dilation) + img_utils.dilate_binary_mask(output_dir + "/endings_segmentations/" + bundle + "_e.nii.gz", + tmp_dir + "/" + bundle + "_e.nii.gz", dilation=dilation + 1) + img_utils.dilate_binary_mask(output_dir + "/endings_segmentations/" + bundle + "_b.nii.gz", + tmp_dir + "/" + bundle + "_b.nii.gz", dilation=dilation + 1) + + # Mrtrix tracking on original FODs (have to be provided to -i) + if tracking_on_FODs: + if tracking_algorithm == "FACT" or tracking_algorithm == "SD_STREAM": + seeds = 1000000 + else: + seeds = 200000 + # Quite slow + # cutoff 0.1 gives more sensitive results than 0.05 (default) (tested for HCP msmt) + # - better for CA & FX (less oversegmentation) + # - worse for CST (missing lateral projections) + subprocess.call("tckgen -algorithm " + tracking_algorithm + " " + + peaks + " " + + output_dir + "/" + tracking_folder + "/" + bundle + ".tck" + + " -seed_image " + tmp_dir + "/" + bundle + ".nii.gz" + + " -mask " + tmp_dir + "/" + bundle + ".nii.gz" + + " -include " + tmp_dir + "/" + bundle + "_b.nii.gz" + + " -include " + tmp_dir + "/" + bundle + "_e.nii.gz" + + " -minlength 40 -maxlength 250 -seeds " + str(seeds) + + " -select " + str(nr_fibers) + " -cutoff 0.05 -force" + nthreads, + shell=True) + if output_format == "trk" or output_format == "trk_legacy": + _mrtrix_tck_to_trk(output_dir, tracking_folder, dir_postfix, bundle, output_format, nr_cpus) + + else: + # FACT tracking on TOMs + if tracking_algorithm == "FACT": + # Takes around 2.5min for 1 subject (2mm resolution) + subprocess.call("tckgen -algorithm FACT " + + output_dir + "/" + TOM_folder + "/" + bundle + ".nii.gz " + + output_dir + "/" + tracking_folder + "/" + bundle + ".tck" + + " -seed_image " + tmp_dir + "/" + bundle + ".nii.gz" + + " -mask " + tmp_dir + "/" + bundle + ".nii.gz" + + " -include " + tmp_dir + "/" + bundle + "_b.nii.gz" + + " -include " + tmp_dir + "/" + bundle + "_e.nii.gz" + + " -minlength 40 -maxlength 250 -select " + str(nr_fibers) + + " -force -quiet" + nthreads, + shell=True) + if output_format == "trk" or output_format == "trk_legacy": + _mrtrix_tck_to_trk(output_dir, tracking_folder, dir_postfix, bundle, output_format, nr_cpus) + + # iFOD2 tracking on TOMs + elif tracking_algorithm == "iFOD2": + # Takes around 12min for 1 subject (2mm resolution) + img_utils.peaks2fixel(output_dir + "/" + TOM_folder + "/" + bundle + ".nii.gz", tmp_dir + "/fixel") + subprocess.call("fixel2sh " + tmp_dir + "/fixel/amplitudes.nii.gz " + + tmp_dir + "/fixel/sh.nii.gz -quiet", shell=True) + subprocess.call("tckgen -algorithm iFOD2 " + + tmp_dir + "/fixel/sh.nii.gz " + output_dir + "/" + tracking_folder + "/" + bundle + ".tck" + " -seed_image " + tmp_dir + "/" + bundle + ".nii.gz" + " -mask " + tmp_dir + "/" + bundle + ".nii.gz" + " -include " + tmp_dir + "/" + bundle + "_b.nii.gz" + " -include " + tmp_dir + "/" + bundle + "_e.nii.gz" + - " -minlength 40 -maxlength 250 -seeds " + str(seeds) + - " -select " + str(nr_fibers) + " -cutoff 0.05 -force" + nthreads, + " -minlength 40 -maxlength 250 -select " + str(nr_fibers) + + " -force -quiet" + nthreads, shell=True) if output_format == "trk" or output_format == "trk_legacy": _mrtrix_tck_to_trk(output_dir, tracking_folder, dir_postfix, bundle, output_format, nr_cpus) else: - # FACT tracking on TOMs - if tracking_algorithm == "FACT": - # Takes around 2.5min for 1 subject (2mm resolution) - subprocess.call("tckgen -algorithm FACT " + - output_dir + "/" + TOM_folder + "/" + bundle + ".nii.gz " + - output_dir + "/" + tracking_folder + "/" + bundle + ".tck" + - " -seed_image " + tmp_dir + "/" + bundle + ".nii.gz" + - " -mask " + tmp_dir + "/" + bundle + ".nii.gz" + - " -include " + tmp_dir + "/" + bundle + "_b.nii.gz" + - " -include " + tmp_dir + "/" + bundle + "_e.nii.gz" + - " -minlength 40 -maxlength 250 -select " + str(nr_fibers) + - " -force -quiet" + nthreads, - shell=True) - if output_format == "trk" or output_format == "trk_legacy": - _mrtrix_tck_to_trk(output_dir, tracking_folder, dir_postfix, bundle, output_format, nr_cpus) - - # iFOD2 tracking on TOMs - elif tracking_algorithm == "iFOD2": - # Takes around 12min for 1 subject (2mm resolution) - img_utils.peaks2fixel(output_dir + "/" + TOM_folder + "/" + bundle + ".nii.gz", tmp_dir + "/fixel") - subprocess.call("fixel2sh " + tmp_dir + "/fixel/amplitudes.nii.gz " + - tmp_dir + "/fixel/sh.nii.gz -quiet", shell=True) - subprocess.call("tckgen -algorithm iFOD2 " + - tmp_dir + "/fixel/sh.nii.gz " + - output_dir + "/" + tracking_folder + "/" + bundle + ".tck" + - " -seed_image " + tmp_dir + "/" + bundle + ".nii.gz" + - " -mask " + tmp_dir + "/" + bundle + ".nii.gz" + - " -include " + tmp_dir + "/" + bundle + "_b.nii.gz" + - " -include " + tmp_dir + "/" + bundle + "_e.nii.gz" + - " -minlength 40 -maxlength 250 -select " + str(nr_fibers) + - " -force -quiet" + nthreads, - shell=True) - if output_format == "trk" or output_format == "trk_legacy": - _mrtrix_tck_to_trk(output_dir, tracking_folder, dir_postfix, bundle, output_format, nr_cpus) - - else: - raise ValueError("Unknown tracking algorithm: {}".format(tracking_algorithm)) - - - # TractSeg probabilistic tracking - else: + raise ValueError("Unknown tracking algorithm: {}".format(tracking_algorithm)) - # Prepare files - bundle_mask_img = nib.load(output_dir + "/bundle_segmentations" + dir_postfix + "/" - + bundle + ".nii.gz") - beginnings_img = nib.load(output_dir + "/endings_segmentations/" + bundle + "_b.nii.gz") - endings_img = nib.load(output_dir + "/endings_segmentations/" + bundle + "_e.nii.gz") - tom_peaks_img = nib.load(output_dir + "/" + TOM_folder + "/" + bundle + ".nii.gz") - - # Ensure same orientation as MNI space - bundle_mask, flip_axis = img_utils.flip_axis_to_match_MNI_space(bundle_mask_img.get_fdata().astype(np.uint8), - bundle_mask_img.affine) - beginnings, flip_axis = img_utils.flip_axis_to_match_MNI_space(beginnings_img.get_fdata().astype(np.uint8), - beginnings_img.affine) - endings, flip_axis = img_utils.flip_axis_to_match_MNI_space(endings_img.get_fdata().astype(np.uint8), - endings_img.affine) - tom_peaks, flip_axis = img_utils.flip_axis_to_match_MNI_space(tom_peaks_img.get_fdata(), - tom_peaks_img.affine) - - # tracking_uncertainties = nib.load(output_dir + "/tracking_uncertainties/" + bundle + ".nii.gz").get_fdata() - tracking_uncertainties = None - - #Get best original peaks - if use_best_original_peaks: - orig_peaks_img = nib.load(peaks) - orig_peaks, flip_axis = img_utils.flip_axis_to_match_MNI_space(orig_peaks_img.get_fdata(), - orig_peaks_img.affine) - best_orig_peaks = fiber_utils.get_best_original_peaks(tom_peaks, orig_peaks) - for axis in flip_axis: - best_orig_peaks = img_utils.flip_axis(best_orig_peaks, axis) - nib.save(nib.Nifti1Image(best_orig_peaks, orig_peaks_img.affine), - output_dir + "/" + tracking_folder + "/" + bundle + ".nii.gz") - tom_peaks = best_orig_peaks - - #Get weighted mean between best original peaks and TOMs - if use_as_prior: - orig_peaks_img = nib.load(peaks) - orig_peaks, flip_axis = img_utils.flip_axis_to_match_MNI_space(orig_peaks_img.get_fdata(), - orig_peaks_img.affine) - best_orig_peaks = fiber_utils.get_best_original_peaks(tom_peaks, orig_peaks) - weighted_peaks = fiber_utils.get_weighted_mean_of_peaks(best_orig_peaks, tom_peaks, weight=0.5) - for axis in flip_axis: - weighted_peaks = img_utils.flip_axis(weighted_peaks, axis) - nib.save(nib.Nifti1Image(weighted_peaks, orig_peaks_img.affine), - output_dir + "/" + tracking_folder + "/" + bundle + "_weighted.nii.gz") - tom_peaks = weighted_peaks - - # Takes around 6min for 1 subject (2mm resolution) - streamlines = tractseg_prob_tracking.track(tom_peaks, max_nr_fibers=nr_fibers, smooth=5, - compress=0.1, bundle_mask=bundle_mask, start_mask=beginnings, - end_mask=endings, - tracking_uncertainties=tracking_uncertainties, - dilation=dilation, - next_step_displacement_std=next_step_displacement_std, - nr_cpus=nr_cpus, affine=bundle_mask_img.affine, - spacing=bundle_mask_img.header.get_zooms()[0], - verbose=False) - - if output_format == "trk_legacy": - fiber_utils.save_streamlines_as_trk_legacy(output_dir + "/" + tracking_folder + "/" + bundle + ".trk", - streamlines, bundle_mask_img.affine, - bundle_mask_img.get_fdata().shape) - else: # tck or trk (determined by file ending) - fiber_utils.save_streamlines( - output_dir + "/" + tracking_folder + "/" + bundle + "." + output_format, - streamlines, bundle_mask_img.affine, - bundle_mask_img.get_fdata().shape) - - - # No streamline filtering + + # TractSeg probabilistic tracking else: - peak_utils.peak_image_to_binary_mask_path(peaks, tmp_dir + "/peak_mask.nii.gz", - peak_length_threshold=0.01) + # Prepare files + bundle_mask_img = nib.load(output_dir + "/bundle_segmentations" + dir_postfix + "/" + + bundle + ".nii.gz") + beginnings_img = nib.load(output_dir + "/endings_segmentations/" + bundle + "_b.nii.gz") + endings_img = nib.load(output_dir + "/endings_segmentations/" + bundle + "_e.nii.gz") + tom_peaks_img = nib.load(output_dir + "/" + TOM_folder + "/" + bundle + ".nii.gz") + + # Ensure same orientation as MNI space + bundle_mask, flip_axis = img_utils.flip_axis_to_match_MNI_space(bundle_mask_img.get_fdata().astype(np.uint8), + bundle_mask_img.affine) + beginnings, flip_axis = img_utils.flip_axis_to_match_MNI_space(beginnings_img.get_fdata().astype(np.uint8), + beginnings_img.affine) + endings, flip_axis = img_utils.flip_axis_to_match_MNI_space(endings_img.get_fdata().astype(np.uint8), + endings_img.affine) + tom_peaks, flip_axis = img_utils.flip_axis_to_match_MNI_space(tom_peaks_img.get_fdata(), + tom_peaks_img.affine) + + # tracking_uncertainties = nib.load(output_dir + "/tracking_uncertainties/" + bundle + ".nii.gz").get_fdata() + tracking_uncertainties = None + + #Get best original peaks + if use_best_original_peaks: + orig_peaks_img = nib.load(peaks) + orig_peaks, flip_axis = img_utils.flip_axis_to_match_MNI_space(orig_peaks_img.get_fdata(), + orig_peaks_img.affine) + best_orig_peaks = fiber_utils.get_best_original_peaks(tom_peaks, orig_peaks) + for axis in flip_axis: + best_orig_peaks = img_utils.flip_axis(best_orig_peaks, axis) + nib.save(nib.Nifti1Image(best_orig_peaks, orig_peaks_img.affine), + output_dir + "/" + tracking_folder + "/" + bundle + ".nii.gz") + tom_peaks = best_orig_peaks + + #Get weighted mean between best original peaks and TOMs + if use_as_prior: + orig_peaks_img = nib.load(peaks) + orig_peaks, flip_axis = img_utils.flip_axis_to_match_MNI_space(orig_peaks_img.get_fdata(), + orig_peaks_img.affine) + best_orig_peaks = fiber_utils.get_best_original_peaks(tom_peaks, orig_peaks) + weighted_peaks = fiber_utils.get_weighted_mean_of_peaks(best_orig_peaks, tom_peaks, weight=0.5) + for axis in flip_axis: + weighted_peaks = img_utils.flip_axis(weighted_peaks, axis) + nib.save(nib.Nifti1Image(weighted_peaks, orig_peaks_img.affine), + output_dir + "/" + tracking_folder + "/" + bundle + "_weighted.nii.gz") + tom_peaks = weighted_peaks + + # Takes around 6min for 1 subject (2mm resolution) + streamlines = tractseg_prob_tracking.track(tom_peaks, max_nr_fibers=nr_fibers, smooth=5, + compress=0.1, bundle_mask=bundle_mask, start_mask=beginnings, + end_mask=endings, + tracking_uncertainties=tracking_uncertainties, + dilation=dilation, + next_step_displacement_std=next_step_displacement_std, + nr_cpus=nr_cpus, affine=bundle_mask_img.affine, + spacing=bundle_mask_img.header.get_zooms()[0], + verbose=False) + + if output_format == "trk_legacy": + fiber_utils.save_streamlines_as_trk_legacy(output_dir + "/" + tracking_folder + "/" + bundle + ".trk", + streamlines, bundle_mask_img.affine, + bundle_mask_img.get_fdata().shape) + else: # tck or trk (determined by file ending) + fiber_utils.save_streamlines( + output_dir + "/" + tracking_folder + "/" + bundle + "." + output_format, + streamlines, bundle_mask_img.affine, + bundle_mask_img.get_fdata().shape) + + # No streamline filtering + else: - # FACT Tracking on TOMs - subprocess.call("tckgen -algorithm FACT " + - output_dir + "/" + TOM_folder + "/" + bundle + ".nii.gz " + - output_dir + "/" + tracking_folder + "/" + bundle + ".tck" + - " -seed_image " + tmp_dir + "/peak_mask.nii.gz" + - " -minlength 40 -maxlength 250 -select " + str(nr_fibers) + - " -force -quiet" + nthreads, shell=True) + peak_utils.peak_image_to_binary_mask_path(peaks, tmp_dir + "/peak_mask.nii.gz", + peak_length_threshold=0.01) - if output_format == "trk" or output_format == "trk_legacy": - _mrtrix_tck_to_trk(output_dir, tracking_folder, dir_postfix, bundle, output_format, nr_cpus) + # FACT Tracking on TOMs + subprocess.call("tckgen -algorithm FACT " + + output_dir + "/" + TOM_folder + "/" + bundle + ".nii.gz " + + output_dir + "/" + tracking_folder + "/" + bundle + ".tck" + + " -seed_image " + tmp_dir + "/peak_mask.nii.gz" + + " -minlength 40 -maxlength 250 -select " + str(nr_fibers) + + " -force -quiet" + nthreads, shell=True) + if output_format == "trk" or output_format == "trk_legacy": + _mrtrix_tck_to_trk(output_dir, tracking_folder, dir_postfix, bundle, output_format, nr_cpus) shutil.rmtree(tmp_dir)