Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion commit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
try:
__version__ = metadata.version('dmri-commit')
except metadata.PackageNotFoundError:
__version__ = 'not installed'
__version__ = 'not installed'
162 changes: 70 additions & 92 deletions commit/core.pyx

Large diffs are not rendered by default.

240 changes: 72 additions & 168 deletions commit/models.pyx
Original file line number Diff line number Diff line change
@@ -1,30 +1,20 @@
#!python
#cython: language_level=3, boundscheck=False, wraparound=False, profile=False

from os import cpu_count as num_cpu
from os.path import join as pjoin

from setuptools import Extension

from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import nibabel as nib

import nibabel
from amico.models import BaseModel, StickZeppelinBall as _StickZeppelinBall, CylinderZeppelinBall as _CylinderZeppelinBall
import amico.util as util

from dicelib.ui import setup_logger, set_verbose, ProgressBar


logger = setup_logger('models')



class StickZeppelinBall(_StickZeppelinBall):
"""Simulate the response functions according to the Stick-Zeppelin-Ball model.
See the AMICO.model module for details.
"""
def resample(self, in_path, idx_out, Ylm_out, doMergeB0, ndirs):
#FIXME: use logger
util.set_verbose(2)
return super().resample(in_path, idx_out, Ylm_out, doMergeB0, ndirs)

Expand All @@ -34,60 +24,15 @@ class CylinderZeppelinBall(_CylinderZeppelinBall):
See the AMICO.model module for details.
"""
def resample(self, in_path, idx_out, Ylm_out, doMergeB0, ndirs):
#FIXME: use logger
util.set_verbose(2)
return super().resample(in_path, idx_out, Ylm_out, doMergeB0, ndirs)


# class VolumeFractions(BaseModel):
# """Implements a simple model where each compartment contributes only with
# its own volume fraction. This model has been created to test there
# ability to remove false positive fibers with COMMIT.
# """
# def __init__(self):
# self.id = 'VolumeFractions'
# self.name = 'Volume fractions'
# self.maps_name = []
# self.maps_descr = []
# self.nolut = True

# def set(self):
# return

# def get_params(self):
# params = {}
# params['id'] = self.id
# params['name'] = self.name
# return params

# def set_solver(self):
# logger.error('Not implemented')

# def generate(self, out_path, aux, idx_in, idx_out, ndirs):
# return

# def resample(self, in_path, idx_out, Ylm_out, doMergeB0, ndirs):
# if doMergeB0:
# nS = 1 + self.scheme.dwi_count
# merge_idx = np.hstack((self.scheme.b0_idx[0], self.scheme.dwi_idx))
# else:
# nS = self.scheme.nS
# merge_idx = np.arange(nS)

# KERNELS = {}
# KERNELS['model'] = self.id
# KERNELS['wmr'] = np.ones((1, ndirs, nS), dtype=np.float32)
# KERNELS['wmh'] = np.ones((0, ndirs, nS), dtype=np.float32)
# KERNELS['iso'] = np.ones((0, nS), dtype=np.float32)
# return KERNELS

# def fit(self, evaluation):
# logger.error('Not implemented')


class ScalarMap( BaseModel ) :
"""Implements a simple model where each compartment contributes only with
its own volume fraction. This model has been created to test there
ability to remove false positive fibers with COMMIT.
"""Implements a simple model where each compartment contributes to a scalar map,
e.g. intra-axonsl signal fraction or myelin water fraction, proportionally to
its local length inside each voxel.
"""

def __init__( self ) :
Expand All @@ -111,7 +56,7 @@ class ScalarMap( BaseModel ) :

def generate( self, out_path, aux, idx_in, idx_out, ndirs ) :
return

def fit(self, evaluation):
"""Placeholder implementation for the abstract method."""
logger.error('Not implemented')
Expand All @@ -134,115 +79,74 @@ class ScalarMap( BaseModel ) :
KERNELS['iso'] = np.ones( (0,nS), dtype=np.float32 )

return KERNELS

def find_idx(self, ic_v, ic_les, dict_icf, max_size, progress_bar, chunk_num):
cdef unsigned int [::1]ic_les_view = ic_les
cdef long les_shape = ic_les.shape[0]
cdef long ic_v_shape = ic_v.shape[0]
cdef int chunk_num_mem = chunk_num
cdef unsigned int [::1]progress_bar_mem = progress_bar


# indices_vox = np.zeros(les_shape*ic_v_shape, dtype=np.uint32)
indices_vox = np.zeros(max_size, dtype=np.uint32)
cdef unsigned int [::1]indices_vox_view = indices_vox

# indices_str = np.zeros(les_shape*ic_v_shape, dtype=np.uint32)
indices_str = np.zeros(max_size, dtype=np.uint32)
cdef unsigned int [::1]indices_str_view = indices_str

cdef unsigned int[::1] ic_v_view = ic_v

cdef int count = 0

cdef unsigned int [::1]dict_icf_view = dict_icf

cdef unsigned int ic_les_val = 0
cdef size_t i, j = 0

with nogil:
for i in range(les_shape):
ic_les_val = ic_les_view[i]
for j in range(ic_v_shape):
if ic_v_view[j] == ic_les_val:# and ic_les_val > 0:
indices_vox_view[count] = ic_v_view[j]
indices_str_view[count] = dict_icf_view[j]
count += 1
progress_bar_mem[chunk_num_mem] += 1

return indices_vox[:count], indices_str[:count]

def _postprocess(self, preproc_dict, verbose=1):
ISO_map = np.array(preproc_dict['niiISO_img'], dtype=np.float32)
IC_map = np.array(preproc_dict['niiIC_img'], dtype=np.float32)

dictionary = preproc_dict['DICTIONARY']

kept = dictionary['TRK']['kept']
x_weights = preproc_dict["streamline_weights"][kept==1].copy()
x_weights_scaled = x_weights.copy()
new_weights = np.zeros_like(kept, dtype=np.float32)


IC = IC_map[dictionary['MASK_ix'], dictionary['MASK_iy'], dictionary['MASK_iz']]
ISO = ISO_map[dictionary['MASK_ix'], dictionary['MASK_iy'], dictionary['MASK_iz']]
ISO_scaled = np.zeros_like(ISO)
def _postprocess(self, evaluation, xic):
"""Rescale the streamline weights using the local tissue damage estimated
in all imaging voxels with the COMMIT_lesion new model.

