diff --git a/src/reboost/__main__.py b/src/reboost/__main__.py new file mode 100644 index 0000000..da6bd1e --- /dev/null +++ b/src/reboost/__main__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from .cli import cli + +if __name__ == "__main__": + cli() diff --git a/src/reboost/optmap/__main__.py b/src/reboost/optmap/__main__.py new file mode 100644 index 0000000..bf636c3 --- /dev/null +++ b/src/reboost/optmap/__main__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from .cli import optical_cli + +if __name__ == "__main__": + optical_cli() diff --git a/src/reboost/optmap/convolve.py b/src/reboost/optmap/convolve.py index 33ed827..bd9d8c9 100644 --- a/src/reboost/optmap/convolve.py +++ b/src/reboost/optmap/convolve.py @@ -32,7 +32,7 @@ class OptmapForConvolve(NamedTuple): def open_optmap(optmap_fn: str) -> OptmapForConvolve: dets = lh5.ls(optmap_fn, "/channels/") - detidx = np.arange(0, dets.shape[0]) + detidx = np.arange(0, len(dets)) optmap_all = lh5.read("/all/prob", optmap_fn) assert isinstance(optmap_all, Histogram) @@ -61,14 +61,14 @@ def open_optmap(optmap_fn: str) -> OptmapForConvolve: raise ValueError(msg) else: detidx = np.array([OPTMAP_ANY_CH]) - dets = np.array(["all"]) + dets = ["all"] # check the exponent from the optical map file if "_hitcounts_exp" in lh5.ls(optmap_fn): msg = "found _hitcounts_exp which is not supported any more" raise RuntimeError(msg) - dets = [d.replace("/channels/", "") for d in dets] + dets = np.array([d.replace("/channels/", "") for d in dets]) return OptmapForConvolve(dets, detidx, optmap_edges, ow) diff --git a/src/reboost/optmap/mapview.py b/src/reboost/optmap/mapview.py index 9b21682..66243c4 100644 --- a/src/reboost/optmap/mapview.py +++ b/src/reboost/optmap/mapview.py @@ -95,7 +95,7 @@ def _read_data( histogram_choice: str = "prob", ) -> tuple[tuple[NDArray], NDArray]: histogram = histogram_choice if histogram_choice != "prob_unc_rel" else "prob" - detid = f"channels/{detid}" if detid != all and not detid.startswith("channels/") else detid + detid = f"channels/{detid}" if detid != "all" and not detid.startswith("channels/") else detid optmap_all = lh5.read(f"/{detid}/{histogram}", optmap_fn) optmap_edges = tuple([b.edges for b in optmap_all.binning]) diff --git a/src/reboost/spms/__init__.py b/src/reboost/spms/__init__.py index b650812..03bbcc8 100644 --- a/src/reboost/spms/__init__.py +++ b/src/reboost/spms/__init__.py @@ -1,5 +1,10 @@ from __future__ import annotations -from .pe import detected_photoelectrons, emitted_scintillation_photons, load_optmap +from .pe import detected_photoelectrons, emitted_scintillation_photons, load_optmap, load_optmap_all -__all__ = ["detected_photoelectrons", "emitted_scintillation_photons", "load_optmap"] +__all__ = [ + "detected_photoelectrons", + "emitted_scintillation_photons", + "load_optmap", + "load_optmap_all", +] diff --git a/src/reboost/spms/pe.py b/src/reboost/spms/pe.py index 5588c52..0884e9f 100644 --- a/src/reboost/spms/pe.py +++ b/src/reboost/spms/pe.py @@ -140,8 +140,8 @@ def detected_photoelectrons( """ hits = ak.Array( { - "num_scint_ph": num_scint_ph, - "particle": particle, + "num_scint_ph": units_conv_ak(num_scint_ph, "dimensionless"), + "particle": units_conv_ak(particle, "dimensionless"), "time": units_conv_ak(time, "ns"), "xloc": units_conv_ak(xloc, "m"), "yloc": units_conv_ak(yloc, "m"), @@ -171,7 +171,12 @@ def emitted_scintillation_photons( material scintillating material name. """ - hits = ak.Array({"edep": units_conv_ak(edep, "keV"), "particle": particle}) + hits = ak.Array( + { + "edep": units_conv_ak(edep, "keV"), + "particle": units_conv_ak(particle, "dimensionless"), + } + ) scint_mat_params = convolve._get_scint_params(material) ph = convolve.iterate_stepwise_depositions_scintillate(hits, scint_mat_params) diff --git a/tests/hit/configs/spms.yaml b/tests/hit/configs/spms.yaml index 07f5d66..45b0793 100644 --- a/tests/hit/configs/spms.yaml +++ b/tests/hit/configs/spms.yaml @@ -1,5 +1,6 @@ objects: spms: "['S001', 'S002']" + optmap_test: reboost.spms.load_optmap_all(ARGS.optmap_path) processing_groups: - name: spms diff --git a/tests/hit/test_build_hit.py b/tests/hit/test_build_hit.py index 0411f79..f6d88f9 100644 --- a/tests/hit/test_build_hit.py +++ b/tests/hit/test_build_hit.py @@ -333,6 +333,7 @@ def test_spms(test_gen_lh5_scint, tmptestdir): m.create_probability() m.write_lh5(map_file, "channels/S001", "overwrite_file") m.write_lh5(map_file, "channels/S002", "write_safe") + m.write_lh5(map_file, "all", "write_safe") outfile = f"{tmptestdir}/spms_hit.lh5" reboost.build_hit(