Skip to content

Commit 9f4da6a

Browse files
authored
allow slicing with SpatialCrop(d) (#1954)
* allow slicing with SpatialCrop(d) Signed-off-by: Richard Brown <[email protected]>
1 parent 357e6d4 commit 9f4da6a

File tree

6 files changed

+96
-47
lines changed

6 files changed

+96
-47
lines changed

monai/apps/deepgrow/transforms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,8 +698,10 @@ def __call__(self, data):
698698
cropper = SpatialCrop(roi_center=center, roi_size=spatial_size)
699699
else:
700700
cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
701-
box_start, box_end = cropper.roi_start, cropper.roi_end
702701

702+
# update bounding box in case it was corrected by the SpatialCrop constructor
703+
box_start = np.array([s.start for s in cropper.slices])
704+
box_end = np.array([s.stop for s in cropper.slices])
703705
for key in self.key_iterator(d):
704706
if not np.array_equal(d[key].shape[1:], original_spatial_shape):
705707
raise RuntimeError("All the image specified in keys should have same spatial shape")

monai/transforms/croppad/array.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,11 @@ class SpatialCrop(Transform):
214214
"""
215215
General purpose cropper to produce sub-volume region of interest (ROI).
216216
It can support to crop ND spatial (channel-first) data.
217-
Either a spatial center and size must be provided, or alternatively,
218-
if center and size are not provided, the start and end coordinates of the ROI must be provided.
217+
218+
The cropped region can be parameterised in various ways:
219+
- a list of slices for each spatial dimension (allows for use of -ve indexing and `None`)
220+
- a spatial center and size
221+
- the start and end coordinates of the ROI
219222
"""
220223

221224
def __init__(
@@ -224,35 +227,44 @@ def __init__(
224227
roi_size: Union[Sequence[int], np.ndarray, None] = None,
225228
roi_start: Union[Sequence[int], np.ndarray, None] = None,
226229
roi_end: Union[Sequence[int], np.ndarray, None] = None,
230+
roi_slices: Optional[Sequence[slice]] = None,
227231
) -> None:
228232
"""
229233
Args:
230234
roi_center: voxel coordinates for center of the crop ROI.
231235
roi_size: size of the crop ROI.
232236
roi_start: voxel coordinates for start of the crop ROI.
233237
roi_end: voxel coordinates for end of the crop ROI.
238+
roi_slices: list of slices for each of the spatial dimensions.
234239
"""
235-
if roi_center is not None and roi_size is not None:
236-
roi_center = np.asarray(roi_center, dtype=np.int16)
237-
roi_size = np.asarray(roi_size, dtype=np.int16)
238-
self.roi_start = np.maximum(roi_center - np.floor_divide(roi_size, 2), 0)
239-
self.roi_end = np.maximum(self.roi_start + roi_size, self.roi_start)
240+
if roi_slices:
241+
if not all(s.step is None or s.step == 1 for s in roi_slices):
242+
raise ValueError("Only slice steps of 1/None are currently supported")
243+
self.slices = list(roi_slices)
240244
else:
241-
if roi_start is None or roi_end is None:
242-
raise ValueError("Please specify either roi_center, roi_size or roi_start, roi_end.")
243-
self.roi_start = np.maximum(np.asarray(roi_start, dtype=np.int16), 0)
244-
self.roi_end = np.maximum(np.asarray(roi_end, dtype=np.int16), self.roi_start)
245-
# Allow for 1D by converting back to np.array (since np.maximum will convert to int)
246-
self.roi_start = self.roi_start if isinstance(self.roi_start, np.ndarray) else np.array([self.roi_start])
247-
self.roi_end = self.roi_end if isinstance(self.roi_end, np.ndarray) else np.array([self.roi_end])
245+
if roi_center is not None and roi_size is not None:
246+
roi_center = np.asarray(roi_center, dtype=np.int16)
247+
roi_size = np.asarray(roi_size, dtype=np.int16)
248+
roi_start_np = np.maximum(roi_center - np.floor_divide(roi_size, 2), 0)
249+
roi_end_np = np.maximum(roi_start_np + roi_size, roi_start_np)
250+
else:
251+
if roi_start is None or roi_end is None:
252+
raise ValueError("Please specify either roi_center, roi_size or roi_start, roi_end.")
253+
roi_start_np = np.maximum(np.asarray(roi_start, dtype=np.int16), 0)
254+
roi_end_np = np.maximum(np.asarray(roi_end, dtype=np.int16), roi_start_np)
255+
# Allow for 1D by converting back to np.array (since np.maximum will convert to int)
256+
roi_start_np = roi_start_np if isinstance(roi_start_np, np.ndarray) else np.array([roi_start_np])
257+
roi_end_np = roi_end_np if isinstance(roi_end_np, np.ndarray) else np.array([roi_end_np])
258+
# convert to slices
259+
self.slices = [slice(s, e) for s, e in zip(roi_start_np, roi_end_np)]
248260

249261
def __call__(self, img: Union[np.ndarray, torch.Tensor]):
250262
"""
251263
Apply the transform to `img`, assuming `img` is channel-first and
252264
slicing doesn't apply to the channel dim.
253265
"""
254-
sd = min(self.roi_start.size, self.roi_end.size, len(img.shape[1:])) # spatial dims
255-
slices = [slice(None)] + [slice(s, e) for s, e in zip(self.roi_start[:sd], self.roi_end[:sd])]
266+
sd = min(len(self.slices), len(img.shape[1:])) # spatial dims
267+
slices = [slice(None)] + self.slices[:sd]
256268
return img[tuple(slices)]
257269

258270

monai/transforms/croppad/dictionary.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,13 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
286286
class SpatialCropd(MapTransform, InvertibleTransform):
287287
"""
288288
Dictionary-based wrapper of :py:class:`monai.transforms.SpatialCrop`.
289-
Either a spatial center and size must be provided, or alternatively if center and size
290-
are not provided, the start and end coordinates of the ROI must be provided.
289+
General purpose cropper to produce sub-volume region of interest (ROI).
290+
It can support to crop ND spatial (channel-first) data.
291+
292+
The cropped region can be parameterised in various ways:
293+
- a list of slices for each spatial dimension (allows for use of -ve indexing and `None`)
294+
- a spatial center and size
295+
- the start and end coordinates of the ROI
291296
"""
292297

293298
def __init__(
@@ -297,6 +302,7 @@ def __init__(
297302
roi_size: Optional[Sequence[int]] = None,
298303
roi_start: Optional[Sequence[int]] = None,
299304
roi_end: Optional[Sequence[int]] = None,
305+
roi_slices: Optional[Sequence[slice]] = None,
300306
allow_missing_keys: bool = False,
301307
) -> None:
302308
"""
@@ -307,10 +313,11 @@ def __init__(
307313
roi_size: size of the crop ROI.
308314
roi_start: voxel coordinates for start of the crop ROI.
309315
roi_end: voxel coordinates for end of the crop ROI.
316+
roi_slices: list of slices for each of the spatial dimensions.
310317
allow_missing_keys: don't raise exception if key is missing.
311318
"""
312319
super().__init__(keys, allow_missing_keys)
313-
self.cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end)
320+
self.cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices)
314321

315322
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
316323
d = dict(data)
@@ -325,9 +332,11 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
325332
for key in self.key_iterator(d):
326333
transform = self.get_most_recent_transform(d, key)
327334
# Create inverse transform
328-
orig_size = transform[InverseKeys.ORIG_SIZE]
329-
pad_to_start = np.array(self.cropper.roi_start)
330-
pad_to_end = orig_size - self.cropper.roi_end
335+
orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
336+
current_size = np.array(d[key].shape[1:])
337+
# get required pad to start and end
338+
pad_to_start = np.array([s.indices(o)[0] for s, o in zip(self.cropper.slices, orig_size)])
339+
pad_to_end = orig_size - current_size - pad_to_start
331340
# interleave mins and maxes
332341
pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist())))
333342
inverse_transform = BorderPad(pad)

