From 2da67a94f61d902483ae5ae8b7d1a7fd78db735c Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Tue, 3 Dec 2024 21:32:40 -0800 Subject: [PATCH] Fixes __tracer for class based actions Refactors __context treatment to also handle __tracer. This now enables one to pass through/request the tracer object in a class based action. --- burr/core/application.py | 36 ++++++++++++++++++++-------------- tests/core/test_application.py | 34 ++++++++++++++++++++++++++++---- 2 files changed, 51 insertions(+), 19 deletions(-) diff --git a/burr/core/application.py b/burr/core/application.py index 183dec9a..06bd2286 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -103,28 +103,34 @@ def _adjust_single_step_output( _raise_fn_return_validation_error(output, action_name) -def _remap_context_variable(run_method: Callable, inputs: Dict[str, Any]) -> dict: - """This is a utility function to remap the __context variable to the mangled variable in the function signature. +def _remap_dunder_parameters( + run_method: Callable, inputs: Dict[str, Any], vars_to_remap: List[str] +) -> dict: + """This is a utility function to remap the __dunder parameters to the mangled version in the function signature. - Python mangles the variable name in the function signature, so we need to remap it to the correct variable name. + Python mangles __parameter names in the function signature, so we need to remap it to the correct parameter name. :param run_method: the run method to inspect. :param inputs: the inputs to inspect + :param vars_to_remap: the variables to remap :return: potentially new dict with the remapped variable, else the original dict. """ # Get the signature of the method being run. This should be Function.run() or similar. signature = inspect.signature(run_method) + mangled_params: Dict[str, Optional[str]] = {v: None for v in vars_to_remap} # Find the name-mangled __context variable - mangled_context_name = None - for param in signature.parameters.values(): - if param.name.endswith("__context"): - mangled_context_name = param.name - break - - # If a mangled __context variable is found, remap the value in inputs - if mangled_context_name and "__context" in inputs: + for dunder_param in mangled_params.keys(): + for param in signature.parameters.values(): + if param.name.endswith(dunder_param): + mangled_params[dunder_param] = param.name + break + + # If any mangled __parameter is found, remap the value in inputs + if any(mangled_params.values()): inputs = inputs.copy() - inputs[mangled_context_name] = inputs.pop("__context") + for dunder_param, mangled_name in mangled_params.items(): + if mangled_name and dunder_param in inputs: + inputs[mangled_name] = inputs.pop(dunder_param) return inputs @@ -146,9 +152,9 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any], name ) state_to_use = state.subset(*function.reads) function.validate_inputs(inputs) - if "__context" in inputs: - # potentially need to remap the __context variable - inputs = _remap_context_variable(function.run, inputs) + if "__context" in inputs or "__tracer" in inputs: + # potentially need to remap the __context & __tracer variables + inputs = _remap_dunder_parameters(function.run, inputs, ["__context", "__tracer"]) result = function.run(state_to_use, **inputs) _validate_result(result, name) return result diff --git a/tests/core/test_application.py b/tests/core/test_application.py index 997ff9ea..dff46628 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -34,7 +34,7 @@ _arun_multi_step_streaming_action, _arun_single_step_action, _arun_single_step_streaming_action, - _remap_context_variable, + _remap_dunder_parameters, _run_function, _run_multi_step_streaming_action, _run_reducer, @@ -3368,12 +3368,20 @@ def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: return ["other_param", "foo", "__context"] +class TestActionWithContextTracer(TestActionWithoutContext): + def run(self, __context, other_param, foo, __tracer): + pass + + def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: + return ["other_param", "foo", "__context", "__tracer"] + + def test_remap_context_variable_with_mangled_context_kwargs(): _action = TestActionWithKwargs() inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} expected = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} - assert _remap_context_variable(_action.run, inputs) == expected + assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected def test_remap_context_variable_with_mangled_context(): @@ -3385,11 +3393,29 @@ def test_remap_context_variable_with_mangled_context(): "other_key": "other_value", "foo": "foo_value", } - assert _remap_context_variable(_action.run, inputs) == expected + assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected + + +def test_remap_context_variable_with_mangled_contexttracer(): + _action = TestActionWithContextTracer() + + inputs = { + "__context": "context_value", + "__tracer": "tracer_value", + "other_key": "other_value", + "foo": "foo_value", + } + expected = { + f"_{TestActionWithContextTracer.__name__}__context": "context_value", + "other_key": "other_value", + "foo": "foo_value", + f"_{TestActionWithContextTracer.__name__}__tracer": "tracer_value", + } + assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected def test_remap_context_variable_without_mangled_context(): _action = TestActionWithoutContext() inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} expected = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} - assert _remap_context_variable(_action.run, inputs) == expected + assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected