Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ scil_tractogram_assign_uniform_color = "scilpy.cli.scil_tractogram_assign_unifo
scil_tractogram_commit = "scilpy.cli.scil_tractogram_commit:main"
scil_tractogram_compress = "scilpy.cli.scil_tractogram_compress:main"
scil_tractogram_compute_density_map = "scilpy.cli.scil_tractogram_compute_density_map:main"
scil_tractogram_compute_ae = "scilpy.cli.scil_tractogram_compute_ae:main"
scil_tractogram_compute_TODI = "scilpy.cli.scil_tractogram_compute_TODI:main"
scil_tractogram_convert_hdf5_to_trk = "scilpy.cli.scil_tractogram_convert_hdf5_to_trk:main"
scil_tractogram_convert = "scilpy.cli.scil_tractogram_convert:main"
Expand All @@ -233,6 +234,7 @@ scil_tractogram_cut_streamlines = "scilpy.cli.scil_tractogram_cut_streamlines:m
scil_tractogram_detect_loops = "scilpy.cli.scil_tractogram_detect_loops:main"
scil_tractogram_dpp_math = "scilpy.cli.scil_tractogram_dpp_math:main"
scil_tractogram_dps_math = "scilpy.cli.scil_tractogram_dps_math:main"
scil_tractogram_extract_streamlines = "scilpy.cli.scil_tractogram_extract_streamlines:main"
scil_tractogram_extract_ushape = "scilpy.cli.scil_tractogram_extract_ushape:main"
scil_tractogram_filter_by_anatomy = "scilpy.cli.scil_tractogram_filter_by_anatomy:main"
scil_tractogram_filter_by_length = "scilpy.cli.scil_tractogram_filter_by_length:main"
Expand Down
157 changes: 157 additions & 0 deletions src/scilpy/cli/scil_tractogram_compute_ae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Compute the angular error (AE) for each segment of the streamlines.

For each segment of each streamline, the direction is compared with the
underlying peak (for single peak files like DTI) or with the closest peak
(ex, with fODF peaks). Currently, interpolation is not supported: peaks of
the closest voxel are used (nearest neighbor). AE is computed as the cosine
difference.