Parameters
----------
evaluation : object
Evaluation object, to enable accessing the whole content of the
whole evaluation object

xic : np.array
Streamline weights

Returns
-------
np.array
The rescaled streamline weights accounting for lesions
"""
if not self.lesion_mask:
# nothing to do if lesion mask is not given
return xic

RESULTS_path = evaluation.get_config('RESULTS_path')
niiISO_img = np.asanyarray( nibabel.load( pjoin(RESULTS_path,'compartment_ISO.nii.gz') ).dataobj ).astype(np.float32)
ISO = niiISO_img[evaluation.DICTIONARY['MASK_ix'], evaluation.DICTIONARY['MASK_iy'], evaluation.DICTIONARY['MASK_iz']]
if np.count_nonzero(ISO>0) == 0:
logger.warning('No lesions found')
return xic

# rescale the input scalar map in each voxel according to estimated lesion contributions
niiIC_img = np.asanyarray( nibabel.load( pjoin(RESULTS_path,'compartment_IC.nii.gz') ).dataobj ).astype(np.float32)
IC = niiIC_img[evaluation.DICTIONARY['MASK_ix'], evaluation.DICTIONARY['MASK_iy'], evaluation.DICTIONARY['MASK_iz']]
ISO_scaled = np.zeros_like(ISO, dtype=np.float32)
ISO_scaled[ISO>0] = (IC[ISO>0] - ISO[ISO>0]) / IC[ISO>0]
ISO_scaled_save = np.zeros_like(niiISO_img, dtype=np.float32)
ISO_scaled_save[evaluation.DICTIONARY['MASK_ix'], evaluation.DICTIONARY['MASK_iy'], evaluation.DICTIONARY['MASK_iz']] = ISO_scaled
affine = evaluation.niiDWI.affine if nibabel.__version__ >= '2.0.0' else evaluation.niiDWI.get_affine()
nibabel.save(nibabel.Nifti1Image(ISO_scaled_save, affine), pjoin(RESULTS_path,'compartment_IC_lesion_scaled.nii.gz'))

ISO_scaled_save = np.zeros_like(ISO_map)
ISO_scaled_save[dictionary['MASK_ix'], dictionary['MASK_iy'], dictionary['MASK_iz']] = ISO_scaled
nib.save(nib.Nifti1Image(ISO_scaled_save, preproc_dict['affine']), pjoin(preproc_dict["RESULTS_path"],'compartment_IC_lesion_scaled.nii.gz'))

idx_les = np.argwhere(ISO_scaled > 0)[:,0].astype(np.uint32)
if idx_les.shape[0] == 0:
logger.error('No lesion found in the input image.')
return


result = []
dict_idx_v = dictionary['IC']['v']
cdef unsigned int [::1]dict_idx_v_view = dict_idx_v

cdef unsigned int cpu_count = num_cpu()
cdef unsigned int[:] find_idx_progress = np.zeros(cpu_count, dtype=np.uint32)

n = idx_les.shape[0]
c = n // cpu_count
max_size = int(3e9/cpu_count)
chunks = []
for ii, jj in zip(range(0, n, c), range(c, n+1, c)):
chunks.append((ii, jj))
if chunks[len(chunks)-1][1] != n:
chunks[len(chunks)-1] = (chunks[len(chunks)-1][0], n)
logger
logger.subinfo('Recomputing streamlines weights accounting for lesions', indent_lvl=3, indent_char='->', with_progress=True)
with ProgressBar(multithread_progress=find_idx_progress, total=n,
disable=verbose<3, hide_on_exit=True, subinfo=True) as pbar:
with ThreadPoolExecutor(max_workers=cpu_count) as executor:
futures = [executor.submit(self.find_idx, dict_idx_v_view, idx_les[ii:jj], dictionary['IC']['fiber'], max_size, find_idx_progress, chunk_num) for chunk_num, (ii, jj) in enumerate(chunks)]
for future in as_completed(futures):
result.append(future.result())

idx_vox = []
idx_str = []
for r in result:
idx_vox.extend(r[0].tolist())
idx_str.extend(r[1].tolist())

cdef double [::1] x_weights_view = x_weights
cdef double [::1] x_weights_scaled_view = x_weights_scaled
cdef double x_weight = 0
cdef float [::1] ISO_scaled_view = ISO_scaled

cdef size_t vox = 0
cdef size_t str_i = 0

for vox, str_i in zip(idx_vox, idx_str):
x_weight = x_weights_view[str_i] * ISO_scaled_view[vox]
if x_weight < x_weights_scaled_view[str_i]:
x_weights_scaled_view[str_i] = x_weight
# save the map of local tissue damage estimated in each voxel
nibabel.save( nibabel.Nifti1Image( niiISO_img, affine ), pjoin(RESULTS_path,'compartment_lesion.nii.gz') )

new_weights[kept==1] = x_weights_scaled
coeffs_format='%.5e'
# override ISO map and set it to 0
nibabel.save( nibabel.Nifti1Image( 0*niiISO_img, affine), pjoin(RESULTS_path,'compartment_ISO.nii.gz') )

np.savetxt( pjoin(preproc_dict["RESULTS_path"],'streamline_weights.txt'), new_weights, fmt=coeffs_format )
# rescale each streamline weight
kept = evaluation.DICTIONARY['TRK']['kept']
cdef double [::1] xic_view = xic[kept==1]
cdef double [::1] xic_scaled_view = xic[kept==1].copy()
cdef float [::1] ISO_scaled_view = ISO_scaled
cdef unsigned int [::1] idx_v_view = evaluation.DICTIONARY['IC']['v']
cdef unsigned int [::1] idx_f_view = evaluation.DICTIONARY['IC']['fiber']
cdef size_t i, idx_v, idx_f
cdef double val

# Rescaling streamline weights accounting for lesions
for i in range(evaluation.DICTIONARY['IC']['v'].shape[0]):
idx_v = idx_v_view[i]
val = ISO_scaled_view[idx_v]
if val > 0:
idx_f = idx_f_view[i]
#TODO: allow considering other than the min value
if xic_view[idx_f] * val < xic_scaled_view[idx_f]:
xic_scaled_view[idx_f] = xic_view[idx_f] * val

# return rescaled streamline weights
xic_scaled = np.zeros_like(kept, dtype=np.float32)
xic_scaled[kept==1] = xic_scaled_view
return xic_scaled
24 changes: 10 additions & 14 deletions commit/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def init_regularisation(regularisation_params):
# check if regularisations are in the list
if regularisation_params['regIC'] not in list_regularizers or regularisation_params['regEC'] not in list_regularizers or regularisation_params['regISO'] not in list_regularizers:
logger.error('Regularisation not in the list')

startIC = regularisation_params.get('startIC')
sizeIC = regularisation_params.get('sizeIC')
startEC = regularisation_params.get('startEC')
Expand Down Expand Up @@ -97,7 +97,7 @@ def init_regularisation(regularisation_params):
proxIC = lambda x, scaling: non_negativity(prox_group_lasso(x,groupIdxIC,groupSizeIC,dictIC_params['group_weights'],scaling*lambda_group_IC),startIC,sizeIC)
else:
proxIC = lambda x, scaling: prox_group_lasso(x,groupIdxIC,groupSizeIC,dictIC_params['group_weights'],scaling*lambda_group_IC)

elif regularisation_params['regIC'] == 'sparse_group_lasso':
if not len(dictIC_params['group_idx_kept']) == len(dictIC_params['group_weights']):
logger.error('Number of groups and weights do not match')
Expand Down Expand Up @@ -132,12 +132,10 @@ def init_regularisation(regularisation_params):
proxIC = lambda x, scaling: prox_group_lasso(soft_thresholding(x,scaling*lambdaIC,startIC,sizeIC),groupIdxIC,groupSizeIC,dictIC_params['group_weights'],scaling*lambda_group_IC)


###########################
# EXTRCELLULAR COMPARTMENT#
###########################

dictEC_params = regularisation_params.get('dictEC_params')

#############################
# EXTRACELLULAR COMPARTMENT #
#############################
# dictEC_params = regularisation_params.get('dictEC_params')
if regularisation_params['regEC'] is None:
omegaEC = lambda x: 0.0
if regularisation_params.get('nnEC')==True:
Expand Down Expand Up @@ -167,12 +165,10 @@ def init_regularisation(regularisation_params):
# proxEC = lambda x: projection_onto_l2_ball(x, lambdaEC, startEC, sizeEC)


########################
# ISOTROPIC COMPARTMENT#
########################

dictISO_params = regularisation_params.get('dictISO_params')

#########################
# ISOTROPIC COMPARTMENT #
#########################
# dictISO_params = regularisation_params.get('dictISO_params')
if regularisation_params['regISO'] is None:
omegaISO = lambda x: 0.0
if regularisation_params.get('nnISO')==True:
Expand Down
Loading