Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions tests/transforms/augmentation/test_random_crop_pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import torchio as tio
from ...utils import TorchioTestCase


class TestRandomCropOrPad(TorchioTestCase):
"""Tests for `RandomCropOrPad`."""
def test_no_changes(self):
sample_t1 = self.sample_subject['t1']
shape = sample_t1.spatial_shape
transform = tio.RandomCropOrPad(shape)
transformed = transform(self.sample_subject)
self.assertTensorEqual(sample_t1.data, transformed['t1'].data)
self.assertTensorEqual(sample_t1.affine, transformed['t1'].affine)

def test_different_shape(self):
shape = self.sample_subject['t1'].spatial_shape
target_shape = 9, 21, 30
transform = tio.RandomCropOrPad(target_shape)
transformed = transform(self.sample_subject)
for key in transformed:
result_shape = transformed[key].spatial_shape
self.assertNotEqual(shape, result_shape)

def test_shape_right(self):
target_shape = 9, 21, 30
transform = tio.RandomCropOrPad(target_shape)
transformed = transform(self.sample_subject)
for key in transformed:
result_shape = transformed[key].spatial_shape
self.assertEqual(target_shape, result_shape)

def test_only_pad(self):
target_shape = 11, 22, 30
transform = tio.RandomCropOrPad(target_shape)
transformed = transform(self.sample_subject)
for key in transformed:
result_shape = transformed[key].spatial_shape
self.assertEqual(target_shape, result_shape)

def test_only_crop(self):
target_shape = 9, 18, 30
transform = tio.RandomCropOrPad(target_shape)
transformed = transform(self.sample_subject)
for key in transformed:
result_shape = transformed[key].spatial_shape
self.assertEqual(target_shape, result_shape)

def test_shape_negative(self):
with self.assertRaises(ValueError):
tio.RandomCropOrPad(-1)

def test_shape_float(self):
with self.assertRaises(ValueError):
tio.RandomCropOrPad(2.5)

def test_shape_one(self):
transform = tio.RandomCropOrPad(1)
transformed = transform(self.sample_subject)
for key in transformed:
result_shape = transformed[key].spatial_shape
self.assertEqual((1, 1, 1), result_shape)
5 changes: 3 additions & 2 deletions tests/transforms/test_invertibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ class TestInvertibility(TorchioTestCase):
def test_all_random_transforms(self):
transform = self.get_large_composed_transform()
# Remove RandomLabelsToImage as it will add a new image to the subject
# Remove RandomCropOrPad as it will change the dimension
for t in transform.transforms:
if t.name == 'RandomLabelsToImage':
if t.name == 'RandomLabelsToImage' or t.name == 'RandomCropOrPad':
transform.transforms.remove(t)
break

# Ignore elastic deformation and gamma warnings during execution
# Ignore some transforms not invertible
with warnings.catch_warnings():
Expand Down
3 changes: 2 additions & 1 deletion torchio/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .augmentation.spatial import RandomFlip, Flip
from .augmentation.spatial import RandomAffine, Affine
from .augmentation.spatial import RandomAnisotropy
from .augmentation.spatial import RandomAnisotropy, RandomCropOrPad
from .augmentation.spatial import RandomElasticDeformation, ElasticDeformation

