Skip to content

Commit 713dbbf

Browse files
committed
Remove pointless .fit() method; docstring updates
1 parent 3c23538 commit 713dbbf

File tree

2 files changed

+43
-95
lines changed

2 files changed

+43
-95
lines changed

cifti.py

+32-61
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ class CiftiHandler(object):
115115
116116
Methods
117117
-------
118-
* get_volume_data : Extract volume data
119-
* get_surface_data : Extract surface data for a given hemisphere
120-
* get_all_data : Convenience function for extracting surface and volume data
121-
* create_new_cifti : Create new CIFTI image from provided data
118+
* ``get_volume_data`` : Extract volume data
119+
* ``get_surface_data`` : Extract surface data for a given hemisphere
120+
* ``get_all_data`` : Convenience function for extracting surface and volume data
121+
* ``create_new_cifti`` : Create new CIFTI image from provided data
122122
123123
Example usage
124124
-------------
@@ -401,20 +401,18 @@ class CiftiMasker(object):
401401
402402
Methods
403403
-------
404-
* fit : Load mask
405-
* transform : Load dataset, applying mask
406-
* transform_multiple : Load multiple datasets, applying mask
407-
* fit_transform[_multiple] : Fit and transform in one go
408-
* inverse_transform : Create new CIFTI image from masked data
409-
* uncache : Clear cache
404+
* ``transform`` : Load dataset, applying mask
405+
* ``transform_multiple`` : Load multiple datasets, applying mask
406+
* ``inverse_transform`` : Create new CIFTI image from masked data
407+
* ``uncache`` : Clear cache
410408
411409
Example usage
412410
-------------
413411
Use masks from Freesurfer Desikan-Killiany atlas for example subject.
414412
415413
>>> maskfile = '/mnt/hcpdata/Facelab/100610/MNINonLinear/' \\
416414
... 'fsaverage_LR32k/100610.aparc.32k_fs_LR.dlabel.nii'
417-
>>> masker = CiftiMasker(maskfile).fit()
415+
>>> masker = CiftiMasker(maskfile)
418416
419417
Use the ``transform`` method to mask data. By default, this will mask by
420418
the label with a numerical ID of 1 - this will work if the dlabel file
@@ -462,14 +460,14 @@ class CiftiMasker(object):
462460
463461
(900, 0)
464462
465-
Instead of using a dlabel mask file, you can also mask by one of the
463+
Instead of using a CIFTI mask file, you can also mask by one of the
466464
labelled structures contained in the data CIFTI file. This is most useful
467465
for extracting subcortical regions. You can use the full structure name,
468466
or anything recognised by nibabel's ``to_cifti_brain_structure_name``
469467
method. Here, we extract data for the left amygdala, comprising 900
470468
timepoints and 315 voxels.
471469
472-
>>> masker = CiftiMasker('left amygdala').fit()
470+
>>> masker = CiftiMasker('left amygdala')
473471
>>> ROI_data = masker.transform(infile)
474472
>>> print(ROI_data.shape)
475473
@@ -485,13 +483,27 @@ class CiftiMasker(object):
485483
>>> new_img.to_filename('my_masked_data.dtseries.nii')
486484
"""
487485
def __init__(self, mask_img):
486+
# Assign arg to class
488487
self.mask_img = mask_img
489-
self._is_fitted = False
490488

491-
def _check_is_fitted(self):
492-
if not self._is_fitted:
493-
raise Exception('This instance is not fitted yet. '
494-
'Call .fit method first.')
489+
# Select CIFTI structure, or load from file, or pass through objects
490+
try:
491+
self.mask_struct = nib.cifti2.BrainModelAxis \
492+
.to_cifti_brain_structure_name(self.mask_img)
493+
self._mask_is_cifti_struct = True
494+
495+
except ValueError:
496+
if isinstance(self.mask_img, CiftiHandler):
497+
self.mask_handler = copy.deepcopy(CiftiHandler)
498+
self.mask_handler.full_surface = True
499+
elif (isinstance(self.mask_img, str) and os.path.isfile(self.mask_img)) \
500+
or isinstance(self.mask_img, nib.Cifti2Image):
501+
self.mask_handler = CiftiHandler(self.mask_img, full_surface=True)
502+
else:
503+
raise ValueError('Invalid mask image')
504+
505+
self.mask_dict = self.mask_handler.get_all_data(dtype=int)
506+
self._mask_is_cifti_struct = False
495507

496508
def _resample_to_data(self, dict_, data_handler, block='all'):
497509
"""
@@ -562,32 +574,6 @@ def _parse_labelID(self, labelID, mapN):
562574
else:
563575
raise TypeError('Invalid label ID type')
564576

