diff --git a/.github/workflows/test-ml.yml b/.github/workflows/test-ml.yml index 32fafdf59..73930f259 100644 --- a/.github/workflows/test-ml.yml +++ b/.github/workflows/test-ml.yml @@ -53,7 +53,8 @@ jobs: libfreetype6-dev \ libdrm-dev \ libgl1-mesa-dev \ - libosmesa6-dev + libosmesa6-dev \ + python3-dev \ - name: stdlib checkout if: ${{ !contains(steps.python-selector.outputs.python-version, '3.12') }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 95e58fa59..b2ba9af96 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -52,7 +52,8 @@ jobs: libfreetype6-dev \ libdrm-dev \ libgl1-mesa-dev \ - libosmesa6-dev + libosmesa6-dev \ + python3-dev \ - name: stdlib checkout if: ${{ !contains(steps.python-selector.outputs.python-version, '3.12') }} diff --git a/.github/workflows/test_tutorials.yml b/.github/workflows/test_tutorials.yml index a01e1b3f0..3593ce6fb 100644 --- a/.github/workflows/test_tutorials.yml +++ b/.github/workflows/test_tutorials.yml @@ -50,8 +50,9 @@ jobs: libfreetype6-dev \ libdrm-dev \ libgl1-mesa-dev \ - libosmesa6-dev \ - wget + libosmesa6-dev \ + python3-dev \ + wget - name: stdlib checkout if: ${{ !contains(steps.python-selector.outputs.python-version, '3.12') }} diff --git a/docs/source/_static/images/scilpy_paper_figure1.png b/docs/source/_static/images/scilpy_paper_figure1.png index efb820c33..d07e176df 100644 Binary files a/docs/source/_static/images/scilpy_paper_figure1.png and b/docs/source/_static/images/scilpy_paper_figure1.png differ diff --git a/docs/source/_static/images/scilpy_paper_figure2.png b/docs/source/_static/images/scilpy_paper_figure2.png index fcd10da8c..7df5d1e38 100644 Binary files a/docs/source/_static/images/scilpy_paper_figure2.png and b/docs/source/_static/images/scilpy_paper_figure2.png differ diff --git a/docs/source/_static/images/scilpy_paper_figure3.png b/docs/source/_static/images/scilpy_paper_figure3.png index f869424bf..6f2be4568 100644 Binary files a/docs/source/_static/images/scilpy_paper_figure3.png and b/docs/source/_static/images/scilpy_paper_figure3.png differ diff --git a/docs/source/_static/images/scilpy_paper_figure4.png b/docs/source/_static/images/scilpy_paper_figure4.png index 25df5729f..b3909f60f 100644 Binary files a/docs/source/_static/images/scilpy_paper_figure4.png and b/docs/source/_static/images/scilpy_paper_figure4.png differ diff --git a/docs/source/_static/images/scilpy_paper_figure5.png b/docs/source/_static/images/scilpy_paper_figure5.png index 48838a0ac..d42416dee 100644 Binary files a/docs/source/_static/images/scilpy_paper_figure5.png and b/docs/source/_static/images/scilpy_paper_figure5.png differ diff --git a/pyproject.toml b/pyproject.toml index c861e62d3..902734cfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "numpy==1.26.*", "openpyxl==3.1.*", "packaging==24.*", +"pathspec==0.12.*", "pybids==0.18.*", "PyMCubes==0.1.*", "pyparsing==3.2.*", @@ -274,6 +275,7 @@ scil_viz_volume_screenshot_mosaic = "scilpy.cli.scil_viz_volume_screenshot_mosa scil_viz_volume_screenshot = "scilpy.cli.scil_viz_volume_screenshot:main" scil_volume_apply_transform = "scilpy.cli.scil_volume_apply_transform:main" scil_volume_b0_synthesis = "scilpy.cli.scil_volume_b0_synthesis:main" +scil_volume_modify_voxel_order = "scilpy.cli.scil_volume_modify_voxel_order:main" scil_volume_count_non_zero_voxels = "scilpy.cli.scil_volume_count_non_zero_voxels:main" scil_volume_crop = "scilpy.cli.scil_volume_crop:main" scil_volume_distance_map = "scilpy.cli.scil_volume_distance_map:main" diff --git a/src/scilpy/cli/scil_bundle_explore_bundleseg.py b/src/scilpy/cli/scil_bundle_explore_bundleseg.py index 0973abbbc..b96461b7c 100755 --- a/src/scilpy/cli/scil_bundle_explore_bundleseg.py +++ b/src/scilpy/cli/scil_bundle_explore_bundleseg.py @@ -478,11 +478,16 @@ def main(): offset = 0 count = 0 for bundle in mapping.keys(): - filename = glob.glob(f'{os.path.join(args.in_folder, bundle)}.t?k')[0] - - if not os.path.exists(filename): - logging.warning(f'File {filename} not found.') + files = glob.glob(f'{os.path.join(args.in_folder, bundle)}.t?k') + if len(files) == 0: + logging.warning("Could not find any file fitting pattern {}" + .format(os.path.join(args.in_folder, bundle))) continue + elif len(files) > 1: + logging.warning("Found two files for bundle {}. Selecting the " + "first one. Verify your files!".format(bundle)) + + filename = files[0] count += 1 tmp_sft = load_tractogram(filename, ref_img) diff --git a/src/scilpy/cli/scil_tractogram_compute_density_map.py b/src/scilpy/cli/scil_tractogram_compute_density_map.py index b180c6500..e31cd1fc4 100755 --- a/src/scilpy/cli/scil_tractogram_compute_density_map.py +++ b/src/scilpy/cli/scil_tractogram_compute_density_map.py @@ -77,6 +77,7 @@ def main(): transformation, dimensions, _, _ = sft.space_attributes # Processing + logging.info("Computing density map...") if args.endpoints_only: streamline_count = get_endpoints_density_map(sft) else: @@ -84,6 +85,7 @@ def main(): dimensions) # Saving + logging.info("Saving density map {}".format(args.out_img)) dtype_to_use = np.int32 if args.binary is not None: if args.binary == 1: diff --git a/src/scilpy/cli/scil_tractogram_segment_with_bundleseg.py b/src/scilpy/cli/scil_tractogram_segment_with_bundleseg.py index 5f1693520..00fd9da2f 100755 --- a/src/scilpy/cli/scil_tractogram_segment_with_bundleseg.py +++ b/src/scilpy/cli/scil_tractogram_segment_with_bundleseg.py @@ -55,7 +55,8 @@ add_reference_arg, add_verbose_arg, assert_inputs_exist, assert_output_dirs_exist_and_empty, - load_matrix_in_any_format, ranged_type) + load_matrix_in_any_format, ranged_type, + assert_inputs_dirs_exist) from scilpy.segment.voting_scheme import VotingScheme from scilpy.version import version_string @@ -136,6 +137,7 @@ def main(): logging.getLogger().setLevel(logging.getLevelName('INFO')) # Verifications + assert_inputs_dirs_exist(parser, args.in_directory) in_models_directories = [ os.path.join(args.in_directory, x) for x in os.listdir(args.in_directory) diff --git a/src/scilpy/cli/scil_tractogram_segment_with_recobundles.py b/src/scilpy/cli/scil_tractogram_segment_with_recobundles.py index e3e7fca2b..8bd1a3371 100755 --- a/src/scilpy/cli/scil_tractogram_segment_with_recobundles.py +++ b/src/scilpy/cli/scil_tractogram_segment_with_recobundles.py @@ -112,8 +112,8 @@ def main(): if args.tractogram_clustering_thr and args.in_pickle: parser.error("Option --tractogram_clustering_thr should not be " "used with --in_pickle.") - else: - # Setting default value. (Will be ignored in args.in_pickle) + elif args.tractogram_clustering_thr is not None: + # Setting default value. (Will be ignored if args.in_pickle) args.tractogram_clustering_thr = 8.0 # Loading diff --git a/src/scilpy/cli/scil_volume_apply_transform.py b/src/scilpy/cli/scil_volume_apply_transform.py index 52536e46d..e07c454d5 100755 --- a/src/scilpy/cli/scil_volume_apply_transform.py +++ b/src/scilpy/cli/scil_volume_apply_transform.py @@ -17,6 +17,7 @@ from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_verbose_arg, load_matrix_in_any_format) +from scilpy.io.stateful_image import StatefulImage from scilpy.utils.filenames import split_name_with_nii from scilpy.version import version_string @@ -72,15 +73,15 @@ def main(): transfo = load_matrix_in_any_format(args.in_transfo) if args.inverse: transfo = np.linalg.inv(transfo) - moving = nib.load(args.in_file) - reference = nib.load(args.in_target_file) + moving = StatefulImage.load(args.in_file) + reference = StatefulImage.load(args.in_target_file) # Processing, saving warped_img = apply_transform( transfo, reference, moving, keep_dtype=args.keep_dtype, interp=args.interpolation) - nib.save(warped_img, args.out_name) + warped_img.save(args.out_name) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_volume_crop.py b/src/scilpy/cli/scil_volume_crop.py index 8fd311e1e..d87ac387f 100755 --- a/src/scilpy/cli/scil_volume_crop.py +++ b/src/scilpy/cli/scil_volume_crop.py @@ -21,7 +21,6 @@ import argparse import logging -import nibabel as nib import numpy as np from scilpy.io.utils import (add_overwrite_arg, @@ -32,6 +31,7 @@ from scilpy.image.utils import compute_nifti_bounding_box from scilpy.image.volume_operations import crop_volume from scilpy.version import version_string +from scilpy.io.stateful_image import StatefulImage def _build_arg_parser(): @@ -73,23 +73,23 @@ def main(): assert_inputs_exist(parser, args.in_image, args.input_bbox) assert_outputs_exist(parser, args, args.out_image, args.output_bbox) - img = nib.load(args.in_image) + simg = StatefulImage.load(args.in_image) if args.input_bbox: wbbox = WorldBoundingBox.load(args.input_bbox, args.use_deprecated_pickle) if not args.ignore_voxel_size: - voxel_size = img.header.get_zooms()[0:3] + voxel_size = simg.header.get_zooms()[0:3] if not np.allclose(voxel_size, wbbox.voxel_size[0:3], atol=1e-03): raise IOError("Bounding box and data voxel sizes are not " "compatible. Use option --ignore_voxel_size " "to ignore this test.") else: - wbbox = compute_nifti_bounding_box(img) + wbbox = compute_nifti_bounding_box(simg) if args.output_bbox: wbbox.dump(args.output_bbox, args.use_deprecated_pickle) - out_nifti_file = crop_volume(img, wbbox) - nib.save(out_nifti_file, args.out_image) + out_simg = crop_volume(simg, wbbox) + out_simg.save(args.out_image) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_volume_flip.py b/src/scilpy/cli/scil_volume_flip.py index 4af1dea37..108b5b991 100755 --- a/src/scilpy/cli/scil_volume_flip.py +++ b/src/scilpy/cli/scil_volume_flip.py @@ -1,7 +1,17 @@ #! /usr/bin/env python3 - +# -*- coding: utf-8 -*- """ -Flip the volume according to the specified axis. +Flip the volume according to the specified axis. In this script, axes are +referred to as 'x', 'y' and 'z', but they simply correspond to the first, +second and third dimensions of the data array. + +This script only flips the data array in memory and does not modify the +image's strides or orientation information in the header. It simply +flips the numpy data. + + +In contrast, `scil_volume_modify_voxel_order` modifies the image header's +voxel order, but does not modify the data array. """ import argparse diff --git a/src/scilpy/cli/scil_volume_modify_voxel_order.py b/src/scilpy/cli/scil_volume_modify_voxel_order.py new file mode 100644 index 000000000..5575f5dd8 --- /dev/null +++ b/src/scilpy/cli/scil_volume_modify_voxel_order.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Change the voxel order (strides) of a NIfTI image. + +This script allows you to change the voxel order of a NIfTI image by modifying +its header. The voxel order, also known as strides, defines the orientation of +the image data in memory. This can be useful for compatibility with different +software packages that expect a specific voxel order. +In contrast, `scil_volume_flip` only flips the data array in memory, +without changing the header's orientation information. + +The new voxel order can be specified in several ways: +- As a string of 3 characters, e.g., 'RAS', 'LPS', 'ASR'. +- As a comma-separated string of 3 characters, e.g., 'R,A,S'. +- As a string of 3 numbers, e.g., '123', '231', '-12-3'. +- As a comma-separated string of 3 numbers, e.g., '1,2,3', '-1,2,-3'. + +For numeric input, 1, 2, and 3 correspond to the R, A, and S axes of the +image when loaded in RAS orientation. A negative sign flips the axis. +For example., '-1,2,-3' would correspond to a voxel order of 'LAS'. + +For 4D images, the voxel order must be specified numerically. +e.g., '1,2,3,4' or '1,2,3' (if the 4th dimension is time and does not +need to be reordered). The 4th dimension must be 4 or -4. + +To change the header of a tractogram (.trk), we recommend converting it to a +.tck file, then converting it back to .trk with the target NIfTI image as a +reference. +""" + +import argparse +import logging +import nibabel as nib + +from scilpy.io.utils import (add_overwrite_arg, + add_verbose_arg, + assert_inputs_exist, + assert_outputs_exist) +from scilpy.utils.orientation import parse_voxel_order +from scilpy.io.stateful_image import StatefulImage +from scilpy.version import version_string + + +def _build_arg_parser(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawTextHelpFormatter, + epilog=version_string) + + p.add_argument('in_image', + help='Path of the NIfTI file to modify.') + p.add_argument('out_image', + help='Path of the modified NIfTI file to write.') + p.add_argument('--new_voxel_order', required=True, + help='The new voxel order (e.g., "RAS", "1,2,3").') + + add_verbose_arg(p) + add_overwrite_arg(p) + + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + logging.getLogger().setLevel(logging.getLevelName(args.verbose)) + + assert_inputs_exist(parser, args.in_image) + assert_outputs_exist(parser, args, args.out_image) + + img = nib.load(args.in_image) + simg = StatefulImage.load(args.in_image) + + parsed_voxel_order = parse_voxel_order(args.new_voxel_order, + dimensions=len(img.shape)) + + simg.reorient(parsed_voxel_order) + + nib.save(simg, args.out_image) + + +if __name__ == "__main__": + main() diff --git a/src/scilpy/cli/scil_volume_resample.py b/src/scilpy/cli/scil_volume_resample.py index c4b4ac663..2a4d5b87f 100755 --- a/src/scilpy/cli/scil_volume_resample.py +++ b/src/scilpy/cli/scil_volume_resample.py @@ -24,6 +24,7 @@ assert_inputs_exist, assert_outputs_exist) from scilpy.image.volume_operations import resample_volume from scilpy.version import version_string +from scilpy.io.stateful_image import StatefulImage def _build_arg_parser(): @@ -95,7 +96,7 @@ def main(): logging.info('Loading raw data from %s', args.in_image) - img = nib.load(args.in_image) + simg = StatefulImage.load(args.in_image) ref_img = None if args.ref: @@ -103,10 +104,10 @@ def main(): # Must not verify that headers are compatible. But can verify that, at # least, the first columns of their affines are compatible. - img_zoom_invert = [1 / zoom for zoom in img.header.get_zooms()] + img_zoom_invert = [1 / zoom for zoom in simg.header.get_zooms()] ref_zoom_invert = [1 / zoom for zoom in ref_img.header.get_zooms()] - img_affine = np.dot(img.affine[:3, :3], img_zoom_invert) + img_affine = np.dot(simg.affine[:3, :3], img_zoom_invert) ref_affine = np.dot(ref_img.affine[:3, :3], ref_zoom_invert) if not np.allclose(img_affine, ref_affine): @@ -114,20 +115,20 @@ def main(): "input image (but with a different sampling).") # Resampling volume - resampled_img = resample_volume(img, ref_img=ref_img, - volume_shape=args.volume_size, - iso_min=args.iso_min, - voxel_res=args.voxel_size, - interp=args.interp, - enforce_dimensions=args.enforce_dimensions) + resampled_simg = resample_volume(simg, ref_img=ref_img, + volume_shape=args.volume_size, + iso_min=args.iso_min, + voxel_res=args.voxel_size, + interp=args.interp, + enforce_dimensions=args.enforce_dimensions) # Saving results - zooms = list(resampled_img.header.get_zooms()) + zooms = list(resampled_simg.header.get_zooms()) if args.voxel_size: if len(args.voxel_size) == 1: args.voxel_size = args.voxel_size * 3 - if not np.array_equal(zooms[:3], args.voxel_size): + if not np.allclose(zooms[:3], args.voxel_size, atol=1e-3): logging.warning('Voxel size is different from expected.' ' Got: %s, expected: %s', tuple(zooms), tuple(args.voxel_size)) @@ -137,10 +138,10 @@ def main(): zooms[0] = args.voxel_size[0] zooms[1] = args.voxel_size[1] zooms[2] = args.voxel_size[2] - resampled_img.header.set_zooms(tuple(zooms)) + resampled_simg.header.set_zooms(tuple(zooms)) logging.info('Saving resampled data to %s', args.out_image) - nib.save(resampled_img, args.out_image) + resampled_simg.save(args.out_image) if __name__ == '__main__': diff --git a/src/scilpy/cli/scil_volume_reshape.py b/src/scilpy/cli/scil_volume_reshape.py index 7a6cf5df4..597e40f3f 100755 --- a/src/scilpy/cli/scil_volume_reshape.py +++ b/src/scilpy/cli/scil_volume_reshape.py @@ -27,6 +27,7 @@ assert_inputs_exist, assert_outputs_exist) from scilpy.image.volume_operations import reshape_volume from scilpy.version import version_string +from scilpy.io.stateful_image import StatefulImage def _build_arg_parser(): @@ -80,7 +81,7 @@ def main(): logging.info('Loading raw data from %s', args.in_image) - img = nib.load(args.in_image) + simg = StatefulImage.load(args.in_image) ref_img = None if args.ref: @@ -93,14 +94,14 @@ def main(): volume_shape = args.volume_size # Resampling volume - reshaped_img = reshape_volume(img, volume_shape, - mode=args.mode, - cval=args.constant_value, - dtype=args.data_type) + reshaped_simg = reshape_volume(simg, volume_shape, + mode=args.mode, + cval=args.constant_value, + dtype=args.data_type) # Saving results logging.info('Saving reshaped data to %s', args.out_image) - nib.save(reshaped_img, args.out_image) + reshaped_simg.save(args.out_image) if __name__ == '__main__': diff --git a/src/scilpy/cli/scil_volume_reslice_to_reference.py b/src/scilpy/cli/scil_volume_reslice_to_reference.py index 12823681e..83cf92167 100755 --- a/src/scilpy/cli/scil_volume_reslice_to_reference.py +++ b/src/scilpy/cli/scil_volume_reslice_to_reference.py @@ -19,13 +19,13 @@ import argparse import logging -import nibabel as nib import numpy as np from scilpy.image.volume_operations import apply_transform from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, add_verbose_arg, assert_outputs_exist) from scilpy.version import version_string +from scilpy.io.stateful_image import StatefulImage def _build_arg_parser(): @@ -63,14 +63,14 @@ def main(): assert_outputs_exist(parser, args, args.out_file) # Load images. - in_file = nib.load(args.in_file) - ref_file = nib.load(args.in_ref_file) + simg = StatefulImage.load(args.in_file) + ref_file = StatefulImage.load(args.in_ref_file) - reshaped_img = apply_transform(np.eye(4), ref_file, in_file, - interp=args.interpolation, - keep_dtype=args.keep_dtype) + reshaped_simg = apply_transform(np.eye(4), ref_file, simg, + interp=args.interpolation, + keep_dtype=args.keep_dtype) - nib.save(reshaped_img, args.out_file) + reshaped_simg.save(args.out_file) if __name__ == "__main__": diff --git a/src/scilpy/cli/tests/test_scil_volume_modify_voxel_order.py b/src/scilpy/cli/tests/test_scil_volume_modify_voxel_order.py new file mode 100644 index 000000000..847f070e6 --- /dev/null +++ b/src/scilpy/cli/tests/test_scil_volume_modify_voxel_order.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import nibabel as nib +import numpy as np +import tempfile + + +tmp_dir = tempfile.TemporaryDirectory() + + +def test_help_option(script_runner): + ret = script_runner.run(['scil_volume_modify_voxel_order', '--help']) + assert ret.success + + +def test_execution(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_file = 'input.nii.gz' + img = nib.Nifti1Image(np.zeros((10, 20, 30)), np.eye(4)) + nib.save(img, in_file) + + # Test with character-based voxel order + out_file_lps = 'output_lps.nii.gz' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, + out_file_lps, '--new_voxel_order=LPS', '-f']) + assert ret.success + lps_img = nib.load(out_file_lps) + assert nib.aff2axcodes(lps_img.affine) == ('L', 'P', 'S') + + # Test with numeric voxel order + out_file_asr = 'output_asr.nii.gz' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, + out_file_asr, '--new_voxel_order=3,1,2', '-f']) + assert ret.success + asr_img = nib.load(out_file_asr) + assert nib.aff2axcodes(asr_img.affine) == ('S', 'R', 'A') + + # Test with negative numeric voxel order + out_file_lai = 'output_lai.nii.gz' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, + out_file_lai, '--new_voxel_order=-1,2,-3', + '-f']) + assert ret.success + lai_img = nib.load(out_file_lai) + assert nib.aff2axcodes(lai_img.affine) == ('L', 'A', 'I') + + # Test with invalid input + ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, + 'output.nii.gz', '--new_voxel_order=invalid', + '-f']) + assert not ret.success diff --git a/src/scilpy/image/tests/test_volume_operations.py b/src/scilpy/image/tests/test_volume_operations.py index 18d3f22ff..a60ebf1fa 100644 --- a/src/scilpy/image/tests/test_volume_operations.py +++ b/src/scilpy/image/tests/test_volume_operations.py @@ -17,6 +17,7 @@ merge_metrics, normalize_metric, resample_volume, reshape_volume, register_image) +from scilpy.io.stateful_image import StatefulImage from scilpy.io.fetcher import fetch_data, get_testing_files_dict from scilpy.image.utils import compute_nifti_bounding_box @@ -53,10 +54,10 @@ def test_crop_volume(): temp = np.ones((3, 3, 3)) vol = np.pad(temp, pad_width=2, mode='constant', constant_values=0) - img = nib.Nifti1Image(vol, np.eye(4)) - wbbox = compute_nifti_bounding_box(img) + simg = StatefulImage(vol, np.eye(4)) + wbbox = compute_nifti_bounding_box(simg) - vol_cropped = crop_volume(img, wbbox) + vol_cropped = crop_volume(simg, wbbox) assert_equal(temp, vol_cropped.get_fdata()) @@ -71,10 +72,10 @@ def test_apply_transform(): transfo = np.eye(4) transfo[0, 3] = 1 - moving3d_img = nib.Nifti1Image(moving3d, np.eye(4)) + moving3d_simg = StatefulImage(moving3d, np.eye(4)) ref3d_img = nib.Nifti1Image(ref3d, np.eye(4)) - warped_img3d = apply_transform(transfo, ref3d_img, moving3d_img) + warped_img3d = apply_transform(transfo, ref3d_img, moving3d_simg) assert_equal(ref3d, warped_img3d.get_fdata()) @@ -82,9 +83,9 @@ def test_apply_transform(): moving4d = np.pad(np.ones((3, 3, 3, 2)), pad_width=1, mode='constant', constant_values=0) - moving4d_img = nib.Nifti1Image(moving4d, np.eye(4)) + moving4d_simg = StatefulImage(moving4d, np.eye(4)) - warped_img4d = apply_transform(transfo, ref3d_img, moving4d_img) + warped_img4d = apply_transform(transfo, ref3d_img, moving4d_simg) assert_equal(ref3d, warped_img4d.get_fdata()[:, :, :, 2]) @@ -176,7 +177,7 @@ def test_resample_volume(): # affine as np.eye => voxel size 1x1x1 moving3d = np.pad(np.ones((4, 4, 4)), pad_width=1, mode='constant', constant_values=0) - moving3d_img = nib.Nifti1Image(moving3d, np.eye(4)) + moving3d_simg = StatefulImage(moving3d, np.eye(4)) # Ref: 2x2x2, voxel size 3x3x3 ref3d = np.ones((2, 2, 2)) @@ -185,111 +186,111 @@ def test_resample_volume(): # 1) Option volume_shape: expecting an output of 2x2x2, which means # voxel resolution 3x3x3 - resampled_img = resample_volume(moving3d_img, volume_shape=(2, 2, 2), - interp='nn') - assert_equal(resampled_img.get_fdata(), ref3d) - assert resampled_img.affine[0, 0] == 3 + resampled_simg = resample_volume(moving3d_simg, volume_shape=(2, 2, 2), + interp='nn') + assert_equal(resampled_simg.get_fdata(), ref3d) + assert resampled_simg.affine[0, 0] == 3 # 2) Option reference image that is 2x2x2, resolution 3x3x3. ref_img = nib.Nifti1Image(ref3d, ref_affine) - resampled_img = resample_volume(moving3d_img, ref_img=ref_img, - interp='nn') - assert_equal(resampled_img.get_fdata(), ref3d) - assert resampled_img.affine[0, 0] == 3 + resampled_simg = resample_volume(moving3d_simg, ref_img=ref_img, + interp='nn') + assert_equal(resampled_simg.get_fdata(), ref3d) + assert resampled_simg.affine[0, 0] == 3 # 3) Option final resolution 3x3x3, should be of shape 2x2x2 - resampled_img = resample_volume(moving3d_img, voxel_res=(3, 3, 3), - interp='nn') - assert_equal(resampled_img.get_fdata(), ref3d) - assert resampled_img.affine[0, 0] == 3 + resampled_simg = resample_volume(moving3d_simg, voxel_res=(3, 3, 3), + interp='nn') + assert_equal(resampled_simg.get_fdata(), ref3d) + assert resampled_simg.affine[0, 0] == 3 # 4) Same test, with a fake 4th dimension moving3d = np.stack((moving3d, moving3d), axis=-1) - moving3d_img = nib.Nifti1Image(moving3d, np.eye(4)) - resampled_img = resample_volume(moving3d_img, voxel_res=(3, 3, 3), - interp='nn') - result = resampled_img.get_fdata() + moving3d_simg = StatefulImage(moving3d, np.eye(4)) + resampled_simg = resample_volume(moving3d_simg, voxel_res=(3, 3, 3), + interp='nn') + result = resampled_simg.get_fdata() assert_equal(result[:, :, :, 0], ref3d) assert_equal(result[:, :, :, 1], ref3d) - assert resampled_img.affine[0, 0] == 3 + assert resampled_simg.affine[0, 0] == 3 def test_reshape_volume_pad(): # 3D img - img = nib.Nifti1Image( + simg = StatefulImage( np.arange(1, (3**3)+1).reshape((3, 3, 3)).astype(float), np.eye(4)) # 1) Reshaping to 4x4x4, padding with 0 - reshaped_img = reshape_volume(img, (4, 4, 4)) + reshaped_simg = reshape_volume(simg, (4, 4, 4)) - assert_equal(reshaped_img.affine[:, -1], [-1, -1, -1, 1]) - assert_equal(reshaped_img.get_fdata()[0, 0, 0], 0) + assert_equal(reshaped_simg.affine[:, -1], [-1, -1, -1, 1]) + assert_equal(reshaped_simg.get_fdata()[0, 0, 0], 0) # 2) Reshaping to 4x4x4, padding with -1 - reshaped_img = reshape_volume(img, (4, 4, 4), mode='constant', - cval=-1) - assert_equal(reshaped_img.get_fdata()[0, 0, 0], -1) + reshaped_simg = reshape_volume(simg, (4, 4, 4), mode='constant', + cval=-1) + assert_equal(reshaped_simg.get_fdata()[0, 0, 0], -1) # 3) Reshaping to 4x4x4, padding with edge - reshaped_img = reshape_volume(img, (4, 4, 4), mode='edge') - assert_equal(reshaped_img.get_fdata()[0, 0, 0], 1) + reshaped_simg = reshape_volume(simg, (4, 4, 4), mode='edge') + assert_equal(reshaped_simg.get_fdata()[0, 0, 0], 1) # 4D img (2 "stacked" 3D volumes) - img = nib.Nifti1Image( + simg = StatefulImage( np.arange(1, ((3**3) * 2)+1).reshape((3, 3, 3, 2)).astype(float), np.eye(4)) # 2) Reshaping to 5x5x5, padding with 0 - reshaped_img = reshape_volume(img, (5, 5, 5)) - assert_equal(reshaped_img.get_fdata()[0, 0, 0, 0], 0) + reshaped_simg = reshape_volume(simg, (5, 5, 5)) + assert_equal(reshaped_simg.get_fdata()[0, 0, 0, 0], 0) def test_reshape_volume_crop(): # 3D img - img = nib.Nifti1Image( + simg = StatefulImage( np.arange(1, (3**3)+1).reshape((3, 3, 3)).astype(float), np.eye(4)) # 1) Cropping to 1x1x1 - reshaped_img = reshape_volume(img, (1, 1, 1)) - assert_equal(reshaped_img.get_fdata().shape, (1, 1, 1)) - assert_equal(reshaped_img.affine[:, -1], [1, 1, 1, 1]) - assert_equal(reshaped_img.get_fdata()[0, 0, 0], 14) + reshaped_simg = reshape_volume(simg, (1, 1, 1)) + assert_equal(reshaped_simg.get_fdata().shape, (1, 1, 1)) + assert_equal(reshaped_simg.affine[:, -1], [1, 1, 1, 1]) + assert_equal(reshaped_simg.get_fdata()[0, 0, 0], 14) # 2) Cropping to 2x2x2 - reshaped_img = reshape_volume(img, (2, 2, 2)) - assert_equal(reshaped_img.get_fdata().shape, (2, 2, 2)) - assert_equal(reshaped_img.affine[:, -1], [0, 0, 0, 1]) - assert_equal(reshaped_img.get_fdata()[0, 0, 0], 1) + reshaped_simg = reshape_volume(simg, (2, 2, 2)) + assert_equal(reshaped_simg.get_fdata().shape, (2, 2, 2)) + assert_equal(reshaped_simg.affine[:, -1], [0, 0, 0, 1]) + assert_equal(reshaped_simg.get_fdata()[0, 0, 0], 1) # 4D img - img = nib.Nifti1Image( + simg = StatefulImage( np.arange(1, ((3**3) * 2)+1).reshape((3, 3, 3, 2)).astype(float), np.eye(4)) # 2) Cropping to 2x2x2 - reshaped_img = reshape_volume(img, (2, 2, 2)) - assert_equal(reshaped_img.get_fdata().shape, (2, 2, 2, 2)) - assert_equal(reshaped_img.affine[:, -1], [0, 0, 0, 1]) - assert_equal(reshaped_img.get_fdata()[0, 0, 0, 0], 1) + reshaped_simg = reshape_volume(simg, (2, 2, 2)) + assert_equal(reshaped_simg.get_fdata().shape, (2, 2, 2, 2)) + assert_equal(reshaped_simg.affine[:, -1], [0, 0, 0, 1]) + assert_equal(reshaped_simg.get_fdata()[0, 0, 0, 0], 1) def test_reshape_volume_dtype(): # 3D img - img = nib.Nifti1Image( + simg = StatefulImage( np.arange(1, (3**3)+1).reshape((3, 3, 3)).astype(np.uint16), np.eye(4)) # 1) Staying in 3x3x3, same dtype - reshaped_img = reshape_volume(img, (3, 3, 3)) - assert_equal(reshaped_img.get_fdata().shape, (3, 3, 3)) - assert reshaped_img.get_data_dtype() == img.get_data_dtype() + reshaped_simg = reshape_volume(simg, (3, 3, 3)) + assert_equal(reshaped_simg.get_fdata().shape, (3, 3, 3)) + assert reshaped_simg.get_data_dtype() == simg.get_data_dtype() # 1) Staying in 3x3x3, casting to float - reshaped_img = reshape_volume(img, (3, 3, 3), dtype=float) - assert_equal(reshaped_img.get_fdata().shape, (3, 3, 3)) - assert reshaped_img.get_data_dtype() == float + reshaped_simg = reshape_volume(simg, (3, 3, 3), dtype=float) + assert_equal(reshaped_simg.get_fdata().shape, (3, 3, 3)) + assert reshaped_simg.get_data_dtype() == float def test_normalize_metric_basic(): diff --git a/src/scilpy/image/volume_math.py b/src/scilpy/image/volume_math.py index 8c5b0bb31..ca0773526 100644 --- a/src/scilpy/image/volume_math.py +++ b/src/scilpy/image/volume_math.py @@ -55,6 +55,7 @@ def get_array_ops(): ('subtraction', subtraction), ('multiplication', multiplication), ('division', division), + ('maximum', maximum), ('mean', mean), ('std', std), ('correlation', neighborhood_correlation), @@ -442,6 +443,27 @@ def convert(input_list, ref_img): return input_list[0].get_fdata(dtype=np.float64) +def maximum(input_list, ref_img): + """ + maximum: IMGs + Compute the voxel-wise maximum across images. + """ + _validate_length(input_list, 2, at_least=True) + _validate_imgs_type(*input_list, all_imgs=False) + _validate_same_shape(*input_list, ref_img, all_imgs=False) + + output_data = np.zeros(ref_img.header.get_data_shape(), dtype=np.float64) + for img in input_list: + if isinstance(img, nib.Nifti1Image): + data = img.get_fdata(dtype=np.float64) + output_data = np.maximum(output_data, data) + img.uncache() + else: + output_data = np.maximum(output_data, img) + + return output_data + + def addition(input_list, ref_img): """ addition: IMGs diff --git a/src/scilpy/image/volume_operations.py b/src/scilpy/image/volume_operations.py index db4fb849e..1acd986d0 100644 --- a/src/scilpy/image/volume_operations.py +++ b/src/scilpy/image/volume_operations.py @@ -24,6 +24,7 @@ from scilpy.gradients.bvec_bval_tools import identify_shells from scilpy.utils.spatial import voxel_to_world from scilpy.utils.spatial import world_to_voxel +from scilpy.io.stateful_image import StatefulImage def count_non_zero_voxels(image): @@ -79,25 +80,27 @@ def flip_volume(data, axes): return data -def crop_volume(img: nib.Nifti1Image, wbbox): +def crop_volume(simg, wbbox): """ Applies cropping from a world space defined bounding box and fixes the affine to keep data aligned. Parameters ---------- - img: nib.Nifti1Image + simg: StatefulImage Input image to crop. wbbox: WorldBoundingBox Bounding box. Returns ------- - cropped_im: nib.Nifti1Image + cropped_im: StatefulImage The image with cropped data and transformed affine. """ - data = img.get_fdata(dtype=np.float32, caching='unchanged') - affine = img.affine + if not isinstance(simg, StatefulImage): + raise TypeError("Input 'simg' must be a StatefulImage object.") + data = simg.get_fdata(dtype=np.float32, caching='unchanged') + affine = simg.affine voxel_bb_mins = world_to_voxel(wbbox.minimums, affine) voxel_bb_maxs = world_to_voxel(wbbox.maximums, affine) @@ -114,11 +117,12 @@ def crop_volume(img: nib.Nifti1Image, wbbox): new_affine = np.copy(affine) new_affine[0:3, 3] = translation[0:3] - return nib.Nifti1Image(data_crop, new_affine) + cropped_img = nib.Nifti1Image(data_crop, new_affine) + return StatefulImage.create_from(cropped_img, simg) -def apply_transform(transfo, reference, moving, - interp='linear', keep_dtype=False): +def apply_transform(transfo, reference, + moving, interp='linear', keep_dtype=False): """ Apply transformation to an image using Dipy's tool @@ -128,7 +132,7 @@ def apply_transform(transfo, reference, moving, Transformation matrix to be applied reference: nib.Nifti1Image Filename of the reference image (target) - moving: nib.Nifti1Image + moving: StatefulImage Filename of the moving image interp : string, either 'linear' or 'nearest' the type of interpolation to be used, either 'linear' @@ -139,9 +143,11 @@ def apply_transform(transfo, reference, moving, Returns ------- - moved_im: nib.Nifti1Image + moved_im: StatefulImage The warped moving image. """ + if not isinstance(moving, StatefulImage): + raise TypeError("Input 'moving' must be a StatefulImage object.") grid2world, dim, _, _ = get_reference_info(reference) static_data = reference.get_fdata(dtype=np.float32) @@ -181,7 +187,8 @@ def apply_transform(transfo, reference, moving, else: raise ValueError('Does not support this dataset (shape, type, etc)') - return nib.Nifti1Image(resampled.astype(orig_type), grid2world) + moved_nib_img = nib.Nifti1Image(resampled.astype(orig_type), grid2world) + return StatefulImage.create_from(moved_nib_img, reference) def transform_dwi(reg_obj, static, dwi, interpolation='linear'): @@ -504,8 +511,8 @@ def _interp_code_to_order(interp_code): return orders[interp_code] -def resample_volume(img, ref_img=None, volume_shape=None, iso_min=False, - voxel_res=None, +def resample_volume(simg, ref_img=None, + volume_shape=None, iso_min=False, voxel_res=None, interp='lin', enforce_dimensions=False): """ Function to resample a dataset to match the resolution of another reference @@ -516,7 +523,7 @@ def resample_volume(img, ref_img=None, volume_shape=None, iso_min=False, Parameters ---------- - img: nib.Nifti1Image + simg: StatefulImage Image to resample. ref_img: nib.Nifti1Image, optional Reference volume to resample to. This method is used only if ref is not @@ -539,13 +546,15 @@ def resample_volume(img, ref_img=None, volume_shape=None, iso_min=False, Returns ------- - resampled_image: nib.Nifti1Image + resampled_image: StatefulImage Resampled image. """ - data = np.asanyarray(img.dataobj) + if not isinstance(simg, StatefulImage): + raise TypeError("Input 'simg' must be a StatefulImage object.") + data = np.asanyarray(simg.dataobj) original_shape = data.shape - affine = img.affine - original_zooms = img.header.get_zooms()[:3] + affine = simg.affine + original_zooms = simg.header.get_zooms()[:3] error_msg = ('Please only provide one option amongst ref_img, ' 'volume_shape, voxel_res or iso_min.') @@ -609,18 +618,17 @@ def resample_volume(img, ref_img=None, volume_shape=None, iso_min=False, data2[:x_dim, :y_dim, :z_dim] data2 = fix_dim_volume - return nib.Nifti1Image(data2.astype(data.dtype), affine2) + resampled_nib_img = nib.Nifti1Image(data2.astype(data.dtype), affine2) + return StatefulImage.create_from(resampled_nib_img, simg) -def reshape_volume( - img, volume_shape, mode='constant', cval=0, dtype=None -): +def reshape_volume(simg, volume_shape, mode='constant', cval=0, dtype=None): """ Reshape a volume to a specified shape by padding or cropping. The new volume is centered wrt the old volume in world space. Parameters ---------- - img : nib.Nifti1Image + simg : StatefulImage The input image. volume_shape : tuple of 3 ints The desired shape of the volume. @@ -634,14 +642,15 @@ def reshape_volume( Returns ------- - reshaped_img : nib.Nifti1Image + reshaped_simg : StatefulImage The reshaped image. """ - + if not isinstance(simg, StatefulImage): + raise TypeError("Input 'simg' must be a StatefulImage object.") if not dtype: - dtype = img.get_data_dtype() - data = img.get_fdata(dtype=np.float32) - affine = img.affine + dtype = simg.get_data_dtype() + data = simg.get_fdata(dtype=np.float32) + affine = simg.affine # Compute the difference between the desired shape and the current shape diff = (np.array(volume_shape) - np.array(data.shape[:3])) // 2 @@ -689,7 +698,8 @@ def reshape_volume( new_affine = np.copy(affine) new_affine[0:3, 3] = translation[0:3] - return nib.Nifti1Image(cropped_data.astype(dtype), new_affine) + reshaped_nib_img = nib.Nifti1Image(cropped_data.astype(dtype), new_affine) + return StatefulImage.create_from(reshaped_nib_img, simg) def mask_data_with_default_cube(data): diff --git a/src/scilpy/io/image.py b/src/scilpy/io/image.py index 64072bdb7..ad87af5db 100644 --- a/src/scilpy/io/image.py +++ b/src/scilpy/io/image.py @@ -8,7 +8,6 @@ from scilpy.utils import is_float - def load_img(arg): """ Function to create the variable for scil_volume_math main function. diff --git a/src/scilpy/io/stateful_image.py b/src/scilpy/io/stateful_image.py new file mode 100644 index 000000000..f6154aeec --- /dev/null +++ b/src/scilpy/io/stateful_image.py @@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- + +import nibabel as nib +from dipy.io.utils import get_reference_info +from scilpy.utils.orientation import validate_voxel_order + + +class StatefulImage(nib.Nifti1Image): + """ + A class that extends nib.Nifti1Image to manage image orientation state. + + This class ensures that image data loaded into memory is always in a + consistent orientation (RAS by default), while preserving the original + on-disk orientation information. When saving, the image is automatically + reverted to its original orientation, ensuring non-destructive operations. + """ + + def __init__(self, dataobj, affine, header=None, extra=None, + file_map=None, original_affine=None, + original_dimensions=None, original_voxel_sizes=None, + original_axcodes=None): + """ + Initialize a StatefulImage object. + + Extends the Nifti1Image constructor to store original orientation info. + """ + super().__init__(dataobj, affine, header, extra, file_map) + + # Store original image information + self._original_affine = original_affine + self._original_dimensions = original_dimensions + self._original_voxel_sizes = original_voxel_sizes + self._original_axcodes = original_axcodes + + @classmethod + def load(cls, filename, to_orientation="RAS"): + """ + Load a NIfTI image, store its original orientation, and reorient it. + + Parameters + ---------- + filename : str + Path to the NIfTI file. + to_orientation : str or tuple, optional + The target orientation for the in-memory data. Default is "RAS". + + Returns + ------- + StatefulImage + An instance of StatefulImage with data in the target orientation. + """ + img = nib.load(filename) + + original_affine = img.affine.copy() + original_axcodes = nib.orientations.aff2axcodes(img.affine) + original_dims = img.header.get_data_shape() + original_voxel_sizes = img.header.get_zooms() + + if to_orientation: + validate_voxel_order(to_orientation) + start_ornt = nib.orientations.io_orientation(img.affine) + target_ornt = nib.orientations.axcodes2ornt(to_orientation) + transform = nib.orientations.ornt_transform(start_ornt, + target_ornt) + reoriented_img = img.as_reoriented(transform) + else: + reoriented_img = img + + return cls(reoriented_img.dataobj, reoriented_img.affine, + reoriented_img.header, original_affine=original_affine, + original_dimensions=original_dims, + original_voxel_sizes=original_voxel_sizes, + original_axcodes=original_axcodes) + + def save(self, filename): + """ + Save the image to a file, reverting to its original orientation. + + Parameters + ---------- + filename : str + Path to save the NIfTI file. + """ + if self._original_axcodes is None: + raise ValueError( + "Unknown original orientation. Ensure the image was loaded" + "with StatefulImage.load() or that original_axcodes was" + "provided when creating the StatefulImage instance.") + + self.reorient_to_original() + nib.save(self, filename) + + @staticmethod + def create_from(source, reference): + """ + Create a new StatefulImage from a source image, preserving the original + orientation information from a reference StatefulImage. + + Parameters + ---------- + source : nib.Nifti1Image + The image data to use for the new StatefulImage. + reference : StatefulImage + The reference image from which to copy original orientation + information. + + Returns + ------- + StatefulImage + A new StatefulImage with the source image's data and the reference + image's original orientation information. + """ + return StatefulImage(source.dataobj, source.affine, + header=source.header, + original_affine=reference._original_affine, + original_dimensions=reference._original_dimensions, + original_voxel_sizes=reference._original_voxel_sizes, + original_axcodes=reference._original_axcodes) + + def reorient_to_original(self): + """ + Reorient the in-memory image to its original orientation. + This method modifies the image in place. It does not return a new + Nifti1Image instance. + + Raises + ------ + ValueError + If the original axis codes are not set. + """ + if self._original_axcodes is None: + raise ValueError( + "Original axis codes are not set cannot reorient to original" + "orientation.") + self.reorient(self._original_axcodes) + + def reorient(self, target_axcodes): + """ + Reorient the in-memory image to a target orientation. + + Parameters + ---------- + target_axcodes : str or tuple + The target orientation axis codes (e.g., "LPS", ("R", "A", "S")). + """ + validate_voxel_order(target_axcodes) + + current_axcodes = nib.orientations.aff2axcodes(self.affine) + if current_axcodes == tuple(target_axcodes): + return + + # Check unique are only valid axis codes + valid_codes = {'L', 'R', 'A', 'P', 'S', 'I'} + for code in target_axcodes: + if code not in valid_codes: + raise ValueError(f"Invalid axis code '{code}' in target.") + + # Check L/R, A/P, S/I pairs are not both present + pairs = [('L', 'R'), ('A', 'P'), ('S', 'I')] + for pair in pairs: + if pair[0] in target_axcodes and pair[1] in target_axcodes: + raise ValueError(f"Conflicting axis codes '{pair[0]}' and " + f"'{pair[1]}' in target.") + + # Check no repeated axis codes (LL, RR, etc.) + if len(set(target_axcodes)) != 3: + raise ValueError("Target axis codes must be unique.") + + start_ornt = nib.orientations.axcodes2ornt(current_axcodes) + target_ornt = nib.orientations.axcodes2ornt(target_axcodes) + transform = nib.orientations.ornt_transform(start_ornt, target_ornt) + + reoriented_img = self.as_reoriented(transform) + self.__init__(reoriented_img.dataobj, reoriented_img.affine, + reoriented_img.header, + original_affine=self._original_affine, + original_dimensions=self._original_dimensions, + original_voxel_sizes=self._original_voxel_sizes, + original_axcodes=self._original_axcodes) + + def to_ras(self): + """Convenience method to reorient in-memory data to RAS.""" + self.reorient(("R", "A", "S")) + + def to_lps(self): + """Convenience method to reorient in-memory data to LPS.""" + self.reorient(("L", "P", "S")) + + def to_reference(self, obj): + """ + Reorient the in-memory image to match the orientation of a reference + object. + + Parameters + ---------- + obj : object + Reference object from which orientation information can be obtained. + Must not be an instance of ``StatefulImage``. + + Raises + ------ + TypeError + If ``obj`` is an instance of ``StatefulImage``. + """ + + if isinstance(obj, StatefulImage): + raise TypeError('Reference object must not be a StatefulImage.') + + _, _, _, voxel_order = get_reference_info(obj) + self.reorient(voxel_order) + + @property + def axcodes(self): + """Get the axis codes for the current image orientation.""" + return nib.orientations.aff2axcodes(self.affine) + + @property + def original_axcodes(self): + """Get the axis codes for the original image orientation.""" + return self._original_axcodes + + def __str__(self): + """Return a string representation of the image, including orientation.""" + base_str = super().__str__() + current_axcodes = self.axcodes + reoriented = current_axcodes != self._original_axcodes + + orientation_info = ( + f"Reorientation Information:\n" + f" Original axis codes: {self._original_axcodes}\n" + f" Current axis codes: {current_axcodes}\n" + f" Reoriented from original: {reoriented}" + ) + + return f"{base_str}\n{orientation_info}" diff --git a/src/scilpy/io/tests/__init__.py b/src/scilpy/io/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/scilpy/io/tests/test_stateful_image.py b/src/scilpy/io/tests/test_stateful_image.py new file mode 100644 index 000000000..78592cace --- /dev/null +++ b/src/scilpy/io/tests/test_stateful_image.py @@ -0,0 +1,256 @@ +# -*- coding: utf-8 -*- + +import os +import pytest +import tempfile +from contextlib import contextmanager + +import nibabel as nib +import numpy as np + +from scilpy.io.stateful_image import StatefulImage + + +@contextmanager +def create_dummy_nifti_file(filename="test.nii.gz", in_lps=False): + """ + Create a dummy NIfTI file for testing in a temporary directory. + """ + with tempfile.TemporaryDirectory() as tmpdir: + shape = (10, 10, 10) + affine = np.eye(4) if not in_lps else np.diag([-1, -1, 1, 1]) + data = np.random.rand(*shape).astype(np.float32) + img = nib.Nifti1Image(data, affine) + + file_path = os.path.join(tmpdir, filename) + nib.save(img, file_path) + + yield file_path + + +def test_load_and_reorient(): + """ + Test loading a NIfTI file and reorienting it to RAS. + """ + with create_dummy_nifti_file() as file_path: + img = StatefulImage.load(file_path, to_orientation="RAS") + + assert isinstance(img, StatefulImage) + assert img.axcodes == ("R", "A", "S") + assert img.original_axcodes == ("R", "A", "S") + + +def test_save_to_original_orientation(): + """ + Test that saving the image reverts it to its original orientation. + """ + with create_dummy_nifti_file() as file_path: + img = StatefulImage.load(file_path, to_orientation="LPS") + + # Save the image + tmp_dir = os.path.dirname(file_path) + output_path = os.path.join(tmp_dir, "output.nii.gz") + img.save(output_path) + + # Load the saved image and check its orientation + saved_img = nib.load(output_path) + assert nib.orientations.aff2axcodes(saved_img.affine) == ("R", "A", "S") + + +def test_reorient_to_original(): + """ + Test reorienting the image back to its original orientation. + """ + with create_dummy_nifti_file() as file_path: + img = StatefulImage.load(file_path, to_orientation="LPS") + img.reorient_to_original() + assert img.axcodes == ("R", "A", "S") + + +def test_to_ras_lps(): + """ + Test the to_ras() and to_lps() convenience methods. + """ + with create_dummy_nifti_file() as file_path: + img = StatefulImage.load(file_path) + + img.to_lps() + assert img.axcodes == ("L", "P", "S") + + img.to_ras() + assert img.axcodes == ("R", "A", "S") + + +def test_to_reference(): + """ + Test reorienting to match a reference image. + """ + with create_dummy_nifti_file() as file_path: + img = StatefulImage.load(file_path) + + # Create a reference image with a different orientation + ref_affine = np.diag([-1, -1, 1, 1]) + ref_img = nib.Nifti1Image(np.zeros((10, 10, 10)), ref_affine) + + img.to_reference(ref_img) + assert img.axcodes == ("L", "P", "S") + + +def test_to_reference_stateful_image(): + """ + Test that to_reference raises a TypeError with a StatefulImage. + """ + with create_dummy_nifti_file() as file_path: + img = StatefulImage.load(file_path) + ref_img = StatefulImage.load(file_path) + + with pytest.raises(TypeError, + match="Reference object cannot be a StatefulImage."): + img.to_reference(ref_img) + + +def test_axcodes_properties_tuple(): + """ + Test the axcodes and original_axcodes properties. + """ + with create_dummy_nifti_file() as file_path: + img = StatefulImage.load(file_path, to_orientation=("L", "P", "S")) + assert img.axcodes == ("L", "P", "S") + assert img.original_axcodes == ("R", "A", "S") + + +def test_axcodes_properties_string(): + """ + Test the axcodes and original_axcodes properties. + """ + with create_dummy_nifti_file() as file_path: + img = StatefulImage.load(file_path, to_orientation="LPS") + assert img.axcodes == ("L", "P", "S") + assert img.original_axcodes == ("R", "A", "S") + + +def test_str_representation(): + """ + Test the string representation of the StatefulImage. + """ + with create_dummy_nifti_file() as file_path: + img = StatefulImage.load(file_path, to_orientation="LPS") + s = str(img) + assert "Original axis codes: ('R', 'A', 'S')" in s + assert "Current axis codes: ('L', 'P', 'S')" in s + assert "Reoriented from original: True" in s + + +def test_load_no_reorientation(): + """ + Test that loading without reorientation works as expected. + """ + with create_dummy_nifti_file() as file_path: + img = StatefulImage.load(file_path, to_orientation=None) + assert img.axcodes == ("R", "A", "S") + assert img.original_axcodes == ("R", "A", "S") + + +def test_reorient_no_op_tuple(): + """ + Test that reorienting to the same orientation is a no-op. + """ + with create_dummy_nifti_file(in_lps=True) as file_path: + img = StatefulImage.load(file_path) + img.reorient(("R", "A", "S")) + assert img.axcodes == ("R", "A", "S") + + +def test_reorient_no_op_string(): + """ + Test that reorienting to the same orientation is a no-op. + """ + with create_dummy_nifti_file(in_lps=True) as file_path: + img = StatefulImage.load(file_path) + img.reorient("RAS") + assert img.axcodes == ("R", "A", "S") + + +def test_direct_instantiation(): + """ + Test direct instantiation of StatefulImage. + """ + with create_dummy_nifti_file() as file_path: + nii = nib.load(file_path) + img = StatefulImage(nii.dataobj, nii.affine, nii.header) + + assert img.original_axcodes is None + assert img.axcodes == ("R", "A", "S") + + # Test that save fails without original orientation information + with pytest.raises(ValueError): + img.save("test.nii.gz") + + # Test that reorient_to_original fails + with pytest.raises(ValueError): + img.reorient_to_original() + + +@pytest.mark.parametrize("codes, error_msg", [ + (None, "Axis codes cannot be None."), + ("INVALID", "Target axis codes must be of length 3."), + ("RAR", "Target axis codes must be unique."), + ("LRR", "Target axis codes must be unique."), + ("LRA", "Conflicting axis codes 'L' and 'R' in target."), + ("API", "Conflicting axis codes 'A' and 'P' in target."), +]) +def test_stateful_image_bad_axcodes_reorient(codes, error_msg): + """ + Test that reorienting with invalid axis codes raises a ValueError. + """ + with create_dummy_nifti_file(filename="dummy.nii.gz", in_lps=True) as filepath: + stateful_img = StatefulImage.load(filepath) + with pytest.raises(ValueError, match=error_msg): + stateful_img.reorient(codes) + + +@pytest.mark.parametrize("codes, error_msg", [ + ("INVALID", "Target axis codes must be of length 3."), + ("RAR", "Target axis codes must be unique."), + ("LRR", "Target axis codes must be unique."), + ("LRA", "Conflicting axis codes 'L' and 'R' in target."), + ("API", "Conflicting axis codes 'A' and 'P' in target."), +]) +def test_stateful_image_bad_axcodes_load(codes, error_msg): + """ + Test that loading with invalid axis codes raises a ValueError. + """ + with create_dummy_nifti_file(filename="dummy.nii.gz", in_lps=True) as filepath: + with pytest.raises(ValueError, match=error_msg): + StatefulImage.load(filepath, to_orientation=codes) + + +@pytest.mark.parametrize("codes", [ + ("R", "A", "S"), "RAS", + ("L", "P", "S"), "LPS", + ("A", "R", "S"), "ARS", + ("L", "P", "I"), "LPI", + ("S", "P", "L"), "SPL", +]) +def test_reorient_valid_codes(codes): + """ + Test that reorienting with valid codes does not raises a ValueError. + """ + with create_dummy_nifti_file() as file_path: + img = StatefulImage.load(file_path) + img.reorient(codes) + + +@pytest.mark.parametrize('codes, invalid_code', [ + (("X", "Y", "Z"), "X"), + (("L", "A", "B"), "B"), +]) +def test_reorient_invalid_codes(codes, invalid_code): + """ + Test that reorienting with invalid codes raises a ValueError. + """ + with create_dummy_nifti_file() as file_path: + img = StatefulImage.load(file_path) + with pytest.raises(ValueError, + match=f"Invalid axis code '{invalid_code}' in target."): + img.reorient(codes) diff --git a/src/scilpy/utils/__init__.py b/src/scilpy/utils/__init__.py index bf683757f..f47858d39 100644 --- a/src/scilpy/utils/__init__.py +++ b/src/scilpy/utils/__init__.py @@ -30,9 +30,9 @@ def recursive_update(d, u, from_existing=False): from_existing=from_existing) else: if not from_existing: - d[k] = float('nan') + d[k] = float("nan") elif k not in d: - d[k] = float('nan') + d[k] = float("nan") return d diff --git a/src/scilpy/utils/orientation.py b/src/scilpy/utils/orientation.py new file mode 100644 index 000000000..1c2b70365 --- /dev/null +++ b/src/scilpy/utils/orientation.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- + +import re + + +def validate_voxel_order(axcodes, dimensions=3): + """ + Validate a set of axis codes. + Parameters + ---------- + axcodes : str or tuple or list + The axis codes to validate (e.g., "LPS", ("R", "A", "S")). + dimensions : int + The number of dimensions of the image. + Returns + ------- + tuple + A tuple of validated axis codes. + Raises + ------ + ValueError + If the axis codes are invalid. + """ + if axcodes is None: + raise ValueError("Axis codes cannot be None.") + + axcodes = tuple(axcodes) + if len(axcodes) != dimensions: + raise ValueError(f"Target axis codes must be of length {dimensions}.") + + # Check unique are only valid axis codes + valid_codes = {"L", "R", "A", "P", "S", "I"} + if dimensions == 4: + valid_codes.add("T") + for code in axcodes: + if code not in valid_codes: + raise ValueError(f"Invalid axis code '{code}' in target.") + + # Check no repeated axis codes (LL, RR, etc.) + if len(set(axcodes)) != dimensions: + raise ValueError("Target axis codes must be unique.") + + # Check L/R, A/P, S/I pairs are not both present + pairs = [("L", "R"), ("A", "P"), ("S", "I")] + for pair in pairs: + if pair[0] in axcodes and pair[1] in axcodes: + raise ValueError(f"Conflicting axis codes '{pair[0]}' and " + f"'{pair[1]}' in target.") + return axcodes + + +def parse_voxel_order(order_str, dimensions=3): + """ + Parse the voxel order string into a tuple of axis codes. + """ + order_str_cleaned = order_str.replace(',', '').replace(' ', '') + + if dimensions == 4 and order_str_cleaned.isalpha(): + raise ValueError("Alphabetical voxel order is not supported for 4D " + "images. Please use numeric format.") + + if order_str_cleaned.isalpha(): + if len(order_str_cleaned) != 3: + raise ValueError("Voxel order string must have 3 characters.") + return validate_voxel_order(tuple(order_str_cleaned.upper())) + + if order_str_cleaned.replace('-', '').isdigit(): + numeric_parts = re.findall(r'-?\d', order_str_cleaned) + if len(numeric_parts) == 4 and dimensions != 4: + raise ValueError("4D voxel order is only supported for 4D images.") + if len(numeric_parts) not in [3, 4]: + raise ValueError("Voxel order string must have 3 or 4 numbers.") + + if dimensions == 4: + ras_map = {1: 'R', 2: 'A', 3: 'S', 4: 'T'} + flip_map = {'R': 'L', 'A': 'P', 'S': 'I', 'T': 'T'} + if len(numeric_parts) == 4: + if abs(int(numeric_parts[3])) != 4: + raise ValueError("The 4th dimension must be 4 or -4.") + else: + ras_map = {1: 'R', 2: 'A', 3: 'S'} + flip_map = {'R': 'L', 'A': 'P', 'S': 'I'} + + order = [] + for part in numeric_parts: + num = int(part) + axis = ras_map[abs(num)] + if num < 0: + axis = flip_map[axis] + order.append(axis) + + # Check for duplicate axes + if len(set(order)) != len(numeric_parts): + # Handle swapped axes from numeric input (e.g., '231') + axis_vals = [ras_map[abs(int(p))] for p in numeric_parts] + if len(set(axis_vals)) == len(numeric_parts): + return validate_voxel_order(tuple(order), dimensions=len(numeric_parts)) + else: + raise ValueError("Invalid numeric voxel order. " + "Axes cannot be repeated.") + + return validate_voxel_order(tuple(order), dimensions=len(numeric_parts)) + + raise ValueError(f"Invalid voxel order format: {order_str}") diff --git a/src/scilpy/utils/scilpy_bot.py b/src/scilpy/utils/scilpy_bot.py index a4961cd1b..b3792776e 100644 --- a/src/scilpy/utils/scilpy_bot.py +++ b/src/scilpy/utils/scilpy_bot.py @@ -201,9 +201,9 @@ def _generate_help_file(args): return # Run the script with --h and capture the output result = subprocess.run(['python', script, '--h'], - capture_output=True, text=True) + capture_output=True, text=True, encoding='utf-8') # Save the output to the hidden file - with open(help_file, 'w') as f: + with open(help_file, 'w', encoding='utf-8') as f: f.write(result.stdout) diff --git a/src/scilpy/utils/tests/test_orientation.py b/src/scilpy/utils/tests/test_orientation.py new file mode 100644 index 000000000..815794550 --- /dev/null +++ b/src/scilpy/utils/tests/test_orientation.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- + +import pytest +from scilpy.utils.orientation import (validate_voxel_order, + parse_voxel_order) + + +def test_validate_voxel_order_valid(): + """Test that valid axis codes pass validation.""" + validate_voxel_order("RAS") + validate_voxel_order(("R", "A", "S")) + validate_voxel_order(["R", "A", "S"]) + + +def test_validate_voxel_order_none(): + """Test that None raises a ValueError.""" + with pytest.raises(ValueError, + match="Axis codes cannot be None."): + validate_voxel_order(None) + + +def test_validate_voxel_order_invalid_code(): + """Test that an invalid code raises a ValueError.""" + with pytest.raises(ValueError, + match="Invalid axis code 'X' in target."): + validate_voxel_order("XAS") + + +def test_validate_voxel_order_conflicting_codes(): + """Test that conflicting codes raise a ValueError.""" + with pytest.raises(ValueError, + match="Conflicting axis codes 'L' and 'R' in target."): + validate_voxel_order("LRS") + + +def test_validate_voxel_order_wrong_length(): + """Test that codes with length != 3 raise a ValueError.""" + with pytest.raises(ValueError, + match="Target axis codes must be of length 3."): + validate_voxel_order("RASL") + + +def test_validate_voxel_order_repeated_codes(): + """Test that repeated codes raise a ValueError.""" + with pytest.raises(ValueError, + match="Target axis codes must be unique."): + validate_voxel_order("RRS") + + +def test_parse_voxel_order_valid_alpha(): + """Test parsing of valid alphabetical voxel order strings.""" + assert parse_voxel_order("RAS") == ("R", "A", "S") + assert parse_voxel_order("LPI") == ("L", "P", "I") + assert parse_voxel_order("ASR") == ("A", "S", "R") + + +def test_parse_voxel_order_invalid_alpha_length(): + """Test that alphabetical strings of incorrect length raise an error.""" + with pytest.raises(ValueError, + match="Voxel order string must have 3 characters."): + parse_voxel_order("RA") + + +def test_parse_voxel_order_valid_numeric(): + """Test parsing of valid numeric voxel order strings.""" + assert parse_voxel_order("1,2,3") == ("R", "A", "S") + assert parse_voxel_order("-1,2,-3") == ("L", "A", "I") + assert parse_voxel_order("2,3,1") == ("A", "S", "R") + + +def test_parse_voxel_order_invalid_numeric_repeat(): + """Test that numeric strings with repeated axes raise an error.""" + with pytest.raises(ValueError, match="Axes cannot be repeated."): + parse_voxel_order("1,1,2") + + +def test_parse_voxel_order_invalid_format(): + """Test that mixed or invalid format strings raise an error.""" + with pytest.raises(ValueError, + match="Invalid voxel order format: 1,A,2"): + parse_voxel_order("1,A,2") + + with pytest.raises(ValueError, + match="Voxel order string must have 3 or 4 numbers."): + parse_voxel_order("1,2,3,4,5", dimensions=4) + +def test_parse_voxel_order_4d_valid_numeric(): + """Test parsing of valid 4D numeric voxel order strings.""" + assert parse_voxel_order("1,2,3,4", dimensions=4) == ("R", "A", "S", "T") + assert parse_voxel_order("-1,2,-3,4", dimensions=4) == ("L", "A", "I", "T") + assert parse_voxel_order("2,3,1", dimensions=4) == ("A", "S", "R") + + +def test_parse_voxel_order_4d_invalid_alpha(): + """Test that 4D alphabetical voxel order strings raise an error.""" + with pytest.raises(ValueError, + match="Alphabetical voxel order is not supported for 4D " + "images. Please use numeric format."): + parse_voxel_order("RAS", dimensions=4) + + +def test_parse_voxel_order_4d_invalid_numeric(): + """Test that invalid 4D numeric voxel order strings raise an error.""" + with pytest.raises(ValueError, + match="The 4th dimension must be 4 or -4."): + parse_voxel_order("1,2,3,5", dimensions=4) + + with pytest.raises(ValueError, + match="Voxel order string must have 3 or 4 numbers."): + parse_voxel_order("1,2", dimensions=4) + + with pytest.raises(ValueError, + match="Voxel order string must have 3 or 4 numbers."): + parse_voxel_order("1,2,3,4,5", dimensions=4) + + with pytest.raises(ValueError, match="Axes cannot be repeated."): + parse_voxel_order("1,1,2,4", dimensions=4) + + +def test_parse_voxel_order_invalid_format_3d(): + """Test that mixed or invalid format strings raise an error for 3D.""" + with pytest.raises(ValueError, + match="Invalid voxel order format: 1A2"): + parse_voxel_order("1A2") + + with pytest.raises(ValueError, + match="4D voxel order is only supported for 4D images."): + parse_voxel_order("1,2,3,4")