from .augmentation.intensity import RandomSwap, Swap
Expand Down Expand Up @@ -94,6 +94,7 @@
'Clamp',
'Mask',
'CropOrPad',
'RandomCropOrPad',
'CopyAffine',
'EnsureShapeMultiple',
'train_histogram',
Expand Down
2 changes: 2 additions & 0 deletions torchio/transforms/augmentation/spatial/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .random_flip import RandomFlip, Flip
from .random_affine import RandomAffine, Affine
from .random_anisotropy import RandomAnisotropy
from .random_crop_or_pad import RandomCropOrPad
from .random_elastic_deformation import (
RandomElasticDeformation,
ElasticDeformation,
Expand All @@ -12,6 +13,7 @@
'Flip',
'RandomAffine',
'Affine',
'RandomCropOrPad',
'RandomAnisotropy',
'RandomElasticDeformation',
'ElasticDeformation',
Expand Down
122 changes: 122 additions & 0 deletions torchio/transforms/augmentation/spatial/random_crop_or_pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from typing import Union, Tuple, Optional

import numpy as np
from random import randint

from ...preprocessing.spatial.crop_or_pad import CropOrPad
from ...preprocessing.spatial.pad import Pad
from ...preprocessing.spatial.crop import Crop
from ... import SpatialTransform
from ...transform import TypeTripletInt, TypeSixBounds
from ....utils import parse_spatial_shape
from ....data.subject import Subject


class RandomCropOrPad(SpatialTransform):
"""Modify the field of view by random cropping or padding to a target shape.

This transform modifies the affine matrix associated to the volume so that
physical positions of the voxels are maintained.

Args:
target_shape: Tuple :math:`(W, H, D)`. If a single value :math:`N` is
provided, then :math:`W = H = D = N`.
padding_mode: Same as :attr:`padding_mode` in
:class:`~torchio.transforms.Pad`.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.

Example:
>>> import torchio as tio
>>> subject = tio.Subject(
... chest_ct=tio.ScalarImage('subject_a_ct.nii.gz'),
... heart_mask=tio.LabelMap('subject_a_heart_seg.nii.gz'),
... )
>>> subject.chest_ct.shape
torch.Size([1, 512, 512, 289])
>>> transform = tio.RandomCropOrPad(
... (120, 80, 180)
... )
>>> transformed = transform(subject)
>>> transformed.chest_ct.shape
torch.Size([1, 120, 80, 180])

.. plot::

import torchio as tio
t1 = tio.datasets.Colin27().t1
crop_pad = tio.RandomCropOrPad((256, 256, 32))
t1_pad_crop = crop_pad(t1)
subject = tio.Subject(t1=t1, crop_pad=t1_pad_crop)
subject.plot()
""" # noqa: E501

def __init__(
self,
target_shape: Union[int, TypeTripletInt, None] = 16,
padding_mode: Union[str, float] = 0,
**kwargs
):
super().__init__(**kwargs)
self.target_shape = parse_spatial_shape(target_shape)
self.padding_mode = padding_mode

def _compute_random_cropping_padding_from_shapes(
self, source_shape: TypeTripletInt,
) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
diff_shape = np.array(self.target_shape) - source_shape

cropping = -np.minimum(diff_shape, 0)
if cropping.any():
cropping_params = CropOrPad._get_six_bounds_parameters(cropping)
# adjust the cropping params by a random amount
# note: randint(0, 0) will return 0
random_x = randint(-cropping[0] // 2, cropping[0] // 2)
random_y = randint(-cropping[1] // 2, cropping[1] // 2)
random_z = randint(-cropping[2] // 2, cropping[2] // 2)
cropping_params = [
cropping_params[0] + random_x,
cropping_params[1] - random_x,
cropping_params[2] + random_y,
cropping_params[3] - random_y,
cropping_params[4] + random_z,
cropping_params[5] - random_z,
]
else:
cropping_params = None

padding = np.maximum(diff_shape, 0)
if padding.any():
padding_params = CropOrPad._get_six_bounds_parameters(padding)
# adjust the padding params by a random amount
# note: randint(0, 0) will return 0
random_x = randint(-padding[0] // 2, padding[0] // 2)
random_y = randint(-padding[1] // 2, padding[1] // 2)
random_z = randint(-padding[2] // 2, padding[2] // 2)
padding_params = [
padding_params[0] + random_x,
padding_params[1] - random_x,
padding_params[2] + random_y,
padding_params[3] - random_y,
padding_params[4] + random_z,
padding_params[5] - random_z,
]
else:
padding_params = None

return padding_params, cropping_params

def apply_transform(self, subject: Subject) -> Subject:
subject.check_consistent_space()
source_shape = subject.spatial_shape
(
padding_params,
cropping_params,
) = self._compute_random_cropping_padding_from_shapes(source_shape)
padding_kwargs = {'padding_mode': self.padding_mode}

if padding_params is not None:
subject = Pad(padding_params, **padding_kwargs)(subject)
if cropping_params is not None:
subject = Crop(cropping_params)(subject)
return subject