565-
def fit(self):
566-
"""
567-
Load mask
568-
"""
569-
# Select CIFTI structure, or load from file, or pass through objects
570-
try:
571-
self.mask_struct = nib.cifti2.BrainModelAxis \
572-
.to_cifti_brain_structure_name(self.mask_img)
573-
self._mask_is_cifti_struct = True
574-
except ValueError:
575-
if isinstance(self.mask_img, CiftiHandler):
576-
self.mask_handler = copy.deepcopy(CiftiHandler)
577-
self.mask_handler.full_surface = True
578-
elif (isinstance(self.mask_img, str) and os.path.isfile(self.mask_img)) \
579-
or isinstance(self.mask_img, nib.Cifti2Image):
580-
self.mask_handler = CiftiHandler(self.mask_img, full_surface=True)
581-
else:
582-
raise ValueError('Invalid mask image')
583-
584-
self.mask_dict = self.mask_handler.get_all_data(dtype=int)
585-
self._mask_is_cifti_struct = False
586-
587-
# Return
588-
self._is_fitted = True
589-
return self
590-
591577
def transform(self, img, mask_block='all', labelID=1, mapN=0, dtype=None):
592578
"""
593579
Load data from CIFTI and apply mask
@@ -624,9 +610,6 @@ def transform(self, img, mask_block='all', labelID=1, mapN=0, dtype=None):
624610
data_array : ndarray
625611
[nSamples x nGrayOrdinates] array of data values after masking
626612
"""
627-
# Error check
628-
self._check_is_fitted()
629-
630613
# Open handler for data file
631614
if isinstance(img, CiftiHandler):
632615
self.data_handler = copy.deepcopy(img)
@@ -692,14 +675,6 @@ def transform_multiple(self, imgs, vstack=False, *args, **kwargs):
692675
data = np.vstack(data)
693676
return data
694677

