Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 37 additions & 29 deletions src/reboost/optmap/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,13 @@ def iterate_stepwise_depositions_scintillate(
msg = "the pe processors only support already reshaped output"
raise ValueError(msg)

builder = ak.ArrayBuilder()
rng = np.random.default_rng() if rng is None else rng
_iterate_stepwise_depositions_scintillate(edep_hits, rng, scint_mat_params, mode, builder)
counts = ak.num(edep_hits.edep)
output_array = _iterate_stepwise_depositions_scintillate(
edep_hits, rng, scint_mat_params, mode, ak.sum(counts)
)

return builder.snapshot()
return ak.unflatten(output_array, counts)


def iterate_stepwise_depositions_numdet(
Expand All @@ -174,17 +176,17 @@ def iterate_stepwise_depositions_numdet(
msg = "the pe processors only support already reshaped output"
raise ValueError(msg)

builder = ak.ArrayBuilder()
rng = np.random.default_rng() if rng is None else rng
res = _iterate_stepwise_depositions_numdet(
counts = ak.num(edep_hits.num_scint_ph)
output_array, res = _iterate_stepwise_depositions_numdet(
edep_hits,
rng,
np.where(optmap.dets == det)[0][0],
map_scaling,
map_scaling_sigma,
optmap.edges,
optmap.weights,
builder,
ak.sum(counts),
)

if res["det_no_stats"] > 0:
Expand All @@ -199,7 +201,7 @@ def iterate_stepwise_depositions_numdet(
(res["oob"] / (res["ib"] + res["oob"])) * 100,
)

return builder.snapshot()
return ak.unflatten(output_array, counts)


def iterate_stepwise_depositions_times(
Expand All @@ -211,11 +213,13 @@ def iterate_stepwise_depositions_times(
msg = "the pe processors only support already reshaped output"
raise ValueError(msg)

builder = ak.ArrayBuilder()
rng = np.random.default_rng() if rng is None else rng
_iterate_stepwise_depositions_times(edep_hits, rng, scint_mat_params, builder)
counts = ak.sum(edep_hits.num_det_ph, axis=1)
output_array = _iterate_stepwise_depositions_times(
edep_hits, rng, scint_mat_params, ak.sum(counts)
)

return builder.snapshot()
return ak.unflatten(output_array, counts)


_pdg_func = numba_pdgid_funcs()
Expand Down Expand Up @@ -321,16 +325,15 @@ def _iterate_stepwise_depositions_pois(
# - cache=True does not work with outer prange, i.e. loading the cached file fails (numba bug?)
@njit(parallel=False, nogil=True, cache=True)
def _iterate_stepwise_depositions_scintillate(
edep_hits, rng, scint_mat_params: sc.ComputedScintParams, mode: str, builder
edep_hits, rng, scint_mat_params: sc.ComputedScintParams, mode: str, output_length: int
):
pdgid_map = {}
output = np.empty(shape=output_length, dtype=np.int64)

output_index = 0
for rowid in range(len(edep_hits)): # iterate hits
hit = edep_hits[rowid]
builder.begin_list()

# iterate steps inside the hit
for si in range(len(hit.particle)):
for si in range(len(hit.particle)): # iterate steps inside the hit
# get the particle information.
particle = hit.particle[si]
if particle not in pdgid_map:
Expand All @@ -345,10 +348,11 @@ def _iterate_stepwise_depositions_scintillate(
rng,
emission_term_model=("poisson" if mode == "no-fano" else "normal_fano"),
)
builder.integer(num_phot)
output[output_index] = num_phot
output_index += 1

# assert len(hit_output) == len(hit.particle)
builder.end_list()
assert output_index == output_length
return output


# - run with NUMBA_FULL_TRACEBACKS=1 NUMBA_BOUNDSCHECK=1 for testing/checking
Expand All @@ -362,13 +366,14 @@ def _iterate_stepwise_depositions_numdet(
map_scaling_sigma: float,
optmap_edges,
optmap_weights,
builder,
output_length: int,
):
oob = ib = det_no_stats = 0
output = np.empty(shape=output_length, dtype=np.int64)

output_index = 0
for rowid in range(len(edep_hits)): # iterate hits
hit = edep_hits[rowid]
builder.begin_list()

map_scaling_evt = map_scaling
if map_scaling_sigma > 0:
Expand Down Expand Up @@ -400,25 +405,26 @@ def _iterate_stepwise_depositions_numdet(
ib += 1

pois_cnt = 0 if detp <= 0.0 else rng.poisson(lam=hit.num_scint_ph[si] * detp)
builder.integer(pois_cnt)

builder.end_list()
output[output_index] = pois_cnt
output_index += 1

return {"oob": oob, "ib": ib, "det_no_stats": det_no_stats}
assert output_index == output_length
return output, {"oob": oob, "ib": ib, "det_no_stats": det_no_stats}


# - run with NUMBA_FULL_TRACEBACKS=1 NUMBA_BOUNDSCHECK=1 for testing/checking
# - cache=True does not work with outer prange, i.e. loading the cached file fails (numba bug?)
# - the output dictionary is not threadsafe, so parallel=True is not working with it.
@njit(parallel=False, nogil=True, cache=True)
def _iterate_stepwise_depositions_times(
edep_hits, rng, scint_mat_params: sc.ComputedScintParams, builder
edep_hits, rng, scint_mat_params: sc.ComputedScintParams, output_length: int
):
pdgid_map = {}
output = np.empty(shape=output_length, dtype=np.float64)

output_index = 0
for rowid in range(len(edep_hits)): # iterate hits
hit = edep_hits[rowid]
builder.begin_list()

assert len(hit.particle) == len(hit.num_det_ph)
# iterate steps inside the hit
Expand All @@ -436,10 +442,12 @@ def _iterate_stepwise_depositions_times(
# get time spectrum.
# note: we assume "immediate" propagation after scintillation.
scint_times = sc.scintillate_times(scint_mat_params, part, pois_cnt, rng) + hit.time[si]
for ti in range(len(scint_times)):
builder.real(scint_times[ti])
assert len(scint_times) == pois_cnt
output[output_index : output_index + len(scint_times)] = scint_times
output_index += len(scint_times)

builder.end_list()
assert output_index == output_length
return output


def _get_scint_params(material: str):
Expand Down
Loading