diff --git a/src/torchio/data/image.py b/src/torchio/data/image.py index 31fbb197c..e31ecf19b 100644 --- a/src/torchio/data/image.py +++ b/src/torchio/data/image.py @@ -52,6 +52,16 @@ PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM TypeBound = Tuple[float, float] TypeBounds = Tuple[TypeBound, TypeBound, TypeBound] +FLIP_AXIS = { + 'L': 'R', + 'R': 'L', + 'A': 'P', + 'P': 'A', + 'I': 'S', + 'S': 'I', + 'T': 'B', + 'B': 'T', +} deprecation_message = ( 'Setting the image data with the property setter is deprecated. Use the' @@ -378,7 +388,7 @@ def axis_name_to_index(self, axis: str) -> int: versions and first letters are also valid, as only the first letter will be used. - .. note:: If you are working with animals, you should probably use + .. note:: If you are working with animals, you should use ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'`` for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``, respectively. @@ -392,6 +402,15 @@ def axis_name_to_index(self, axis: str) -> int: if not isinstance(axis, str): raise ValueError('Axis must be a string') axis = axis[0].upper() + if axis not in 'LRPAISTB': + message = ( + 'Incorrect axis naming. Please use one of: "Left", "Right", ' + '"Anterior", "Posterior", "Inferior", "Superior". ' + 'Lower-case versions and first letters are also valid ' + '(i.e., "L", "r", etc). For 2D images, use "Top" and "Bottom" ' + 'to refer to the vertical (2nd) axis.' + ) + raise ValueError(message) # Generally, TorchIO tensors are (C, W, H, D) if axis in 'TB': # Top, Bottom @@ -400,31 +419,12 @@ def axis_name_to_index(self, axis: str) -> int: try: index = self.orientation.index(axis) except ValueError: - index = self.orientation.index(self.flip_axis(axis)) + index = self.orientation.index(FLIP_AXIS[axis]) # Return negative indices so that it does not matter whether we # refer to spatial dimensions or not index = -3 + index return index - @staticmethod - def flip_axis(axis: str) -> str: - """Return the opposite axis label. For example, ``'L'`` -> ``'R'``. - - Args: - axis: Axis label, such as ``'L'`` or ``'left'``. - """ - labels = 'LRPAISTBDV' - first = labels[::2] - last = labels[1::2] - flip_dict = {a: b for a, b in zip(first + last, last + first)} - axis = axis[0].upper() - flipped_axis = flip_dict.get(axis) - if flipped_axis is None: - values = ', '.join(labels) - message = f'Axis not understood. 
Please use one of: {values}'
-            raise ValueError(message)
-        return flipped_axis
-
     def get_spacing_string(self) -> str:
         strings = [f'{n:.2f}' for n in self.spacing]
         string = f'({", ".join(strings)})'
diff --git a/src/torchio/data/io.py b/src/torchio/data/io.py
index d463f492f..4eef27f2a 100644
--- a/src/torchio/data/io.py
+++ b/src/torchio/data/io.py
@@ -14,11 +14,9 @@
 from ..typing import TypeData
 from ..typing import TypeDataAffine
 from ..typing import TypeDirection
-from ..typing import TypeDoubletInt
 from ..typing import TypePath
 from ..typing import TypeQuartetInt
 from ..typing import TypeTripletFloat
-from ..typing import TypeTripletInt
 
 
 # Matrices used to switch between LPS and RAS
@@ -87,26 +85,43 @@ def _read_dicom(directory: TypePath):
 
 
 def read_shape(path: TypePath) -> TypeQuartetInt:
-    reader = sitk.ImageFileReader()
-    reader.SetFileName(str(path))
-    reader.ReadImageInformation()
-    num_channels = reader.GetNumberOfComponents()
-    num_dimensions = reader.GetDimension()
+    try:
+        reader = sitk.ImageFileReader()
+        reader.SetFileName(str(path))
+        reader.ReadImageInformation()
+        num_channels = reader.GetNumberOfComponents()
+        num_dimensions = reader.GetDimension()
+        shape = reader.GetSize()
+    except RuntimeError as e:  # try with NiBabel
+        message = f'Error loading image with SimpleITK:\n{e}\n\nTrying NiBabel...'
+        warnings.warn(message, stacklevel=2)
+        try:
+            obj: SpatialImage = nib.load(str(path))  # type: ignore[assignment]
+        except nib.loadsave.ImageFileError as e:
+            message = (
+                f'File "{path}" not understood.'
+                ' Check the supported formats at'
+                ' https://simpleitk.readthedocs.io/en/master/IO.html#images'
+                ' and https://nipy.org/nibabel/api.html#file-formats'
+            )
+            raise RuntimeError(message) from e
+        num_dimensions = obj.ndim
+        shape = obj.shape
+        num_channels = 1 if num_dimensions < 4 else shape[-1]
     assert 2 <= num_dimensions <= 4
     if num_dimensions == 2:
-        spatial_shape_2d: TypeDoubletInt = reader.GetSize()
-        assert len(spatial_shape_2d) == 2
-        si, sj = spatial_shape_2d
+        assert len(shape) == 2
+        si, sj = shape
         sk = 1
     elif num_dimensions == 4:
         # We assume bad NIfTI file (channels encoded as spatial dimension)
-        spatial_shape_4d: TypeQuartetInt = reader.GetSize()
-        assert len(spatial_shape_4d) == 4
-        si, sj, sk, num_channels = spatial_shape_4d
+        assert len(shape) == 4
+        si, sj, sk, num_channels = shape
     elif num_dimensions == 3:
-        spatial_shape_3d: TypeTripletInt = reader.GetSize()
-        assert len(spatial_shape_3d) == 3
-        si, sj, sk = spatial_shape_3d
+        assert len(shape) == 3
+        si, sj, sk = shape
+    else:
+        raise ValueError(f'Unsupported number of dimensions: {num_dimensions}')
     shape = num_channels, si, sj, sk
     return shape
 
diff --git a/src/torchio/transforms/augmentation/intensity/random_ghosting.py b/src/torchio/transforms/augmentation/intensity/random_ghosting.py
index 4992fdcbe..fb8da3b5b 100644
--- a/src/torchio/transforms/augmentation/intensity/random_ghosting.py
+++ b/src/torchio/transforms/augmentation/intensity/random_ghosting.py
@@ -1,6 +1,5 @@
 from collections import defaultdict
 from typing import Dict
-from typing import Iterable
 from typing import Tuple
 from typing import Union
 
@@ -60,16 +59,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        if not isinstance(axes, tuple):
-            try:
-                axes = tuple(axes)  # type: ignore[arg-type]
-            except TypeError:
-                axes = (axes,)  # type: ignore[assignment]
-        assert isinstance(axes, Iterable)
-        for axis in axes:
-            if not isinstance(axis, str) and axis not in (0, 1, 2):
-                raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"')
-
self.axes = axes + self.axes = self.parse_axes(axes) self.num_ghosts_range = self._parse_range( num_ghosts, 'num_ghosts', @@ -84,16 +74,13 @@ def __init__( self.restore = _parse_restore(restore) def apply_transform(self, subject: Subject) -> Subject: + axes = self.ensure_axes_indices(subject, self.axes) arguments: Dict[str, dict] = defaultdict(dict) - if any(isinstance(n, str) for n in self.axes): - subject.check_consistent_orientation() - for name, image in self.get_images_dict(subject).items(): - is_2d = image.is_2d() - axes = [a for a in self.axes if a != 2] if is_2d else self.axes + for name, _ in self.get_images_dict(subject).items(): min_ghosts, max_ghosts = self.num_ghosts_range params = self.get_params( + axes, (int(min_ghosts), int(max_ghosts)), - axes, # type: ignore[arg-type] self.intensity_range, ) num_ghosts_param, axis_param, intensity_param = params @@ -108,8 +95,8 @@ def apply_transform(self, subject: Subject) -> Subject: def get_params( self, - num_ghosts_range: Tuple[int, int], axes: Tuple[int, ...], + num_ghosts_range: Tuple[int, int], intensity_range: Tuple[float, float], ) -> Tuple: ng_min, ng_max = num_ghosts_range @@ -118,6 +105,17 @@ def get_params( intensity = self.sample_uniform(*intensity_range) return num_ghosts, axis, intensity + @staticmethod + def parse_restore(restore): + try: + restore = float(restore) + except ValueError as e: + raise TypeError(f'Restore must be a float, not "{restore}"') from e + if not 0 <= restore <= 1: + message = f'Restore must be a number between 0 and 1, not {restore}' + raise ValueError(message) + return restore + class Ghosting(IntensityTransform, FourierTransform): r"""Add MRI ghosting artifact. diff --git a/src/torchio/transforms/augmentation/intensity/random_motion.py b/src/torchio/transforms/augmentation/intensity/random_motion.py index 13b3b66d7..b29271a6d 100644 --- a/src/torchio/transforms/augmentation/intensity/random_motion.py +++ b/src/torchio/transforms/augmentation/intensity/random_motion.py @@ -26,6 +26,8 @@ class RandomMotion(RandomTransform, IntensityTransform, FourierTransform): simulate motion artifacts for data augmentation. Args: + axes: Tuple of integers or strings representing the axes along which + the simulated movements will occur. degrees: Tuple :math:`(a, b)` defining the rotation range in degrees of the simulated movements. 
The rotation angles around each axis are :math:`(\theta_1, \theta_2, \theta_3)`,
@@ -52,6 +54,7 @@ class RandomMotion(RandomTransform, IntensityTransform, FourierTransform):
 
     def __init__(
         self,
+        axes: Union[int, Tuple[int, ...], str, Tuple[str, ...]] = (0, 1, 2),
         degrees: Union[float, Tuple[float, float]] = 10,
         translation: Union[float, Tuple[float, float]] = 10,  # in mm
         num_transforms: int = 2,
@@ -59,6 +62,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
+        self.axes = self.parse_axes(axes)
         self.degrees_range = self.parse_degrees(degrees)
         self.translation_range = self.parse_translation(translation)
         if num_transforms < 1 or not isinstance(num_transforms, int):
@@ -73,18 +77,20 @@
         )
 
     def apply_transform(self, subject: Subject) -> Subject:
+        axes = self.ensure_axes_indices(subject, self.axes)
         arguments: Dict[str, dict] = defaultdict(dict)
         for name, image in self.get_images_dict(subject).items():
-            params = self.get_params(
+            axis, times, degrees, translation = self.get_params(
+                axes,
                 self.degrees_range,
                 self.translation_range,
                 self.num_transforms,
                 is_2d=image.is_2d(),
             )
-            times_params, degrees_params, translation_params = params
-            arguments['times'][name] = times_params
-            arguments['degrees'][name] = degrees_params
-            arguments['translation'][name] = translation_params
+            arguments['axis'][name] = axis
+            arguments['times'][name] = times
+            arguments['degrees'][name] = degrees
+            arguments['translation'][name] = translation
             arguments['image_interpolation'][name] = self.image_interpolation
         transform = Motion(**self.add_include_exclude(arguments))
         transformed = transform(subject)
@@ -93,12 +99,14 @@ def get_params(
         self,
+        axes: Tuple[int, ...],
         degrees_range: Tuple[float, float],
         translation_range: Tuple[float, float],
         num_transforms: int,
         perturbation: float = 0.3,
         is_2d: bool = False,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    ) -> Tuple[int, np.ndarray, np.ndarray, np.ndarray]:
+        axis = axes[torch.randint(0, len(axes), (1,))]
         # If perturbation is 0, time intervals between movements are constant
         degrees_params = self.get_params_array(
             degrees_range,
@@ -117,7 +125,7 @@
         noise.uniform_(-step * perturbation, step * perturbation)
         times += noise
         times_params = times.numpy()
-        return times_params, degrees_params, translation_params
+        return axis, times_params, degrees_params, translation_params
 
     @staticmethod
     def get_params_array(nums_range: Tuple[float, float], num_transforms: int):
@@ -134,6 +142,7 @@ class Motion(IntensityTransform, FourierTransform):
     simulate motion artifacts for data augmentation.
 
     Args:
+        axis: Integer representing the axis along which the simulated movements occur.
         degrees: Sequence of rotations :math:`(\theta_1, \theta_2, \theta_3)`.
         translation: Sequence of translations :math:`(t_1, t_2, t_3)` in mm.
         times: Sequence of times from 0 to 1 at which the motions happen.
@@ -144,6 +153,7 @@ class Motion(IntensityTransform, FourierTransform): def __init__( self, + axis: Union[int, Dict[str, int]], degrees: Union[TypeTripletFloat, Dict[str, TypeTripletFloat]], translation: Union[TypeTripletFloat, Dict[str, TypeTripletFloat]], times: Union[Sequence[float], Dict[str, Sequence[float]]], @@ -153,11 +163,13 @@ def __init__( **kwargs, ): super().__init__(**kwargs) + self.axis = axis self.degrees = degrees self.translation = translation self.times = times self.image_interpolation = image_interpolation self.args_names = [ + 'axis', 'degrees', 'translation', 'times', @@ -165,16 +177,19 @@ def __init__( ] def apply_transform(self, subject: Subject) -> Subject: + axis = self.axis degrees = self.degrees translation = self.translation times = self.times image_interpolation = self.image_interpolation for image_name, image in self.get_images_dict(subject).items(): if self.arguments_are_dict(): + assert isinstance(self.axis, dict) assert isinstance(self.degrees, dict) assert isinstance(self.translation, dict) assert isinstance(self.times, dict) assert isinstance(self.image_interpolation, dict) + axis = self.axis[image_name] degrees = self.degrees[image_name] translation = self.translation[image_name] times = self.times[image_name] @@ -191,11 +206,13 @@ def apply_transform(self, subject: Subject) -> Subject: np.asarray(translation), sitk_image, ) + assert isinstance(axis, int) assert isinstance(image_interpolation, str) transformed_channel = self.add_artifact( sitk_image, transforms, np.asarray(times), + axis, image_interpolation, ) result_arrays.append(transformed_channel) @@ -211,35 +228,18 @@ def get_rigid_transforms( ) -> List[sitk.Euler3DTransform]: center_ijk = np.array(image.GetSize()) / 2 center_lps = image.TransformContinuousIndexToPhysicalPoint(center_ijk) - identity = np.eye(4) - matrices = [identity] + ident_transform = sitk.Euler3DTransform() + ident_transform.SetCenter(center_lps) + transforms = [ident_transform] for degrees, translation in zip(degrees_params, translation_params): radians = np.radians(degrees).tolist() motion = sitk.Euler3DTransform() motion.SetCenter(center_lps) motion.SetRotation(*radians) motion.SetTranslation(translation.tolist()) - motion_matrix = self.transform_to_matrix(motion) - matrices.append(motion_matrix) - transforms = [self.matrix_to_transform(m) for m in matrices] + transforms.append(motion) return transforms - @staticmethod - def transform_to_matrix(transform: sitk.Euler3DTransform) -> np.ndarray: - matrix = np.eye(4) - rotation = np.array(transform.GetMatrix()).reshape(3, 3) - matrix[:3, :3] = rotation - matrix[:3, 3] = transform.GetTranslation() - return matrix - - @staticmethod - def matrix_to_transform(matrix: np.ndarray) -> sitk.Euler3DTransform: - transform = sitk.Euler3DTransform() - rotation = matrix[:3, :3].flatten().tolist() - transform.SetMatrix(rotation) - transform.SetTranslation(matrix[:3, 3]) - return transform - def resample_images( self, image: sitk.Image, @@ -248,10 +248,10 @@ def resample_images( ) -> List[sitk.Image]: floating = reference = image default_value = np.float64(sitk.GetArrayViewFromImage(image).min()) + interpolator = self.get_sitk_interpolator(interpolation) transforms = transforms[1:] # first is identity images = [image] # first is identity for transform in transforms: - interpolator = self.get_sitk_interpolator(interpolation) resampler = sitk.ResampleImageFilter() resampler.SetInterpolator(interpolator) resampler.SetReferenceImage(reference) @@ -277,6 +277,7 @@ def add_artifact( image: 
sitk.Image, transforms: Sequence[sitk.Euler3DTransform], times: np.ndarray, + axis: int, interpolation: str, ): images = self.resample_images(image, transforms, interpolation) @@ -287,12 +288,14 @@ def add_artifact( spectra.append(spectrum) self.sort_spectra(spectra, times) result_spectrum = torch.empty_like(spectra[0]) - last_index = result_spectrum.shape[2] + last_index = result_spectrum.shape[axis] indices = (last_index * times).astype(int).tolist() indices.append(last_index) ini = 0 + slices = [slice(None)] * len(result_spectrum.shape) for spectrum, fin in zip(spectra, indices): - result_spectrum[..., ini:fin] = spectrum[..., ini:fin] + slices[axis] = slice(ini, fin) + result_spectrum[slices] = spectrum[slices] ini = fin result_image = self.inv_fourier_transform(result_spectrum).real.float() return result_image diff --git a/src/torchio/transforms/augmentation/spatial/random_anisotropy.py b/src/torchio/transforms/augmentation/spatial/random_anisotropy.py index 2711f415c..de439d511 100644 --- a/src/torchio/transforms/augmentation/spatial/random_anisotropy.py +++ b/src/torchio/transforms/augmentation/spatial/random_anisotropy.py @@ -1,4 +1,3 @@ -import warnings from typing import Tuple from typing import Union @@ -7,7 +6,6 @@ from .. import RandomTransform from ....data.subject import Subject from ....typing import TypeRangeFloat -from ....utils import to_tuple from ...preprocessing import Resample @@ -48,7 +46,7 @@ class RandomAnisotropy(RandomTransform): def __init__( self, - axes: Union[int, Tuple[int, ...]] = (0, 1, 2), + axes: Union[int, Tuple[int, ...], str, Tuple[str, ...]] = (0, 1, 2), downsampling: TypeRangeFloat = (1.5, 5), image_interpolation: str = 'linear', scalars_only: bool = True, @@ -74,27 +72,10 @@ def get_params( downsampling = self.sample_uniform(*downsampling_range) return axis, downsampling - @staticmethod - def parse_axes(axes: Union[int, Tuple[int, ...]]): - axes_tuple = to_tuple(axes) - for axis in axes_tuple: - is_int = isinstance(axis, int) - if not is_int or axis not in (0, 1, 2): - raise ValueError('All axes must be 0, 1 or 2') - return axes_tuple - def apply_transform(self, subject: Subject) -> Subject: - is_2d = subject.get_first_image().is_2d() - if is_2d and 2 in self.axes: - warnings.warn( - f'Input image is 2D, but "2" is in axes: {self.axes}', - RuntimeWarning, - stacklevel=2, - ) - self.axes = list(self.axes) - self.axes.remove(2) + axes = self.ensure_axes_indices(subject, self.axes) axis, downsampling = self.get_params( - self.axes, + axes, self.downsampling_range, ) target_spacing = list(subject.spacing) diff --git a/src/torchio/transforms/augmentation/spatial/random_flip.py b/src/torchio/transforms/augmentation/spatial/random_flip.py index b61237b30..abbacc177 100644 --- a/src/torchio/transforms/augmentation/spatial/random_flip.py +++ b/src/torchio/transforms/augmentation/spatial/random_flip.py @@ -8,7 +8,6 @@ from .. import RandomTransform from ... 
import SpatialTransform from ....data.subject import Subject -from ....utils import to_tuple class RandomFlip(RandomTransform, SpatialTransform): @@ -45,11 +44,11 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.axes = _parse_axes(axes) + self.axes = self.parse_axes(axes) self.flip_probability = self.parse_probability(flip_probability) def apply_transform(self, subject: Subject) -> Subject: - potential_axes = _ensure_axes_indices(subject, self.axes) + potential_axes = self.ensure_axes_indices(subject, self.axes) axes_to_flip_hot = self.get_params(self.flip_probability) for i in range(3): if i not in potential_axes: @@ -85,51 +84,31 @@ class Flip(SpatialTransform): image orientation is not known. """ - def __init__(self, axes, **kwargs): + def __init__( + self, axes: Union[int, Tuple[int, ...], str, Tuple[str, ...]], **kwargs + ): super().__init__(**kwargs) - self.axes = _parse_axes(axes) - self.args_names = ('axes',) + self.axes = self.parse_axes(axes) + self.args_names = ['axes'] def apply_transform(self, subject: Subject) -> Subject: - axes = _ensure_axes_indices(subject, self.axes) + axes = self.ensure_axes_indices(subject, self.axes) for image in self.get_images(subject): - _flip_image(image, axes) + self.flip_image(image, axes) return subject + @staticmethod + def flip_image(image, axes): + spatial_axes = np.array(axes, int) + 1 + data = image.numpy() + data = np.flip(data, axis=spatial_axes) + data = data.copy() # remove negative strides + data = torch.as_tensor(data) + image.set_data(data) + @staticmethod def is_invertible(): return True def inverse(self): return self - - -def _parse_axes(axes: Union[int, Tuple[int, ...]]): - axes_tuple = to_tuple(axes) - for axis in axes_tuple: - is_int = isinstance(axis, int) - is_string = isinstance(axis, str) - valid_number = is_int and axis in (0, 1, 2) - if not is_string and not valid_number: - message = ( - f'All axes must be 0, 1 or 2, but found "{axis}" with type {type(axis)}' - ) - raise ValueError(message) - return axes_tuple - - -def _ensure_axes_indices(subject, axes): - if any(isinstance(n, str) for n in axes): - subject.check_consistent_orientation() - image = subject.get_first_image() - axes = sorted(3 + image.axis_name_to_index(n) for n in axes) - return axes - - -def _flip_image(image, axes): - spatial_axes = np.array(axes, int) + 1 - data = image.numpy() - data = np.flip(data, axis=spatial_axes) - data = data.copy() # remove negative strides - data = torch.as_tensor(data) - image.set_data(data) diff --git a/src/torchio/transforms/transform.py b/src/torchio/transforms/transform.py index 26f522443..2618af79a 100644 --- a/src/torchio/transforms/transform.py +++ b/src/torchio/transforms/transform.py @@ -226,6 +226,12 @@ def to_range(n, around): def parse_params(self, params, around, name, make_ranges=True, **kwargs): params = to_tuple(params) + if make_ranges and any(isinstance(p, (str, bytes)) for p in params): + message = ( + f'"{name}" must be a number or a sequence of numbers for' + f' make_ranges=True, not {params}' + ) + raise ValueError(message) # d or (a, b) if len(params) == 1 or (len(params) == 2 and make_ranges): params *= 3 # (d, d, d) or (a, b, a, b, a, b) @@ -388,6 +394,42 @@ def parse_include_and_exclude_keys( Transform.validate_keys_sequence(label_keys, 'label_keys') return include, exclude + @staticmethod + def parse_axes( + axes: Union[int, str, Tuple[int, ...], Tuple[str, ...]], + ) -> Union[Tuple[int, ...], Tuple[str, ...]]: + axes_tuple = to_tuple(axes) + for axis in axes_tuple: + valid_number = 
isinstance(axis, int) and axis in (0, 1, 2) + valid_str = isinstance(axis, str) and axis[0].upper() in 'LRAPSITB' + if not valid_str and not valid_number: + message = ( + f'All axes must be 0, 1 or 2 or axis strings, ' + f'but found "{axis}" with type {type(axis)}' + ) + raise ValueError(message) + return tuple(sorted(set(axes_tuple))) + + @staticmethod + def ensure_axes_indices( + subject: Subject, + axes: Union[Tuple[int, ...], Tuple[str, ...]], + ) -> Tuple[int, ...]: + image = subject.get_first_image() + if any(isinstance(n, str) for n in axes): # axis strings + subject.check_consistent_orientation() + int_axes = tuple( + { + (3 + image.axis_name_to_index(n)) if isinstance(n, str) else int(n) + for n in axes + } + ) + if image.is_2d() and 2 in int_axes: + list_axes = list(int_axes) + list_axes.remove(2) + int_axes = tuple(list_axes) + return int_axes + @staticmethod def validate_keys_sequence(keys: TypeKeys, name: str) -> None: """Ensure that the input is not a string but a sequence of strings.""" diff --git a/src/torchio/utils.py b/src/torchio/utils.py index 6f7eecc4a..6a16ef236 100644 --- a/src/torchio/utils.py +++ b/src/torchio/utils.py @@ -7,6 +7,7 @@ import sys import tempfile from pathlib import Path +from collections import abc from typing import Any from typing import Dict from typing import Iterable @@ -25,14 +26,13 @@ from tqdm.auto import trange from . import constants -from .typing import TypeNumber from .typing import TypePath def to_tuple( - value: Any, + value: Union[Any, Iterable[Any]], length: int = 1, -) -> Tuple[TypeNumber, ...]: +) -> Tuple[Any, ...]: """Convert variable to tuple of length n. Example: @@ -52,10 +52,9 @@ def to_tuple( >>> to_tuple([1, 2], length=3) (1, 2) """ - try: - iter(value) + if isinstance(value, abc.Iterable) and not isinstance(value, (str, bytes)): value = tuple(value) - except TypeError: + else: value = length * (value,) return value @@ -386,7 +385,7 @@ def guess_external_viewer() -> Optional[Path]: def parse_spatial_shape(shape): result = to_tuple(shape, length=3) for n in result: - if n < 1 or n % 1: + if isinstance(n, (str, bytes)) or n < 1 or n % 1: message = ( 'All elements in a spatial shape must be positive integers,' f' but the following shape was passed: {shape}' diff --git a/tests/transforms/augmentation/test_random_flip.py b/tests/transforms/augmentation/test_random_flip.py index 79bc52f91..d766b591b 100644 --- a/tests/transforms/augmentation/test_random_flip.py +++ b/tests/transforms/augmentation/test_random_flip.py @@ -31,6 +31,10 @@ def test_wrong_flip_probability_type(self): with pytest.raises(ValueError): tio.RandomFlip(flip_probability='wrong') + def test_wrong_anatomical_axis(self): + with pytest.raises(ValueError): + tio.RandomFlip(axes=('g',)) + def test_anatomical_axis(self): transform = tio.RandomFlip(axes=['i'], flip_probability=1) tensor = torch.rand(1, 2, 3, 4) diff --git a/tests/transforms/augmentation/test_random_ghosting.py b/tests/transforms/augmentation/test_random_ghosting.py index 944aafe49..b969538aa 100644 --- a/tests/transforms/augmentation/test_random_ghosting.py +++ b/tests/transforms/augmentation/test_random_ghosting.py @@ -31,6 +31,14 @@ def test_with_ghosting(self): transformed.t1.data, ) + def test_anatomical_axis(self): + transform = RandomGhosting(axes=['a']) + transformed = transform(self.sample_subject) + self.assert_tensor_not_equal( + self.sample_subject.t1.data, + transformed.t1.data, + ) + def test_intensity_range_with_negative_min(self): with pytest.raises(ValueError): 
             RandomGhosting(intensity=(-0.5, 4))
@@ -74,3 +82,7 @@ def test_out_of_range_restore(self):
     def test_wrong_restore_type(self):
         with pytest.raises(TypeError):
             RandomGhosting(restore='wrong')
+
+    def test_wrong_anatomical_axis(self):
+        with pytest.raises(ValueError):
+            RandomGhosting(axes=('v',))
diff --git a/tests/transforms/augmentation/test_random_motion.py b/tests/transforms/augmentation/test_random_motion.py
index 69678434a..70bf3a5b6 100644
--- a/tests/transforms/augmentation/test_random_motion.py
+++ b/tests/transforms/augmentation/test_random_motion.py
@@ -35,6 +35,14 @@ def test_with_movement(self):
             transformed.t1.data,
         )
 
+    def test_anatomical_axis(self):
+        transform = RandomMotion(axes=('a',))
+        transformed = transform(self.sample_subject)
+        self.assert_tensor_not_equal(
+            self.sample_subject.t1.data,
+            transformed.t1.data,
+        )
+
     def test_negative_degrees(self):
         with pytest.raises(ValueError):
             RandomMotion(degrees=-10)
@@ -58,3 +66,19 @@ def test_wrong_image_interpolation_type(self):
     def test_wrong_image_interpolation_value(self):
         with pytest.raises(ValueError):
             RandomMotion(image_interpolation='wrong')
+
+    def test_out_of_range_axis(self):
+        with pytest.raises(ValueError):
+            RandomMotion(axes=3)
+
+    def test_out_of_range_axis_in_tuple(self):
+        with pytest.raises(ValueError):
+            RandomMotion(axes=(0, -1, 2))
+
+    def test_wrong_axes_type(self):
+        with pytest.raises(ValueError):
+            RandomMotion(axes=None)
+
+    def test_wrong_anatomical_axis(self):
+        with pytest.raises(ValueError):
+            RandomMotion(axes=('C',))
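
Notes and usage sketches (not part of the patch):

The new module-level FLIP_AXIS table replaces the removed Image.flip_axis
static method: axis_name_to_index now looks up the opposite anatomical label
directly when the requested one is absent from the image orientation. A
condensed illustration of that lookup outside the class (the helper name
name_to_index and the hard-coded RAS orientation are made up for the example):

    FLIP_AXIS = {
        'L': 'R', 'R': 'L',
        'A': 'P', 'P': 'A',
        'I': 'S', 'S': 'I',
        'T': 'B', 'B': 'T',
    }

    def name_to_index(axis_name: str, orientation: tuple) -> int:
        # Mirror axis_name_to_index: take the first letter, fall back to the
        # opposite label, and return a negative index into the spatial dims.
        axis = axis_name[0].upper()
        try:
            index = orientation.index(axis)
        except ValueError:
            index = orientation.index(FLIP_AXIS[axis])
        return -3 + index

    print(name_to_index('Left', ('R', 'A', 'S')))      # -3 (first spatial axis)
    print(name_to_index('anterior', ('R', 'A', 'S')))  # -2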
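
Motion.add_artifact now splices k-space along the randomly chosen axis instead
of always along the last dimension. A standalone sketch of that splicing step
with made-up spectra, times and axis (tuple(slices) is the multi-axis slicing
form documented by NumPy and PyTorch):

    import torch

    spectra = [torch.full((4, 6, 8), float(i)) for i in range(3)]  # one per motion state
    times = torch.tensor([0.25, 0.75])  # transition times in [0, 1]
    axis = 1  # axis sampled by RandomMotion.get_params

    result = torch.empty_like(spectra[0])
    last = result.shape[axis]
    bounds = (last * times).int().tolist() + [last]
    slices = [slice(None)] * result.ndim
    start = 0
    for spectrum, stop in zip(spectra, bounds):
        slices[axis] = slice(start, stop)  # select one band of k-space lines
        result[tuple(slices)] = spectrum[tuple(slices)]
        start = stop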
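
With the transforms above all routing axis handling through
Transform.parse_axes and Transform.ensure_axes_indices, anatomical axis names
are accepted uniformly. A minimal end-to-end sketch, assuming this branch is
installed (the toy subject and parameter values are arbitrary):

    import torch
    import torchio as tio

    subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 8, 8, 8)))

    # Names ('Left', 'anterior', 'i', ...) are resolved to spatial indices per
    # image orientation when the transform is applied.
    flip = tio.RandomFlip(axes=['inferior'], flip_probability=1)
    ghosting = tio.RandomGhosting(axes=('a',))
    motion = tio.RandomMotion(axes=('a',), num_transforms=1)

    transformed = motion(ghosting(flip(subject)))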