@@ -214,8 +214,11 @@ class SpatialCrop(Transform):
214
214
"""
215
215
General purpose cropper to produce sub-volume region of interest (ROI).
216
216
It can support to crop ND spatial (channel-first) data.
217
- Either a spatial center and size must be provided, or alternatively,
218
- if center and size are not provided, the start and end coordinates of the ROI must be provided.
217
+
218
+ The cropped region can be parameterised in various ways:
219
+ - a list of slices for each spatial dimension (allows for use of -ve indexing and `None`)
220
+ - a spatial center and size
221
+ - the start and end coordinates of the ROI
219
222
"""
220
223
221
224
def __init__ (
@@ -224,35 +227,44 @@ def __init__(
224
227
roi_size : Union [Sequence [int ], np .ndarray , None ] = None ,
225
228
roi_start : Union [Sequence [int ], np .ndarray , None ] = None ,
226
229
roi_end : Union [Sequence [int ], np .ndarray , None ] = None ,
230
+ roi_slices : Optional [Sequence [slice ]] = None ,
227
231
) -> None :
228
232
"""
229
233
Args:
230
234
roi_center: voxel coordinates for center of the crop ROI.
231
235
roi_size: size of the crop ROI.
232
236
roi_start: voxel coordinates for start of the crop ROI.
233
237
roi_end: voxel coordinates for end of the crop ROI.
238
+ roi_slices: list of slices for each of the spatial dimensions.
234
239
"""
235
- if roi_center is not None and roi_size is not None :
236
- roi_center = np .asarray (roi_center , dtype = np .int16 )
237
- roi_size = np .asarray (roi_size , dtype = np .int16 )
238
- self .roi_start = np .maximum (roi_center - np .floor_divide (roi_size , 2 ), 0 )
239
- self .roi_end = np .maximum (self .roi_start + roi_size , self .roi_start )
240
+ if roi_slices :
241
+ if not all (s .step is None or s .step == 1 for s in roi_slices ):
242
+ raise ValueError ("Only slice steps of 1/None are currently supported" )
243
+ self .slices = list (roi_slices )
240
244
else :
241
- if roi_start is None or roi_end is None :
242
- raise ValueError ("Please specify either roi_center, roi_size or roi_start, roi_end." )
243
- self .roi_start = np .maximum (np .asarray (roi_start , dtype = np .int16 ), 0 )
244
- self .roi_end = np .maximum (np .asarray (roi_end , dtype = np .int16 ), self .roi_start )
245
- # Allow for 1D by converting back to np.array (since np.maximum will convert to int)
246
- self .roi_start = self .roi_start if isinstance (self .roi_start , np .ndarray ) else np .array ([self .roi_start ])
247
- self .roi_end = self .roi_end if isinstance (self .roi_end , np .ndarray ) else np .array ([self .roi_end ])
245
+ if roi_center is not None and roi_size is not None :
246
+ roi_center = np .asarray (roi_center , dtype = np .int16 )
247
+ roi_size = np .asarray (roi_size , dtype = np .int16 )
248
+ roi_start_np = np .maximum (roi_center - np .floor_divide (roi_size , 2 ), 0 )
249
+ roi_end_np = np .maximum (roi_start_np + roi_size , roi_start_np )
250
+ else :
251
+ if roi_start is None or roi_end is None :
252
+ raise ValueError ("Please specify either roi_center, roi_size or roi_start, roi_end." )
253
+ roi_start_np = np .maximum (np .asarray (roi_start , dtype = np .int16 ), 0 )
254
+ roi_end_np = np .maximum (np .asarray (roi_end , dtype = np .int16 ), roi_start_np )
255
+ # Allow for 1D by converting back to np.array (since np.maximum will convert to int)
256
+ roi_start_np = roi_start_np if isinstance (roi_start_np , np .ndarray ) else np .array ([roi_start_np ])
257
+ roi_end_np = roi_end_np if isinstance (roi_end_np , np .ndarray ) else np .array ([roi_end_np ])
258
+ # convert to slices
259
+ self .slices = [slice (s , e ) for s , e in zip (roi_start_np , roi_end_np )]
248
260
249
261
def __call__ (self , img : Union [np .ndarray , torch .Tensor ]):
250
262
"""
251
263
Apply the transform to `img`, assuming `img` is channel-first and
252
264
slicing doesn't apply to the channel dim.
253
265
"""
254
- sd = min (self .roi_start . size , self . roi_end . size , len (img .shape [1 :])) # spatial dims
255
- slices = [slice (None )] + [ slice ( s , e ) for s , e in zip ( self .roi_start [:sd ], self . roi_end [: sd ]) ]
266
+ sd = min (len ( self .slices ) , len (img .shape [1 :])) # spatial dims
267
+ slices = [slice (None )] + self .slices [:sd ]
256
268
return img [tuple (slices )]
257
269
258
270
0 commit comments