Skip to content

Commit fd89b21

Browse files
oestebanjhlegarreta
authored andcommitted
enh: add downsample() in addition to decimate()
1 parent b4e5173 commit fd89b21

File tree

2 files changed

+176
-32
lines changed

2 files changed

+176
-32
lines changed

src/nifreeze/data/filtering.py

Lines changed: 108 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import numpy as np
3030
from nibabel import Nifti1Image, load
31+
from nibabel.affines import apply_affine, voxel_sizes
3132
from scipy.ndimage import gaussian_filter as _gs
3233
from scipy.ndimage import map_coordinates, median_filter
3334
from skimage.morphology import ball
@@ -141,15 +142,15 @@ def gaussian_filter(
141142
return _gs(data, vox_width)
142143

143144

144-
def decimate(
145+
def downsample(
145146
in_file: str,
146-
factor: int | tuple[int, int, int],
147+
shape: tuple[int, int, int],
147148
smooth: bool | tuple[int, int, int] = True,
148149
order: int = 3,
149150
nonnegative: bool = True,
150151
) -> Nifti1Image:
151152
"""
152-
Decimates a 3D or 4D Nifti image by a specified downsampling factor.
153+
Downsamples a 3D or 4D Nifti image by a specified downsampling factor.
153154
154155
This function downsamples a Nifti image by averaging voxels within a user-defined
155156
factor in each spatial dimension. It optionally applies Gaussian smoothing
@@ -186,47 +187,54 @@ def decimate(
186187

187188
imnii = load(in_file)
188189
data = np.squeeze(imnii.get_fdata()) # Remove unused dimensions
189-
datashape = data.shape
190+
datashape = np.array(data.shape)
191+
shape = np.array(shape)
190192
ndim = data.ndim
191193

192-
if isinstance(factor, Number):
193-
factor = tuple([factor] * min(3, ndim))
194-
195-
if any(f <= 0 for f in factor[:3]):
196-
raise ValueError("All spatial downsampling factors must be positive.")
197-
198-
if ndim == 4 and len(factor) == 3:
199-
factor = (*factor, 0)
200-
201194
if smooth:
202195
if smooth is True:
203-
smooth = factor
196+
smooth = datashape[:3] / shape[:3]
204197
data = gaussian_filter(data, smooth)
205198

199+
extents = (
200+
apply_affine(imnii.affine, datashape - 0.5)
201+
- apply_affine(imnii.affine, (-0.5, -0.5, -0.5))
202+
)
203+
newzooms = extents / shape
204+
205+
# Update affine transformation
206+
newaffine = np.eye(4)
207+
oldzooms = voxel_sizes(imnii.affine)
208+
newaffine[:3, :3] = np.diag(newzooms / oldzooms) @ imnii.affine[:3, :3]
209+
210+
# Update offset so new array is aligned with original
211+
newaffine[:3, 3] = (
212+
apply_affine(imnii.affine, 0.5 * datashape)
213+
- apply_affine(newaffine, 0.5 * shape)
214+
)
215+
216+
xfm = np.linalg.inv(imnii.affine) @ newaffine
217+
206218
# Create downsampled grid
207219
down_grid = np.array(
208220
np.meshgrid(
209-
*[np.arange(_s, step=int(_f) or 1) for _s, _f in zip(datashape, factor)],
221+
*[np.arange(_s, step=1) for _s in shape],
210222
indexing="ij",
211223
)
212224
)
213-
new_shape = down_grid.shape[1:]
214-
215-
# Update affine transformation
216-
newaffine = imnii.affine.copy()
217-
newaffine[:3, :3] = np.array(factor[:3]) * newaffine[:3, :3]
218225

219-
# TODO: Update offset so new array is aligned with original
226+
# Locations is an Nx3 array of index coordinates of the original image where we sample
227+
locations = apply_affine(xfm, down_grid.reshape((ndim, np.prod(shape))).T)
220228

221229
# Resample data on the new grid
222230
resampled = map_coordinates(
223231
data,
224-
down_grid.reshape((ndim, np.prod(new_shape))),
232+
locations.T,
225233
order=order,
226234
mode="constant",
227235
cval=0,
228236
prefilter=True,
229-
).reshape(new_shape)
237+
).reshape(shape)
230238

231239
# Set negative values to zero (optional)
232240
if order > 2 and nonnegative:
@@ -238,3 +246,80 @@ def decimate(
238246
newnii.set_qform(newaffine, code=1)
239247

240248
return newnii
249+
250+
251+
def decimate(
252+
in_file: str,
253+
factor: int | tuple[int, int, int],
254+
smooth: bool | tuple[int, int, int] = True,
255+
order: int = 3,
256+
nonnegative: bool = True,
257+
) -> Nifti1Image:
258+
"""
259+
Decimates a 3D or 4D Nifti image by a specified downsampling factor.
260+
261+
This function downsamples a Nifti image by averaging voxels within a user-defined
262+
factor in each spatial dimension. It optionally applies Gaussian smoothing
263+
before downsampling to reduce aliasing artifacts. The function also handles
264+
updating the affine transformation matrix to reflect the change in voxel size.
265+
266+
Parameters
267+
----------
268+
in_file : :obj:`str`
269+
Path to the input NIfTI image file.
270+
factor : :obj:`int` or :obj:`tuple`
271+
The downsampling factor. If a single integer is provided, it is applied
272+
uniformly across all spatial dimensions. Alternatively, a tuple of three
273+
integers can be provided to specify different downsampling factors for each
274+
spatial dimension (x, y, z). Values must be greater than 0.
275+
smooth : :obj:`bool` or :obj:`tuple`, optional (default=``True``)
276+
Controls application of Gaussian smoothing before downsampling. If True,
277+
a smoothing kernel size equal to the downsampling factor is applied.
278+
Alternatively, a tuple of three integers can be provided to specify
279+
different smoothing kernel sizes for each spatial dimension. Setting to
280+
False disables smoothing.
281+
order : :obj:`int`, optional (default=3)
282+
The order of the spline interpolation used for downsampling. Higher
283+
orders provide smoother results but are computationally more expensive.
284+
nonnegative : :obj:`bool`, optional (default=``True``)
285+
If True, negative values in the downsampled data are set to zero.
286+
287+
Returns
288+
-------
289+
:obj:`~nibabel.Nifti1Image`
290+
The downsampled NIfTI image object.
291+
292+
"""
293+
294+
imnii = load(in_file)
295+
data = np.squeeze(imnii.get_fdata()) # Remove unused dimensions
296+
ndim = data.ndim
297+
298+
if isinstance(factor, Number):
299+
factor = tuple([factor] * min(3, ndim))
300+
301+
if any(f <= 0 for f in factor[:3]):
302+
raise ValueError("All spatial downsampling factors must be positive.")
303+
304+
if ndim == 4 and len(factor) == 3:
305+
factor = (*factor, 0)
306+
307+
if smooth:
308+
if smooth is True:
309+
smooth = factor
310+
data = gaussian_filter(data, smooth)
311+
312+
# Update affine transformation
313+
newaffine = imnii.affine.copy()
314+
newaffine[:3, :3] = np.array(factor[:3]) * newaffine[:3, :3]
315+
316+
# Create new Nifti image with updated information
317+
newnii = Nifti1Image(
318+
data[::factor[0], ::factor[1], ::factor[2]],
319+
newaffine,
320+
imnii.header,
321+
)
322+
newnii.set_sform(newaffine, code=1)
323+
newnii.set_qform(newaffine, code=1)
324+
325+
return newnii

test/test_filtering.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,70 @@
2626

2727
import pytest
2828

29-
from nifreeze.data.filtering import decimate
29+
from nifreeze.data.filtering import decimate, downsample
3030

3131

3232
@pytest.mark.parametrize(
3333
("size", "block_size"),
3434
[
35-
((20, 20, 20), (5, 5, 5),)
35+
((20, 20, 20), (5, 5, 5),),
36+
((21, 21, 21), (5, 5, 5),),
3637
],
3738
)
38-
def test_decimation(tmp_path, size, block_size):
39+
@pytest.mark.parametrize(
40+
("zoom_x", ),
41+
# [(1.0, ), (-1.0, ), (2.0, ), (-2.0, )],
42+
[(2.0,)],
43+
)
44+
@pytest.mark.parametrize(
45+
("zoom_y", ),
46+
# [(1.0, ), (-1.0, ), (2.0, ), (-2.0, )],
47+
[(-2.0,)],
48+
)
49+
@pytest.mark.parametrize(
50+
("zoom_z", ),
51+
# [(1.0, ), (-1.0, ), (2.0, ), (-2.0, )],
52+
[(-2.0,)],
53+
)
54+
@pytest.mark.parametrize(
55+
("angle_x", ),
56+
# [(0.0, ), (0.2, ), (-0.05, )],
57+
[(-0.05,)]
58+
)
59+
@pytest.mark.parametrize(
60+
("angle_y", ),
61+
[(0.0, ), (0.2, ), (-0.05, )],
62+
)
63+
@pytest.mark.parametrize(
64+
("angle_z", ),
65+
[(0.0, ), (0.2, ), (-0.05, )],
66+
)
67+
@pytest.mark.parametrize(
68+
("offsets", ),
69+
[
70+
(None, ),
71+
((0.0, 0.0, 0.0),),
72+
],
73+
)
74+
def test_decimation(
75+
tmp_path,
76+
size,
77+
block_size,
78+
zoom_x,
79+
zoom_y,
80+
zoom_z,
81+
angle_x,
82+
angle_y,
83+
angle_z,
84+
offsets,
85+
):
3986
"""Exercise decimation."""
4087

4188
# Calculate the number of sub-blocks in each dimension
4289
num_blocks = [s // b for s, b in zip(size, block_size)]
4390

4491
# Create the empty array
45-
voxel_array = np.zeros(size, dtype=int)
92+
voxel_array = np.zeros(size, dtype=np.uint16)
4693

4794
# Fill the array with increasing values based on sub-block position
4895
current_block = 0
@@ -58,11 +105,23 @@ def test_decimation(tmp_path, size, block_size):
58105

59106
fname = tmp_path / "test_img.nii.gz"
60107

61-
nb.Nifti1Image(voxel_array, None, None).to_filename(fname)
108+
affine = np.eye(4)
109+
affine[:3, :3] = (
110+
nb.eulerangles.euler2mat(x=angle_x, y=angle_y, z=angle_z)
111+
@ np.diag((zoom_x, zoom_y, zoom_z))
112+
@ affine[:3, :3]
113+
)
62114

63-
# Need to define test oracle. For now, just see if it doesn't smoke.
64-
decimate(fname, factor=2, smooth=False, order=1)
115+
if offsets is None:
116+
affine[:3, 3] = -0.5 * nb.affines.apply_affine(affine, np.array(size) - 1)
65117

66-
# out.to_filename(tmp_path / "decimated.nii.gz")
118+
test_image = nb.Nifti1Image(voxel_array.astype(np.uint16), affine, None)
119+
test_image.header.set_data_dtype(np.uint16)
120+
test_image.to_filename(fname)
121+
122+
# Need to define test oracle. For now, just see if it doesn't smoke.
123+
out = decimate(fname, factor=2, smooth=False, order=1)
124+
out.to_filename(tmp_path / "decimated.nii.gz")
67125

68-
# import pdb; pdb.set_trace()
126+
out = downsample(fname, shape=(10, 10, 10), smooth=False, order=1)
127+
out.to_filename(tmp_path / "downsampled.nii.gz")

0 commit comments

Comments
 (0)