Skip to content

Commit 71ea2d8

Browse files
committed
WIP for function-based actions
1 parent 2b814d8 commit 71ea2d8

File tree

8 files changed

+218
-297
lines changed

8 files changed

+218
-297
lines changed

burr/core/__init__.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
from burr.core.state import State
44

55
__all__ = [
6+
"action",
67
"Action",
8+
"Application",
79
"ApplicationBuilder",
810
"Condition",
9-
"Result",
1011
"default",
11-
"when",
1212
"expr",
13-
"Application",
13+
"Result",
1414
"State",
15+
"when",
1516
]

burr/core/action.py

+49-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import ast
33
import copy
44
import inspect
5-
from typing import Callable, List, Tuple, Union
5+
import types
6+
from typing import Any, Callable, List, Protocol, Tuple, TypeVar, Union
67

78
from burr.core.state import State
89

@@ -169,7 +170,11 @@ class FunctionBasedAction(Action):
169170
ACTION_FUNCTION = "action_function"
170171

171172
def __init__(
172-
self, fn: Callable[[State], Tuple[dict, State]], reads: List[str], writes: List[str]
173+
self,
174+
fn: Callable[..., Tuple[dict, State]],
175+
reads: List[str],
176+
writes: List[str],
177+
bound_params: dict = None,
173178
):
174179
"""Instantiates a function-based action with the given function, reads, and writes.
175180
The function must take in a state and return a tuple of (result, new_state).
@@ -183,13 +188,18 @@ def __init__(
183188
self._reads = reads
184189
self._writes = writes
185190
self._state_created = None
191+
self._bound_params = bound_params if bound_params is not None else {}
192+
193+
@property
194+
def fn(self) -> Callable:
195+
return self._fn
186196

187197
@property
188198
def reads(self) -> list[str]:
189199
return self._reads
190200

191201
def run(self, state: State) -> dict:
192-
result, new_state = self._fn(state)
202+
result, new_state = self._fn(state, **self._bound_params)
193203
self._state_created = new_state
194204
return result
195205

@@ -202,7 +212,23 @@ def update(self, result: dict, state: State) -> State:
202212
raise ValueError(
203213
"FunctionBasedAction.run must be called before FunctionBasedAction.update"
204214
)
205-
return self._state_created
215+
# TODO -- validate that all the keys are contained -- fix up subset to handle this
216+
# TODO -- validate that we've (a) written only to the write ones (by diffing the read ones),
217+
# and (b) written to no more than the write ones
218+
return self._state_created.subset(*self._writes)
219+
220+
def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
221+
"""Binds parameters to the function.
222+
Note that there is no reason to call this by the user. This *could*
223+
be done at the class level, but given that API allows for constructor parameters
224+
(which do the same thing in a cleaner way), it is best to keep it here for now.
225+
226+
:param kwargs:
227+
:return:
228+
"""
229+
new_action = copy.copy(self)
230+
new_action._bound_params = {**self._bound_params, **kwargs}
231+
return new_action
206232

207233

208234
def _validate_action_function(fn: Callable):
@@ -225,7 +251,23 @@ def _validate_action_function(fn: Callable):
225251
)
226252

227253

228-
def action(reads: List[str], writes: List[str]) -> Callable[[Callable], Callable]:
254+
C = TypeVar("C", bound=Callable) # placeholder for any Callable
255+
256+
257+
class FunctionRepresentingAction(Protocol[C]):
258+
action_function: FunctionBasedAction
259+
__call__: C
260+
261+
def bind(self, **kwargs: Any):
262+
...
263+
264+
265+
def bind(self: FunctionRepresentingAction, **kwargs: Any) -> FunctionRepresentingAction:
266+
self.action_function = self.action_function.with_params(**kwargs)
267+
return self
268+
269+
270+
def action(reads: List[str], writes: List[str]) -> Callable[[Callable], FunctionRepresentingAction]:
229271
"""Decorator to create a function-based action. This is user-facing.
230272
Note that, in the future, with typed state, we may not need this for
231273
all cases.
@@ -235,8 +277,9 @@ def action(reads: List[str], writes: List[str]) -> Callable[[Callable], Callable
235277
:return: The decorator to assign the function as an action
236278
"""
237279

238-
def decorator(fn: Callable) -> Callable:
280+
def decorator(fn) -> FunctionRepresentingAction:
239281
setattr(fn, FunctionBasedAction.ACTION_FUNCTION, FunctionBasedAction(fn, reads, writes))
282+
setattr(fn, "bind", types.MethodType(bind, fn))
240283
return fn
241284

242285
return decorator

burr/core/application.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
11
import collections
22
import dataclasses
33
import logging
4-
from typing import Any, AsyncGenerator, Generator, List, Literal, Optional, Set, Tuple, Union
4+
from typing import (
5+
Any,
6+
AsyncGenerator,
7+
Callable,
8+
Generator,
9+
List,
10+
Literal,
11+
Optional,
12+
Set,
13+
Tuple,
14+
Union,
15+
)
516

617
from burr.core.action import Action, Condition, Function, Reducer, create_action, default
718
from burr.core.state import State
@@ -66,7 +77,7 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta
6677
:return:
6778
"""
6879
state_to_use = state.subset(*reducer.writes)
69-
new_state = reducer.update(result, state_to_use)
80+
new_state = reducer.update(result, state_to_use).subset(*reducer.writes)
7081
keys_in_new_state = set(new_state.keys())
7182
extra_keys = keys_in_new_state - set(reducer.writes)
7283
if extra_keys:
@@ -440,7 +451,7 @@ def with_entrypoint(self, action: str) -> "ApplicationBuilder":
440451
self.start = action
441452
return self
442453

443-
def with_actions(self, **actions: Action) -> "ApplicationBuilder":
454+
def with_actions(self, **actions: Union[Action, Callable]) -> "ApplicationBuilder":
444455
"""Adds an action to the application. The actions are granted names (using the with_name)
445456
method post-adding, using the kw argument. Thus, this is the only supported way to add actions.
446457

burr/integrations/streamlit.py

+5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import List, Optional
55

66
from burr.core import Application
7+
from burr.core.action import FunctionBasedAction
78
from burr.integrations.base import require_plugin
89
from burr.integrations.hamilton import Hamilton, StateSource
910

@@ -176,6 +177,7 @@ def render_action(state: AppState):
176177
st.header(f"`{current_node}`")
177178
action_object = actions[current_node]
178179
is_hamilton = isinstance(action_object, Hamilton)
180+
is_function_api = isinstance(action_object, FunctionBasedAction)
179181

180182
def format_read(var):
181183
out = f"- `{var}`"
@@ -210,6 +212,9 @@ def format_write(var):
210212
if is_hamilton:
211213
digraph = action_object.visualize_step(show_legend=False)
212214
st.graphviz_chart(digraph, use_container_width=False)
215+
elif is_function_api:
216+
code = inspect.getsource(action_object.fn)
217+
st.code(code, language="python")
213218
else:
214219
code = inspect.getsource(action_object.__class__)
215220
st.code(code, language="python")

examples/counter/application.py

+10-19
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,22 @@
1+
from typing import Tuple
2+
13
import burr.core
2-
from burr.core import Action, Result, State, default, expr
4+
from burr.core import Result, State, default, expr
5+
from burr.core.action import action
36
from burr.lifecycle import StateAndResultsFullLogger
47

58

6-
class CounterAction(Action):
7-
@property
8-
def reads(self) -> list[str]:
9-
return ["counter"]
10-
11-
def run(self, state: State) -> dict:
12-
return {"counter": state["counter"] + 1}
13-
14-
@property
15-
def writes(self) -> list[str]:
16-
return ["counter"]
17-
18-
def update(self, result: dict, state: State) -> State:
19-
return state.update(**result)
9+
@action(reads=["counter"], writes=["counter"])
10+
def counter(state: State) -> Tuple[dict, State]:
11+
result = {"counter": state["counter"] + 1}
12+
return result, state.update(**result)
2013

2114

2215
def application(count_up_to: int = 10, log_file: str = None):
2316
return (
2417
burr.core.ApplicationBuilder()
25-
.with_state(
26-
counter=0,
27-
)
28-
.with_actions(counter=CounterAction(), result=Result(["counter"]))
18+
.with_state(counter=0)
19+
.with_actions(counter=counter, result=Result(["counter"]))
2920
.with_transitions(
3021
("counter", "counter", expr(f"counter < {count_up_to}")),
3122
("counter", "result", default),

examples/cowsay/application.py

+27-54
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,14 @@
11
import random
22
import time
3-
from typing import List, Optional
3+
from typing import Tuple
44

55
import cowsay
66

77
from burr.core import Action, Application, ApplicationBuilder, State, default, expr
8+
from burr.core.action import action
89
from burr.lifecycle import PostRunStepHook
910

1011

11-
class CowSay(Action):
12-
def __init__(self, say_what: List[Optional[str]]):
13-
super(CowSay, self).__init__()
14-
self.say_what = say_what
15-
16-
@property
17-
def reads(self) -> list[str]:
18-
return []
19-
20-
def run(self, state: State) -> dict:
21-
say_what = random.choice(self.say_what)
22-
return {
23-
"cow_said": cowsay.get_output_string("cow", say_what) if say_what is not None else None
24-
}
25-
26-
@property
27-
def writes(self) -> list[str]:
28-
return ["cow_said"]
29-
30-
def update(self, result: dict, state: State) -> State:
31-
return state.update(**result)
32-
33-
34-
class CowShouldSay(Action):
35-
@property
36-
def reads(self) -> list[str]:
37-
return []
38-
39-
def run(self, state: State) -> dict:
40-
if not random.randint(0, 3):
41-
return {"cow_should_speak": True}
42-
return {"cow_should_speak": False}
43-
44-
@property
45-
def writes(self) -> list[str]:
46-
return ["cow_should_speak"]
47-
48-
def update(self, result: dict, state: State) -> State:
49-
return state.update(**result)
50-
51-
5212
class PrintWhatTheCowSaid(PostRunStepHook):
5313
def post_run_step(self, *, state: "State", action: "Action", **future_kwargs):
5414
if action.name != "cow_should_say" and state["cow_said"] is not None:
@@ -65,6 +25,19 @@ def post_run_step(self, *, state: "State", action: "Action", **future_kwargs):
6525
time.sleep(self.sleep_time)
6626

6727

28+
@action(reads=[], writes=["cow_said"])
29+
def cow_said(state: State, say_what: list[str]) -> Tuple[dict, State]:
30+
said = random.choice(say_what)
31+
result = {"cow_said": cowsay.get_output_string("cow", said) if say_what is not None else None}
32+
return result, state.update(**result)
33+
34+
35+
@action(reads=[], writes=["cow_should_speak"])
36+
def cow_should_speak(state: State) -> Tuple[dict, State]:
37+
result = {"cow_should_speak": random.randint(0, 3) == 0}
38+
return result, state.update(**result)
39+
40+
6841
def application(in_terminal: bool = False) -> Application:
6942
hooks = (
7043
[
@@ -76,21 +49,21 @@ def application(in_terminal: bool = False) -> Application:
7649
)
7750
return (
7851
ApplicationBuilder()
79-
.with_state(
80-
cow_said=None,
81-
)
52+
.with_state(cow_said=None)
8253
.with_actions(
83-
say_nothing=CowSay([None]),
84-
say_hello=CowSay(["Hello world!", "What's up?", "Are you Aaron Burr, sir?"]),
85-
cow_should_say=CowShouldSay(),
54+
say_nothing=cow_said.bind(say_what=None),
55+
say_hello=cow_said.bind(
56+
say_what=["Hello world!", "What's up?", "Are you Aaron Burr, sir?"]
57+
),
58+
cow_should_speak=cow_should_speak,
8659
)
8760
.with_transitions(
88-
("cow_should_say", "say_hello", expr("cow_should_speak")),
89-
("say_hello", "cow_should_say", default),
90-
("cow_should_say", "say_nothing", expr("not cow_should_speak")),
91-
("say_nothing", "cow_should_say", default),
61+
("cow_should_speak", "say_hello", expr("cow_should_speak")),
62+
("say_hello", "cow_should_speak", default),
63+
("cow_should_speak", "say_nothing", expr("not cow_should_speak")),
64+
("say_nothing", "cow_should_speak", default),
9265
)
93-
.with_entrypoint("cow_should_say")
66+
.with_entrypoint("cow_should_speak")
9467
.with_hooks(*hooks)
9568
.build()
9669
)
@@ -100,4 +73,4 @@ def application(in_terminal: bool = False) -> Application:
10073
app = application(in_terminal=True)
10174
app.visualize(output_file_path="cowsay.png", include_conditions=True, view=True)
10275
while True:
103-
state, result, action = app.step()
76+
s, r, action = app.step()

0 commit comments

Comments
 (0)