diff --git a/pyproject.toml b/pyproject.toml
index ee642a451..c4fef1d20 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -224,6 +224,7 @@ scil_tractogram_assign_uniform_color = "scilpy.cli.scil_tractogram_assign_unifo
 scil_tractogram_commit = "scilpy.cli.scil_tractogram_commit:main"
 scil_tractogram_compress = "scilpy.cli.scil_tractogram_compress:main"
 scil_tractogram_compute_density_map = "scilpy.cli.scil_tractogram_compute_density_map:main"
+scil_tractogram_compute_ae = "scilpy.cli.scil_tractogram_compute_ae:main"
 scil_tractogram_compute_TODI = "scilpy.cli.scil_tractogram_compute_TODI:main"
 scil_tractogram_convert_hdf5_to_trk = "scilpy.cli.scil_tractogram_convert_hdf5_to_trk:main"
 scil_tractogram_convert = "scilpy.cli.scil_tractogram_convert:main"
@@ -233,6 +234,7 @@ scil_tractogram_cut_streamlines = "scilpy.cli.scil_tractogram_cut_streamlines:m
 scil_tractogram_detect_loops = "scilpy.cli.scil_tractogram_detect_loops:main"
 scil_tractogram_dpp_math = "scilpy.cli.scil_tractogram_dpp_math:main"
 scil_tractogram_dps_math = "scilpy.cli.scil_tractogram_dps_math:main"
+scil_tractogram_extract_streamlines = "scilpy.cli.scil_tractogram_extract_streamlines:main"
 scil_tractogram_extract_ushape = "scilpy.cli.scil_tractogram_extract_ushape:main"
 scil_tractogram_filter_by_anatomy = "scilpy.cli.scil_tractogram_filter_by_anatomy:main"
 scil_tractogram_filter_by_length = "scilpy.cli.scil_tractogram_filter_by_length:main"
diff --git a/src/scilpy/cli/scil_tractogram_compute_ae.py b/src/scilpy/cli/scil_tractogram_compute_ae.py
new file mode 100644
index 000000000..4d4cda22d
--- /dev/null
+++ b/src/scilpy/cli/scil_tractogram_compute_ae.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Compute the angular error (AE) for each segment of the streamlines.
+
+For each segment of each streamline, the direction is compared with the
+underlying peak (for single-peak files such as DTI) or with the closest peak
+(e.g., with fODF peaks). Currently, interpolation is not supported: the peaks
+of the closest voxel are used (nearest neighbor). The AE is the angle, in
+degrees, between the segment and the (closest) peak.
+
+The AE is added as data_per_point (dpp) for each segment, using the last point
+of the segment. The last point of each streamline has an AE of zero.
+Optionally, you may also save it as a color.
+"""
+
+import argparse
+import logging
+
+import nibabel as nib
+import numpy as np
+
+from scilpy.io.streamlines import (load_tractogram_with_reference,
+                                   save_tractogram)
+from scilpy.io.utils import (add_processes_arg, add_verbose_arg,
+                             add_overwrite_arg, assert_headers_compatible,
+                             assert_inputs_exist, assert_outputs_exist,
+                             add_bbox_arg)
+from scilpy.tractanalysis.scoring import compute_ae
+from scilpy.tractograms.dps_and_dpp_management import (add_data_as_color_dpp,
+                                                       project_dpp_to_map)
+from scilpy.version import version_string
+from scilpy.viz.color import get_lookup_table
+
+
+def _build_arg_parser():
+    p = argparse.ArgumentParser(description=__doc__,
+                                formatter_class=argparse.RawTextHelpFormatter,
+                                epilog=version_string)
+
+    p.add_argument('in_tractogram',
+                   help='Path of the input tractogram file (trk or tck).')
+    p.add_argument('in_peaks',
+                   help='Path of the input peaks file.')
+    p.add_argument('out_tractogram',
+                   help='Path of the output tractogram file (trk or tck).')
+    p.add_argument('--dpp_key', default="AE",
+                   help="Name of the dpp key containing the AE in the "
+                        "output. Default: AE")
+
+    g = p.add_argument_group("Optional outputs")
+    g.add_argument('--save_as_color', action='store_true',
+                   help="Save the AE as a color. Colors will range between "
+                        "black (0) and yellow (--cmap_max).\n"
+                        "See also scil_tractogram_assign_custom_color, "
+                        "option --use_dpp.")
+    g.add_argument('--save_mean_map', metavar='filename',
+                   help="If set, save the mean value of each streamline per "
+                        "voxel. Name of the map file (nifti).\n"
+                        "See also scil_tractogram_project_streamlines_to_map.")
+    g.add_argument('--save_worst', metavar='filename',
+                   help="If set, save the worst streamlines in a separate "
+                        "tractogram.")
+
+    g = p.add_argument_group("Processing options")
+    g.add_argument('--cmap_max', nargs='?', type=float, const=180,
+                   help="If set, the maximum color on the colormap (yellow) "
+                        "will be associated \nwith this value. If not set, "
+                        "the maximum value found in the data will be used "
+                        "instead.\nDefault if set: 180 degrees.")
+
+    add_processes_arg(p)
+    add_verbose_arg(p)
+    add_overwrite_arg(p)
+    add_bbox_arg(p)
+
+    return p
+
+
+def main():
+    parser = _build_arg_parser()
+    args = parser.parse_args()
+    logging.getLogger().setLevel(logging.getLevelName(args.verbose))
+
+    # -- Verifications
+    args.reference = args.in_peaks
+    assert_inputs_exist(parser, [args.in_tractogram, args.in_peaks],
+                        args.reference)
+    assert_headers_compatible(parser, [args.in_tractogram, args.in_peaks], [],
+                              args.reference)
+    assert_outputs_exist(parser, args, args.out_tractogram,
+                         [args.save_mean_map, args.save_worst])
+
+    # -- Loading
+    peaks = nib.load(args.in_peaks).get_fdata()
+    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
+    logging.info("Loaded data")
+
+    # Removing invalid streamlines
+    len_before = len(sft)
+    sft.remove_invalid_streamlines()
+    len_after = len(sft)
+    if len_before != len_after:
+        logging.warning("Removed {} invalid streamlines before processing"
+                        .format(len_before - len_after))
+
+    # Verify if the keys already exist
+    if args.dpp_key in sft.data_per_point.keys() and not args.overwrite:
+        parser.error("The chosen --dpp_key already exists in the input "
+                     "tractogram. Use -f to overwrite it.")
+    if (args.save_as_color and 'color' in sft.data_per_point.keys() and
+            not args.overwrite):
+        parser.error("The 'color' dpp already exists. Use -f to overwrite "
+                     "it.")
+
+    # -- Processing
+    ae = compute_ae(sft, peaks, nb_processes=args.nbr_processes)
+
+    # Printing stats
+    stacked_ae = np.hstack(ae)
+    mean_ae = np.mean(stacked_ae)
+    std_ae = np.std(stacked_ae)
+    min_ae = np.min(stacked_ae)
+    max_ae = np.max(stacked_ae)
+    logging.info("AE computed. Some statistics:\n"
+                 "- Mean AE: {} +- {}\n"
+                 "- Range: [{}, {}]".format(mean_ae, std_ae, min_ae, max_ae))
+
+    # Add as dpp
+    ae_dpp = [ae_s[:, None] for ae_s in ae]
+    sft.data_per_point[args.dpp_key] = ae_dpp
+
+    # Add as color (optional)
+    if args.save_as_color:
+        max_cmap = args.cmap_max if args.cmap_max is not None \
+            else np.max(stacked_ae)
+        logging.info("Saving colors. The maximum color is associated with "
+                     "value {}".format(max_cmap))
+
+        cmap = get_lookup_table('jet')
+        sft, _, _ = add_data_as_color_dpp(sft, cmap, stacked_ae,
+                                          min_cmap=0, max_cmap=max_cmap)
+
+    # -- Saving
+    logging.info("Saving file {}.".format(args.out_tractogram))
+    save_tractogram(sft, args.out_tractogram, no_empty=False)
+
+    # Save map (optional)
+    if args.save_mean_map is not None:
+        logging.info("Preparing map.")
+        the_map = project_dpp_to_map(sft, args.dpp_key)
+
+        logging.info("Saving file {}".format(args.save_mean_map))
+        nib.save(nib.Nifti1Image(the_map, sft.affine), args.save_mean_map)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/scilpy/cli/scil_tractogram_extract_streamlines.py b/src/scilpy/cli/scil_tractogram_extract_streamlines.py
new file mode 100644
index 000000000..f51b0ece4
--- /dev/null
+++ b/src/scilpy/cli/scil_tractogram_extract_streamlines.py
@@ -0,0 +1,177 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Extract streamlines according to a chosen criterion based on the streamlines'
+dpp (data_per_point) or dps (data_per_streamline).
+
+See also:
+    - To modify your dpp / dps values: see scil_tractogram_dpp_math and
+      scil_tractogram_dps_math.
+    - To extract streamlines based on regions of interest (ROI), see
+      scil_tractogram_filter_by_roi.
+    - To extract U-shaped streamlines, see scil_tractogram_extract_ushape.
+"""
+
+import argparse
+import logging
+
+import numpy as np
+
+from scilpy.io.streamlines import (load_tractogram_with_reference,
+                                   save_tractogram)
+from scilpy.io.utils import (add_bbox_arg, add_overwrite_arg,
+                             add_reference_arg, add_verbose_arg,
+                             assert_inputs_exist, assert_outputs_exist,
+                             ranged_type)
+from scilpy.version import version_string
+
+
+def _build_arg_parser():
+    p = argparse.ArgumentParser(description=__doc__,
+                                formatter_class=argparse.RawTextHelpFormatter,
+                                epilog=version_string)
+
+    p.add_argument('in_tractogram',
+                   help='Path of the input tractogram file (trk or tck).')
+    p.add_argument('out_tractogram',
+                   help='Path of the output tractogram file (trk or tck).')
+    p.add_argument('--no_empty', action='store_true',
+                   help="Do not save the output tractogram if no streamline "
+                        "fits the criterion.")
+
+    g = p.add_argument_group("Criterion's data")
+    gg = g.add_mutually_exclusive_group(required=True)
+    gg.add_argument("--from_dps", metavar='dps_key',
+                    help="Use this dps key as criterion.")
+    gg.add_argument("--from_dpp", metavar='dpp_key',
+                    help="Use this dpp key as criterion. Uses the average "
+                         "value over each streamline.")
+
+    g = p.add_argument_group("Direction of the criterion")
+    gg = g.add_mutually_exclusive_group(required=True)
+    gg.add_argument('--top', action='store_true',
+                    help="If set, selects the portion of streamlines with "
+                         "the highest value in their dps \nor mean dpp.")
+    gg.add_argument('--bottom', action='store_true',
+                    help="If set, selects the portion of streamlines with "
+                         "the lowest value in their dps \nor mean dpp.")
+    gg.add_argument('--center', action='store_true',
+                    help="Selects the streamlines closest to the average "
+                         "value.")
+
+    g = p.add_argument_group("Criterion")
+    gg = g.add_mutually_exclusive_group(required=True)
+    gg.add_argument(
+        '--nb', type=int,
+        help="Selects a chosen number of streamlines.")
+    gg.add_argument(
+        '--percent', type=ranged_type(float, 0, 100), const=5, nargs='?',
+        help="Selects the streamlines in the top / bottom percentage.\n"
+             "Default if set: the top / bottom 5%%.")
+    gg.add_argument(
+        '--mean_std', type=int, const=3, nargs='?', dest='std',
+        help="Selects the streamlines with value above mean + N*std (option "
+             "--top), below \nmean - N*std (option --bottom) or in the "
+             "range [mean - N*std, mean + N*std] \n(option --center). "
+             "Default if set: N = 3.")
+
+    add_verbose_arg(p)
+    add_overwrite_arg(p)
+    add_bbox_arg(p)
+    add_reference_arg(p)
+
+    return p
+
+
+def main():
+    parser = _build_arg_parser()
+    args = parser.parse_args()
+    logging.getLogger().setLevel(logging.getLevelName(args.verbose))
+
+    # -- Verifications
+    assert_inputs_exist(parser, args.in_tractogram, args.reference)
+    assert_outputs_exist(parser, args, args.out_tractogram)
+
+    # -- Loading
+    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
+    logging.info("Loaded data")
+
+    # Verify that the chosen key exists
+    if (args.from_dpp is not None and
+            args.from_dpp not in sft.data_per_point.keys()):
+        parser.error("dpp key '{}' not found.".format(args.from_dpp))
+    if (args.from_dps is not None and
+            args.from_dps not in sft.data_per_streamline.keys()):
+        parser.error("dps key '{}' not found.".format(args.from_dps))
+
+    if args.from_dps is not None:
+        data = sft.data_per_streamline[args.from_dps]
+        data = [np.squeeze(data_s) for data_s in data]
+        if data[0].ndim > 0:
+            parser.error(
+                "Script not ready to deal with dps of more than one value "
+                "per streamline. Use scil_tractogram_dps_math to modify "
+                "your data.")
+    else:
+        data = sft.data_per_point[args.from_dpp]
+        if np.squeeze(data[0][0, :]).ndim > 0:
+            parser.error(
+                "Script not ready to deal with dpp of more than one value "
+                "per point. Use scil_tractogram_dpp_math to modify your "
+                "data.")
+        data = [np.mean(np.squeeze(data_s)) for data_s in data]
+    data = np.asarray(data)
+
+    nb_init = len(sft)
+
+    if args.percent or args.nb:
+        ordered_ind = np.argsort(data)
+
+        if args.percent:
+            nb_streamlines = int(args.percent / 100.0 * len(sft))
+            percent = args.percent
+        else:
+            nb_streamlines = args.nb
+            percent = np.round(nb_streamlines / len(sft) * 100, decimals=3)
+
+        if args.top:
+            ind = ordered_ind[-nb_streamlines:]
+            logging.info("Saving {}/{} streamlines; the top {}% of "
+                         "streamlines."
+                         .format(len(ind), nb_init, percent))
+        elif args.bottom:
+            ind = ordered_ind[0:nb_streamlines]
+            logging.info("Saving {}/{} streamlines; the bottom {}% of "
+                         "streamlines."
+                         .format(len(ind), nb_init, percent))
+        else:  # args.center
+            half_remains = int((len(sft) - nb_streamlines) / 2)
+            ind = ordered_ind[half_remains:len(sft) - half_remains]
+            logging.info("Saving {}/{} streamlines; the middle {}% of "
+                         "streamlines."
+                         .format(len(ind), nb_init, percent))
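+        # Example: with 1000 streamlines and "--percent 5 --top", 50
+        # streamlines (the ones with the highest dps or mean dpp values)
+        # are kept.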
+
+    else:  # Using mean +- N*std
+        mean = np.mean(data)
+        std = np.std(data)
+
+        if args.top:
+            limit = mean + args.std * std
+            ind = np.flatnonzero(data >= limit)
+            logging.info("Number of streamlines above the mean + {}*std "
+                         "limit: {}".format(args.std, len(ind)))
+        elif args.bottom:
+            limit = mean - args.std * std
+            ind = np.flatnonzero(data <= limit)
+            logging.info("Number of streamlines below the mean - {}*std "
+                         "limit: {}".format(args.std, len(ind)))
+        else:  # args.center
+            limit1 = mean - args.std * std
+            limit2 = mean + args.std * std
+            ind = np.flatnonzero(np.logical_and(data > limit1,
+                                                data < limit2))
+            logging.info("Number of streamlines in the range mean +- {}*std: "
+                         "{}".format(args.std, len(ind)))
+
+    sft = sft[ind]
+    save_tractogram(sft, args.out_tractogram, no_empty=args.no_empty)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/scilpy/cli/tests/test_tractogram_compute_ae.py b/src/scilpy/cli/tests/test_tractogram_compute_ae.py
new file mode 100644
index 000000000..400668a62
--- /dev/null
+++ b/src/scilpy/cli/tests/test_tractogram_compute_ae.py
@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import tempfile
+
+from scilpy import SCILPY_HOME
+from scilpy.io.fetcher import fetch_data, get_testing_files_dict
+
+# If they already exist, this only takes 5 seconds (check md5sum)
+fetch_data(get_testing_files_dict(), keys=['commit_amico.zip'])
+tmp_dir = tempfile.TemporaryDirectory()
+
+
+def test_help_option(script_runner):
+    ret = script_runner.run(['scil_tractogram_compute_ae',
+                             '--help'])
+    assert ret.success
+
+
+def test_execution(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_bundle = os.path.join(SCILPY_HOME, 'commit_amico', 'tracking.trk')
+    in_peaks = os.path.join(SCILPY_HOME, 'commit_amico', 'peaks.nii.gz')
+
+    ret = script_runner.run(['scil_tractogram_compute_ae', in_bundle, in_peaks,
+                             'out_bundle.trk', '--dpp_key', 'AE',
+                             '--save_mean_map', 'out_map.nii.gz',
+                             '--save_as_color', '--processes', '4'])
+    assert ret.success
\ No newline at end of file
diff --git a/src/scilpy/cli/tests/test_tractogram_extract_streamlines.py b/src/scilpy/cli/tests/test_tractogram_extract_streamlines.py
new file mode 100644
index 000000000..37b66c151
--- /dev/null
+++ b/src/scilpy/cli/tests/test_tractogram_extract_streamlines.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import tempfile
+
+from scilpy import SCILPY_HOME
+from scilpy.io.fetcher import fetch_data, get_testing_files_dict
+
+# If they already exist, this only takes 5 seconds (check md5sum)
+fetch_data(get_testing_files_dict(), keys=['tractometry.zip'])
+tmp_dir = tempfile.TemporaryDirectory()
+
+
+def test_help_option(script_runner):
+    ret = script_runner.run(['scil_tractogram_extract_streamlines',
+                             '--help'])
+    assert ret.success
+
+
+def _create_tractogram_with_dpp(script_runner):
+    """
+    Copied from test_tractogram_project_streamlines_to_map.
+    ToDo: Add a tractogram with dpp to our test data.
+    """
+    in_bundle = os.path.join(SCILPY_HOME, 'tractometry', 'IFGWM_uni.trk')
+    in_mni = os.path.join(SCILPY_HOME, 'tractometry', 'mni_masked.nii.gz')
+    in_bundle_with_dpp = 'IFGWM_uni_with_dpp.trk'
+
+    # Create our test data with dpp: add a metric (the MNI map values) as dpp.
+    # Or get a tractogram that already has some dpp in the test data.
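+    # Note: the projected values are stored under the dpp key 'some_metric',
+    # which the tests below rely on.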
+    script_runner.run(['scil_tractogram_project_map_to_streamlines',
+                       in_bundle, in_bundle_with_dpp, '-f',
+                       '--in_maps', in_mni, '--out_dpp_name', 'some_metric'])
+
+
+def test_from_dpp(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+
+    # Creating the data.
+    # Could eventually split this test into many tests, but we want to create
+    # the data only once.
+    _create_tractogram_with_dpp(script_runner)
+    in_bundle_with_dpp = 'IFGWM_uni_with_dpp.trk'
+
+    # From NB, top
+    ret = script_runner.run(['scil_tractogram_extract_streamlines',
+                             in_bundle_with_dpp, 'out_200_top.trk',
+                             '--from_dpp', 'some_metric',
+                             '--top', '--nb', '200'])
+    assert ret.success
+
+    # From NB, center
+    ret = script_runner.run(['scil_tractogram_extract_streamlines',
+                             in_bundle_with_dpp, 'out_200_middle.trk',
+                             '--from_dpp', 'some_metric',
+                             '--center', '--nb', '200'])
+    assert ret.success
+
+    # From percent, bottom
+    ret = script_runner.run(['scil_tractogram_extract_streamlines',
+                             in_bundle_with_dpp, 'out_5percent_bottom.trk',
+                             '--from_dpp', 'some_metric',
+                             '--bottom', '--percent', '5'])
+    assert ret.success
+
+    # From mean +- std, center
+    ret = script_runner.run(['scil_tractogram_extract_streamlines',
+                             in_bundle_with_dpp, 'out_middle_std.trk',
+                             '--from_dpp', 'some_metric',
+                             '--center', '--mean_std', '3'])
+    assert ret.success
+
+
+def test_from_dps(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+
+    # Creating the data: we create a file with dpp and then average it over
+    # each streamline to get a dps.
+    # Could eventually split this test into many tests, but we want to create
+    # the data only once.
+    _create_tractogram_with_dpp(script_runner)
+    in_bundle_with_dpp = 'IFGWM_uni_with_dpp.trk'
+    in_bundle_with_dps = 'IFGWM_uni_with_dps.trk'
+    script_runner.run(['scil_tractogram_dpp_math', 'mean', in_bundle_with_dpp,
+                       in_bundle_with_dps, '--mode', 'dps',
+                       '--in_dpp_name', 'some_metric',
+                       '--out_keys', 'mean_dpp'])
+
+    # No need to retest all options.
+    # From NB, top
+    ret = script_runner.run(['scil_tractogram_extract_streamlines',
+                             in_bundle_with_dps, 'out_200_top.trk',
+                             '--from_dps', 'mean_dpp',
+                             '--top', '--nb', '200'])
+    assert ret.success
diff --git a/src/scilpy/io/streamlines.py b/src/scilpy/io/streamlines.py
index 0fc196f6a..09aed6a77 100644
--- a/src/scilpy/io/streamlines.py
+++ b/src/scilpy/io/streamlines.py
@@ -120,6 +120,21 @@ def load_tractogram_with_reference(parser, args, filepath, arg_name=None):
 
 
 def save_tractogram(sft, filename, no_empty, bbox_valid_check=True):
+    """
+    Save a tractogram. If no_empty is set and the tractogram has 0
+    streamlines, the file is not saved.
+
+    Parameters
+    ----------
+    sft: StatefulTractogram
+        The tractogram.
+    filename: str
+        Where to save. Filename with a valid extension.
+    no_empty: bool
+        Whether saving empty files is allowed or not.
+    bbox_valid_check: bool
+        Verify if streamlines are in the bounding box. Default: True.
+    """
     if len(sft.streamlines) == 0 and no_empty:
         logging.info("The file {} won't be written (0 streamlines)"
                      .format(filename))
diff --git a/src/scilpy/tractanalysis/scoring.py b/src/scilpy/tractanalysis/scoring.py
index 0f42af9ba..b15aa4eae 100644
--- a/src/scilpy/tractanalysis/scoring.py
+++ b/src/scilpy/tractanalysis/scoring.py
@@ -13,7 +13,7 @@
 - Optional:
     - WPC: wrong path connections, streamlines connecting correct ROIs but not
     respecting the other criteria for that bundle. Such streamlines always
-    exist but they are only saved separately if specified in the options.
+    exist, but they are only saved separately if specified in the options.
     Else, they are merged back with the IS. By definition. WPC are only
     computed if "limits masks" are provided.
 - IC: invalid connections, streamlines joining an incorrect combination of
@@ -38,14 +38,145 @@
 
 import logging
+from multiprocessing import Pool
 
 import numpy as np
+from tqdm import tqdm
 
 from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map
-
 from scilpy.tractograms.streamline_and_mask_operations import \
     get_endpoints_density_map
 
 
+def _compute_ae(args):
+    """Used for multiprocessing in compute_ae below."""
+    process_id, start, end, dirs, coords, peaks = args
+
+    ae_chunk = np.zeros(end - start)
+    nb_nan_chunk = 0
+
+    # Hiding the segment count in the tqdm bar; it is not meaningful here.
+    bar_format = '{desc}:{percentage:3.0f}%|{bar}| '
+    bar_format += '[{elapsed}<{remaining}, {rate_fmt}{postfix}]'
+
+    for chunk_i, i in enumerate(tqdm(range(start, end), ncols=120,
+                                     bar_format=bar_format, leave=False,
+                                     desc="Process {}".format(process_id + 1),
+                                     position=process_id + 1)):
+        current_peaks = peaks[coords[i][0], coords[i][1], coords[i][2], :, :]
+
+        # Using only non-zero peaks. Voxels with no valid peak get an AE of 0.
+        current_peaks = current_peaks[np.any(current_peaks != 0, axis=-1)]
+        if current_peaks.size == 0:
+            nb_nan_chunk += 1
+            ae_chunk[chunk_i] = 0
+            continue
+
+        # Using the absolute value because peaks are undirected.
+        cos_theta = np.abs(np.dot(current_peaks, dirs[i]))
+        cos_theta = np.clip(cos_theta, -1.0, 1.0)  # numerical safety
+        theta = np.rad2deg(np.arccos(cos_theta))
+        ae_chunk[chunk_i] = np.min(theta)
+
+    return start, ae_chunk, nb_nan_chunk
+
+
+def compute_ae(sft, peaks, nb_processes=1):
+    """
+    Compute the angular error (AE) for each segment. The direction of each
+    segment is compared with the underlying peak (for single-peak files such
+    as DTI) or with the closest peak (e.g., with fODF peaks). Currently,
+    interpolation is not supported: the peaks of the closest voxel are used
+    (nearest neighbor). The AE is the angle, in degrees, between the segment
+    and the (closest) peak.
+
+    Parameters
+    ----------
+    sft: StatefulTractogram
+        The tractogram.
+    peaks: np.ndarray of shape [x, y, z, nb_peaks, 3] or [x, y, z, nb_peaks*3]
+        The peaks.
+    nb_processes: int
+        Number of processes to use.
+
+    Returns
+    -------
+    ae: list[np.ndarray]
+        The angular error for each streamline, in degrees. The last point of
+        each streamline has an AE of zero.
+    """
+    # Make sure peaks have shape (x, y, z, nb_peaks, 3), even with one peak.
+    peaks = peaks.reshape(peaks.shape[:3] + (-1, 3))
+
+    if peaks.shape[3] == 1:
+        multi_peaks = False
+        logging.info("Peaks seem to be single-peaks (probably DTI). Simple "
+                     "alignment measure.")
+    else:
+        multi_peaks = True
+        logging.info("Peaks seem to be multi-peaks (maybe coming from ODF, "
+                     "fODF, etc.). Will verify alignment with the closest "
+                     "peak.")
+
+    # Sending sft to vox space, corner origin. Then nearest neighbor
+    # interpolation is just the floor.
+    previous_space = sft.space
+    previous_origin = sft.origin
+    sft.to_vox()
+    sft.to_corner()
+
+    # Normalizing the peaks
+    _norm = np.linalg.norm(peaks, axis=-1, keepdims=True)
+    _norm[_norm == 0] = 1  # Making sure we don't divide by 0
+    peaks = peaks / _norm
+
+    # Getting segments and normalizing. Concatenating.
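+    # Each segment direction is the difference between consecutive points. A
+    # segment is attributed to the voxel of its second point (s[1:] below),
+    # matching the "last point of the segment" convention described above.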
+    dirs = [np.diff(s, axis=0) for s in sft.streamlines]
+    dirs = np.vstack(dirs)
+    _norm = np.linalg.norm(dirs, axis=-1, keepdims=True)
+    _norm[_norm == 0] = 1  # Zero-length segments (duplicated points)
+    dirs = dirs / _norm
+
+    # Concatenating streamlines for faster processing + nearest neighbor
+    coords = np.floor(np.vstack([s[1:] for s in sft.streamlines])).astype(int)
+
+    # Preparing multiprocessing
+    nb_items = len(coords)
+    chunk_size = (nb_items + nb_processes - 1) // nb_processes
+    split_indices = [(i, min(i + chunk_size, nb_items))
+                     for i in range(0, nb_items, chunk_size)]
+
+    # Finding the angular difference with the closest peak
+    ae = np.zeros(len(coords))
+    nb_nan = 0
+    with Pool(processes=nb_processes) as pool:
+        for start, ae_chunk, nb_nan_chunk in pool.imap_unordered(
+                _compute_ae, [(i, start, end, dirs, coords, peaks)
+                              for i, (start, end) in enumerate(split_indices)]):
+            ae[start:start + len(ae_chunk)] = ae_chunk
+            nb_nan += nb_nan_chunk
+
+    print(' ')  # Required because finishing sub-processes' tqdm is flaky
+    if nb_nan > 0:
+        msg = "The AE in these voxels was set to 0. Total number of " + \
+              "segments traversing these voxels: {}/{}.".format(nb_nan,
+                                                                nb_items)
+        if multi_peaks:
+            logging.warning("Some voxels had 0 valid peaks out of the {} "
+                            "possible peaks (they were all [0, 0, 0]). "
+                            .format(peaks.shape[3]) + msg)
+        else:
+            logging.warning("Invalid peaks ([0, 0, 0]) were found in some "
+                            "voxels. " + msg)
+
+    # Split back into streamlines
+    lengths = [len(s) - 1 for s in sft.streamlines]
+    ae = np.split(ae, np.cumsum(lengths)[:-1])
+
+    # Add 0 as the last value of each streamline
+    ae = [np.append(line_ae, 0) for line_ae in ae]
+
+    # Sending back to previous space
+    sft.to_space(previous_space)
+    sft.to_origin(previous_origin)
+
+    return ae
+
+
 def compute_f1_score(overlap, overreach):
     """
     Compute the F1 score between overlap and overreach (they must be