diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py index e0a6beb7..6c83326d 100644 --- a/burr/core/parallelism.py +++ b/burr/core/parallelism.py @@ -628,15 +628,17 @@ def actions( :return: Generator of actions to run """ - @abc.abstractmethod def state(self, state: State, inputs: Dict[str, Any]): - """Gives the state for each of the actions + """Gives the state for each of the actions. + By default, this will give out the current state. That said, + you may want to adjust this -- E.G. to translate state into + a format the sub-actions would expect. :param state: State at the time of running the action :param inputs: Runtime inputs to the action :return: State for the action """ - pass + return state def states( self, state: State, context: ApplicationContext, inputs: Dict[str, Any] diff --git a/burr/tracking/server/schema.py b/burr/tracking/server/schema.py index d26eace8..c770034a 100644 --- a/burr/tracking/server/schema.py +++ b/burr/tracking/server/schema.py @@ -93,30 +93,30 @@ def from_logs(log_lines: List[bytes]) -> List["Step"]: json_line = safe_json_load(line) # TODO -- make these into constants if json_line["type"] == "begin_entry": - begin_step = BeginEntryModel.parse_obj(json_line) + begin_step = BeginEntryModel.model_validate(json_line) steps_by_sequence_id[begin_step.sequence_id].step_start_log = begin_step elif json_line["type"] == "end_entry": - step_end_log = EndEntryModel.parse_obj(json_line) + step_end_log = EndEntryModel.model_validate(json_line) steps_by_sequence_id[step_end_log.sequence_id].step_end_log = step_end_log elif json_line["type"] == "begin_span": - span = BeginSpanModel.parse_obj(json_line) + span = BeginSpanModel.model_validate(json_line) spans_by_id[span.span_id] = PartialSpan( begin_entry=span, end_entry=None, ) elif json_line["type"] == "end_span": - end_span = EndSpanModel.parse_obj(json_line) + end_span = EndSpanModel.model_validate(json_line) span = spans_by_id[end_span.span_id] span.end_entry = end_span elif json_line["type"] == "attribute": - attribute = AttributeModel.parse_obj(json_line) + attribute = AttributeModel.model_validate(json_line) attributes_by_step[attribute.action_sequence_id].append(attribute) elif json_line["type"] in ["begin_stream", "first_item_stream", "end_stream"]: streaming_event = { "begin_stream": InitializeStreamModel, "first_item_stream": FirstItemStreamModel, "end_stream": EndStreamModel, - }[json_line["type"]].parse_obj(json_line) + }[json_line["type"]].model_validate(json_line) steps_by_sequence_id[streaming_event.sequence_id].streaming_events.append( streaming_event ) diff --git a/pyproject.toml b/pyproject.toml index 2b2e4d06..2734ffd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,7 @@ documentation = [ ] tracking-client = [ - "pydantic" + "pydantic>1" ] tracking-client-s3 = [ diff --git a/tests/core/test_application.py b/tests/core/test_application.py index dff46628..d9767cc3 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -3333,7 +3333,7 @@ def load( builder.with_state_persister(persister) -class TestActionWithoutContext(Action): +class ActionWithoutContext(Action): def run(self, other_param, foo): pass @@ -3352,7 +3352,7 @@ def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: return ["other_param", "foo"] -class TestActionWithContext(TestActionWithoutContext): +class ActionWithContext(ActionWithoutContext): def run(self, __context, other_param, foo): pass @@ -3360,7 +3360,7 @@ def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: return ["other_param", "foo", "__context"] -class TestActionWithKwargs(TestActionWithoutContext): +class ActionWithKwargs(ActionWithoutContext): def run(self, other_param, foo, **kwargs): pass @@ -3368,7 +3368,7 @@ def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: return ["other_param", "foo", "__context"] -class TestActionWithContextTracer(TestActionWithoutContext): +class ActionWithContextTracer(ActionWithoutContext): def run(self, __context, other_param, foo, __tracer): pass @@ -3377,7 +3377,7 @@ def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: def test_remap_context_variable_with_mangled_context_kwargs(): - _action = TestActionWithKwargs() + _action = ActionWithKwargs() inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} expected = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} @@ -3385,11 +3385,11 @@ def test_remap_context_variable_with_mangled_context_kwargs(): def test_remap_context_variable_with_mangled_context(): - _action = TestActionWithContext() + _action = ActionWithContext() inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} expected = { - f"_{TestActionWithContext.__name__}__context": "context_value", + f"_{ActionWithContext.__name__}__context": "context_value", "other_key": "other_value", "foo": "foo_value", } @@ -3397,7 +3397,7 @@ def test_remap_context_variable_with_mangled_context(): def test_remap_context_variable_with_mangled_contexttracer(): - _action = TestActionWithContextTracer() + _action = ActionWithContextTracer() inputs = { "__context": "context_value", @@ -3406,16 +3406,16 @@ def test_remap_context_variable_with_mangled_contexttracer(): "foo": "foo_value", } expected = { - f"_{TestActionWithContextTracer.__name__}__context": "context_value", + f"_{ActionWithContextTracer.__name__}__context": "context_value", "other_key": "other_value", "foo": "foo_value", - f"_{TestActionWithContextTracer.__name__}__tracer": "tracer_value", + f"_{ActionWithContextTracer.__name__}__tracer": "tracer_value", } assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected def test_remap_context_variable_without_mangled_context(): - _action = TestActionWithoutContext() + _action = ActionWithoutContext() inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} expected = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected diff --git a/tests/core/test_parallelism.py b/tests/core/test_parallelism.py index 6404414b..4c1f15f3 100644 --- a/tests/core/test_parallelism.py +++ b/tests/core/test_parallelism.py @@ -370,6 +370,28 @@ def _group_events_by_app_id( return grouped_events +def test_map_actions_default_state(): + class MapActionsAllApproaches(MapActions): + def actions( + self, state: State, inputs: Dict[str, Any], context: ApplicationContext + ) -> Generator[Union[Action, Callable, RunnableGraph], None, None]: + ... + + def reduce(self, state: State, states: Generator[State, None, None]) -> State: + ... + + @property + def writes(self) -> list[str]: + return [] + + @property + def reads(self) -> list[str]: + return [] + + state_to_test = State({"foo": "bar", "baz": "qux"}) + assert MapActionsAllApproaches().state(state_to_test, {}).get_all() == state_to_test.get_all() + + def test_e2e_map_actions_sync_subgraph(): """Tests map actions over multiple action types (runnable graph, function, action class...)""" diff --git a/tests/integrations/test_burr_pydantic.py b/tests/integrations/test_burr_pydantic.py index 0d330f0d..878833fc 100644 --- a/tests/integrations/test_burr_pydantic.py +++ b/tests/integrations/test_burr_pydantic.py @@ -3,7 +3,7 @@ import pydantic import pytest -from pydantic import BaseModel, EmailStr, Field +from pydantic import BaseModel, ConfigDict, EmailStr, Field from pydantic.fields import FieldInfo from burr.core import expr @@ -110,8 +110,7 @@ class MyModelWithConfig(pydantic.BaseModel): foo: int arbitrary: Arbitrary - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) SubsetModel = subset_model(MyModelWithConfig, ["foo", "bar"], [], "Subset") assert SubsetModel.__name__ == "MyModelWithConfigSubset"