diff --git a/commit/core.pyx b/commit/core.pyx index 479816b..ae9a607 100644 --- a/commit/core.pyx +++ b/commit/core.pyx @@ -770,6 +770,97 @@ cdef class Evaluation : return y + def set_wLasso_ISO(self, img_weights_filename, lambda_perc_iso): + """ + Compute array of weights for ISO compartment from given image and set weighted lasso regularisation. + + Parameters + ---------- + img_weights_filename - string : + Path to the NIFTI file containing the image with the weights for the ISO compartment. + NB: The file can be 3D in the same space as the dwi_filename used (dim and voxel size). + NB: the weights must be greater or equal to 1 (greater the value, more that voxel is penalized). + + lambda_perc_iso - float : + percentage of the maximum value of the regularisation parameter for the ISO compartment. + NB: lambda_perc_iso must be a float greater than 0. + """ + tic = time.time() + logger.subinfo('') + logger.info( 'Setting weighted lasso regularisation for ISO compartment' ) + + # check if all the necessary functions have been called + if self.niiDWI is None : + logger.error( 'Data not loaded; call "load_data()" first' ) + if self.DICTIONARY is None : + logger.error( 'Dictionary not loaded; call "load_dictionary()" first' ) + if self.KERNELS is None : + logger.error( 'Response functions not generated; call "generate_kernels()" and "load_kernels()" first' ) + if self.THREADS is None : + logger.error( 'Threads not set; call "set_threads()" first' ) + if self.A is None : + logger.error( 'Operator not built; call "build_operator()" first' ) + + if self.DICTIONARY['IC']['nF'] <= 0 : + logger.error( 'No streamline found in the dictionary; check your data' ) + + if int( self.DICTIONARY['nV'] * self.KERNELS['iso'].shape[0] ) == 0 : + logger.error( 'Unable to set regularisation because no isotropic compartment found in the dictionary.' ) + + # load image and check it + logger.subinfo('Loading image with weights', indent_char='*', indent_lvl=1) + + if not exists(pjoin(img_weights_filename)) : + logger.error( 'Image not found' ) + weights_nii = nibabel.load(img_weights_filename) + weights_img = np.asanyarray( weights_nii.dataobj ).astype(np.float32) + + if weights_img.ndim != 3: + logger.error( 'Weights image must be 3D dataset' ) + + hdr = weights_nii.header if nibabel.__version__ >= '2.0.0' else weights_nii.get_header() + weights_nii_dim = weights_img.shape[0:3] + weights_nii_pixdim = tuple( hdr.get_zooms()[:3] ) + + if ( self.get_config('dim') != weights_nii_dim ): + logger.error( 'Dataset does not have the same geometry (number of voxels) as the DWI signal' ) + if (self.get_config('pixdim') != weights_nii_pixdim ): + logger.error( 'Dataset does not have the same geometry (voxel size) as the DWI signal' ) + if (np.isnan(weights_img).any()): + weights_img[np.isnan(weights_img)] = 1. + logger.warning('Weights image contains NaNs. Those values were changed to 1.') + + # compute array of weights from image + array_weights = weights_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ].flatten().astype(np.float32) + if array_weights.size != self.DICTIONARY['nV']: + logger.error( 'Number of voxels in the weights image does not match the number of voxels in the dictionary' ) + if np.any(array_weights < 1): + logger.error('All weights must be greater or equal to 1') + + if self.KERNELS['iso'].shape[0] > 1: + array_weights = np.tile(array_weights, self.KERNELS['iso'].shape[0]) + dict_iso = {} + dict_iso['coeff_weights'] = array_weights + + # set reg + logger.subinfo('Setting regularisation', indent_char='*', indent_lvl=1) + ui.set_verbose( 'core', 1 ) + self.set_regularisation( + regularisers = (None, None, 'lasso'), + lambdas = (None, None, lambda_perc_iso), + is_nonnegative = (True, True, True), + params = (None, None, dict_iso)) + ui.set_verbose( 'core', self.verbose ) + + # save the weights image to check if is correct + # path_to_save = img_weights_filename.replace('.nii.gz', '_resaved.nii.gz') + # affine = self.niiDWI.affine if nibabel.__version__ >= '2.0.0' else self.niiDWI.get_affine() + # img_to_save = np.copy(self.niiDWI_img[:,:,:,0]) + # img_to_save[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'], : ] = np.reshape(array_weights[:self.DICTIONARY['nV']], (self.DICTIONARY['nV'], -1))[:,0] + # nibabel.save( nibabel.Nifti1Image( img_to_save , affine ), path_to_save ) + + logger.info( f'[ {format_time(time.time() - tic)} ]' ) + def set_regularisation(self, regularisers=(None, None, None), lambdas=(None, None, None), is_nonnegative=(True, True, True), params=(None, None, None)): """ @@ -1208,9 +1299,9 @@ cdef class Evaluation : regularisation['lambdaISO_perc'] = lambdas[2] else: regularisation['lambdaISO_perc'] = lambdas[2] - # if dictISO_params is not None and 'coeff_weights' in dictISO_params: - # if dictISO_params['coeff_weights'].size != regularisation['sizeISO']: - # logger.error(f'"coeff_weights" must have the same size as the number of elements in the ISO compartment (got {dictISO_params["coeff_weights"].size} but {regularisation["sizeISO"]} expected)') + if dictISO_params is not None and 'coeff_weights' in dictISO_params: + if dictISO_params['coeff_weights'].size != regularisation['sizeISO']: + logger.error(f'"coeff_weights" must have the same size as the number of elements in the ISO compartment (got {dictISO_params["coeff_weights"].size} but {regularisation["sizeISO"]} expected)') elif regularisation['regISO'] == 'smoothness': logger.error('Not yet implemented') elif regularisation['regISO'] == 'group_lasso': @@ -1224,18 +1315,18 @@ cdef class Evaluation : # update lambdas using lambda_max if regularisation['regISO'] == 'lasso': - # if dictISO_params is not None and 'coeff_weights' in dictISO_params: - # regularisation['lambdaISO_max'] = compute_lambda_max_lasso(regularisation['startISO'], regularisation['sizeISO'], dictISO_params['coeff_weights']) - # else: - regularisation['lambdaISO_max'] = compute_lambda_max_lasso(regularisation['startISO'], regularisation['sizeISO'], np.ones(regularisation['sizeISO'], dtype=np.float64)) + if dictISO_params is not None and 'coeff_weights' in dictISO_params: + regularisation['lambdaISO_max'] = compute_lambda_max_lasso(regularisation['startISO'], regularisation['sizeISO'], dictISO_params['coeff_weights']) + else: + regularisation['lambdaISO_max'] = compute_lambda_max_lasso(regularisation['startISO'], regularisation['sizeISO'], np.ones(regularisation['sizeISO'], dtype=np.float64)) regularisation['lambdaISO'] = regularisation['lambdaISO_perc'] * regularisation['lambdaISO_max'] # print if regularisation['regISO'] is not None: - # if regularisation['regISO'] == 'lasso' and dictISO_params is not None and 'coeff_weights' in dictISO_params: - # logger.subinfo( f'Regularisation type: {regularisation["regISO"]} (weighted version)', indent_lvl=2, indent_char='-' ) - # else: - logger.subinfo( f'Regularisation type: {regularisation["regISO"]}', indent_lvl=2, indent_char='-' ) + if regularisation['regISO'] == 'lasso' and dictISO_params is not None and 'coeff_weights' in dictISO_params: + logger.subinfo( f'Regularisation type: {regularisation["regISO"]} (weighted version)', indent_lvl=2, indent_char='-' ) + else: + logger.subinfo( f'Regularisation type: {regularisation["regISO"]}', indent_lvl=2, indent_char='-' ) logger.subinfo( f'Non-negativity constraint: {regularisation["nnISO"]}', indent_char='-', indent_lvl=2 ) if regularisation['regISO'] is not None: diff --git a/commit/solvers.py b/commit/solvers.py index 730f06d..140cfc5 100755 --- a/commit/solvers.py +++ b/commit/solvers.py @@ -38,8 +38,8 @@ def init_regularisation(regularisation_params): all_coeff_weights[startIC:(startIC+sizeIC)] = regularisation_params['dictIC_params']["coeff_weights_kept"] # if regularisation_params.get('dictEC_params') is not None and "coeff_weights" in regularisation_params['dictEC_params'].keys(): # all_coeff_weights[startEC:(startEC+sizeEC)] = regularisation_params['dictEC_params']["coeff_weights"] - # if regularisation_params.get('dictISO_params') is not None and "coeff_weights" in regularisation_params['dictISO_params'].keys(): - # all_coeff_weights[startISO:(startISO+sizeISO)] = regularisation_params['dictISO_params']["coeff_weights"] + if regularisation_params.get('dictISO_params') is not None and "coeff_weights" in regularisation_params['dictISO_params'].keys(): + all_coeff_weights[startISO:(startISO+sizeISO)] = regularisation_params['dictISO_params']["coeff_weights"] ############################ # INTRACELLULAR COMPARTMENT# @@ -183,18 +183,18 @@ def init_regularisation(regularisation_params): elif regularisation_params['regISO'] == 'lasso': lambdaISO = regularisation_params.get('lambdaISO') # check if weights are provided - # if dictISO_params is not None and "coeff_weights" in dictISO_params.keys(): - # omegaISO = lambda x: lambdaISO * np.linalg.norm(all_coeff_weights[startISO:sizeISO]*x[startISO:sizeISO],1) - # if regularisation_params.get('nnISO'): - # proxISO = lambda x, scaling: non_negativity(w_soft_thresholding(x,all_coeff_weights,scaling*lambdaISO,startISO,sizeISO),startISO,sizeISO) - # else: - # proxISO = lambda x, scaling: w_soft_thresholding(x,all_coeff_weights,scaling*lambdaISO,startISO,sizeISO) - # else: - omegaISO = lambda x: lambdaISO * np.linalg.norm(x[startISO:(startISO+sizeISO)],1) - if regularisation_params.get('nnISO'): - proxISO = lambda x, scaling: non_negativity(soft_thresholding(x,scaling*lambdaISO,startISO,sizeISO),startISO,sizeISO) + if dictISO_params is not None and "coeff_weights" in dictISO_params.keys(): + omegaISO = lambda x: lambdaISO * np.linalg.norm(all_coeff_weights[startISO:sizeISO]*x[startISO:sizeISO],1) + if regularisation_params.get('nnISO'): + proxISO = lambda x, scaling: non_negativity(w_soft_thresholding(x,all_coeff_weights,scaling*lambdaISO,startISO,sizeISO),startISO,sizeISO) + else: + proxISO = lambda x, scaling: w_soft_thresholding(x,all_coeff_weights,scaling*lambdaISO,startISO,sizeISO) else: - proxISO = lambda x, scaling: soft_thresholding(x,scaling*lambdaISO,startISO,sizeISO) + omegaISO = lambda x: lambdaISO * np.linalg.norm(x[startISO:(startISO+sizeISO)],1) + if regularisation_params.get('nnISO'): + proxISO = lambda x, scaling: non_negativity(soft_thresholding(x,scaling*lambdaISO,startISO,sizeISO),startISO,sizeISO) + else: + proxISO = lambda x, scaling: soft_thresholding(x,scaling*lambdaISO,startISO,sizeISO) # elif regularisation_params['regISO'] == 'group_lasso': # lambdaISO = regularisation_params.get('lambdaISO')