Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 102 additions & 11 deletions commit/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,97 @@ cdef class Evaluation :

return y

def set_wLasso_ISO(self, img_weights_filename, lambda_perc_iso):
"""
Compute array of weights for ISO compartment from given image and set weighted lasso regularisation.

Parameters
----------
img_weights_filename - string :
Path to the NIFTI file containing the image with the weights for the ISO compartment.
NB: The file can be 3D in the same space as the dwi_filename used (dim and voxel size).
NB: the weights must be greater or equal to 1 (greater the value, more that voxel is penalized).

lambda_perc_iso - float :
percentage of the maximum value of the regularisation parameter for the ISO compartment.
NB: lambda_perc_iso must be a float greater than 0.
"""
tic = time.time()
logger.subinfo('')
logger.info( 'Setting weighted lasso regularisation for ISO compartment' )

# check if all the necessary functions have been called
if self.niiDWI is None :
logger.error( 'Data not loaded; call "load_data()" first' )
if self.DICTIONARY is None :
logger.error( 'Dictionary not loaded; call "load_dictionary()" first' )
if self.KERNELS is None :
logger.error( 'Response functions not generated; call "generate_kernels()" and "load_kernels()" first' )
if self.THREADS is None :
logger.error( 'Threads not set; call "set_threads()" first' )
if self.A is None :
logger.error( 'Operator not built; call "build_operator()" first' )

if self.DICTIONARY['IC']['nF'] <= 0 :
logger.error( 'No streamline found in the dictionary; check your data' )

if int( self.DICTIONARY['nV'] * self.KERNELS['iso'].shape[0] ) == 0 :
logger.error( 'Unable to set regularisation because no isotropic compartment found in the dictionary.' )

# load image and check it
logger.subinfo('Loading image with weights', indent_char='*', indent_lvl=1)

if not exists(pjoin(img_weights_filename)) :
logger.error( 'Image not found' )
weights_nii = nibabel.load(img_weights_filename)
weights_img = np.asanyarray( weights_nii.dataobj ).astype(np.float32)

if weights_img.ndim != 3:
logger.error( 'Weights image must be 3D dataset' )

hdr = weights_nii.header if nibabel.__version__ >= '2.0.0' else weights_nii.get_header()
weights_nii_dim = weights_img.shape[0:3]
weights_nii_pixdim = tuple( hdr.get_zooms()[:3] )

if ( self.get_config('dim') != weights_nii_dim ):
logger.error( 'Dataset does not have the same geometry (number of voxels) as the DWI signal' )
if (self.get_config('pixdim') != weights_nii_pixdim ):
logger.error( 'Dataset does not have the same geometry (voxel size) as the DWI signal' )
if (np.isnan(weights_img).any()):
weights_img[np.isnan(weights_img)] = 1.
logger.warning('Weights image contains NaNs. Those values were changed to 1.')

# compute array of weights from image
array_weights = weights_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ].flatten().astype(np.float32)
if array_weights.size != self.DICTIONARY['nV']:
logger.error( 'Number of voxels in the weights image does not match the number of voxels in the dictionary' )
if np.any(array_weights < 1):
logger.error('All weights must be greater or equal to 1')

if self.KERNELS['iso'].shape[0] > 1:
array_weights = np.tile(array_weights, self.KERNELS['iso'].shape[0])
dict_iso = {}
dict_iso['coeff_weights'] = array_weights

# set reg
logger.subinfo('Setting regularisation', indent_char='*', indent_lvl=1)
ui.set_verbose( 'core', 1 )
self.set_regularisation(
regularisers = (None, None, 'lasso'),
lambdas = (None, None, lambda_perc_iso),
is_nonnegative = (True, True, True),
params = (None, None, dict_iso))
ui.set_verbose( 'core', self.verbose )

# save the weights image to check if is correct
# path_to_save = img_weights_filename.replace('.nii.gz', '_resaved.nii.gz')
# affine = self.niiDWI.affine if nibabel.__version__ >= '2.0.0' else self.niiDWI.get_affine()
# img_to_save = np.copy(self.niiDWI_img[:,:,:,0])
# img_to_save[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'], : ] = np.reshape(array_weights[:self.DICTIONARY['nV']], (self.DICTIONARY['nV'], -1))[:,0]
# nibabel.save( nibabel.Nifti1Image( img_to_save , affine ), path_to_save )

logger.info( f'[ {format_time(time.time() - tic)} ]' )