The ae is added as data_per_point (dpp) for each segment, using the last point
of the segment. The last point of each streamline has an AE of zero.
Optionnally, you may also save it as a color.
"""

import argparse
import logging

import nibabel as nib
import numpy as np

from scilpy.io.streamlines import (load_tractogram_with_reference,
save_tractogram)
from scilpy.io.utils import (add_processes_arg, add_verbose_arg,
add_overwrite_arg, assert_headers_compatible,
assert_inputs_exist, assert_outputs_exist,
add_bbox_arg)
from scilpy.tractanalysis.scoring import compute_ae
from scilpy.tractograms.dps_and_dpp_management import (add_data_as_color_dpp,
project_dpp_to_map)
from scilpy.version import version_string
from scilpy.viz.color import get_lookup_table


def _build_arg_parser():
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
epilog=version_string)

p.add_argument('in_tractogram',
help='Path of the input tractogram file (trk or tck).')
p.add_argument('in_peaks',
help='Path of the input peaks file.')
p.add_argument('out_tractogram',
help='Path of the output tractogram file (trk or tck).')
p.add_argument('--dpp_key', default="AE",
help="Name of the dpp key containg the AE in the output. "
"Default: AE")

g = p.add_argument_group("Optional outputs")
g.add_argument('--save_as_color', action='store_true',
help="Save the AE as a color. Colors will range between "
"black (0) and yellow (--cmax_max) \n"
"See also scil_tractogram_assign_custom_color, option "
"--use_dpp.")
g.add_argument('--save_mean_map', metavar='filename',
help="If set, save the mean value of each streamline per "
"voxel. Name of the map file (nifti).\n"
"See also scil_tractogram_project_streamlines_to_map.")
g.add_argument('--save_worst', metavar='filename',
help="If set, save the worst streamlines in a separate "
"tractogram.")

g = p.add_argument_group("Processing options")
g.add_argument('--cmap_max', nargs='?', const=180,
help="If set, the maximum color on the colormap (yellow) "
"will be associated \nto this value. If not set, the "
"maxium value found in the data will be used instead. "
"Default if set: 180 degrees.")

add_processes_arg(p)
add_verbose_arg(p)
add_overwrite_arg(p)
add_bbox_arg(p)

return p


def main():
parser = _build_arg_parser()
args = parser.parse_args()
logging.getLogger().setLevel(logging.getLevelName(args.verbose))

# -- Verifications
args.reference = args.in_peaks
assert_inputs_exist(parser, [args.in_tractogram, args.in_peaks],
args.reference)
assert_headers_compatible(parser, [args.in_tractogram, args.in_peaks], [],
args.reference)
assert_outputs_exist(parser, args, args.out_tractogram,
[args.save_mean_map, args.save_worst])

# -- Loading
peaks = nib.load(args.in_peaks).get_fdata()
sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
logging.info("Loaded data")

# Removing invalid
len_before = len(sft)
sft.remove_invalid_streamlines()
len_after = len(sft)
if len_before != len_after:
logging.warning("Removed {} invalid streamlines before processing"
.format(len_before - len_after))

# Verify if the key already exists
if args.dpp_key in sft.data_per_point.keys() and not args.overwrite:
parser.error("--dpp_key already exists. Use --overwrite to proceed.")
if (args.save_as_color and 'color' in sft.data_per_point.keys() and
not args.overwrite):
parser.error("The 'color' dpp already exists. Use --overwrite to "
"proceed.")

# -- Processing
ae = compute_ae(sft, peaks, nb_processes=args.nbr_processes)

# Printing stats
stacked_ae = np.hstack(ae)
mean_ae = np.mean(stacked_ae)
std_ae = np.std(stacked_ae)
min_ae = np.min(stacked_ae)
max_ae = np.max(stacked_ae)
logging.info("AE computed. Some statistics:\n"
"- Mean AE: {} +- {} \n"
"- Range:[{}, {}]".format(mean_ae, std_ae, min_ae, max_ae))

# Add as dpp
ae_dpp = [ae_s[:, None] for ae_s in ae]
sft.data_per_point[args.dpp_key] = ae_dpp

# Add as color (optional)
if args.save_as_color:
max_cmap = args.cmap_max if args.cmap_max is not None \
else np.max(stacked_ae)
logging.info("Saving colors. The maxium color is assiociated to "
"value {}".format(max_cmap))

cmap = get_lookup_table('jet')
sft, _, _ = add_data_as_color_dpp(sft, cmap, stacked_ae,
min_cmap=0, max_cmap=max_cmap)

# -- Saving
logging.info("Saving file {}.".format(args.out_tractogram))
save_tractogram(sft, args.out_tractogram, no_empty=False)

# Save map (optional)
if args.save_mean_map is not None:
logging.info("Preparing map.")
the_map = project_dpp_to_map(sft, args.dpp_key)

logging.info("Saving file {}".format(args.save_mean_map))
nib.save(nib.Nifti1Image(the_map, sft.affine), args.save_mean_map)


if __name__ == "__main__":
main()
177 changes: 177 additions & 0 deletions src/scilpy/cli/scil_tractogram_extract_streamlines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Extract some streamlines from chosen criterion based on streamlines' dpp
(data_per_point) or dps (data_per_streamline).

See also:
- To modify your dpp / dps values: see scil_tractogram_dpp_math and
scil_tractogram_dps_math.
- To extract streamlines based on regions of interest (ROI), see
scil_tractogram_segment_with_ROI.
- To extract U-shaped streamlines, see scil_tractogram_extract_ushape
"""

import argparse
import logging

import nibabel as nib
import numpy as np

from scilpy.io.streamlines import (load_tractogram_with_reference,
save_tractogram)
from scilpy.io.utils import (add_bbox_arg, add_overwrite_arg, add_reference_arg,
add_verbose_arg, assert_inputs_exist,
assert_outputs_exist, ranged_type)
from scilpy.version import version_string


def _build_arg_parser():
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
epilog=version_string)

p.add_argument('in_tractogram',
help='Path of the input tractogram file (trk or tck).')
p.add_argument('out_tractogram',
help='Path of the output tractogram file (trk or tck).')
p.add_argument('--no_empty', action='store_true',
help="Do not save the output tractogram if no streamline "
"fit the criterion.")

g = p.add_argument_group("Criterion's data")
gg = g.add_mutually_exclusive_group(required=True)
gg.add_argument("--from_dps", metavar='dpp_key',
help="Use DPS as criteria.")
gg.add_argument("--from_dpp", metavar='dpp_key',
help="Use DPP as criteria. Uses the average value over "
"each streamline.")

g = p.add_argument_group("Direction of the criterion:")
gg = g.add_mutually_exclusive_group(required=True)
gg.add_argument('--top', action='store_true',
help="If set, selects a portion of streamlines that has "
"the highest value in its dps \nor mean dpp.")
gg.add_argument('--bottom', action='store_true',
help="If set, selects a portion of streamlines that has "
"the lowest value in its dps \nor mean dpp.")
gg.add_argument('--center', action='store_true',
help="Selects the average streamlines.")


