Skip to content

Commit c94ea90

Browse files
authored
Merge pull request #249 from jhlegarreta/tst/test-estimator-iterates-correctly
TST: Test that the estimator iterates correctly over the volumes
2 parents 0f701aa + 65d2fa1 commit c94ea90

File tree

3 files changed

+148
-4
lines changed

3 files changed

+148
-4
lines changed

src/nifreeze/estimator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,13 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
125125

126126
# Prepare iterator
127127
iterfunc = getattr(iterators, f"{self._strategy}_iterator")
128-
index_iter = iterfunc(size=len(dataset), seed=kwargs.get("seed", None))
128+
index_iter = iterfunc(
129+
size=len(dataset),
130+
bvals=kwargs.pop("bvals", None),
131+
uptake=kwargs.pop("uptake", None),
132+
seed=kwargs.get("seed", None),
133+
round_decimals=kwargs.pop("round_decimals", iterators.DEFAULT_ROUND_DECIMALS),
134+
)
129135

130136
# Initialize model
131137
if isinstance(self._model, str):

src/nifreeze/utils/iterators.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
from itertools import chain, zip_longest
2727
from typing import Iterator
2828

29+
DEFAULT_ROUND_DECIMALS = 2
30+
"""Round decimals to use when comparing values to be sorted for iteration purposes."""
31+
2932
SIZE_KEYS = ("size", "bvals", "uptake")
3033
"""Keys that may be used to infer the number of volumes in a dataset. When the
3134
size of the structure to iterate over is not given explicitly, these keys
@@ -168,7 +171,9 @@ def random_iterator(**kwargs) -> Iterator[int]:
168171
"""
169172

170173