def set_regularisation(self, regularisers=(None, None, None), lambdas=(None, None, None), is_nonnegative=(True, True, True), params=(None, None, None)):
"""
Expand Down Expand Up @@ -1208,9 +1299,9 @@ cdef class Evaluation :
regularisation['lambdaISO_perc'] = lambdas[2]
else:
regularisation['lambdaISO_perc'] = lambdas[2]
# if dictISO_params is not None and 'coeff_weights' in dictISO_params:
# if dictISO_params['coeff_weights'].size != regularisation['sizeISO']:
# logger.error(f'"coeff_weights" must have the same size as the number of elements in the ISO compartment (got {dictISO_params["coeff_weights"].size} but {regularisation["sizeISO"]} expected)')
if dictISO_params is not None and 'coeff_weights' in dictISO_params:
if dictISO_params['coeff_weights'].size != regularisation['sizeISO']:
logger.error(f'"coeff_weights" must have the same size as the number of elements in the ISO compartment (got {dictISO_params["coeff_weights"].size} but {regularisation["sizeISO"]} expected)')
elif regularisation['regISO'] == 'smoothness':
logger.error('Not yet implemented')
elif regularisation['regISO'] == 'group_lasso':
Expand All @@ -1224,18 +1315,18 @@ cdef class Evaluation :

# update lambdas using lambda_max
if regularisation['regISO'] == 'lasso':
# if dictISO_params is not None and 'coeff_weights' in dictISO_params:
# regularisation['lambdaISO_max'] = compute_lambda_max_lasso(regularisation['startISO'], regularisation['sizeISO'], dictISO_params['coeff_weights'])
# else:
regularisation['lambdaISO_max'] = compute_lambda_max_lasso(regularisation['startISO'], regularisation['sizeISO'], np.ones(regularisation['sizeISO'], dtype=np.float64))
if dictISO_params is not None and 'coeff_weights' in dictISO_params:
regularisation['lambdaISO_max'] = compute_lambda_max_lasso(regularisation['startISO'], regularisation['sizeISO'], dictISO_params['coeff_weights'])
else:
regularisation['lambdaISO_max'] = compute_lambda_max_lasso(regularisation['startISO'], regularisation['sizeISO'], np.ones(regularisation['sizeISO'], dtype=np.float64))
regularisation['lambdaISO'] = regularisation['lambdaISO_perc'] * regularisation['lambdaISO_max']

# print
if regularisation['regISO'] is not None:
# if regularisation['regISO'] == 'lasso' and dictISO_params is not None and 'coeff_weights' in dictISO_params:
# logger.subinfo( f'Regularisation type: {regularisation["regISO"]} (weighted version)', indent_lvl=2, indent_char='-' )
# else:
logger.subinfo( f'Regularisation type: {regularisation["regISO"]}', indent_lvl=2, indent_char='-' )
if regularisation['regISO'] == 'lasso' and dictISO_params is not None and 'coeff_weights' in dictISO_params:
logger.subinfo( f'Regularisation type: {regularisation["regISO"]} (weighted version)', indent_lvl=2, indent_char='-' )
else:
logger.subinfo( f'Regularisation type: {regularisation["regISO"]}', indent_lvl=2, indent_char='-' )

logger.subinfo( f'Non-negativity constraint: {regularisation["nnISO"]}', indent_char='-', indent_lvl=2 )
if regularisation['regISO'] is not None:
Expand Down
26 changes: 13 additions & 13 deletions commit/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def init_regularisation(regularisation_params):
all_coeff_weights[startIC:(startIC+sizeIC)] = regularisation_params['dictIC_params']["coeff_weights_kept"]
# if regularisation_params.get('dictEC_params') is not None and "coeff_weights" in regularisation_params['dictEC_params'].keys():
# all_coeff_weights[startEC:(startEC+sizeEC)] = regularisation_params['dictEC_params']["coeff_weights"]
# if regularisation_params.get('dictISO_params') is not None and "coeff_weights" in regularisation_params['dictISO_params'].keys():
# all_coeff_weights[startISO:(startISO+sizeISO)] = regularisation_params['dictISO_params']["coeff_weights"]
if regularisation_params.get('dictISO_params') is not None and "coeff_weights" in regularisation_params['dictISO_params'].keys():
all_coeff_weights[startISO:(startISO+sizeISO)] = regularisation_params['dictISO_params']["coeff_weights"]

############################
# INTRACELLULAR COMPARTMENT#
Expand Down Expand Up @@ -183,18 +183,18 @@ def init_regularisation(regularisation_params):
elif regularisation_params['regISO'] == 'lasso':
lambdaISO = regularisation_params.get('lambdaISO')
# check if weights are provided
# if dictISO_params is not None and "coeff_weights" in dictISO_params.keys():
# omegaISO = lambda x: lambdaISO * np.linalg.norm(all_coeff_weights[startISO:sizeISO]*x[startISO:sizeISO],1)
# if regularisation_params.get('nnISO'):
# proxISO = lambda x, scaling: non_negativity(w_soft_thresholding(x,all_coeff_weights,scaling*lambdaISO,startISO,sizeISO),startISO,sizeISO)
# else:
# proxISO = lambda x, scaling: w_soft_thresholding(x,all_coeff_weights,scaling*lambdaISO,startISO,sizeISO)
# else:
omegaISO = lambda x: lambdaISO * np.linalg.norm(x[startISO:(startISO+sizeISO)],1)
if regularisation_params.get('nnISO'):
proxISO = lambda x, scaling: non_negativity(soft_thresholding(x,scaling*lambdaISO,startISO,sizeISO),startISO,sizeISO)
if dictISO_params is not None and "coeff_weights" in dictISO_params.keys():
omegaISO = lambda x: lambdaISO * np.linalg.norm(all_coeff_weights[startISO:sizeISO]*x[startISO:sizeISO],1)
if regularisation_params.get('nnISO'):
proxISO = lambda x, scaling: non_negativity(w_soft_thresholding(x,all_coeff_weights,scaling*lambdaISO,startISO,sizeISO),startISO,sizeISO)
else:
proxISO = lambda x, scaling: w_soft_thresholding(x,all_coeff_weights,scaling*lambdaISO,startISO,sizeISO)
else:
proxISO = lambda x, scaling: soft_thresholding(x,scaling*lambdaISO,startISO,sizeISO)
omegaISO = lambda x: lambdaISO * np.linalg.norm(x[startISO:(startISO+sizeISO)],1)
if regularisation_params.get('nnISO'):
proxISO = lambda x, scaling: non_negativity(soft_thresholding(x,scaling*lambdaISO,startISO,sizeISO),startISO,sizeISO)
else:
proxISO = lambda x, scaling: soft_thresholding(x,scaling*lambdaISO,startISO,sizeISO)

# elif regularisation_params['regISO'] == 'group_lasso':
# lambdaISO = regularisation_params.get('lambdaISO')
Expand Down