diff --git a/src/reboost/optmap/convolve.py b/src/reboost/optmap/convolve.py index d05abb0..b6c2072 100644 --- a/src/reboost/optmap/convolve.py +++ b/src/reboost/optmap/convolve.py @@ -155,11 +155,13 @@ def iterate_stepwise_depositions_scintillate( msg = "the pe processors only support already reshaped output" raise ValueError(msg) - builder = ak.ArrayBuilder() rng = np.random.default_rng() if rng is None else rng - _iterate_stepwise_depositions_scintillate(edep_hits, rng, scint_mat_params, mode, builder) + counts = ak.num(edep_hits.edep) + output_array = _iterate_stepwise_depositions_scintillate( + edep_hits, rng, scint_mat_params, mode, ak.sum(counts) + ) - return builder.snapshot() + return ak.unflatten(output_array, counts) def iterate_stepwise_depositions_numdet( @@ -174,9 +176,9 @@ def iterate_stepwise_depositions_numdet( msg = "the pe processors only support already reshaped output" raise ValueError(msg) - builder = ak.ArrayBuilder() rng = np.random.default_rng() if rng is None else rng - res = _iterate_stepwise_depositions_numdet( + counts = ak.num(edep_hits.num_scint_ph) + output_array, res = _iterate_stepwise_depositions_numdet( edep_hits, rng, np.where(optmap.dets == det)[0][0], @@ -184,7 +186,7 @@ def iterate_stepwise_depositions_numdet( map_scaling_sigma, optmap.edges, optmap.weights, - builder, + ak.sum(counts), ) if res["det_no_stats"] > 0: @@ -199,7 +201,7 @@ def iterate_stepwise_depositions_numdet( (res["oob"] / (res["ib"] + res["oob"])) * 100, ) - return builder.snapshot() + return ak.unflatten(output_array, counts) def iterate_stepwise_depositions_times( @@ -211,11 +213,13 @@ def iterate_stepwise_depositions_times( msg = "the pe processors only support already reshaped output" raise ValueError(msg) - builder = ak.ArrayBuilder() rng = np.random.default_rng() if rng is None else rng - _iterate_stepwise_depositions_times(edep_hits, rng, scint_mat_params, builder) + counts = ak.sum(edep_hits.num_det_ph, axis=1) + output_array = _iterate_stepwise_depositions_times( + edep_hits, rng, scint_mat_params, ak.sum(counts) + ) - return builder.snapshot() + return ak.unflatten(output_array, counts) _pdg_func = numba_pdgid_funcs() @@ -321,16 +325,15 @@ def _iterate_stepwise_depositions_pois( # - cache=True does not work with outer prange, i.e. loading the cached file fails (numba bug?) @njit(parallel=False, nogil=True, cache=True) def _iterate_stepwise_depositions_scintillate( - edep_hits, rng, scint_mat_params: sc.ComputedScintParams, mode: str, builder + edep_hits, rng, scint_mat_params: sc.ComputedScintParams, mode: str, output_length: int ): pdgid_map = {} + output = np.empty(shape=output_length, dtype=np.int64) + output_index = 0 for rowid in range(len(edep_hits)): # iterate hits hit = edep_hits[rowid] - builder.begin_list() - - # iterate steps inside the hit - for si in range(len(hit.particle)): + for si in range(len(hit.particle)): # iterate steps inside the hit # get the particle information. particle = hit.particle[si] if particle not in pdgid_map: @@ -345,10 +348,11 @@ def _iterate_stepwise_depositions_scintillate( rng, emission_term_model=("poisson" if mode == "no-fano" else "normal_fano"), ) - builder.integer(num_phot) + output[output_index] = num_phot + output_index += 1 - # assert len(hit_output) == len(hit.particle) - builder.end_list() + assert output_index == output_length + return output # - run with NUMBA_FULL_TRACEBACKS=1 NUMBA_BOUNDSCHECK=1 for testing/checking @@ -362,13 +366,14 @@ def _iterate_stepwise_depositions_numdet( map_scaling_sigma: float, optmap_edges, optmap_weights, - builder, + output_length: int, ): oob = ib = det_no_stats = 0 + output = np.empty(shape=output_length, dtype=np.int64) + output_index = 0 for rowid in range(len(edep_hits)): # iterate hits hit = edep_hits[rowid] - builder.begin_list() map_scaling_evt = map_scaling if map_scaling_sigma > 0: @@ -400,11 +405,11 @@ def _iterate_stepwise_depositions_numdet( ib += 1 pois_cnt = 0 if detp <= 0.0 else rng.poisson(lam=hit.num_scint_ph[si] * detp) - builder.integer(pois_cnt) - - builder.end_list() + output[output_index] = pois_cnt + output_index += 1 - return {"oob": oob, "ib": ib, "det_no_stats": det_no_stats} + assert output_index == output_length + return output, {"oob": oob, "ib": ib, "det_no_stats": det_no_stats} # - run with NUMBA_FULL_TRACEBACKS=1 NUMBA_BOUNDSCHECK=1 for testing/checking @@ -412,13 +417,14 @@ def _iterate_stepwise_depositions_numdet( # - the output dictionary is not threadsafe, so parallel=True is not working with it. @njit(parallel=False, nogil=True, cache=True) def _iterate_stepwise_depositions_times( - edep_hits, rng, scint_mat_params: sc.ComputedScintParams, builder + edep_hits, rng, scint_mat_params: sc.ComputedScintParams, output_length: int ): pdgid_map = {} + output = np.empty(shape=output_length, dtype=np.float64) + output_index = 0 for rowid in range(len(edep_hits)): # iterate hits hit = edep_hits[rowid] - builder.begin_list() assert len(hit.particle) == len(hit.num_det_ph) # iterate steps inside the hit @@ -436,10 +442,12 @@ def _iterate_stepwise_depositions_times( # get time spectrum. # note: we assume "immediate" propagation after scintillation. scint_times = sc.scintillate_times(scint_mat_params, part, pois_cnt, rng) + hit.time[si] - for ti in range(len(scint_times)): - builder.real(scint_times[ti]) + assert len(scint_times) == pois_cnt + output[output_index : output_index + len(scint_times)] = scint_times + output_index += len(scint_times) - builder.end_list() + assert output_index == output_length + return output def _get_scint_params(material: str):