171-
def _value_iterator(values: list, ascending: bool, round_decimals: int = 2) -> Iterator[int]:
174+
def _value_iterator(
175+
values: list, ascending: bool, round_decimals: int = DEFAULT_ROUND_DECIMALS
176+
) -> Iterator[int]:
172177
"""
173178
Traverse the given values in ascending or descenting order.
174179
@@ -231,7 +236,9 @@ def bvalue_iterator(*_, **kwargs) -> Iterator[int]:
231236
bvals = kwargs.pop(BVALS_KWARG, None)
232237
if bvals is None:
233238
raise TypeError(KWARG_ERROR_MSG.format(kwarg=BVALS_KWARG))
234-
return _value_iterator(bvals, ascending=True, **kwargs)
239+
return _value_iterator(
240+
bvals, ascending=True, round_decimals=kwargs.pop("round_decimals", DEFAULT_ROUND_DECIMALS)
241+
)
235242

236243

237244
def uptake_iterator(*_, **kwargs) -> Iterator[int]:
@@ -263,7 +270,11 @@ def uptake_iterator(*_, **kwargs) -> Iterator[int]:
263270
uptake = kwargs.pop(UPTAKE_KWARG, None)
264271
if uptake is None:
265272
raise TypeError(KWARG_ERROR_MSG.format(kwarg=UPTAKE_KWARG))
266-
return _value_iterator(uptake, ascending=False, **kwargs)
273+
return _value_iterator(
274+
uptake,
275+
ascending=False,
276+
round_decimals=kwargs.pop("round_decimals", DEFAULT_ROUND_DECIMALS),
277+
)
267278

268279

269280
def centralsym_iterator(**kwargs) -> Iterator[int]:

test/test_estimator.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,17 @@
2121
# https://www.nipreps.org/community/licensing/
2222
#
2323

24+
from typing import Union
25+
2426
import numpy as np
27+
import pytest
2528

29+
import nifreeze.estimator
2630
from nifreeze.data.base import BaseDataset
31+
from nifreeze.data.dmri import DEFAULT_LOWB_THRESHOLD
2732
from nifreeze.estimator import Estimator
2833
from nifreeze.model.base import BaseModel
34+
from nifreeze.utils import iterators
2935

3036
DATAOBJ_SIZE = (5, 5, 5, 4)
3137

@@ -57,6 +63,37 @@ def set_transform(self, idx, matrix):
5763
pass
5864

5965

66+
class DummyDWIDataset(BaseDataset):
67+
def __init__(self, dwi_dataobj, affine, brainmask_dataobj, b0_dataobj, gradients):
68+
self.dataobj = dwi_dataobj
69+
self.affine = affine
70+
self.brainmask = brainmask_dataobj
71+
self.bzero = b0_dataobj
72+
self.gradients = gradients
73+
74+
def __len__(self):
75+
return self.dataobj.shape[-1]
76+
77+
def __getitem__(self, idx):
78+
return self.dataobj[..., idx], self.brainmask, self.gradients
79+
80+
81+
class DummyPETDataset(BaseDataset):
82+
def __init__(self, pet_dataobj, affine, brainmask_dataobj, midrame, total_duration):
83+
self.dataobj = pet_dataobj
84+
self.affine = affine
85+
self.brainmask = brainmask_dataobj
86+
self.midrame = midrame
87+
self.gradients = total_duration
88+
self.uptake = np.sum(pet_dataobj.reshape(-1, pet_dataobj.shape[-1]), axis=0)
89+
90+
def __len__(self):
91+
return self.dataobj.shape[-1]
92+
93+
def __getitem__(self, idx):
94+
return self.dataobj[..., idx], self.brainmask, self.midrame[idx]
95+
96+
6097
def test_estimator_init_model_instance(request):
6198
rng = request.node.rng
6299
model = DummyModel(dataset=DummyDataset(rng))
@@ -78,3 +115,93 @@ def test_estimator_init_model_string(request, monkeypatch):
78115
est.run(_dataset)
79116
assert isinstance(est._model, str)
80117
assert est._model == model_name
118+
119+
120+
@pytest.mark.parametrize(
121+
"strategy, iterator_func, modality",
122+
[
123+
("linear", iterators.linear_iterator, "dwi"),
124+
("linear", iterators.linear_iterator, "pet"),
125+
("random", iterators.random_iterator, "dwi"),
126+
("random", iterators.random_iterator, "pet"),
127+
("centralsym", iterators.centralsym_iterator, "dwi"),
128+
("centralsym", iterators.centralsym_iterator, "pet"),
129+
("bvalue", iterators.bvalue_iterator, "dwi"),
130+
("uptake", iterators.uptake_iterator, "pet"),
131+
],
132+
)
133+
def test_estimator_iterator_index_match(
134+
monkeypatch, setup_random_dwi_data, setup_random_pet_data, strategy, iterator_func, modality
135+
):
136+
dataset: Union["DummyDWIDataset", "DummyPETDataset"] # Avoids type annotation errors
137+
if modality == "dwi":
138+
(
139+
dwi_dataobj,
140+
affine,
141+
brainmask_dataobj,
142+
b0_dataobj,
143+
gradients,
144+
_,
145+
) = setup_random_dwi_data
146+
147+
dataset = DummyDWIDataset(dwi_dataobj, affine, brainmask_dataobj, b0_dataobj, gradients)
148+
bvals = gradients[-1, :][np.where(gradients[-1, :] > DEFAULT_LOWB_THRESHOLD)]
149+
kwargs = dict({"bvals": bvals})
150+
elif modality == "pet":
151+
(
152+
pet_dataobj,
153+
affine,
154+
brainmask_dataobj,
155+
midframe,
156+
total_duration,
157+
) = setup_random_pet_data
158+
159+
dataset = DummyPETDataset(pet_dataobj, affine, brainmask_dataobj, midframe, total_duration)
160+
uptake = dataset.uptake
161+
kwargs = dict({"uptake": uptake})
162+
else:
163+
raise NotImplementedError(f"{modality} not implemented")
164+
165+
# Patch set_transform to record indices and matrices
166+
recorded_indices = []
167+
recorded_matrices = []
168+
169+
# Make this accept `self` so it behaves as a proper instance method
170+
def fake_set_transform(self, i, xform):
171+
recorded_indices.append(i)
172+
recorded_matrices.append(xform)
173+
174+
monkeypatch.setattr(type(dataset), "set_transform", fake_set_transform)
175+
176+
# Patch registration to return identity matrix
177+
class DummyXForm:
178+
matrix = np.eye(4)
179+
180+
nifreeze.estimator._run_registration = lambda *a, **k: DummyXForm()
181+
182+
model = DummyModel(dataset=dataset)
183+
estimator = Estimator(model, strategy=strategy)
184+
estimator.run(dataset, **kwargs)
185+
186+
n_vols = len(dataset)
187+
188+
# Get expected indices
189+
if strategy == "linear":
190+
expected_indices = list(iterator_func(size=n_vols))
191+
elif strategy == "random":
192+
expected_indices = sorted(list(iterator_func(size=n_vols, seed=42)))
193+
recorded_indices_sorted = sorted(recorded_indices)
194+
assert recorded_indices_sorted == expected_indices
195+
return
196+
elif strategy == "centralsym":
197+
expected_indices = list(iterator_func(size=n_vols))
198+
elif strategy == "bvalue":
199+
expected_indices = list(iterator_func(bvals=bvals))
200+
elif strategy == "uptake":
201+
expected_indices = list(iterator_func(uptake=uptake))
202+
else:
203+
raise ValueError(f"Unknown strategy {strategy}")
204+
205+
# Assert indices and matrices
206+
assert recorded_indices == expected_indices
207+
assert all(np.allclose(mat, np.eye(4)) for mat in recorded_matrices)

0 commit comments

Comments
 (0)