diff --git a/src/dspeed/build_dsp.py b/src/dspeed/build_dsp.py
index 0dac1c44..c0ddb0ab 100644
--- a/src/dspeed/build_dsp.py
+++ b/src/dspeed/build_dsp.py
@@ -7,54 +7,92 @@
 import logging
 import os
+import re
 import time
 from collections.abc import Collection, Mapping
+from concurrent.futures import ProcessPoolExecutor
+from fnmatch import fnmatch
+from functools import partial
+from itertools import chain

-import h5py
-import numpy as np
-from lgdo import lh5
+from lgdo import LGDO, Struct, Table, lh5
 from tqdm.auto import tqdm
 from yaml import safe_load

-from .errors import DSPFatal
+from .errors import DSPFatal, ProcessingChainError
 from .processing_chain import build_processing_chain

 log = logging.getLogger("dspeed")


 def build_dsp(
-    f_raw: str,
-    f_dsp: str,
+    raw_in: str | LGDO,
+    dsp_out: str | None = None,
     dsp_config: str | Mapping = None,
     lh5_tables: Collection[str] | str = None,
+    base_group: str = None,
     database: str | Mapping = None,
     outputs: Collection[str] = None,
-    n_max: int = np.inf,
     write_mode: str = None,
+    entry_list: Collection[int] = None,
+    entry_mask: Collection[bool] = None,
+    i_start: int = 0,
+    n_entries: int | None = None,
     buffer_len: int = 3200,
     block_width: int = 16,
-    chan_config: Mapping[str, str] = None,
+    processes: int = None,
+    chan_config: str | Mapping[str, str] = None,
 ) -> None:
     """Convert raw-tier LH5 data into dsp-tier LH5 data by running a sequence
     of processors via the :class:`~.processing_chain.ProcessingChain`.

     Parameters
     ----------
-    f_raw
-        name of raw-tier LH5 file to read from.
-    f_dsp
-        name of dsp-tier LH5 file to write to.
+    raw_in
+        raw data to process. Can be the name of a raw-tier LH5 file to read
+        from, an LH5Iterator, or an LGDO Table.
+    dsp_out
+        name of file in which to output data. If ``None``, return a
+        :class:`lgdo.Struct` or :class:`lgdo.Table`.
     dsp_config
-        :class:`dict` or name of JSON or YAML file containing
-        :class:`~.processing_chain.ProcessingChain` config. See
-        :func:`~.processing_chain.build_processing_chain` for details.
+        :class:`dict` or name of JSON or YAML file containing the recipe for computing
+        DSP parameters. If ``chan_config`` is provided, this is the default configuration
+        to use. Can only be ``None`` if ``chan_config`` is provided, in which case we
+        skip channels that are not found in ``chan_config``. The format is as follows:
+
+        .. code-block:: json
+
+            {
+                "inputs" : [
+                    { "file": "fname", "group": "gname", "prefix": "pre_" }
+                ],
+                "outputs" : [ "par1", "par2" ],
+                "processors" : {
+                    ...
+                }
+            }
+
+        - ``inputs`` (optional) -- list of files/lh5 table names to read input data
+          from. These will be friended to any input data provided to
+          build_processing_chain.
+
+          - ``file`` -- file path.
+          - ``group`` -- lh5 table group name.
+          - ``prefix`` (optional) -- prefix to disambiguate variable names.
+          - ``suffix`` (optional) -- suffix to disambiguate variable names.
+        - ``outputs`` (optional) -- list of output parameters (strings) to compute by
+          default. This will be used if no argument is provided for ``outputs``.
+        - ``processors`` -- configuration for :class:`~.processing_chain.ProcessingChain`.
+          See :func:`~.processing_chain.build_processing_chain` for details.
     lh5_tables
         list of LGDO groups to process in the input file. These table
         should include all input variables for processing or contain a
         subgroup called raw that contains such a table. If ``None``, process
-        all valid groups. Note that wildcards are accepted (e.g. "ch*").
+        all valid groups. Note that wildcards are accepted (e.g. "ch*"). Not a
+        valid argument if ``raw_in`` is an :class:`lgdo.Table`.
+    base_group
+        name of group in which to find tables listed in ``lh5_tables``. By default,
+        check if there is a base group called ``raw``, otherwise use no base.
     database
-        dictionary or name of JSON or YAMLfile containing a parameter database. See
+        dictionary or name of JSON or YAML file containing a parameter database. See
         :func:`~.processing_chain.build_processing_chain` for details.
     outputs
         list of parameter names to write to the output file. If not provided,
@@ -71,71 +109,103 @@ def build_dsp(
     block_width
         number of waveforms to process at a time.
     chan_config
-        contains JSON or YAML DSP configuration file names for every table in
-        `lh5_tables`.
+        an ordered mapping, or a JSON file containing such a mapping, from
+        a channel or wildcard pattern to a DSP config. Loop over channels in
+        ``lh5_tables`` and match them to a separate DSP config. If no matching
+        channel or pattern is found, use ``dsp_config`` as a default. If a channel
+        matches several patterns, use the first one found; an ordered mapping
+        can be used to override certain patterns. For example:
+
+        .. code-block:: json
+
+            {
+                "ch1*": "config1.json",
+                "ch2000000": "config2.json",
+                "ch2*": "config3.json"
+            }
+
+        will process all channels beginning with ``ch2``, except for
+        ``ch2000000``, with ``config3.json``.
     """
+    db_parser = re.compile(r"(?![^\w_.])db\.[\w_.]+")

-    if chan_config is not None:
-        # clear existing output files
-        if write_mode == "r":
-            if os.path.isfile(f_dsp):
-                os.remove(f_dsp)
-            write_mode = "a"
-
-        for tb, dsp_config in chan_config.items():
-            log.debug(f"processing table: {tb} with DSP config file {dsp_config}")
-            try:
-                build_dsp(
-                    f_raw,
-                    f_dsp,
-                    dsp_config,
-                    [tb],
-                    database,
-                    outputs,
-                    n_max,
-                    write_mode,
-                    buffer_len,
-                    block_width,
-                )
-            except RuntimeError:
-                log.debug(f"table {tb} not found")
-        return
-
-    raw_store = lh5.LH5Store()
-    lh5_file = raw_store.gimme_file(f_raw, "r")
-    if lh5_file is None:
-        raise ValueError(f"input file not found: {f_raw}")
-        return
-
-    # if no group is specified, assume we want to decode every table in the file
-    if lh5_tables is None:
-        lh5_tables = lh5.ls(f_raw)
-    elif isinstance(lh5_tables, str):
-        lh5_tables = lh5.ls(f_raw, lh5_tables)
-    elif isinstance(lh5_tables, Collection):
-        lh5_tables = [tab for tab_wc in lh5_tables for tab in lh5.ls(f_raw, tab_wc)]
-    elif not (
-        hasattr(lh5_tables, "__iter__")
-        and all(isinstance(el, str) for el in lh5_tables)
-    ):
-        raise RuntimeError("lh5_tables must be None, a string, or a list of strings")
-
-    # check if group points to raw data; sometimes 'raw' is nested, e.g g024/raw
-    for i, tb in enumerate(lh5_tables):
-        if (
-            "raw" not in tb
-            and not isinstance(raw_store.gimme_file(lh5_file, "r")[tb], h5py.Dataset)
-            and lh5.ls(lh5_file, f"{tb}/raw")
+    if isinstance(lh5_tables, str):
+        lh5_tables = [lh5_tables]
+
+    if isinstance(raw_in, (Table, lh5.LH5Iterator)):
+        # single table
+
+        # in this case, lh5_tables will just be used for naming the output group
+        if base_group is None:
+            base_group = ""
+        if lh5_tables is None:
+            lh5_tables = [""]
+        elif len(lh5_tables) > 1:
+            raise RuntimeError(
+                "Cannot have more than one value in lh5_tables for input of type Table or LH5Iterator"
+            )
+
+    elif isinstance(raw_in, str):
+        # file name
+        # default base_group behavior
+        if base_group is None:
+            if lh5.ls(raw_in, "raw"):
+                base_group = "raw"
+            else:
+                base_group = ""
+
+        # if no group is specified, assume we want to decode every table in the file
+        if lh5_tables is None:
+            lh5_tables = lh5.ls(raw_in, f"{base_group}/*")
+        elif isinstance(lh5_tables, str):
+            lh5_tables = lh5.ls(raw_in, f"{base_group}/{lh5_tables}")
+        elif isinstance(lh5_tables, Collection):
+            lh5_tables = [
+                tab
+                for tab_wc in lh5_tables
+                for tab in lh5.ls(raw_in, f"{base_group}/{tab_wc}")
+            ]
+
+        elif not (
+            isinstance(lh5_tables, Collection)
+            and all(isinstance(el, str) for el in lh5_tables)
         ):
-            lh5_tables[i] = f"{tb}/raw"
-        elif not lh5.ls(lh5_file, tb):
-            del lh5_tables[i]
+            raise RuntimeError(
+                "lh5_tables must be None, a string, or a collection of strings"
+            )

-    if len(lh5_tables) == 0:
-        raise RuntimeError(f"could not find any valid LH5 table in {f_raw}")
+        # check if group points to raw data; sometimes 'raw' is nested, e.g g024/raw
+        for i, tb in enumerate(lh5_tables):
+            if lh5.ls(raw_in, f"{tb}/*") == [f"{tb}/raw"]:
+                lh5_tables[i] = f"{tb}/raw"
+            elif not lh5.ls(raw_in, tb):
+                del lh5_tables[i]

-    # get the database parameters. For now, this will just be a dict in a
-    # file, but eventually we will want to interface with the metadata repo
+        if len(lh5_tables) == 0:
+            raise RuntimeError(f"could not find any valid LH5 table in {raw_in}")
+
+    else:
+        raise RuntimeError(
+            f"raw_in was not a file name, Table, or LH5Iterator: {raw_in}"
+        )
+
+    # get the config(s)
+    if isinstance(dsp_config, str):
+        with open(lh5.utils.expand_path(dsp_config)) as config_file:
+            dsp_config = safe_load(config_file)
+
+    if isinstance(chan_config, str):
+        with open(lh5.utils.expand_path(chan_config)) as config_file:
+            # safe_load is order preserving, but doesn't load into an OrderedDict
+            # and so may not be totally robust here...
+            chan_config = safe_load(config_file)
+    elif chan_config is None:
+        chan_config = {}
+
+    for chan, config in chan_config.items():
+        if isinstance(config, str):
+            with open(lh5.utils.expand_path(config)) as config_file:
+                chan_config[chan] = safe_load(config_file)
+
+    # get the database parameters
     if isinstance(database, str):
         with open(lh5.utils.expand_path(database)) as db_file:
             database = safe_load(db_file)
@@ -143,100 +213,285 @@ def build_dsp(
     if database and not isinstance(database, Mapping):
         raise ValueError("input database is not a valid JSON or YAML file or dict")

-    if write_mode is None and os.path.isfile(f_dsp):
-        raise FileExistsError(
-            f"output file {f_dsp} exists. Set the 'write_mode' keyword"
-        )
+    # Setup output
+    if dsp_out is None:
+        # Output to tables
+        if lh5_tables == [""]:
+            dsp_st = Table()
+        else:
+            dsp_st = Struct()
+    else:
+        # Output to file
+        if write_mode is None and os.path.isfile(dsp_out):
+            raise FileExistsError(
+                f"output file {dsp_out} exists. Set the 'write_mode' keyword"
+            )
+
+        # clear existing output files
+        if write_mode == "r":
+            if os.path.isfile(dsp_out):
+                os.remove(dsp_out)
+
+        dsp_st = lh5.LH5Store(keep_open=True)

-    # clear existing output files
-    if write_mode == "r":
-        if os.path.isfile(f_dsp):
-            os.remove(f_dsp)
+    # Initialize processing
+    if processes and isinstance(raw_in, (lh5.LH5Iterator, str)):
+        pool = ProcessPoolExecutor(max_workers=processes)
+    else:
+        pool = None
+    dsp_names = []
+    async_results = []

     # loop over tables to run DSP on
     for tb in lh5_tables:
-        # load primary table and build processing chain and output table
-        tot_n_rows = raw_store.read_n_rows(tb, f_raw)
-        if n_max and n_max < tot_n_rows:
-            tot_n_rows = n_max
-
-        chan_name = tb.split("/")[0]
-        log.info(f"Processing table {tb} with {tot_n_rows} rows")
-        start = time.time()
-        db_dict = database.get(chan_name) if database else None
-        if db_dict is not None:
-            log.info(f"Found database for {chan_name}")
-        tb_name = tb.replace("/raw", "/dsp")
-
-        write_offset = 0
-        raw_store.gimme_file(f_dsp, "a")
-        if write_mode == "a" and lh5.ls(f_dsp, tb_name):
-            write_offset = raw_store.read_n_rows(tb_name, f_dsp)
-
-        loading_time = 0
-        write_time = 0
-        start = time.time()
-        # Main processing loop
-        lh5_it = lh5.LH5Iterator(f_raw, tb, buffer_len=buffer_len, n_entries=tot_n_rows)
-        proc_chain = None
-        curr = time.time()
-        loading_time += curr - start
-        processing_time = 0
-
-        for lh5_in in lh5_it:
-            loading_time += time.time() - curr
-            # Initialize
-
-            if proc_chain is None:
-                proc_chain_start = time.time()
-                proc_chain, lh5_it.field_mask, tb_out = build_processing_chain(
-                    lh5_in, dsp_config, db_dict, outputs, block_width
+        # get the config to use
+        this_config = dsp_config
+        for pat, config in chan_config.items():
+            if fnmatch(tb, pat):
+                this_config = config
+                break
+
+        # get the DB values
+        if tb not in ("", "raw"):
+            chan_name = next(k for k in tb.split("/") if k not in ("", "raw"))
+            db_dict = database.get(chan_name) if database else None
+            if db_dict is not None:
+                log.info(f"Found database for {chan_name}")
+        else:
+            db_dict = database
+
+        # get input as either table or iterator
+        if isinstance(raw_in, str):
+            # Setup lh5 iterator from input
+            lh5_in = lh5.LH5Iterator(
+                raw_in,
+                tb,
+                entry_list=entry_list,
+                entry_mask=entry_mask,
+                i_start=i_start,
+                n_entries=n_entries,
+                buffer_len=buffer_len,
+            )
+        else:
+            lh5_in = raw_in
+
+        # Check for aux input files
+        inputs = []
+        config_inputs = this_config.get("inputs", [])
+        if isinstance(config_inputs, Mapping):
+            inputs += [
+                (
+                    config_inputs["file"],
+                    config_inputs["group"],
+                    config_inputs.get("prefix", ""),
+                    config_inputs.get("suffix", ""),
                 )
-            if log.getEffectiveLevel() >= logging.INFO:
-                progress_bar = tqdm(
-                    desc=f"Processing table {tb}",
-                    total=tot_n_rows,
-                    delay=2,
-                    unit=" rows",
-                )
-                log.info(
-                    f"Table: {tb} processing chain built in {time.time() - proc_chain_start:.2f} seconds"
+            ]
+        elif isinstance(config_inputs, Collection):
+            inputs += [
+                (ci["file"], ci["group"], ci.get("prefix", ""), ci.get("suffix", ""))
+                for ci in config_inputs
+            ]
+
+        for file, group, prefix, suffix in inputs:
+            # check if file points to a db override
+            if db_parser.fullmatch(file):
+                try:
+                    db_node = db_dict
+                    for db_key in file.split(".")[1:]:
+                        db_node = db_node[db_key]
+                    log.debug(f"database lookup: found {db_node} for {file}")
+                    file = db_node
+                except (KeyError, TypeError):
+                    raise ProcessingChainError(f"did not find {file} in database.")
+
+            # check if group points to a db override
+            if db_parser.fullmatch(group):
+                try:
+                    db_node = db_dict
+                    for db_key in group.split(".")[1:]:
+                        db_node = db_node[db_key]
+                    log.debug(f"database lookup: found {db_node} for {group}")
+                    group = db_node
+                except (KeyError, TypeError):
+                    raise ProcessingChainError(f"did not find {group} in database.")
+
+            if isinstance(lh5_in, lh5.LH5Iterator):
+                lh5_in.add_friend(
+                    lh5.LH5Iterator(
+                        file,
+                        group,
+                        entry_list=entry_list,
+                        entry_mask=entry_mask,
+                        i_start=i_start,
+                        n_entries=n_entries,
+                        buffer_len=buffer_len,
+                    ),
+                    prefix=prefix,
+                    suffix=suffix,
+                )
+            else:
+                lh5_in.join(
+                    lh5.read(group, file, n_rows=len(lh5_in)),
+                    prefix=prefix,
+                    suffix=suffix,
                 )

-            entries = lh5_it.current_global_entries
-            processing_time_start = time.time()
-            try:
-                proc_chain.execute(0, len(lh5_in))
-            except DSPFatal as e:
-                # Update the wf_range to reflect the file position
-                e.wf_range = f"{entries[0]}-{entries[-1]}"
-                raise e
-            processing_time += time.time() - processing_time_start
-            write_start = time.time()
-            raw_store.write(
-                obj=tb_out,
-                name=tb_name,
-                lh5_file=f_dsp,
-                wo_mode="o" if write_mode == "u" else "a",
-                write_start=write_offset + entries[0],
+        processors = this_config["processors"]
+
+        # Get outputs from config if they weren't provided
+        if outputs is None:
+            outputs = this_config["outputs"]
+
+        # start processing
+        if isinstance(lh5_in, lh5.LH5Iterator):
+            if n_entries is not None:
+                lh5_in.n_entries = n_entries
+            dsp_result = lh5_in.map(
+                _build_dsp,
+                processes=pool,
+                chunks=1,
+                aggregate=Table.append,
+                begin=partial(
+                    _initialize_channel,
+                    processors,
+                    db_dict,
+                    outputs,
+                    buffer_len,
+                    block_width,
+                ),
+                terminate=_finish_channel,
             )
-            write_time += time.time() - write_start
-            if log.getEffectiveLevel() >= logging.INFO:
-                progress_bar.update(len(lh5_in))
-
-            curr = time.time()
-        if log.getEffectiveLevel() >= logging.INFO:
-            progress_bar.close()
-
-        log.info(f"Table {tb} processed in {time.time() - start:.2f} seconds")
-        log.debug(f"Table {tb} loading time: {loading_time:.2f} seconds")
-        log.debug(f"Table {tb} write time: {write_time:.2f} seconds")
-        log.debug(f"Table {tb} processing time: {processing_time:.2f} seconds")
-
-        if log.getEffectiveLevel() >= logging.DEBUG:
-            times = proc_chain.get_timing()
-            log.debug("Processor timing info: ")
-            for proc, t in dict(
-                sorted(times.items(), key=lambda item: item[1], reverse=True)
-            ).items():
-                log.debug(f"{proc}: {t:.3f} s")
+        else:
+            if n_entries is not None:
+                lh5_in.resize(n_entries)
+            _initialize_channel(
+                processors, db_dict, outputs, buffer_len, block_width, lh5_in
+            )
+            dsp_result = _build_dsp(lh5_in)
+            _finish_channel(lh5_in)
+
+        # if not multiprocessing, record results; otherwise add them to our list of asynchronous results
+        dsp_name = tb.replace("raw", "dsp")
+        if not pool:
+            _record_result(dsp_result, dsp_st, dsp_name, dsp_out, write_mode, i_start)
+        else:
+            dsp_names.append(dsp_name)
+            async_results.append(dsp_result)
+
+    # if multiprocessing, record results now
+    if pool:
+        pool.shutdown(wait=False)
+        for dsp_name, dsp_tb in zip(dsp_names, chain(*async_results)):
+            _record_result(dsp_tb, dsp_st, dsp_name, dsp_out, write_mode, i_start)
+
+    if isinstance(dsp_st, Struct):
+        return dsp_st
+
+
+def _initialize_channel(processors, db_dict, outputs, buffer_len, block_width, lh5_it):
+    """Initialize a processing chain for the current table"""
+    global t0
+    t0 = time.time()
+
+    n_rows = len(lh5_it)
+    if isinstance(lh5_it, lh5.LH5Iterator):
+        tb_name = lh5_it.groups[0]
+        tb_in = lh5_it.read(0)
+    else:
+        tb_name = "raw"
+        tb_in = lh5_it
+    log.info(f"Processing table {tb_name} with {n_rows} rows")
+
+    # Setup processing chain
+    global proc_chain, field_mask, tb_out
+    proc_chain, field_mask, tb_out = build_processing_chain(
+        processors,
+        tb_in,
+        db_dict=db_dict,
+        outputs=outputs,
+        buffer_len=buffer_len,
+        block_width=block_width,
+    )
+
+    if isinstance(lh5_it, lh5.LH5Iterator):
+        lh5_it.reset_field_mask(field_mask)
+
+    global progress_bar
+    if log.getEffectiveLevel() >= logging.INFO:
+        progress_bar = tqdm(
+            desc=f"Processing table {tb_name}",
+            total=n_rows,
+            delay=2,
+            unit=" rows",
+        )
+
+    global timing_info
+    timing_info = {"Disk read": 0, "build_processing_chain": time.time() - t0}
+
+    global t_iter
+    t_iter = time.time()
+
+
+def _finish_channel(lh5_it):
+    global t0  # noqa: F824
+    tb_name = lh5_it.groups[0] if isinstance(lh5_it, lh5.LH5Iterator) else "raw"
+    log.info(f"Table {tb_name} processed in {time.time() - t0:.2f} s")
+
+    global timing_info, proc_chain  # noqa: F824
+    timing_info.update(proc_chain.get_timing())
+    for name, t in timing_info.items():
+        log.debug(f"- {name}: {t} s")
+
+    global progress_bar  # noqa: F824
+    if log.getEffectiveLevel() >= logging.INFO:
+        progress_bar.close()
+
+
+def _build_dsp(tb_in, lh5_it=None):
+    global timing_info, t_iter  # noqa: F824
+    timing_info["Disk read"] += time.time() - t_iter
+
+    i_entry = lh5_it.current_i_entry if isinstance(lh5_it, lh5.LH5Iterator) else 0
+    try:
+        proc_chain.execute(0, len(tb_in))
+    except DSPFatal as e:
+        # Update the wf_range to reflect the file position
+        e.wf_range = f"{i_entry}-{i_entry+len(tb_in)}"
+        raise e
+
+    if log.getEffectiveLevel() >= logging.INFO:
+        progress_bar.update(len(tb_in))
+
+    t_iter = time.time()
+
+    # copy and return result
+    tb_out.resize(len(tb_in))
+    return tb_out
+
+
+def _record_result(dsp_tb, dsp_st, dsp_name, dsp_out, write_mode, i_start):
+    # combine tables from each processing chunk
+    t0 = time.time()
+
+    if dsp_name == "":
+        dsp_name = "dsp"
+
+    # write to file/update struct
+    if isinstance(dsp_st, lh5.LH5Store):
+        # if name has nesting, build a struct
+        gp_name = dsp_name.split("/", 1)
+        if len(gp_name) == 2:
+            dsp_tb = Struct({gp_name[1]: dsp_tb})
+        gp_name = gp_name[0]
+
+        dsp_st.write(
+            dsp_tb,
+            name=gp_name,
+            lh5_file=dsp_out,
+            wo_mode="o" if write_mode == "u" else "a",
+            write_start=i_start,
+        )
+
+        log.info(f"Wrote table {dsp_name} (write took {time.time() - t0} s).")
+    elif isinstance(dsp_st, Table):
+        dsp_st.update(dsp_tb)
+    elif isinstance(dsp_st, Struct):
+        dsp_st.update({dsp_name: dsp_tb})
diff --git a/src/dspeed/cli.py b/src/dspeed/cli.py
index c42ec851..b0e304dd 100644
--- a/src/dspeed/cli.py
+++ b/src/dspeed/cli.py
@@ -180,7 +180,7 @@ def dspeed_cli():
         lh5_tables=args.hdf5_groups,
         database=args.database,
         outputs=args.output_pars,
-        n_max=args.max_rows,
+        n_entries=args.max_rows,
         write_mode=args.writemode,
         buffer_len=args.chunk,
         block_width=args.block,
diff --git a/src/dspeed/processing_chain.py b/src/dspeed/processing_chain.py
index 1d33f7f0..62733c6b 100644
--- a/src/dspeed/processing_chain.py
+++ b/src/dspeed/processing_chain.py
@@ -8,6 +8,7 @@
 import ast
 import importlib
 import itertools as it
+import json
 import logging
 import re
 import time
@@ -21,10 +22,9 @@
 import lgdo
 import numpy as np
-from lgdo import LGDO, lh5
 from numba import guvectorize, vectorize
 from pint import Quantity, Unit
-from yaml import dump, safe_load
+from yaml import safe_load

 from .errors import DSPFatal, ProcessingChainError
 from .processors.round_to_nearest import round_to_nearest
@@ -33,7 +33,6 @@
 from .utils import numba_defaults_kwargs as nb_kwargs

 log = logging.getLogger("dspeed")
-sto = lh5.LH5Store()

 # Filler value for variables to be automatically deduced later
 auto = "auto"
@@ -258,7 +257,7 @@ def _make_buffer(self) -> np.ndarray:
             if self.is_const
             else (self.proc_chain._block_width,) + self.shape
         )
-        len = np.prod(shape)
+        len = int(np.prod(shape))
         # Flattened array, with padding to allow memory alignment
         buf = np.zeros(len + 64 // self.dtype.itemsize, dtype=self.dtype)
         # offset to ensure memory alignment
@@ -519,8 +518,8 @@ def set_constant(
         return param

     def link_input_buffer(
-        self, varname: str, buff: np.ndarray | LGDO = None
-    ) -> np.ndarray | LGDO:
+        self, varname: str, buff: np.ndarray | lgdo.LGDO = None
+    ) -> np.ndarray | lgdo.LGDO:
         """Link an input buffer to a variable.

         Parameters
@@ -599,8 +598,8 @@ def link_input_buffer(
         return buff

     def link_output_buffer(
-        self, varname: str, buff: np.ndarray | LGDO = None
-    ) -> np.ndarray | LGDO:
+        self, varname: str, buff: np.ndarray | lgdo.LGDO = None
+    ) -> np.ndarray | lgdo.LGDO:
         """Link an output buffer to a variable.

         Parameters
@@ -713,7 +712,10 @@ def execute(self, start: int = 0, stop: int = None) -> None:
         if stop is None:
             stop = self._buffer_len
         for i in range(start, stop, self._block_width):
-            self._execute_procs(i, min(i + self._block_width, stop))
+            try:
+                self._execute_procs(i, min(i + self._block_width, stop))
+            except IndexError:
+                break

     def get_variable(
         self, expr: str, get_names_only: bool = False, expr_only: bool = False
@@ -922,46 +924,10 @@ def _parse_expr(
                 ret = ret.to(ureg.dimensionless).magnitude
             return ret

-        name = "(" + op_form.format(str(lhs), str(rhs)) + ")"
-        if isinstance(lhs, ProcChainVar) and isinstance(rhs, ProcChainVar):
-            if is_in_pint(lhs.unit) and is_in_pint(rhs.unit):
-                unit = op(Quantity(lhs.unit), Quantity(rhs.unit)).u
-                if unit == ureg.dimensionless:
-                    unit = None
-            elif lhs.unit is not None and rhs.unit is not None:
-                if type(node.op) in (ast.Mult, ast.Div, ast.FloorDiv):
-                    unit = op_form.format(str(lhs.unit), str(rhs.unit))
-                else:
-                    unit = str(lhs.unit)
-            elif lhs.unit is not None:
-                unit = lhs.unit
-            else:
-                unit = rhs.unit
-            # If both vars are coordinates, this is probably not a coord.
-            # If one var is a coord, this is probably a coord
-            out = ProcChainVar(
-                self,
-                name,
-                grid=None if lhs.is_coord and rhs.is_coord else auto,
-                is_coord=(
-                    False if lhs.is_coord is True and rhs.is_coord is True else auto
-                ),
-                unit=unit,
-            )
-        elif isinstance(lhs, ProcChainVar):
-            out = ProcChainVar(
-                self,
-                name,
-                unit=lhs.unit,
-                is_coord=lhs.is_coord,
-            )
-        else:
-            out = ProcChainVar(
-                self,
-                name,
-                unit=rhs.unit,
-                is_coord=rhs.is_coord,
-            )
+        out = ProcChainVar(
+            self,
+            "(" + op_form.format(str(lhs), str(rhs)) + ")",
+        )

         proc_man = ProcessorManager(self, op, [lhs, rhs, out])
         self._proc_managers.append(proc_man)
@@ -1266,8 +1232,10 @@ def _loadlh5(path_to_file, path_in_file: str) -> np.array:  # noqa: N805
             list: The loaded data.
         """
+        from lgdo import lh5
+
         try:
-            loaded_data = sto.read(path_in_file, path_to_file)
+            loaded_data = lh5.read(path_in_file, path_to_file)
             if isinstance(loaded_data, lgdo.types.Scalar):
                 loaded_data = loaded_data.value
             else:
@@ -1812,12 +1780,17 @@ def __init__(self, io_array: lgdo.Array, var: ProcChainVar) -> None:
         )

     def read(self, start: int, end: int) -> None:
+        if start >= len(self.io_array):
+            raise IndexError
         end = min(end, len(self.io_array))
         self.raw_var[0 : end - start, ...] = self.io_array[start:end, ...]

     def write(self, start: int, end: int) -> None:
         self.io_array.resize(end)
-        self.io_array[start:end, ...] = self.raw_var[0 : end - start, ...]
+        if self.var.is_const:
+            self.io_array[start:end, ...] = self.raw_var[...]
+        else:
+            self.io_array[start:end, ...] = self.raw_var[0 : end - start, ...]

     def __str__(self) -> str:
         return f"{self.var} linked to lgdo.Array(shape={self.io_array.shape}, dtype={self.io_array.dtype}, attrs={self.io_array.attrs})"
@@ -1872,6 +1845,8 @@ def __init__(self, io_array: np.ArrayOfEqualSizedArrays, var: ProcChainVar) -> N
         )

     def read(self, start: int, end: int) -> None:
+        if start >= len(self.io_array):
+            raise IndexError
         end = min(end, len(self.io_array))
         self.raw_var[0 : end - start, ...] = self.io_array[start:end, ...]

@@ -1976,6 +1951,8 @@ def _vov2nda(flat_arr_in, cl_in, start_idx_in, l_out, aoa_out):  # noqa: N805
             prev_cl = cl

     def read(self, start: int, end: int) -> None:
+        if start >= len(self.io_vov):
+            raise IndexError
         end = min(end, len(self.io_vov))
         self.raw_var = 0 if np.issubdtype(self.raw_var.dtype, np.integer) else np.nan
         LGDOVectorOfVectorsIOManager._vov2nda(
@@ -2062,10 +2039,14 @@ def __init__(self, wf_table: lgdo.WaveformTable, variable: ProcChainVar) -> None
             self.io_wf.dt_units = dt_units

     def read(self, start: int, end: int) -> None:
+        if start >= len(self.io_wf):
+            raise IndexError
+        end = min(end, len(self.io_wf))
         self.wf_var[0 : end - start, ...] = self.io_wf.values[start:end, ...]
         self.t0_var[0 : end - start, ...] = self.io_wf.t0[start:end, ...]

     def write(self, start: int, end: int) -> None:
+        self.io_wf.resize(end)
         self.io_wf.values[start:end, ...] = self.wf_var[0 : end - start, ...]
         if self.variable_t0:
             self.io_wf.t0[start:end, ...] = self.t0_var[0 : end - start, ...]
@@ -2080,32 +2061,27 @@ def __str__(self) -> str:


 def build_processing_chain(
-    lh5_in: lgdo.Table,
-    dsp_config: dict | str,
+    processors: dict | str,
+    tb_in: lgdo.Table = None,
     db_dict: dict = None,
     outputs: list[str] = None,
+    buffer_len: int = 3200,
     block_width: int = 16,
 ) -> tuple[ProcessingChain, list[str], lgdo.Table]:
-    """Produces a :class:`ProcessingChain` object and an LH5
-    :class:`~lgdo.types.table.Table` for output parameters from an input LH5
+    """Produces a :class:`ProcessingChain` object and an LGDO
+    :class:`~lgdo.types.table.Table` for output parameters from an input LGDO
     :class:`~lgdo.types.table.Table` and a JSON or YAML recipe.

     Parameters
     ----------
-    lh5_in
-        HDF5 table from which raw data is read. At least one row of entries
-        should be read in prior to calling this!
-
-    dsp_config
+    processors
         A dictionary or YAML/JSON filename containing the recipes for computing
         DSP parameter from raw parameters. The format is as follows:

         .. code-block:: json

             {
-            "outputs" : [ "par1", "par2" ]
-            "processors" : {
-                "name1, name2" : {
+                "name1, name2" : {
                     "function" : "func1"
                     "module" : "mod1"
                     "args" : ["arg1", 3, "arg2"]
                     "kwargs" : {"key1": "val1"}
                     "init_args" : ["arg1", 3, "arg2"]
                     "unit" : ["u1", "u2"]
                     "defaults" : {"arg1": "defval1"}
-                }
-            }
+                },
+                ...
             }

-        - ``outputs`` -- list of output parameters (strings) to compute by
-          default. See `outputs` argument
-        - ``processors`` -- configuration dictionary
-
         - ``name1, name2`` -- dictionary. key contains comma-separated
           names of parameters computed
@@ -2143,17 +2115,19 @@ def build_processing_chain(
         - ``unit`` -- list of strings. Units for parameters
         - ``defaults`` -- dictionary. Default value to be used for arguments
           read from the database
+        - The dictionary can also be nested inside another mapping, under the
+          key ``processors``.
+    tb_in
+        input table. This table will be linked and used as input when
+        executing processors. It can be empty (for now), as long as its
+        fields and attrs are set.
     db_dict
         A nested :class:`dict` pointing to values for database arguments. As
         instance, if a processor uses the argument ``db.trap.risetime``, it
         will look up ``db_dict['trap']['risetime']`` and use the found value.
-        If no value is found, use the default defined in `dsp_config`.
-
+        If no value is found, use the default defined in `processors`.
     outputs
-        List of parameters to put in the output LH5 table. If ``None``,
-        use the parameters in the ``"outputs"`` list from `dsp_config`.
-
+        List of parameters to put in the output LGDO table.
     block_width
         number of entries to process at once. To optimize performance,
         a multiple of 16 is preferred, but if performance is not an issue
@@ -2161,33 +2135,33 @@ def build_processing_chain(

     Returns
     -------
-    (proc_chain, field_mask, lh5_out)
+    (proc_chain, field_mask, tb_out)
         - `proc_chain` -- :class:`ProcessingChain` object that is executed
-        - `field_mask` -- list of input fields that are used
-        - `lh5_out` -- output :class:`~lgdo.table.Table` containing processed
-          values
+        - `field_mask` -- list of names of input fields that will be used.
+          This can be used to ensure only needed values are read in.
+        - `tb_out` -- output :class:`~lgdo.table.Table` with size 0, with
+          fields and attrs set up to contain outputs
     """
-    proc_chain = ProcessingChain(block_width, lh5_in.size)
-
-    if isinstance(dsp_config, str):
-        with open(lh5.utils.expand_path(dsp_config)) as f:
-            dsp_config = safe_load(f)
-    elif dsp_config is None:
-        dsp_config = {"outputs": [], "processors": {}}
-    elif isinstance(dsp_config, MutableMapping):
+    db_parser = re.compile(r"(?![^\w_.])db\.[\w_.]+")
+
+    if isinstance(processors, str):
+        with open(processors) as f:
+            processors = safe_load(f)
+    elif processors is None:
+        processors = {}
+    elif isinstance(processors, MutableMapping):
         # We don't want to modify the input!
-        dsp_config = deepcopy(dsp_config)
+        processors = deepcopy(processors)
     else:
-        raise ValueError("dsp_config must be a dict, json/yaml file, or None")
+        raise ValueError("processors must be a dict, json/yaml file, or None")

-    if outputs is None:
-        outputs = dsp_config["outputs"]
+    if "processors" in processors:
+        processors = processors["processors"]

-    processors = dsp_config["processors"]
+    proc_chain = ProcessingChain(block_width, buffer_len)

     # prepare the processor list
     multi_out_procs = {}
-    db_parser = re.compile(r"(?![^\w_.])db\.[\w_.]+")
     for key, node in processors.items():
         # if we have multiple outputs, add each to the processesors list
         keys = [k for k in re.split(",| ", key) if k != ""]
@@ -2305,15 +2279,12 @@ def resolve_dependencies(
     log.debug(f"copied output parameters: {copy_par_list}")
     log.debug(f"processed output parameters: {out_par_list}")

-    # Now add all of the input buffers from lh5_in (and also the clk time)
+    # Now add all of the input buffers from tb_in
     for input_par in input_par_list:
-        buf_in = lh5_in.get(input_par)
-        if buf_in is None:
-            log.warning(
-                f"I don't know what to do with '{input_par}'. Building output without it!"
-            )
+        if input_par not in tb_in:
+            log.warning(f"'{input_par}' not found in input files or dsp config.")
         try:
-            proc_chain.link_input_buffer(input_par, buf_in)
+            proc_chain.link_input_buffer(input_par, tb_in.get(input_par))
         except Exception as e:
             raise ProcessingChainError(
                 f"Exception raised while linking input buffer '{input_par}'."
@@ -2329,19 +2300,18 @@ def resolve_dependencies(
         if "args" not in recipe:
             fun_str = recipe if isinstance(recipe, str) else recipe["function"]
             fun_var = proc_chain.get_variable(fun_str)
-            if not isinstance(fun_var, ProcChainVar):
-                raise ProcessingChainError(
-                    f"Could not find function {recipe['function']}"
+            if isinstance(fun_var, ProcChainVar):
+                new_var = proc_chain.add_variable(
+                    name=proc_par,
+                    dtype=fun_var.dtype,
+                    shape=fun_var.shape,
+                    grid=fun_var.grid,
+                    unit=fun_var.unit,
+                    is_coord=fun_var.is_coord,
                 )
-            new_var = proc_chain.add_variable(
-                name=proc_par,
-                dtype=fun_var.dtype,
-                shape=fun_var.shape,
-                grid=fun_var.grid,
-                unit=fun_var.unit,
-                is_coord=fun_var.is_coord,
-            )
-            new_var._buffer = fun_var._buffer
+                new_var._buffer = fun_var._buffer
+            else:
+                new_var = proc_chain.set_constant(varname=proc_par, val=fun_var)

             log.debug(f"setting {new_var} = {fun_var}")
             continue
@@ -2478,23 +2448,25 @@ def resolve_dependencies(

         except Exception as e:
             raise ProcessingChainError(
-                "Exception raised while attempting to add processor:\n" + dump(recipe)
+                "Exception raised while attempting to add processor:\n"
+                + json.dumps(recipe, indent=2)
             ) from e

     # build the output buffers
-    lh5_out = lgdo.Table(size=proc_chain._buffer_len)
+    tb_out = lgdo.Table(size=buffer_len)

     # add inputs that are directly copied
     for copy_par in copy_par_list:
-        buf_in = lh5_in.get(copy_par)
-        if buf_in is None:
+        if copy_par not in tb_in:
             log.warning(
-                f"Did not find {copy_par} in either input file or parameter list. Building output without it!"
+                f"'{copy_par}' not found in input files or dsp config. Building output without it!"
             )
         else:
-            lh5_out.add_field(copy_par, buf_in)
+            if len(tb_in) < len(tb_out):
+                tb_out.resize(len(tb_in))
+            tb_out.add_field(copy_par, tb_in[copy_par])

-    # finally, add the output buffers to lh5_out and the proc chain
+    # finally, add the output buffers to tb_out and the proc chain
     for out_par in out_par_list:
         try:
             buf_out = proc_chain.link_output_buffer(out_par)
@@ -2502,11 +2474,13 @@ def resolve_dependencies(
             if isinstance(recipe, str):
                 recipe = processors[recipe]
             buf_out.attrs.update(recipe.get("lh5_attrs", {}))
-            lh5_out.add_field(out_par, buf_out)
+            buf_out.resize(len(tb_out))
+            tb_out.add_field(out_par, buf_out)
         except Exception as e:
             raise ProcessingChainError(
                 f"Exception raised while linking output buffer {out_par}."
             ) from e

     field_mask = input_par_list + copy_par_list
-    return (proc_chain, field_mask, lh5_out)
+
+    return (proc_chain, field_mask, tb_out)
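
A minimal sketch of the reordered `build_processing_chain` calling convention introduced above: the processor recipe now comes first, the input table second, and `buffer_len` is a new keyword. The toy table below mirrors the comparator test added in this patch; the values are illustrative only.

```python
# Sketch of the new build_processing_chain signature; the recipe mapping
# may be passed bare or nested under a "processors" key.
import lgdo
import numpy as np

from dspeed.processing_chain import build_processing_chain

w_in = np.arange(10)
tb_in = lgdo.types.Table(
    {"w_in": lgdo.types.ArrayOfEqualSizedArrays(nda=w_in.reshape((1, 10)))}
)

proc_chain, field_mask, tb_out = build_processing_chain(
    {"gt": "w_in > 5"},  # recipe first (was dsp_config/lh5_in order)
    tb_in,               # input table second
    outputs=["gt"],
    buffer_len=1,        # new keyword: output buffer length
    block_width=1,
)
proc_chain.execute(0, 1)
print(tb_out["gt"].nda[0])  # elementwise w_in > 5
```
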
diff --git a/src/dspeed/vis/waveform_browser.py b/src/dspeed/vis/waveform_browser.py
index 72674063..4b4f1ca2 100644
--- a/src/dspeed/vis/waveform_browser.py
+++ b/src/dspeed/vis/waveform_browser.py
@@ -5,15 +5,16 @@
 import math
 import string
 import sys
+from collections.abc import Collection, Mapping

 import lgdo
 import matplotlib.pyplot as plt
 import numpy as np
-import pandas
-import pint
 from cycler import cycler
 from lgdo.lh5 import LH5Iterator
+from lgdo.types import Table
 from matplotlib.lines import Line2D
+from pint import Quantity, Unit

 from ..processing_chain import build_processing_chain
 from ..units import unit_registry as ureg
@@ -32,22 +33,22 @@ class WaveformBrowser:

     def __init__(
         self,
-        files_in: str | list[str] | LH5Iterator,  # noqa: F821
-        lh5_group: str | list[str] = "",
+        raw_in: str | Collection[str] | LH5Iterator | Table,
+        lh5_group: str | Collection[str] = "",
         base_path: str = "",
-        entry_list: list[int] | list[list[int]] = None,
-        entry_mask: list[int] | list[list[int]] = None,
-        dsp_config: dict | str = None,
-        database: str | dict = None,
-        aux_values: pandas.DataFrame = None,
-        lines: str | list[str] = None,
-        styles: dict[str, list] | str = None,
-        legend: str | list[str] = None,
-        legend_opts: dict = None,
+        entry_list: Collection[int] | Collection[Collection[int]] = None,
+        entry_mask: Collection[bool] | Collection[Collection[bool]] = None,
+        dsp_config: str | Mapping = None,
+        database: str | Mapping = None,
+        aux_values: Mapping[str, np.ndarray] = None,
+        lines: str | Collection[str] = None,
+        styles: Mapping[str, Collection] | str = None,
+        legend: str | Collection[str] = None,
+        legend_opts: Mapping = None,
         n_drawn: int = 1,
-        x_unit: pint.Unit | str = None,
-        x_lim: tuple[float | str | pint.Quantity] = None,
-        y_lim: tuple[float | str | pint.Quantity] = None,
+        x_unit: str | Unit = None,
+        x_lim: Collection[float | str | Quantity] = None,
+        y_lim: Collection[float | str | Quantity] = None,
         norm: str = None,
         align: str = None,
         buffer_len: int = 128,
@@ -56,9 +57,9 @@ def __init__(
         """
         Parameters
         ----------
-        files_in
-            name of file or list of files to browse. Can use wildcards. Can
-            also pass an LH5Iterator
+        raw_in
+            raw data with waveforms. Can be a file or list of lh5 files
+            (requires use of the lh5_group argument), an LH5Iterator, or an
+            LGDO Table.

         lh5_group
             HDF5 base group(s) to read containing a LGDO table that contains
@@ -67,7 +68,7 @@ def __init__(
             the same group will be assigned to each file found

         base_path
-            base path for file. See :class:`~lgdo.lh5.LH5Store`.
+            base path for files. See :class:`~lgdo.lh5.LH5Store`.

         entry_list
             list of event indices to draw. If it is a nested list, use local
@@ -157,18 +158,15 @@ def __init__(
         self.next_entry = 0

         # data i/o initialization
-        if isinstance(files_in, LH5Iterator):
-            self.lh5_it = files_in
+        if isinstance(raw_in, LH5Iterator):
+            self.lh5_it = raw_in
         else:
-            # HACK: do not read VOV "tracelist", cannot be handled correctly by LH5Iterator
-            # remove this hack once VOV support is properly implemented
             self.lh5_it = LH5Iterator(
-                files_in,
+                raw_in,
                 lh5_group,
                 base_path=base_path,
                 entry_list=entry_list,
                 entry_mask=entry_mask,
-                field_mask={"tracelist": False},
                 buffer_len=buffer_len,
             )

@@ -203,7 +201,7 @@ def __init__(
             self.lines = {line: [] for line in lines}

         # styles
-        if isinstance(styles, (list, tuple)):
+        if isinstance(styles, Collection) and not isinstance(styles, str):
             self.styles = [None for _ in self.lines]
             for i, sty in enumerate(styles):
                 if isinstance(sty, str):
@@ -252,7 +250,7 @@ def __init__(
                     legend_format += f"{st}{{{name}:{form}{cv}}}"
                 self.legend_format.append(legend_format)

-        self.legend_kwargs = legend_opts if isinstance(legend_opts, dict) else {}
+        self.legend_kwargs = legend_opts if isinstance(legend_opts, Mapping) else {}

         # make processing chain and output buffer
         outputs = list(self.lines) + list(self.legend_vals)
@@ -266,8 +264,8 @@ def __init__(
         outputs = [o for o in outputs if o not in self.aux_vals]

         self.proc_chain, field_mask, self.lh5_out = build_processing_chain(
-            self.lh5_in,
             dsp_config,
+            self.lh5_in,
             db_dict=database,
             outputs=outputs,
             block_width=block_width,
@@ -356,7 +354,7 @@ def clear_data(self) -> None:
         self.n_stored = 0

     def find_entry(
-        self, entry: int | list[int], append: bool = True, safe: bool = False
+        self, entry: int | Collection[int], append: bool = True, safe: bool = False
     ) -> None:
         """Find the requested data associated with entry in input files and
         place store it internally without drawing it.
@@ -512,13 +510,13 @@ def draw_current(self, clear: bool = True) -> None:
         leg_handles = []
         leg_labels = []

-        if not isinstance(self.styles, list):
+        if not isinstance(self.styles, Collection):
             styles = self.styles

         # draw lines
         default_style = itertools.cycle(cycler(plt.rcParams["axes.prop_cycle"]))
         for i, lines in enumerate(self.lines.values()):
-            if isinstance(self.styles, list):
+            if isinstance(self.styles, Collection):
                 styles = self.styles[i]
             else:
                 styles = self.styles
@@ -594,7 +592,7 @@ def _update_auto_limit(self, x: np.ndarray, y: np.ndarray) -> None:

     def draw_entry(
         self,
-        entry: int | list[int],
+        entry: int | Collection[int],
         append: bool = False,
         clear: bool = True,
         safe: bool = False,
diff --git a/tests/configs/icpc-dsp-config.yaml b/tests/configs/icpc-dsp-config-yaml.yaml
similarity index 88%
rename from tests/configs/icpc-dsp-config.yaml
rename to tests/configs/icpc-dsp-config-yaml.yaml
index 32a1ca9d..26789e34 100644
--- a/tests/configs/icpc-dsp-config.yaml
+++ b/tests/configs/icpc-dsp-config-yaml.yaml
@@ -1,38 +1,38 @@
 outputs:
--tp_min
--tp_max
--wf_min
--wf_max
--bl_mean
--bl_std
--bl_slope
--bl_intercept
--pz_slope
--pz_std
--pz_mean
--trapTmax
--tp_0_est
--tp_0_atrap
--tp_10
--tp_20
--tp_50
--tp_80
--tp_90
--tp_99
--tp_100
--tp_01
--tp_95
--A_max
--QDrift
--dt_eff
--tp_aoe_max
--tp_aoe_samp
--trapEmax
--trapEftp
--cuspEmax
--zacEmax
--zacEftp
--cuspEftp
+- tp_min
+- tp_max
+- wf_min
+- wf_max
+- bl_mean
+- bl_std
+- bl_slope
+- bl_intercept
+- pz_slope
+- pz_std
+- pz_mean
+- trapTmax
+- tp_0_est
+- tp_0_atrap
+- tp_10
+- tp_20
+- tp_50
+- tp_80
+- tp_90
+- tp_99
+- tp_100
+- tp_01
+- tp_95
+- A_max
+- QDrift
+- dt_eff
+- tp_aoe_max
+- tp_aoe_samp
+- trapEmax
+- trapEftp
+- cuspEmax
+- zacEmax
+- zacEftp
+- cuspEftp
 processors:
   tp_min, tp_max, wf_min, wf_max:
     function: dspeed.processors.min_max
@@ -58,7 +58,10 @@ processors:
     unit: [ADC, ADC, ADC, ADC]
   t0_kernel:
     function: dspeed.processors.t0_filter
-    args: [128*ns, 2*us, "t0_kernel(round(128*ns+2*us, wf_pz.period), 'f')"]
+    args:
+      - 128*ns/wf_pz.period
+      - 2*us/wf_pz.period
+      - t0_kernel(round((128*ns+2*us)/wf_pz.period), 'f')
     coord_grid: wf_pz
     unit: ADC
   wf_t0_filter:
@@ -122,10 +125,10 @@ processors:
   cusp_kernel:
     function: dspeed.processors.cusp_filter
     args:
-      - db.cusp.sigma
-      - round(db.cusp.flat, wf_blsub.period)
-      - db.pz.tau
-      - cusp_kernel(len(wf_blsub)-round(33.6*us - 4.8*us, wf_blsub.period), 'f')
+      - db.cusp.sigma/wf_blsub.period
+      - round(db.cusp.flat/wf_blsub.period)
+      - db.pz.tau/wf_blsub.period
+      - cusp_kernel(round(len(wf_blsub)-(33.6*us/wf_blsub.period)-(4.8*us/wf_blsub.period)), 'f')
     defaults:
       db.cusp.sigma: 20*us
       db.cusp.flat: 3*us
@@ -135,10 +138,10 @@ processors:
   wf_cusp:
     function: dspeed.processors.fft_convolve_wf
     args:
-      - wf_blsub[:len(wf_blsub)-round(33.6*us, wf_blsub.period)]
+      - wf_blsub[:len(wf_blsub)-round(33.6*us/wf_blsub.period)]
       - cusp_kernel
      - "'v'"
-      - wf_cusp(round(4.8*us, wf_blsub.period) + 1, 'f')
+      - wf_cusp(round(4.8*us/wf_blsub.period) + 1, 'f')
     unit: ADC
   cuspEmax:
     function: numpy.amax
@@ -256,7 +259,4 @@ processors:
     function: dspeed.processors.min_max
     args: [curr_av, aoe_t_min, tp_aoe_max, A_min, A_max]
     unit: [ns, ns, ADC/sample, ADC/sample]
-  tp_aoe_samp:
-    function: dspeed.processors.add
-    args: [tp_0_est, tp_aoe_max/16, tp_aoe_samp]
-    unit: ns
+  tp_aoe_samp: tp_0_est + tp_aoe_max/16
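
The updated YAML recipe divides times by the waveform period explicitly (e.g. `128*ns/wf_pz.period`) instead of relying on `round(x, period)` to carry the unit conversion. A back-of-the-envelope check of what such an expression evaluates to, assuming a hypothetical 16 ns sampling period (the real value comes from the raw waveform metadata):

```python
# Dimensional sanity check with toy numbers: dividing a time by the
# sampling period yields a plain (dimensionless) sample count.
from pint import UnitRegistry

ureg = UnitRegistry()
period = 16 * ureg.ns                  # stand-in for wf_pz.period
samples = (128 * ureg.ns) / period     # as in 128*ns/wf_pz.period
print(samples.to(ureg.dimensionless))  # -> 8.0 dimensionless
```
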
diff --git a/tests/processors/test_histogram.py b/tests/processors/test_histogram.py
index fd69c6b3..57a0fdd1 100644
--- a/tests/processors/test_histogram.py
+++ b/tests/processors/test_histogram.py
@@ -1,39 +1,22 @@
-import os
+import numpy as np
+import pytest

-from lgdo import lh5
+from dspeed.errors import DSPFatal
+from dspeed.processors import histogram

-from dspeed import build_dsp

+def test_histogram_fixed_width(compare_numba_vs_python):
+    vals = np.arange(100) * 2 / 3
+    with pytest.raises(DSPFatal):
+        histogram(vals, np.zeros(10), np.zeros(10))

-def test_histogram_fixed_width(lgnd_test_data, tmptestdir):
-    dsp_file = f"{tmptestdir}/LDQTA_r117_20200110T105115Z_cal_geds__numpy_test_dsp.lh5"
-    dsp_config = {
-        "outputs": ["hist_weights", "hist_borders"],
-        "processors": {
-            "hist_weights , hist_borders": {
-                "function": "histogram",
-                "module": "dspeed.processors.histogram",
-                "args": ["waveform", "hist_weights(100)", "hist_borders(101)"],
-                "unit": ["none", "ADC"],
-            }
-        },
-    }
-    build_dsp(
-        f_raw=lgnd_test_data.get_path(
-            "lh5/LDQTA_r117_20200110T105115Z_cal_geds_raw.lh5"
-        ),
-        f_dsp=dsp_file,
-        dsp_config=dsp_config,
-        write_mode="r",
-    )
-    assert os.path.exists(dsp_file)
+    hist_weights = np.zeros(66)
+    hist_edges = np.zeros(67)
+    histogram(vals, hist_weights, hist_edges)
+    assert all(hist_edges == np.arange(67))
+    assert all(hist_weights[0::2] == 2) and all(hist_weights[1::2] == 1)

-    df = lh5.read_as(
-        "geds/dsp/", dsp_file, "pd", field_mask=["hist_weights", "hist_borders"]
-    )
-
-    assert len(df["hist_weights"][0]) + 1 == len(df["hist_borders"][0])
-    for i in range(2, len(df["hist_borders"][0])):
-        a = df["hist_borders"][0][i - 1] - df["hist_borders"][0][i - 2]
-        b = df["hist_borders"][0][i] - df["hist_borders"][0][i - 1]
-        assert round(a, 2) == round(b, 2)
+    vals[5] = np.nan
+    histogram(vals, hist_weights, hist_edges)
+    assert all(np.isnan(hist_edges))
+    assert all(hist_weights == 0)
diff --git a/tests/test_build_dsp.py b/tests/test_build_dsp.py
index 0ae28975..72471309 100644
--- a/tests/test_build_dsp.py
+++ b/tests/test_build_dsp.py
@@ -1,16 +1,18 @@
 import os
 from pathlib import Path

-import lgdo
+import numpy as np
 import pytest
-from lgdo.lh5 import ls, read
+from lgdo import Struct, Table, VectorOfVectors, lh5
+from test_utils import isclose

 from dspeed import build_dsp

 config_dir = Path(__file__).parent / "configs"


-def test_build_dsp_json(lgnd_test_data, tmptestdir):
+@pytest.fixture(scope="session")
+def dsp_test_file_geds(lgnd_test_data, tmptestdir):
     out_name = f"{tmptestdir}/LDQTA_r117_20200110T105115Z_cal_geds_dsp.lh5"
     build_dsp(
         lgnd_test_data.get_path("lh5/LDQTA_r117_20200110T105115Z_cal_geds_raw.lh5"),
@@ -21,17 +23,21 @@ def test_build_dsp_json(lgnd_test_data, tmptestdir):
     )
     assert os.path.exists(out_name)

+    return out_name
+

-def test_build_dsp_yaml(lgnd_test_data, tmptestdir):
-    out_name = f"{tmptestdir}/LDQTA_r117_20200110T105115Z_cal_geds_dsp.lh5"
-    build_dsp(
+def test_build_dsp_yaml(lgnd_test_data, dsp_test_file_geds):
+    dsp_out = build_dsp(
         lgnd_test_data.get_path("lh5/LDQTA_r117_20200110T105115Z_cal_geds_raw.lh5"),
-        out_name,
-        dsp_config=f"{config_dir}/icpc-dsp-config.yaml",
+        dsp_config=f"{config_dir}/icpc-dsp-config-yaml.yaml",
         database={"pz": {"tau": 27460.5}},
         write_mode="r",
     )
-    assert os.path.exists(out_name)
+    assert isinstance(dsp_out, Struct)
+
+    # TODO: make sure the configs are the same for json and yaml so we can do this
+    # dsp_file = lh5.read("geds/dsp", dsp_test_file_geds)
+    # assert isclose(dsp_out, dsp_file)


 def test_build_dsp_errors(lgnd_test_data, tmptestdir):
@@ -51,6 +57,93 @@ def test_build_dsp_errors(lgnd_test_data, tmptestdir):
     )


+# test different input types
+def test_dsp_in_types(lgnd_test_data):
+    # input from file
+    raw_path = lgnd_test_data.get_path(
+        "lh5/LDQTA_r117_20200110T105115Z_cal_geds_raw.lh5"
+    )
+    dsp_file = build_dsp(
+        raw_path,
+        dsp_config=f"{config_dir}/icpc-dsp-config.json",
+        lh5_tables="geds/raw",
+        database={"pz": {"tau": 27460.5}},
+    )
+    assert isinstance(dsp_file, Struct)
+    assert "geds" in dsp_file and "dsp" in dsp_file["geds"]
+
+    # input iterator directly
+    raw_it = lh5.LH5Iterator(raw_path, "geds/raw")
+    dsp_it = build_dsp(
+        raw_it,
+        dsp_config=f"{config_dir}/icpc-dsp-config.json",
+        lh5_tables="geds/raw",
+        database={"pz": {"tau": 27460.5}},
+    )
+    assert isinstance(dsp_it, Struct)
+    assert "geds" in dsp_it and "dsp" in dsp_it["geds"]
+
+    # input table directly
+    raw_tb = lh5.read("geds/raw", raw_path)
+    dsp_tb = build_dsp(
+        raw_tb,
+        dsp_config=f"{config_dir}/icpc-dsp-config.json",
+        lh5_tables="geds/raw",
+        database={"pz": {"tau": 27460.5}},
+    )
+    assert isinstance(dsp_tb, Struct)
+    assert "geds" in dsp_tb and "dsp" in dsp_tb["geds"]
+
+    # make sure these all give the same result
+    assert isclose(dsp_file, dsp_it)
+    assert isclose(dsp_file, dsp_tb)
+
+
+# test different output types
+def test_dsp_out_struct(lgnd_test_data, dsp_test_file_geds):
+    dsp_tb = lh5.read("geds/dsp", dsp_test_file_geds)
+    assert isinstance(dsp_tb, Table)
+
+    dsp_st = build_dsp(
+        lgnd_test_data.get_path("lh5/LDQTA_r117_20200110T105115Z_cal_geds_raw.lh5"),
+        dsp_config=f"{config_dir}/icpc-dsp-config.json",
+        database={"pz": {"tau": 27460.5}},
+        write_mode="r",
+    )
+    assert isinstance(dsp_st, Struct)
+    assert "geds" in dsp_st and "dsp" in dsp_st["geds"]
+    assert len(dsp_tb) == len(dsp_st["geds"]["dsp"])
+    assert isclose(dsp_tb, dsp_st["geds"]["dsp"])
+
+
+# test input field in dsp config
+def test_aux_inputs(lgnd_test_data, dsp_test_file_geds):
+    raw_in = lgnd_test_data.get_path("lh5/LDQTA_r117_20200110T105115Z_cal_geds_raw.lh5")
+
+    # This config will find tp_max in an already calculated dsp file (friended
+    # with suffix "1"), recalculate it as tp_max2, and make sure they are the same
+    dsp_config = {
+        "inputs": [{"file": dsp_test_file_geds, "group": "geds/dsp", "suffix": "1"}],
+        "outputs": ["compare", "tp_max1", "tp_max2"],
+        "processors": {
+            "tp_min2, tp_max2, wf_min2, wf_max2": {
+                "function": "min_max",
+                "module": "dspeed.processors",
+                "args": ["waveform", "tp_min2", "tp_max2", "wf_min2", "wf_max2"],
+                "unit": ["ns", "ns", "ADC", "ADC"],
+            },
+            "compare": "tp_max1 - tp_max2",
+        },
+    }
+
+    dsp_out = build_dsp(
+        raw_in,
+        lh5_tables="geds/raw",
+        dsp_config=dsp_config,
+    )
+    assert np.all(np.isclose(dsp_out["geds"]["dsp"]["compare"], 0))
+
+
 @pytest.fixture(scope="session")
 def dsp_test_file_spm(lgnd_test_data, tmptestdir):
     chan_config = {
@@ -64,7 +157,7 @@ def dsp_test_file_spm(lgnd_test_data, tmptestdir):
         lgnd_test_data.get_path("lh5/L200-comm-20211130-phy-spms.lh5"),
         out_file,
         {},
-        n_max=5,
+        n_entries=5,
         lh5_tables=chan_config.keys(),
         chan_config=chan_config,
         write_mode="r",
@@ -76,13 +169,37 @@ def dsp_test_file_spm(lgnd_test_data, tmptestdir):


 def test_build_dsp_spms_channelwise(dsp_test_file_spm):
-    assert ls(dsp_test_file_spm) == ["ch0", "ch1", "ch2"]
-    assert ls(dsp_test_file_spm, "ch0/") == ["ch0/dsp"]
-    assert ls(dsp_test_file_spm, "ch0/dsp/") == [
+    assert lh5.ls(dsp_test_file_spm) == ["ch0", "ch1", "ch2"]
+    assert lh5.ls(dsp_test_file_spm, "ch0/") == ["ch0/dsp"]
+    assert lh5.ls(dsp_test_file_spm, "ch0/dsp/") == [
         "ch0/dsp/energies",
         "ch0/dsp/trigger_pos",
     ]

-    lh5_obj = read("/ch0/dsp/energies", dsp_test_file_spm)
-    assert isinstance(lh5_obj, lgdo.VectorOfVectors)
+    lh5_obj = lh5.read("/ch0/dsp/energies", dsp_test_file_spm)
+    assert isinstance(lh5_obj, VectorOfVectors)
     assert len(lh5_obj) == 5
+
+
+def test_build_dsp_multiprocessing(lgnd_test_data, dsp_test_file_spm):
+    chan_config = {
+        "ch0/raw": f"{config_dir}/sipm-dsp-config.json",
+        "ch1/raw": f"{config_dir}/sipm-dsp-config.json",
+        "ch2/raw": f"{config_dir}/sipm-dsp-config.json",
+    }
+
+    dsp_out = build_dsp(
+        lgnd_test_data.get_path("lh5/L200-comm-20211130-phy-spms.lh5"),
+        n_entries=5,
+        lh5_tables=chan_config.keys(),
+        chan_config=chan_config,
+        write_mode="r",
+        processes=2,
+    )
+
+    lh5_obj = lh5.read("/ch0", dsp_test_file_spm)
+    assert dsp_out["ch0"] == lh5_obj
+    lh5_obj = lh5.read("/ch1", dsp_test_file_spm)
+    assert dsp_out["ch1"] == lh5_obj
+    lh5_obj = lh5.read("/ch2", dsp_test_file_spm)
+    assert dsp_out["ch2"] == lh5_obj
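
A sketch of how the per-channel configuration exercised by these tests can be combined with the new `processes` option; the file and config paths here are hypothetical placeholders, not test assets.

```python
# Hypothetical per-channel configuration. Patterns are tried in mapping
# order and the first fnmatch against the table name wins.
from dspeed import build_dsp

chan_config = {
    "ch000*/raw": "special-dsp-config.json",  # overrides the pattern below
    "ch*/raw": "default-dsp-config.json",     # fallback for everything else
}

dsp_st = build_dsp(
    "run0_raw.lh5",
    lh5_tables="ch*",
    chan_config=chan_config,
    processes=2,     # channels are farmed out to a ProcessPoolExecutor
    n_entries=100,
)
```
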
(np.isnan(df["calc6"])).all() diff --git a/tests/test_operators.py b/tests/test_operators.py deleted file mode 100644 index 79693d10..00000000 --- a/tests/test_operators.py +++ /dev/null @@ -1,33 +0,0 @@ -import lgdo -import numpy as np - -from dspeed.processing_chain import build_processing_chain - - -def test_waveform_slicing(): - dsp_config = { - "outputs": ["eq", "neq", "gt", "gte", "lt", "lte"], - "processors": { - "eq": "w_in == 5", - "neq": "w_in != 5", - "gt": "w_in > 5", - "gte": "w_in >= 5", - "lt": "w_in < 5", - "lte": "w_in <= 5", - }, - } - w_in = np.arange(10) - tbl_in = lgdo.types.Table( - {"w_in": lgdo.types.ArrayOfEqualSizedArrays(nda=w_in.reshape((1, 10)))} - ) - proc_chain, _, tbl_out = build_processing_chain(tbl_in, dsp_config) - proc_chain.execute(0, 1) - - assert list(tbl_out.keys()) == ["eq", "neq", "gt", "gte", "lt", "lte"] - assert all([tbl_out[k].nda.dtype == np.dtype("bool") for k in tbl_out.keys()]) - assert all(tbl_out["eq"].nda[0] == (w_in == 5)) - assert all(tbl_out["neq"].nda[0] == (w_in != 5)) - assert all(tbl_out["gt"].nda[0] == (w_in > 5)) - assert all(tbl_out["gte"].nda[0] == (w_in >= 5)) - assert all(tbl_out["lt"].nda[0] == (w_in < 5)) - assert all(tbl_out["lte"].nda[0] == (w_in <= 5)) diff --git a/tests/test_processing_chain.py b/tests/test_processing_chain.py index b446caec..6f1e325f 100644 --- a/tests/test_processing_chain.py +++ b/tests/test_processing_chain.py @@ -2,27 +2,7 @@ import numpy as np import pytest -from dspeed.processing_chain import build_processing_chain - - -def test_waveform_slicing(geds_raw_tbl): - dsp_config = { - "outputs": ["wf_blsub"], - "processors": { - "wf_blsub": { - "function": "bl_subtract", - "module": "dspeed.processors", - "args": ["waveform[0:100]", "baseline", "wf_blsub"], - "unit": "ADC", - }, - }, - } - proc_chain, _, tbl_out = build_processing_chain(geds_raw_tbl, dsp_config) - proc_chain.execute(0, 1) - - assert list(tbl_out.keys()) == ["wf_blsub"] - assert isinstance(tbl_out["wf_blsub"], lgdo.WaveformTable) - assert tbl_out["wf_blsub"].wf_len == 100 +from dspeed import build_dsp def test_processor_none_arg(geds_raw_tbl): @@ -38,12 +18,111 @@ def test_processor_none_arg(geds_raw_tbl): } }, } - proc_chain, _, _ = build_processing_chain(geds_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + build_dsp(raw_in=geds_raw_tbl, dsp_config=dsp_config, n_entries=1) dsp_config["processors"]["wf_cum"]["args"][2] = "None" - proc_chain, _, _ = build_processing_chain(geds_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + build_dsp(raw_in=geds_raw_tbl, dsp_config=dsp_config, n_entries=1) + + +def test_numpy_math_constants_dsp(lgnd_test_data): + dsp_config = { + "outputs": ["timestamp", "calc1", "calc2", "calc3", "calc4", "calc5", "calc6"], + "processors": { + "calc1": "np.pi*timestamp", + "calc2": "np.pi", + "calc3": "np.pi*np.e", + "calc4": "np.nan", + "calc5": "np.inf", + "calc6": "np.nan*timestamp", + }, + } + + f_raw = lgnd_test_data.get_path("lh5/LDQTA_r117_20200110T105115Z_cal_geds_raw.lh5") + dsp_out = build_dsp(raw_in=f_raw, dsp_config=dsp_config) + df = dsp_out["geds"]["dsp"].view_as("pd") + + assert (df["calc1"] == np.pi * df["timestamp"]).all() + assert (df["calc2"] == np.pi).all() + assert (df["calc3"] == np.pi * np.e).all() + assert (np.isnan(df["calc4"])).all() + assert (np.isinf(df["calc5"])).all() + assert (np.isnan(df["calc6"])).all() + + +def test_list_parsing(lgnd_test_data, tmptestdir): + dsp_config = { + "outputs": ["wf_out", "ievt"], + "processors": { + "a1": "[1,2,3,4,5]", + "a2": "[6,7,8,9,10]", + "wf_out": 
"a1+a2", + }, + } + + raw_in = lgnd_test_data.get_path("lh5/LDQTA_r117_20200110T105115Z_cal_geds_raw.lh5") + dsp_out = build_dsp(raw_in=raw_in, dsp_config=dsp_config, n_entries=1) + assert np.all(dsp_out["geds"]["dsp"]["wf_out"].nda == np.array([7, 9, 11, 13, 15])) + + +def test_comparators(): + dsp_config = { + "outputs": ["eq", "neq", "gt", "gte", "lt", "lte"], + "processors": { + "eq": "w_in == 5", + "neq": "w_in != 5", + "gt": "w_in > 5", + "gte": "w_in >= 5", + "lt": "w_in < 5", + "lte": "w_in <= 5", + }, + } + w_in = np.arange(10) + tbl_in = lgdo.types.Table( + {"w_in": lgdo.types.ArrayOfEqualSizedArrays(nda=w_in.reshape((1, 10)))} + ) + tbl_out = build_dsp(tbl_in, dsp_config=dsp_config, n_entries=1) + + assert set(tbl_out.keys()) == {"eq", "neq", "gt", "gte", "lt", "lte"} + assert all([tbl_out[k].nda.dtype == np.dtype("bool") for k in tbl_out.keys()]) + assert all(tbl_out["eq"].nda[0] == (w_in == 5)) + assert all(tbl_out["neq"].nda[0] == (w_in != 5)) + assert all(tbl_out["gt"].nda[0] == (w_in > 5)) + assert all(tbl_out["gte"].nda[0] == (w_in >= 5)) + assert all(tbl_out["lt"].nda[0] == (w_in < 5)) + assert all(tbl_out["lte"].nda[0] == (w_in <= 5)) + + +def test_waveform_slicing(geds_raw_tbl): + dsp_config = { + "outputs": ["waveform", "wf_sample", "wf_slice", "wf_slice_stride"], + "processors": { + "wf_sample": {"function": "waveform[50]"}, + "wf_slice": {"function": "waveform[50:100]"}, + "wf_slice_stride": {"function": "waveform[50:100:2]"}, + }, + } + tbl_out = build_dsp(geds_raw_tbl, dsp_config=dsp_config, n_entries=10) + + assert isinstance(tbl_out.waveform, lgdo.WaveformTable) + assert isinstance(tbl_out.wf_sample, lgdo.Array) + assert isinstance(tbl_out.wf_slice, lgdo.WaveformTable) + assert isinstance(tbl_out.wf_slice_stride, lgdo.WaveformTable) + + assert np.all(tbl_out.waveform.values[:, 50] == tbl_out.wf_sample) + assert np.all(tbl_out.waveform.values[:, 50:100] == tbl_out.wf_slice.values) + assert np.all( + tbl_out.waveform.t0.nda + 50 * tbl_out.waveform.dt.nda + == tbl_out.wf_slice.t0.nda + ) + assert np.all(tbl_out.waveform.dt.nda == tbl_out.wf_slice.dt.nda) + assert np.all( + tbl_out.waveform.values[:, 50:100:2] == tbl_out.wf_slice_stride.values + ) + assert np.all( + tbl_out.waveform.t0.nda + 50 * tbl_out.waveform.dt.nda + == tbl_out.wf_slice_stride.t0.nda + ) + assert np.all(tbl_out.waveform.dt.nda == tbl_out.wf_slice_stride.dt.nda / 2) def test_processor_kwarg_assignment(geds_raw_tbl): @@ -59,13 +138,11 @@ def test_processor_kwarg_assignment(geds_raw_tbl): } }, } - proc_chain, _, _ = build_processing_chain(geds_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + build_dsp(geds_raw_tbl, dsp_config=dsp_config, n_entries=1) dsp_config["processors"]["wf_cum"]["args"][1] = "dtypo=None" - proc_chain, _, _ = build_processing_chain(geds_raw_tbl, dsp_config) with pytest.raises(TypeError): - proc_chain.execute(0, 1) + build_dsp(geds_raw_tbl, dsp_config=dsp_config, n_entries=1) def test_processor_dtype_arg(geds_raw_tbl): @@ -81,8 +158,7 @@ def test_processor_dtype_arg(geds_raw_tbl): } }, } - proc_chain, _, _ = build_processing_chain(geds_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + build_dsp(geds_raw_tbl, dsp_config=dsp_config, n_entries=1) def test_scipy_gauss_filter(geds_raw_tbl): @@ -104,8 +180,7 @@ def test_scipy_gauss_filter(geds_raw_tbl): } }, } - proc_chain, _, _ = build_processing_chain(geds_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + build_dsp(geds_raw_tbl, dsp_config=dsp_config, n_entries=1) def test_histogram_processor_fixed_width(spms_raw_tbl): @@ -120,8 
+195,7 @@ def test_histogram_processor_fixed_width(spms_raw_tbl): } }, } - proc_chain, _, _ = build_processing_chain(spms_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + build_dsp(spms_raw_tbl, dsp_config=dsp_config, n_entries=1) def test_processor_variable_array_output(spms_raw_tbl): @@ -147,9 +221,7 @@ def test_processor_variable_array_output(spms_raw_tbl): } }, } - - proc_chain, _, _ = build_processing_chain(spms_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + build_dsp(spms_raw_tbl, dsp_config=dsp_config, n_entries=1) def test_proc_chain_unit_conversion(spms_raw_tbl): @@ -178,8 +250,7 @@ def test_proc_chain_unit_conversion(spms_raw_tbl): }, }, } - proc_chain, _, lh5_out = build_processing_chain(spms_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + lh5_out = build_dsp(spms_raw_tbl, dsp_config=dsp_config, n_entries=1) assert lh5_out["a_unitless"][0] == lh5_out["a_ns"][0] assert lh5_out["a_unitless"][0] == lh5_out["a_us"][0] assert lh5_out["a_unitless"][0] == lh5_out["a_ghz"][0] @@ -247,8 +318,7 @@ def test_proc_chain_coordinate_grid(spms_raw_tbl): }, } - proc_chain, _, lh5_out = build_processing_chain(spms_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + lh5_out = build_dsp(spms_raw_tbl, dsp_config=dsp_config, n_entries=1) assert lh5_out["a_window"][0] == lh5_out["a_downsample"][0] assert lh5_out["tp_window"][0] == lh5_out["tp"][0] assert -128 < lh5_out["tp_downsample"][0] - lh5_out["tp"][0] < 128 @@ -260,8 +330,7 @@ def test_proc_chain_round(spms_raw_tbl): "processors": {"waveform_round": "round(waveform, 4)"}, } - proc_chain, _, lh5_out = build_processing_chain(spms_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + lh5_out = build_dsp(spms_raw_tbl, dsp_config=dsp_config, n_entries=1) assert np.all( np.rint(spms_raw_tbl["waveform"].values[0] / 4) * 4 == lh5_out["waveform_round"].values[0] @@ -274,8 +343,7 @@ def test_proc_chain_as_type(spms_raw_tbl): "processors": {"waveform_32": "astype(waveform, 'float32')"}, } - proc_chain, _, lh5_out = build_processing_chain(spms_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + lh5_out = build_dsp(spms_raw_tbl, dsp_config=dsp_config, n_entries=1) assert np.all( spms_raw_tbl["waveform"].values[0] == lh5_out["waveform_32"].values[0] ) @@ -307,8 +375,7 @@ def test_output_types(spms_raw_tbl): }, } - proc_chain, _, lh5_out = build_processing_chain(spms_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + lh5_out = build_dsp(spms_raw_tbl, dsp_config=dsp_config, n_entries=1) assert isinstance(lh5_out["n_max_out"], lgdo.Array) assert isinstance(lh5_out["wf_out"], lgdo.WaveformTable) assert isinstance(lh5_out["aoa_out"], lgdo.ArrayOfEqualSizedArrays) @@ -328,8 +395,7 @@ def test_output_attrs(geds_raw_tbl): } }, } - proc_chain, _, lh5_out = build_processing_chain(geds_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + lh5_out = build_dsp(geds_raw_tbl, dsp_config=dsp_config, n_entries=1) assert lh5_out["wf_blsub"].attrs["test_attr"] == "This is a test" @@ -345,13 +411,10 @@ def test_database_params(geds_raw_tbl): }, }, } - - proc_chain, _, lh5_out = build_processing_chain(geds_raw_tbl, dsp_config) - proc_chain.execute(0, 1) + lh5_out = build_dsp(geds_raw_tbl, dsp_config=dsp_config, n_entries=1) assert lh5_out["test"][0] == 8 - proc_chain, _, lh5_out = build_processing_chain( - geds_raw_tbl, dsp_config, db_dict={"a": 2, "c": 0} + lh5_out = build_dsp( + geds_raw_tbl, dsp_config=dsp_config, database={"a": 2, "c": 0}, n_entries=1 ) - proc_chain.execute(0, 1) assert lh5_out["test"][0] == 3 diff --git a/tests/test_utils.py b/tests/test_utils.py index 
97f34c1c..e420c27a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,38 @@ +import lgdo +import numpy as np + from dspeed.utils import numba_defaults def test_numba_defaults_loading(): numba_defaults.cache = False numba_defaults.boundscheck = True + + +def isclose(lhs, rhs, rtol=1e-5, atol=1e-8, equal_nan=True): + # an is close comparison for LGDO structures + + if isinstance(lhs, lgdo.Struct) and isinstance(rhs, lgdo.Struct): + if set(lhs) != set(rhs) or lhs.attrs != rhs.attrs: + return False + + for k in lhs: + if not isclose(lhs[k], rhs[k], rtol=rtol, atol=atol, equal_nan=equal_nan): + return False + return True + + elif isinstance(lhs, lgdo.Array) and isinstance(rhs, lgdo.Array): + if len(lhs) != len(rhs) or lhs.attrs != rhs.attrs: + return False + return np.all(np.isclose(lhs, rhs, rtol=rtol, atol=atol, equal_nan=equal_nan)) + + elif isinstance(lhs, lgdo.VectorOfVectors) and isinstance( + rhs, lgdo.VectorOfVectors + ): + if len(lhs) != len(rhs) or lhs.attrs != rhs.attrs: + return False + return lhs.cumulative_length == rhs.cumulative_length and np.all( + np.isclose(lhs, rhs, rtol=rtol, atol=atol, equal_nan=equal_nan) + ) + + return False
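
A short usage sketch of the `isclose` helper defined above, with toy tables that agree to within floating-point tolerance; the field name and values are illustrative only.

```python
# Compare two LGDO structures field by field using the new test helper.
import lgdo
import numpy as np

from test_utils import isclose  # tests/test_utils.py from this patch

a = lgdo.Table({"e": lgdo.Array(np.array([1.0, 2.0, 3.0]))})
b = lgdo.Table({"e": lgdo.Array(np.array([1.0, 2.0, 3.0 + 1e-9]))})
c = lgdo.Table({"e": lgdo.Array(np.array([1.0, 2.0, 4.0]))})

assert isclose(a, b)      # equal within the default rtol/atol
assert not isclose(a, c)  # values differ beyond tolerance
```
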