Skip to content
Open
228 changes: 185 additions & 43 deletions src/reboost/hpge/psd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .. import units
from ..units import ureg as u
from .utils import HPGeScalarRZField
from .utils import HPGePulseShapeLibrary, HPGeRZField

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -83,7 +83,7 @@ def drift_time(
xloc: ArrayLike,
yloc: ArrayLike,
zloc: ArrayLike,
dt_map: HPGeScalarRZField,
dt_map: HPGeRZField,
coord_offset: pint.Quantity | pyg4ometry.gdml.Position = (0, 0, 0) * u.m,
) -> VectorOfVectors:
"""Calculates drift times for each step (cluster) in an HPGe detector.
Expand Down Expand Up @@ -513,6 +513,7 @@ def _get_waveform_value_surface(
else:
out += E * _interpolate_pulse_model(bulk_template, time, start, start + dt * n, dt, mu)
etmp += E

return out, etmp


Expand Down Expand Up @@ -546,11 +547,12 @@ def _get_waveform_value(
-------
Value of the current waveform
"""
n = len(template)
out = 0
time = start + dt * idx

for i in range(len(edep)):
n = len(template)

E = edep[i]
mu = drift_time[i]

Expand All @@ -559,6 +561,83 @@ def _get_waveform_value(
return out


@numba.njit(cache=True)
def _get_waveform_value_pulse_shape_library(
idx: int,
edep: ak.Array,
drift_time: ak.Array,
r: ak.Array,
z: ak.Array,
pulse_shape_library: tuple[np.array, np.array, np.array],
start: float,
dt: float,
) -> float:
"""Get the value of the waveform at a certain index.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this docstring is out of date, but maybe it can be made very short since it's not even a public function?


Parameters
----------
idx
the index of the time array to find the waveform at.
edep
Array of energies for each step
drift_time
Array of drift times for each step
template
array of the template for the current waveforms
start
first time value of the template
dt
timestep (in ns) for the template.

Returns
-------
Value of the current waveform
"""
out = 0
time = start + dt * idx

for i in range(len(edep)):
ri, zi = _get_template_idx(r[i], z[i], pulse_shape_library[1], pulse_shape_library[2])

n = len(pulse_shape_library[0][ri][zi])

E = edep[i]
mu = drift_time[i]

out += E * _interpolate_pulse_model(
pulse_shape_library[0][ri][zi], time, start, start + dt * n, dt, mu
)

return out


@numba.njit(cache=True)
def _get_template_idx(
r: float,
z: float,
r_grid: np.array,
z_grid: np.array,
) -> tuple[int, int]:
"""Extract the closest template to a given (r,z) point with uniform grid, apart from the first and last point."""
if r < r_grid[1]:
ri = 0
elif r > r_grid[-2]:
ri = len(r_grid)
else:
dr = r_grid[2] - r_grid[1]
ri = int((r - r_grid[1]) / dr) + 1

if z < z_grid[1]:
zi = 0
elif z > z_grid[-2]:
zi = len(z_grid)
else:
dz = z_grid[2] - z_grid[1]
zi = int((z - z_grid[1]) / dz) + 1

return ri, zi


def get_current_template(
low: float = -1000, high: float = 4000, step: float = 1, mean_aoe: float = 1, **kwargs
) -> tuple[NDArray, NDArray]:
Expand Down Expand Up @@ -598,7 +677,10 @@ def _get_waveform_maximum_impl(
t: ArrayLike,
e: ArrayLike,
dist: ArrayLike,
r: ArrayLike,
z: ArrayLike,
template: ArrayLike,
pulse_shape_library: tuple[np.array, np.array, np.array],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this annotation correct?

templates_surface: ArrayLike,
activeness_surface: ArrayLike,
tmin: float,
Expand All @@ -609,18 +691,9 @@ def _get_waveform_maximum_impl(
time_step: int,
surface_step_in_um: float,
include_surface_effects: bool,
use_library: bool,
):
"""Basic implementation to get the maximum of the waveform.

Parameters
----------
t
drift time for each step.
e
energy for each step.
dist
distance to surface for each step.
"""
"""Basic implementation to get the maximum of the waveform."""
max_a = 0
max_t = 0
energy = np.sum(e)
Expand All @@ -634,8 +707,12 @@ def _get_waveform_maximum_impl(
if time < tmin or (time > (tmax + time_step)):
continue

if not has_surface_hit:
if not has_surface_hit and (not use_library):
val_tmp = _get_waveform_value(j, e, t, template, start=start, dt=1.0)
elif use_library:
val_tmp = _get_waveform_value_pulse_shape_library(
j, e, t, r, z, pulse_shape_library, start=start, dt=1.0
)
else:
val_tmp, energy = _get_waveform_value_surface(
j,
Expand Down Expand Up @@ -663,9 +740,13 @@ def _estimate_current_impl(
edep: ak.Array,
dt: ak.Array,
dist_to_nplus: ak.Array,
r: ak.Array,
z: ak.Array,
template: np.array,
pulse_shape_library: tuple[np.array, np.array, np.array],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again annotation here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also update docstring

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes its correct, since you cant pass the class to a numba func

times: np.array,
include_surface_effects: bool,
use_library: bool,
fccd: float,
templates_surface: np.array,
activeness_surface: np.array,
Expand All @@ -684,7 +765,8 @@ def _estimate_current_impl(
dist_to_nplus
Array of distance to nplus contact for each step (can be `None`, in which case no surface effects are included.)
template
array of the bulk pulse template
array of the bulk pulse template, in the case of a full pulse shape library, 3 arrays can be passed corresponding to the
"r" and "z" coordinates of the library and the waveforms for each point.
times
time-stamps for the bulk pulse template
"""
Expand All @@ -693,7 +775,7 @@ def _estimate_current_impl(
energy = np.zeros(len(dt))

time_step = 1
n = len(template)
n = len(times)
start = times[0]

if include_surface_effects:
Expand All @@ -707,6 +789,9 @@ def _estimate_current_impl(
for i in range(len(dt)):
t = np.asarray(dt[i])
e = np.asarray(edep[i])
r_tmp = np.asarray(r[i])
z_tmp = np.asarray(z[i])

dist = np.asarray(dist_to_nplus[i])

# get the expected maximum
Expand All @@ -720,7 +805,6 @@ def _estimate_current_impl(
for j, d in enumerate(dist):
dtmp = int(d / surface_step_in_um)

# Use branchless selection
use_offset = dtmp <= ncols
offset_val = offsets[dtmp] if use_offset else 0.0
time_tmp = t[j] + offset_val * use_offset
Expand All @@ -737,9 +821,12 @@ def _estimate_current_impl(
t,
e,
dist,
template,
templates_surface,
activeness_surface,
r=r_tmp,
z=z_tmp,
template=template,
pulse_shape_library=pulse_shape_library,
templates_surface=templates_surface,
activeness_surface=activeness_surface,
tmin=tmin,
tmax=tmax,
start=start,
Expand All @@ -748,17 +835,73 @@ def _estimate_current_impl(
time_step=time_step,
surface_step_in_um=surface_step_in_um,
include_surface_effects=include_surface_effects,
use_library=use_library,
)

return A, maximum_t, energy


def prepare_surface_inputs(
dist_to_nplus: ak.Array,
edep: ak.Array,
templates_surface: ArrayLike,
activeness_surface,
template: ArrayLike,
) -> tuple:
"""Prepare the inputs needed for surface sims."""
include_surface_effects = False

# prepare surface templates
if templates_surface is not None:
if dist_to_nplus is None:
msg = "Surface effects requested but distance not provided"
raise ValueError(msg)

include_surface_effects = True
else:
# convert types to keep numba happy
templates_surface = np.zeros((1, len(template)))
dist_to_nplus = ak.full_like(edep, np.nan)

# convert types for numba
if activeness_surface is None:
activeness_surface = np.zeros(len(template))

return include_surface_effects, dist_to_nplus, templates_surface, activeness_surface


def prepare_pulse_shape_library(
template: ArrayLike | HPGePulseShapeLibrary,
times: ArrayLike,
edep: ak.Array,
r: ak.Array,
z: ak.Array,
):
"""Prepare the inputs for the full pulse shape library."""
use_library = False
if isinstance(template, HPGePulseShapeLibrary):
# convert to a form we can use
times = template.t
pulse_shape_library = (template.waveforms, template.r, template.z)
template = np.zeros_like(template.waveforms[0][0])
use_library = True

else:
pulse_shape_library = (np.zeros((1, 1, len(template))), np.zeros(1), np.zeros(1))
r = ak.full_like(edep, np.nan)
z = ak.full_like(edep, np.nan)

return use_library, pulse_shape_library, template, times, r, z


def maximum_current(
edep: ArrayLike,
drift_time: ArrayLike,
dist_to_nplus: ArrayLike | None = None,
r: ArrayLike | None = None,
z: ArrayLike | None = None,
*,
template: np.array,
template: np.array | HPGePulseShapeLibrary,
times: np.array,
fccd_in_um: float = 0,
templates_surface: ArrayLike | None = None,
Expand All @@ -776,8 +919,12 @@ def maximum_current(
Array of drift times for each step.
dist_to_nplus
Distance to n-plus electrode, only needed if surface heuristics are enabled.
r
Radial coordinate (only needed if a full PSS library is used)
z
z coordinate (only needed if a full PSS library is used).
template
array of the bulk pulse template
array of the bulk pulse template, can also be a :class:`HPGePulseShapeLibrary`.
times
time-stamps for the bulk pulse template
fccd
Expand All @@ -796,40 +943,35 @@ def maximum_current(
An Array of the maximum current/ time / energy for each hit.
"""
# extract LGDO data and units

drift_time, _ = units.unwrap_lgdo(drift_time)
edep, _ = units.unwrap_lgdo(edep)
dist_to_nplus, _ = units.unwrap_lgdo(dist_to_nplus)
r, _ = units.unwrap_lgdo(r)
z, _ = units.unwrap_lgdo(z)

include_surface_effects = False

if templates_surface is not None:
if dist_to_nplus is None:
msg = "Surface effects requested but distance not provided"
raise ValueError(msg)

include_surface_effects = True
else:
# convert types to keep numba happy
templates_surface = np.zeros((1, len(template)))
dist_to_nplus = ak.full_like(edep, np.nan)

# convert types for numba
if activeness_surface is None:
activeness_surface = np.zeros(len(template))
# prepare inputs for surface sims
include_surface_effects, dist_to_nplus, templates_surface, activeness_surface = (
prepare_surface_inputs(dist_to_nplus, edep, templates_surface, activeness_surface, template)
)

if not ak.all(ak.num(edep, axis=-1) == ak.num(drift_time, axis=-1)):
msg = "edep and drift time must have the same shape"
raise ValueError(msg)
# and for the full PS library
use_library, pulse_shape_library, template, times, r, z = prepare_pulse_shape_library(
template, times, edep, r, z
)

# and now compute the current
curr, time, energy = _estimate_current_impl(
ak.values_astype(ak.Array(edep), np.float64),
ak.values_astype(ak.Array(drift_time), np.float64),
ak.values_astype(ak.Array(dist_to_nplus), np.float64),
r=ak.values_astype(ak.Array(r), np.float64),
z=ak.values_astype(ak.Array(z), np.float64),
template=template,
pulse_shape_library=pulse_shape_library,
times=times,
fccd=fccd_in_um,
include_surface_effects=include_surface_effects,
use_library=use_library,
templates_surface=templates_surface,
activeness_surface=activeness_surface,
surface_step_in_um=surface_step_in_um,
Expand Down
Loading
Loading