g = p.add_argument_group("Criterion")
gg = g.add_mutually_exclusive_group(required=True)
gg.add_argument(
'--nb', type=int,
help="Selects a chosen number of streamlines.")
gg.add_argument(
'--percent', type=ranged_type(float, 0, 100), const=5, nargs='?',
help="Saves the streamlines in the top / lowest percentile.\n"
"Default if set: The top / bottom 5%%")
gg.add_argument(
'--mean_std', type=int, const=3, nargs='?', dest='std',
help="Saves the streamlines with value above mean + N*std (option "
"--top), below \nmean - N*std (option --below) or in the "
"range [mean - N*std, mean + N*std] \n(option --center)."
"Default if set: uses mean +- 3std.")

add_verbose_arg(p)
add_overwrite_arg(p)
add_bbox_arg(p)
add_reference_arg(p)

return p


def main():
parser = _build_arg_parser()
args = parser.parse_args()
logging.getLogger().setLevel(logging.getLevelName(args.verbose))

# -- Verifications
assert_inputs_exist(parser, args.in_tractogram, args.reference)
assert_outputs_exist(parser, args, args.out_tractogram)

# -- Loading
sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
logging.info("Loaded data")

# Verify if the key already exists
if (args.from_dpp is not None and
args.from_dpp not in sft.data_per_point.keys()):
parser.error("dpp key not found")
if (args.from_dps is not None and
args.from_dps not in sft.data_per_streamline.keys()):
parser.error("dps key not found")

if args.from_dps is not None:
data = sft.data_per_streamline[args.from_dps]
data = [np.squeeze(data_s) for data_s in data]
if len(data[0]).shape > 1:
parser.error(
"Script not ready to deal with dps of more than one value per "
"streamline. Use scil_tractogram_dps_math to modify your data.")
else:
data = sft.data_per_point[args.from_dpp]
if len(np.squeeze(data[0][0, :]).shape) > 1:
parser.error(
"Script not ready to deal with dpp of more than one value per "
"point. Use scil_tractogram_dpp_math to modify your data.")
data = [np.mean(np.squeeze(data_s)) for data_s in data]

nb_init = len(sft)

if args.percent or args.nb:
ordered_ind = np.argsort(data)

if args.percent:
nb_streamlines = int(args.percent / 100.0 * len(sft))
percent = args.percent
else:
nb_streamlines = args.nb
percent = np.round(nb_streamlines / len(sft) * 100, decimals=3)

if args.top:
ind = ordered_ind[-nb_streamlines:]
logging.info("Saving {}/{} streamlines; the top {}% of "
"streamlines."
.format(len(ind), nb_init, percent))
elif args.bottom:
ind = ordered_ind[0:nb_streamlines]
logging.info("Saving {}/{} streamlines; the bottom {}% of "
"streamlines."
.format(len(ind), nb_init, percent))
else: # args.center
half_remains = int((len(sft) - nb_streamlines) / 2)
ind = ordered_ind[half_remains:-half_remains]
logging.info("Saving {}/{} streamlines; the middle {}% of "
"streamlines."
.format(len(ind), nb_init, percent))

else: # Using mean +- STD
mean = np.mean(data)
std = np.std(data)

if args.top:
limit = mean + args.std * std
ind = data >= limit
logging.info("Number of streamlines above mean + {}std limit: {}"
.format(args.std, sum(ind)))
elif args.bottom:
limit = mean - args.std * std
ind = data <= limit
logging.info("Number of streamlines below mean - {}std limit: {}"
.format(args.std, sum(ind)))
else: # args.center
limit1 = mean - args.std * std
limit2 = mean + args.std * std
ind = np.logical_and(data > limit1, data < limit2)
logging.info("Number of streamlines in the range mean +- {}std: {}"
.format(args.std, sum(ind)))

sft = sft[ind]
save_tractogram(sft, args.out_tractogram, no_empty=args.no_empty)


if __name__ == "__main__":
main()
30 changes: 30 additions & 0 deletions src/scilpy/cli/tests/test_tractogram_compute_ae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import tempfile

from scilpy import SCILPY_HOME
from scilpy.io.fetcher import fetch_data, get_testing_files_dict

# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['commit_amico.zip'])
tmp_dir = tempfile.TemporaryDirectory()


def test_help_option(script_runner):
ret = script_runner.run(['scil_tractogram_compute_ae',
'--help'])
assert ret.success


def test_execution(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_bundle = os.path.join(SCILPY_HOME, 'commit_amico', 'tracking.trk')
in_peaks = os.path.join(SCILPY_HOME, 'commit_amico', 'peaks.nii.gz')

ret = script_runner.run(['scil_tractogram_compute_ae', in_bundle, in_peaks,
'out_bundle.trk', '--dpp_key', 'AE',
'--save_mean_map', 'out_map.nii.gz',
'--save_as_color', '--processes', '4'])
assert ret.success
Loading