diff --git a/src/dspeed/processing_chain.py b/src/dspeed/processing_chain.py index ff62cdf..21bca2d 100644 --- a/src/dspeed/processing_chain.py +++ b/src/dspeed/processing_chain.py @@ -2632,8 +2632,10 @@ def resolve_dependencies( init_kwargs = {} for arg in init_args_in: if not isinstance(arg, str): - pass + init_args.append(arg) + continue + # find and replace db values for db_var in db_parser.findall(arg): try: db_node = db_dict diff --git a/src/dspeed/processors/iir_filter.py b/src/dspeed/processors/iir_filter.py index dcc3fae..994d7f3 100644 --- a/src/dspeed/processors/iir_filter.py +++ b/src/dspeed/processors/iir_filter.py @@ -57,7 +57,7 @@ def iir_filter( wf_lp: function: iir_filter module: dspeed.processors - args_in: + init_args: - "15*MHz" - 4 - wf @@ -138,7 +138,7 @@ def notch_filter( wf_notch: function: notch_filter module: dspeed.processors - args_in: + init_args: - "15*MHz" - "1.5*MHz" - wf @@ -196,7 +196,7 @@ def peak_filter( wf_peak: function: peak_filter module: dspeed.processors - args_in: + init_args: - "15*MHz" - "1.5*MHz" - wf diff --git a/src/dspeed/processors/poly_fit.py b/src/dspeed/processors/poly_fit.py index fceaa66..86e1783 100644 --- a/src/dspeed/processors/poly_fit.py +++ b/src/dspeed/processors/poly_fit.py @@ -34,7 +34,18 @@ def _poly_fitter(w_in: np.ndarray, inv: np.ndarray, poly_pars: np.ndarray) -> No def poly_fit(length, deg): """Factory function for generating a polynomial fitter for an input of length - `length` to a polynomial of order `deg`.""" + `length` to a polynomial of order `deg`. + + YAML Configuration Example + -------------------------- + + .. code-block:: yaml + + fit_pars: + function: dspeed.processors.poly_fit + args: ["wf_logged", "fit_pars(shape=4)"] + init_args: ["len(wf_logged)", 3] + """ vals_array = np.zeros(2 * deg + 1, dtype="float64") diff --git a/tests/test_processing_chain.py b/tests/test_processing_chain.py index acca063..82e196a 100644 --- a/tests/test_processing_chain.py +++ b/tests/test_processing_chain.py @@ -664,3 +664,42 @@ def test_database_params(geds_raw_tbl): geds_raw_tbl, dsp_config=dsp_config, database={"a": 2, "c": 0}, n_entries=1 ) assert lh5_out["test"][0] == 3 + + def test_init_args(geds_raw_tbl): + dsp_config = { + "outputs": ["test"], + "processors": { + "test": { + "function": "dspeed.processors.poly_fit", + "args": ["waveform", "test(shape=4)"], + "init_args": ["len(waveform)", 3], + } + }, + } + build_dsp(geds_raw_tbl, dsp_config=dsp_config, n_entries=1) + + dsp_config = { + "outputs": ["test"], + "processors": { + "test": { + "function": "dspeed.processors.poly_fit", + "args": ["waveform", "test(shape=3)"], + "init_args": ["len(waveform)", 3], + } + }, + } + with pytest.raises(ProcessingChainError): + build_dsp(geds_raw_tbl, dsp_config=dsp_config, n_entries=1) + + dsp_config = { + "outputs": ["test"], + "processors": { + "test": { + "function": "dspeed.processors.poly_fit", + "args": ["waveform", "test(shape=4)"], + "init_args": ["len(waveform)", "db.deg"], + "defaults": {"db.deg": 3}, + } + }, + } + build_dsp(geds_raw_tbl, dsp_config=dsp_config, n_entries=1)