695-
def fit_transform(self, *args, **kwargs):
696-
self.fit()
697-
return self.transform(*args, **kwargs)
698-
699-
def fit_transform_multiple(self, *args, **kwargs):
700-
self.fit()
701-
return self.transform_multiple(*args, **kwargs)
702-
703678
def inverse_transform(self, data_array, mask_block='all', labelID=1,
704679
mapN=0, dtype=None, template_img=None,
705680
return_as_cifti=True, *args, **kwargs):
@@ -714,8 +689,8 @@ def inverse_transform(self, data_array, mask_block='all', labelID=1,
714689
array containing masked data.
715690
716691
mask_block : str {all | lh | rh | surface | volume}
717-
Which blocks from the CIFTI array to return data from. Should
718-
match value supplied to forward transform.
692+
Which blocks in the CIFTI array to allocate data to. Should match
693+
value supplied to forward transform.
719694
720695
labelID : int or str
721696
ID or name of label to select if mask contains multiple labels.
@@ -748,9 +723,6 @@ def inverse_transform(self, data_array, mask_block='all', labelID=1,
748723
Data reshaped to full set of grayordinates. Returned as Cifti2Image
749724
object if return_as_cifti is True, otherwise returned as array.
750725
"""
751-
# Error check
752-
self._check_is_fitted()
753-
754726
# Check dtype
755727
if dtype is None:
756728
dtype = data_array.dtype
@@ -804,7 +776,6 @@ def uncache(self):
804776
"""
805777
Clear mask and data from cache
806778
"""
807-
self._check_is_fitted()
808779
if not self._mask_is_cifti_struct:
809780
self.mask_handler.uncache()
810781
self.data_handler.uncache()

nifti.py

+11-34
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,29 @@ class QuickMasker(object):
4444
4545
Methods
4646
-------
47-
* ``fit`` : Load mask
4847
* ``transform`` : Load data and apply mask
49-
* ``fit_transform`` : Fit and transform in one go
5048
* ``inverse_transform`` : Create new NIFTI image from masked data
5149
5250
Example useage
5351
--------------
54-
>>> masker = QuickMasker('/path/to/mask.nii.gz').fit()
52+
>>> masker = QuickMasker('/path/to/mask.nii.gz')
5553
>>> ROI_data = masker.transform('/path/to/data.nii.gz')
5654
>>> masker.inverse_trasnform(ROI_data) \\
5755
... .to_filename('/path/to/masked_data.nii.gz')
5856
"""
5957
def __init__(self, mask, mask2=None):
58+
# Assign args to class
6059
self.mask = mask
6160
self.mask2 = mask2
62-
self._is_fitted = False
6361

64-
def _check_is_fitted(self):
65-
if not self._is_fitted:
66-
raise Exception('Must call .fit() method first')
62+
# Load primary mask
63+
self.mask_img, self.mask_array = self._load_mask(self.mask)
64+
65+
# Load secondary mask?
66+
if self.mask2 is not None:
67+
_, self.mask_array2 = self._load_mask(self.mask2)
68+
else:
69+
self.mask_array2 = None
6770

6871
@staticmethod
6972
def _load_mask(mask):
@@ -92,23 +95,6 @@ def _load_data(img, dtype):
9295
'or numpy array')
9396
return data
9497

95-
def fit(self):
96-
"""
97-
Load mask image
98-
"""
99-
# Load primary mask
100-
self.mask_img, self.mask_array = self._load_mask(self.mask)
101-
102-
# Load secondary mask?
103-
if self.mask2 is not None:
104-
_, self.mask_array2 = self._load_mask(self.mask2)
105-
else:
106-
self.mask_array2 = None
107-
108-
# Finish up and return
109-
self._is_fitted = True
110-
return self
111-
11298
def transform(self, imgs, labelID=None, invert_mask=False, vstack=False,
11399
dtype=np.float64):
114100
"""
@@ -125,7 +111,7 @@ def transform(self, imgs, labelID=None, invert_mask=False, vstack=False,
125111
contained within mask. If None (default), use all non-zero labels.
126112
127113
invert_mask : bool
128-
If True, load from vertices OUTSIDE of mask instead
114+
If True, load from voxels OUTSIDE of mask instead
129115
(default = False)
130116
131117
vstack : bool
@@ -144,8 +130,6 @@ def transform(self, imgs, labelID=None, invert_mask=False, vstack=False,
144130
axis if vstack is True.
145131
"""
146132
# Setup
147-
self._check_is_fitted()
148-
149133
if not isinstance(imgs, (tuple, list)):
150134
imgs = [imgs]
151135

@@ -171,10 +155,6 @@ def transform(self, imgs, labelID=None, invert_mask=False, vstack=False,
171155
# Return
172156
return data
173157

174-
def fit_transform(self, *args, **kwargs):
175-
self.fit()
176-
return self.transform(*args, **kwargs)
177-
178158
def inverse_transform(self, data, labelID=None, invert_mask=False,
179159
dtype=np.float32, return_as_nii=True, header=None,
180160
affine=None, extra=None):
@@ -222,8 +202,6 @@ def inverse_transform(self, data, labelID=None, invert_mask=False,
222202
Unmasked data in requested format.
223203
"""
224204
# Setup
225-
self._check_is_fitted()
226-
227205
if labelID is None:
228206
mask = self.mask_array.astype(bool)
229207
else:
@@ -259,4 +237,3 @@ def inverse_transform(self, data, labelID=None, invert_mask=False,
259237

260238
else:
261239
return inv_data
262-

0 commit comments

Comments
 (0)