2121# https://www.nipreps.org/community/licensing/
2222#
2323
24+ from typing import Union
25+
2426import numpy as np
27+ import pytest
2528
29+ import nifreeze .estimator
2630from nifreeze .data .base import BaseDataset
31+ from nifreeze .data .dmri import DEFAULT_LOWB_THRESHOLD
2732from nifreeze .estimator import Estimator
2833from nifreeze .model .base import BaseModel
34+ from nifreeze .utils import iterators
2935
3036DATAOBJ_SIZE = (5 , 5 , 5 , 4 )
3137
@@ -57,6 +63,37 @@ def set_transform(self, idx, matrix):
5763 pass
5864
5965
66+ class DummyDWIDataset (BaseDataset ):
67+ def __init__ (self , dwi_dataobj , affine , brainmask_dataobj , b0_dataobj , gradients ):
68+ self .dataobj = dwi_dataobj
69+ self .affine = affine
70+ self .brainmask = brainmask_dataobj
71+ self .bzero = b0_dataobj
72+ self .gradients = gradients
73+
74+ def __len__ (self ):
75+ return self .dataobj .shape [- 1 ]
76+
77+ def __getitem__ (self , idx ):
78+ return self .dataobj [..., idx ], self .brainmask , self .gradients
79+
80+
81+ class DummyPETDataset (BaseDataset ):
82+ def __init__ (self , pet_dataobj , affine , brainmask_dataobj , midrame , total_duration ):
83+ self .dataobj = pet_dataobj
84+ self .affine = affine
85+ self .brainmask = brainmask_dataobj
86+ self .midrame = midrame
87+ self .gradients = total_duration
88+ self .uptake = np .sum (pet_dataobj .reshape (- 1 , pet_dataobj .shape [- 1 ]), axis = 0 )
89+
90+ def __len__ (self ):
91+ return self .dataobj .shape [- 1 ]
92+
93+ def __getitem__ (self , idx ):
94+ return self .dataobj [..., idx ], self .brainmask , self .midrame [idx ]
95+
96+
6097def test_estimator_init_model_instance (request ):
6198 rng = request .node .rng
6299 model = DummyModel (dataset = DummyDataset (rng ))
@@ -78,3 +115,93 @@ def test_estimator_init_model_string(request, monkeypatch):
78115 est .run (_dataset )
79116 assert isinstance (est ._model , str )
80117 assert est ._model == model_name
118+
119+
120+ @pytest .mark .parametrize (
121+ "strategy, iterator_func, modality" ,
122+ [
123+ ("linear" , iterators .linear_iterator , "dwi" ),
124+ ("linear" , iterators .linear_iterator , "pet" ),
125+ ("random" , iterators .random_iterator , "dwi" ),
126+ ("random" , iterators .random_iterator , "pet" ),
127+ ("centralsym" , iterators .centralsym_iterator , "dwi" ),
128+ ("centralsym" , iterators .centralsym_iterator , "pet" ),
129+ ("bvalue" , iterators .bvalue_iterator , "dwi" ),
130+ ("uptake" , iterators .uptake_iterator , "pet" ),
131+ ],
132+ )
133+ def test_estimator_iterator_index_match (
134+ monkeypatch , setup_random_dwi_data , setup_random_pet_data , strategy , iterator_func , modality
135+ ):
136+ dataset : Union ["DummyDWIDataset" , "DummyPETDataset" ] # Avoids type annotation errors
137+ if modality == "dwi" :
138+ (
139+ dwi_dataobj ,
140+ affine ,
141+ brainmask_dataobj ,
142+ b0_dataobj ,
143+ gradients ,
144+ _ ,
145+ ) = setup_random_dwi_data
146+
147+ dataset = DummyDWIDataset (dwi_dataobj , affine , brainmask_dataobj , b0_dataobj , gradients )
148+ bvals = gradients [- 1 , :][np .where (gradients [- 1 , :] > DEFAULT_LOWB_THRESHOLD )]
149+ kwargs = dict ({"bvals" : bvals })
150+ elif modality == "pet" :
151+ (
152+ pet_dataobj ,
153+ affine ,
154+ brainmask_dataobj ,
155+ midframe ,
156+ total_duration ,
157+ ) = setup_random_pet_data
158+
159+ dataset = DummyPETDataset (pet_dataobj , affine , brainmask_dataobj , midframe , total_duration )
160+ uptake = dataset .uptake
161+ kwargs = dict ({"uptake" : uptake })
162+ else :
163+ raise NotImplementedError (f"{ modality } not implemented" )
164+
165+ # Patch set_transform to record indices and matrices
166+ recorded_indices = []
167+ recorded_matrices = []
168+
169+ # Make this accept `self` so it behaves as a proper instance method
170+ def fake_set_transform (self , i , xform ):
171+ recorded_indices .append (i )
172+ recorded_matrices .append (xform )
173+
174+ monkeypatch .setattr (type (dataset ), "set_transform" , fake_set_transform )
175+
176+ # Patch registration to return identity matrix
177+ class DummyXForm :
178+ matrix = np .eye (4 )
179+
180+ nifreeze .estimator ._run_registration = lambda * a , ** k : DummyXForm ()
181+
182+ model = DummyModel (dataset = dataset )
183+ estimator = Estimator (model , strategy = strategy )
184+ estimator .run (dataset , ** kwargs )
185+
186+ n_vols = len (dataset )
187+
188+ # Get expected indices
189+ if strategy == "linear" :
190+ expected_indices = list (iterator_func (size = n_vols ))
191+ elif strategy == "random" :
192+ expected_indices = sorted (list (iterator_func (size = n_vols , seed = 42 )))
193+ recorded_indices_sorted = sorted (recorded_indices )
194+ assert recorded_indices_sorted == expected_indices
195+ return
196+ elif strategy == "centralsym" :
197+ expected_indices = list (iterator_func (size = n_vols ))
198+ elif strategy == "bvalue" :
199+ expected_indices = list (iterator_func (bvals = bvals ))
200+ elif strategy == "uptake" :
201+ expected_indices = list (iterator_func (uptake = uptake ))
202+ else :
203+ raise ValueError (f"Unknown strategy { strategy } " )
204+
205+ # Assert indices and matrices
206+ assert recorded_indices == expected_indices
207+ assert all (np .allclose (mat , np .eye (4 )) for mat in recorded_matrices )
0 commit comments