tests/test_inverse.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@
122122
)
123123
)
124124

125-
126125
TESTS.append(
127126
(
128127
"SpatialCropd 2d",
@@ -132,6 +131,15 @@
132131
)
133132
)
134133

134+
TESTS.append(
135+
(
136+
"SpatialCropd 3d",
137+
"3D",
138+
0,
139+
SpatialCropd(KEYS, roi_slices=[slice(s, e) for s, e in zip([None, None, -99], [None, -2, None])]),
140+
)
141+
)
142+
135143
TESTS.append(
136144
(
137145
"SpatialCropd 2d",

tests/test_spatial_crop.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@
4040
(3, 3, 3, 3),
4141
(3, 0, 3, 3),
4242
],
43+
[
44+
{"roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]},
45+
(3, 3, 3, 3),
46+
(3, 1, 2, 2),
47+
],
48+
]
49+
50+
TEST_ERRORS = [
51+
[{"roi_slices": [slice(s, e, 2) for s, e in zip([-1, -2, 0], [None, None, 2])]}],
4352
]
4453

4554

@@ -56,6 +65,11 @@ def test_tensor_shape(self, input_param, input_shape, expected_shape):
5665
result = SpatialCrop(**input_param)(input_data)
5766
self.assertTupleEqual(result.shape, expected_shape)
5867

68+
@parameterized.expand(TEST_ERRORS)
69+
def test_error(self, input_param):
70+
with self.assertRaises(ValueError):
71+
SpatialCrop(**input_param)
72+
5973

6074
if __name__ == "__main__":
6175
unittest.main()

tests/test_spatial_cropd.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,33 +16,37 @@
1616

1717
from monai.transforms import SpatialCropd
1818

19-
TEST_CASE_1 = [
20-
{"keys": ["img"], "roi_center": [1, 1, 1], "roi_size": [2, 2, 2]},
21-
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
22-
(3, 2, 2, 2),
23-
]
24-
25-
TEST_CASE_2 = [
26-
{"keys": ["img"], "roi_start": [0, 0, 0], "roi_end": [2, 2, 2]},
27-
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
28-
(3, 2, 2, 2),
29-
]
30-
31-
TEST_CASE_3 = [
32-
{"keys": ["img"], "roi_start": [0, 0], "roi_end": [2, 2]},
33-
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
34-
(3, 2, 2, 3),
35-
]
36-
37-
TEST_CASE_4 = [
38-
{"keys": ["img"], "roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]},
39-
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
40-
(3, 2, 2, 2),
19+
TEST_CASES = [
20+
[
21+
{"keys": ["img"], "roi_center": [1, 1, 1], "roi_size": [2, 2, 2]},
22+
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
23+
(3, 2, 2, 2),
24+
],
25+
[
26+
{"keys": ["img"], "roi_start": [0, 0, 0], "roi_end": [2, 2, 2]},
27+
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
28+
(3, 2, 2, 2),
29+
],
30+
[
31+
{"keys": ["img"], "roi_start": [0, 0], "roi_end": [2, 2]},
32+
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
33+
(3, 2, 2, 3),
34+
],
35+
[
36+
{"keys": ["img"], "roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]},
37+
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
38+
(3, 2, 2, 2),
39+
],
40+
[
41+
{"keys": ["img"], "roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]},
42+
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
43+
(3, 1, 2, 2),
44+
],
4145
]
4246

4347

4448
class TestSpatialCropd(unittest.TestCase):
45-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
49+
@parameterized.expand(TEST_CASES)
4650
def test_shape(self, input_param, input_data, expected_shape):
4751
result = SpatialCropd(**input_param)(input_data)
4852
self.assertTupleEqual(result["img"].shape, expected_shape)

0 commit comments

Comments
 (0)