diff --git a/invisible_cities/cities/beersheba.py b/invisible_cities/cities/beersheba.py index c30cd5004..493c2b534 100644 --- a/invisible_cities/cities/beersheba.py +++ b/invisible_cities/cities/beersheba.py @@ -137,7 +137,8 @@ def deconvolve_signal(det_db : pd.DataFrame, n_dim : Optional[int]=2, cut_type : Optional[CutType]=CutType.abs, inter_method : Optional[InterpolationMethod]=InterpolationMethod.cubic, - n_iterations_g : Optional[int]=0): + n_iterations_g : Optional[int]=0, + use_gpu : Optional[bool]=False): """ Applies Lucy Richardson deconvolution to SiPM response with a given set of PSFs and parameters. @@ -166,6 +167,7 @@ def deconvolve_signal(det_db : pd.DataFrame, `rel`: cut on the relative value (to the max) of the hits. inter_method : Interpolation method (`nointerpolation`, `nearest`, `linear` or `cubic`). n_iterations_g : Number of Lucy-Richardson iterations for gaussian in 'separate mode' + use_gpu : If True, use GPU for the deconvolution. Default is False. Returns ---------- @@ -185,7 +187,7 @@ def deconvolve_signal(det_db : pd.DataFrame, deconvolution = deconvolve(n_iterations, iteration_tol, sample_width, det_grid, **satellite_params, - inter_method = inter_method) + inter_method=inter_method, use_gpu=use_gpu) if not isinstance(energy_type , HitEnergy ): raise ValueError(f'energy_type {energy_type} is not a valid energy type.') @@ -228,7 +230,7 @@ def deconvolve_hits(df, z): deconv_image = nan_to_num(richardson_lucy(deconv_image, psf, iterations = n_iterations_g, iter_thr = iteration_tol, - **satellite_params)) + **satellite_params, use_gpu=use_gpu)) return create_deconvolution_df(df, deconv_image.flatten(), pos, cut_type, e_cut, n_dim) @@ -456,6 +458,9 @@ def beersheba( files_in : OneOrManyFiles 'cubic' not supported for 3D deconvolution. n_iterations_g : int Number of Lucy-Richardson iterations for gaussian in 'separate mode' + use_gpu : bool + Whether to use the GPU for the deconvolution. + satellite_params : dict, None satellite_start_iter : int Iteration no. when satellite killer starts being used. diff --git a/invisible_cities/reco/deconv_functions.py b/invisible_cities/reco/deconv_functions.py index 799d8e08d..6038198ae 100644 --- a/invisible_cities/reco/deconv_functions.py +++ b/invisible_cities/reco/deconv_functions.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +import warnings from typing import List from typing import Tuple @@ -7,12 +8,24 @@ from typing import Optional from typing import Union +from functools import lru_cache + from scipy import interpolate from scipy.signal import fftconvolve from scipy.signal import convolve from scipy.spatial.distance import cdist from scipy import ndimage as ndi +# Check if there is a GPU available +try: + import cupy as cp + import cupyx as cpx + from cupyx.scipy.signal import convolve + from cupyx.scipy.signal import fftconvolve + import cupyx.scipy.ndimage as cpx_ndi +except ImportError: + warnings.warn("Impossible to import cupy. Computations will be done in CPU.") + from ..core .core_functions import shift_to_bin_centers from .. core.core_functions import binedges_from_bincenters from ..core .core_functions import in_range @@ -23,7 +36,8 @@ from .. types.symbols import CutType -def collect_component_sizes(im_mask : np.ndarray) -> (np.ndarray, np.ndarray): +def collect_component_sizes(im_mask : np.ndarray, + use_gpu : Optional[bool]=False) -> (np.ndarray, np.ndarray): ''' A function that returns the sizes of different clusters of 1s and 0s within the data for removal of satellites. @@ -45,16 +59,24 @@ def collect_component_sizes(im_mask : np.ndarray) -> (np.ndarray, np.ndarray): # label deposits within the array # hardcoded to include diagonals in the grouping stage (2) # count the bins of each labelled deposit - footprint = ndi.generate_binary_structure(im_mask.ndim, 2) - labels, _ = ndi.label(im_mask, footprint) - component_sizes = np.bincount(labels.ravel()) + + xdi = ndi + xp = np + if use_gpu: + xdi = cpx_ndi + xp = cp + + footprint = xdi.generate_binary_structure(im_mask.ndim, 2) + labels, _ = xdi.label(im_mask, footprint) + component_sizes = xp.bincount(labels.ravel()) return labels, component_sizes def generate_satellite_mask(im_deconv : np.ndarray, satellite_max_size : int, e_cut : float, - cut_type : Optional[CutType]=CutType.abs) -> np.ndarray: + cut_type : Optional[CutType]=CutType.abs, + use_gpu : Optional[bool]=False) -> np.ndarray: ''' An adaptation to the scikit-image (v0.24.0) function [1], identifies satellite energy depositions within deconvolution image by size @@ -87,17 +109,20 @@ def generate_satellite_mask(im_deconv : np.ndarray, ---------- .. [1] https://github.com/scikit-image/scikit-image/blob/main/skimage/morphology/misc.py#L59-L151 ''' + xp = np + if use_gpu: + xp = cp if cut_type is CutType.rel: im_deconv = im_deconv / im_deconv.max() # separate different regions below and above e_cut # then label regions (components) appropriately and determine their sizes. - labels, component_sizes = collect_component_sizes(im_deconv >= e_cut) + labels, component_sizes = collect_component_sizes(im_deconv >= e_cut, use_gpu) # check if no satellites within deposit return False array # (mask that removes no satellites). if len(component_sizes) <= 2: - return np.full(im_deconv.shape, False) + return xp.full(im_deconv.shape, False) # Find regions smaller than `satellite_max_size` and mask them, # ignoring the first region (background). Read gist for full explanation. @@ -286,6 +311,7 @@ def find_nearest(array : np.ndarray, idx = (np.abs (array - value)).argmin() return array[idx] + no_satellite_killer = dict(satellite_start_iter = None, satellite_max_size = 0, e_cut = 0, @@ -301,7 +327,8 @@ def deconvolve(n_iterations : int, satellite_max_size : int, e_cut : float, cut_type : Optional[CutType] = CutType.abs, - inter_method : InterpolationMethod = InterpolationMethod.cubic + inter_method : InterpolationMethod = InterpolationMethod.cubic, + use_gpu : Optional[bool] = False ) -> Callable: """ Deconvolves a given set of data (sensor position and its response) @@ -322,7 +349,8 @@ def deconvolve(n_iterations : int, iteration_tol : Stopping threshold (difference between iterations). sample_width : Sampling size of the sensors. det_grid : xy-coordinates of the detector grid, to interpolate on them - inter_method : Interpolation method. + inter_method : Interpolation method. Default is cubic. + use_gpu : Use GPU for the deconvolution. Default is False. Returns ------- @@ -343,14 +371,33 @@ def deconvolve(data : Tuple[np.ndarray, ...], psf_deco = psf.factor.values.reshape(psf.loc[:, columns].nunique().values) deconv_image = np.nan_to_num(richardson_lucy(inter_signal, psf_deco, satellite_start_iter, satellite_max_size, e_cut, cut_type, - n_iterations, iteration_tol)) + n_iterations, iteration_tol, use_gpu)) return deconv_image, inter_pos return deconvolve -def richardson_lucy(image, psf, satellite_start_iter, satellite_max_size, e_cut, cut_type, iterations=50, iter_thr=0.): +@lru_cache +def is_gpu_available() -> bool: + """Check if a GPU is available for computations. + Returns + ------- + bool + True if a GPU is available, False otherwise. + """ + try: + if cp.cuda.runtime.getDeviceCount(): + print("GPUs available:", cp.cuda.runtime.getDeviceCount()) + return True + else: + warnings.warn("Cupy is installed but no GPUs are available. Computations will be done in CPU.") + return False + except NameError: + return False + + +def richardson_lucy(image, psf, satellite_start_iter, satellite_max_size, e_cut, cut_type, iterations=50, iter_thr=0., use_gpu=False): """Richardson-Lucy deconvolution (modification from scikit-image package). The modification adds a value=0 protection, the possibility to stop iterating @@ -377,7 +424,9 @@ def richardson_lucy(image, psf, satellite_start_iter, satellite_max_size, e_cut, regularisation. iter_thr : float, optional Threshold on the relative difference between iterations to stop iterating. - + use_gpu : bool, optional + If True, use GPU for the deconvolution. Default is False. + Returns ------- im_deconv : ndarray @@ -410,30 +459,46 @@ def richardson_lucy(image, psf, satellite_start_iter, satellite_max_size, e_cut, else: convolve_method = convolve - image = image.astype(float) - psf = psf.astype(float) - im_deconv = 0.5 * np.ones(image.shape) + xp = np + using_gpu = use_gpu and is_gpu_available() + if using_gpu: + xp = cp + + # Convert image and psf to the appropriate array type (float) + image = xp.asarray(image, dtype=float) + psf = xp.asarray(psf, dtype=float) + im_deconv = 0.5 * xp.ones(image.shape) s = slice(None, None, -1) psf_mirror = psf[(s,) * psf.ndim] ### Allow for n-dim mirroring. - eps = np.finfo(image.dtype).eps ### Protection against 0 value + eps = xp.finfo(image.dtype).eps ### Protection against 0 value ref_image = image/image.max() for i in range(iterations): x = convolve_method(im_deconv, psf, 'same') - np.place(x, x==0, eps) ### Protection against 0 value + xp.place(x, x==0, eps) ### Protection against 0 value relative_blur = image / x im_deconv *= convolve_method(relative_blur, psf_mirror, 'same') - # if satellite parameters are provided kill satellites after each iteration. if satellite_start_iter is not None and i >= satellite_start_iter: - sat_mask = generate_satellite_mask(im_deconv, satellite_max_size, e_cut, cut_type) + sat_mask = generate_satellite_mask(im_deconv, satellite_max_size, e_cut, cut_type, using_gpu) im_deconv[sat_mask] = 0 - - with np.errstate(divide='ignore', invalid='ignore'): - rel_diff = np.nansum(np.divide(((im_deconv/im_deconv.max() - ref_image)**2), ref_image)) + + # This is needed because Cupy does not have a full errstate implementation + if use_gpu and is_gpu_available(): + rel_diff_array = (im_deconv/im_deconv.max() - ref_image) ** 2 + with cpx.errstate(): + rel_diff_array = xp.where(ref_image != 0, xp.divide(rel_diff_array, ref_image), 0) + rel_diff = xp.sum(rel_diff_array) + else: + with np.errstate(divide='ignore', invalid='ignore'): + rel_diff = xp.sum(xp.divide(((im_deconv/im_deconv.max() - ref_image)**2), ref_image)) + if rel_diff < iter_thr: ### Break if a given threshold is reached. break ref_image = im_deconv/im_deconv.max() + if using_gpu: + im_deconv = cp.asnumpy(im_deconv) + return im_deconv