Skip to content

Commit 2cb45e5

Browse files
committed
refactor _done out
1 parent f864128 commit 2cb45e5

File tree

8 files changed

+125
-137
lines changed

8 files changed

+125
-137
lines changed

src/workflows/context/context.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,11 @@
2222
from workflows.decorators import StepConfig
2323
from workflows.errors import (
2424
ContextSerdeError,
25-
WorkflowDone,
2625
WorkflowRuntimeError,
2726
)
2827
from workflows.events import (
2928
Event,
3029
StartEvent,
31-
StopEvent,
3230
)
3331
from workflows.runtime.state import WorkflowBrokerState
3432
from workflows.runtime.broker import WorkflowBroker
@@ -151,7 +149,7 @@ def __init__(
151149

152150
state_types: set[Type[BaseModel]] = set()
153151
for _, step_func in workflow._get_steps().items():
154-
step_config: StepConfig = getattr(step_func, "__step_config")
152+
step_config: StepConfig = step_func._step_config
155153
if (
156154
step_config.context_state_type is not None
157155
and step_config.context_state_type != DictState
@@ -245,21 +243,6 @@ async def after_complete() -> None:
245243
after_complete=after_complete,
246244
)
247245

248-
def _internal_finalize_run(
249-
self,
250-
event: StopEvent,
251-
result: RunResultT,
252-
) -> None:
253-
"""
254-
Called internally from the workflow on a context's completion
255-
"""
256-
if self._broker_run is None:
257-
raise WorkflowRuntimeError("Workflow run is not yet running")
258-
self._broker_run.finalize_run(event, result)
259-
self.write_event_to_stream(event)
260-
# Signal we want to stop the workflow
261-
raise WorkflowDone
262-
263246
def _internal_cancel_run(self) -> None:
264247
"""
265248
Called internally from the handler to cancel a context's run

src/workflows/decorators.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ class StepConfig(BaseModel):
5555

5656

5757
class StepFunction(Protocol, Generic[P, R]):
58-
"""A decorated function, that has some __step_config metadata from the @step decorator"""
58+
"""A decorated function, that has some step_config metadata from the @step decorator"""
5959

60-
__step_config: StepConfig
60+
_step_config: StepConfig
6161

6262
__name__: str
6363
__qualname__: str
@@ -135,24 +135,7 @@ def decorator(func: Callable[P, R]) -> StepFunction[P, R]:
135135
"num_workers must be an integer greater than 0"
136136
)
137137

138-
# This will raise providing a message with the specific validation failure
139-
spec = inspect_signature(func)
140-
validate_step_signature(spec)
141-
142-
event_name, accepted_events = next(iter(spec.accepted_events.items()))
143-
144-
func = cast(StepFunction[P, R], func)
145-
# store the configuration in the function object
146-
func.__step_config = StepConfig(
147-
accepted_events=accepted_events,
148-
event_name=event_name,
149-
return_types=spec.return_types,
150-
context_parameter=spec.context_parameter,
151-
context_state_type=spec.context_state_type,
152-
num_workers=num_workers,
153-
retry_policy=retry_policy,
154-
resources=spec.resources,
155-
)
138+
func = make_step_function(func, num_workers, retry_policy)
156139

157140
# If this is a free function, call add_step() explicitly.
158141
if is_free_function(func.__qualname__):
@@ -167,3 +150,27 @@ def decorator(func: Callable[P, R]) -> StepFunction[P, R]:
167150
# The decorator was used without parentheses, like `@step`
168151
return decorator(func)
169152
return decorator
153+
154+
155+
def make_step_function(
156+
func: Callable[P, R], num_workers: int = 4, retry_policy: RetryPolicy | None = None
157+
) -> StepFunction[P, R]:
158+
# This will raise providing a message with the specific validation failure
159+
spec = inspect_signature(func)
160+
validate_step_signature(spec)
161+
162+
event_name, accepted_events = next(iter(spec.accepted_events.items()))
163+
164+
casted = cast(StepFunction[P, R], func)
165+
casted._step_config = StepConfig(
166+
accepted_events=accepted_events,
167+
event_name=event_name,
168+
return_types=spec.return_types,
169+
context_parameter=spec.context_parameter,
170+
context_state_type=spec.context_state_type,
171+
num_workers=num_workers,
172+
retry_policy=retry_policy,
173+
resources=spec.resources,
174+
)
175+
176+
return casted

src/workflows/runtime/broker.py

Lines changed: 81 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from llama_index_instrumentation.dispatcher import Dispatcher
2525

26-
from workflows.decorators import StepConfig
26+
from workflows.decorators import StepConfig, StepFunction, step
2727
from workflows.errors import (
2828
WorkflowCancelledByUser,
2929
WorkflowDone,
@@ -111,6 +111,86 @@ class WorkflowBroker(Generic[MODEL_T]):
111111
# run state
112112
_handler: WorkflowHandler | None
113113

114+
def __init__(
115+
self,
116+
workflow: Workflow,
117+
context: Context[MODEL_T],
118+
state: WorkflowBrokerState,
119+
run_id: str,
120+
) -> None:
121+
self._context = context
122+
self._handler = None
123+
124+
# Store the step configs of this workflow, to be used in send_event
125+
self._step_configs = {}
126+
127+
steps: dict[str, StepFunction] = {
128+
"_done": self._done,
129+
**workflow._get_steps(),
130+
}
131+
for step_name, step_func in steps.items():
132+
self._step_configs[step_name] = step_func._step_config
133+
134+
# Transient runtime fields (always reinitialized)
135+
self._tasks = set()
136+
self._cancel_flag = asyncio.Event()
137+
self._step_flags = {}
138+
self._step_events_holding = None
139+
self._step_lock = asyncio.Lock()
140+
self._retval = None
141+
142+
self._lock = asyncio.Lock()
143+
144+
self._dispatcher = workflow._dispatcher
145+
146+
self._step_condition = asyncio.Condition(lock=self._step_lock)
147+
self._step_event_written = asyncio.Condition(lock=self._step_lock)
148+
# Keep track of the steps currently running (transient)
149+
self._currently_running_steps = defaultdict(int)
150+
# Default initial values for persistent fields
151+
152+
self._state = state
153+
154+
# initialize running state from workflow
155+
for name, step_func in steps.items():
156+
if name not in self._state.queues:
157+
self._state.queues[name] = asyncio.Queue()
158+
159+
if name not in self._step_flags:
160+
self._step_flags[name] = asyncio.Event()
161+
162+
step_config: StepConfig = step_func._step_config
163+
164+
# Make the system step "_done" accept custom stop events
165+
if (
166+
name == "_done"
167+
and workflow._stop_event_class not in step_config.accepted_events
168+
):
169+
step_config.accepted_events.append(workflow._stop_event_class)
170+
171+
for _ in range(step_config.num_workers):
172+
self._add_step_worker(
173+
name=name,
174+
step=step_func,
175+
config=step_config,
176+
verbose=workflow._verbose,
177+
run_id=run_id,
178+
worker_id=str(uuid.uuid4()),
179+
resource_manager=workflow._resource_manager,
180+
)
181+
182+
# add dedicated cancel task
183+
self._add_cancel_worker()
184+
185+
@step(num_workers=1)
186+
async def _done(self, ev: StopEvent) -> None:
187+
"""Tears down the whole workflow and stop execution."""
188+
result = ev.result if type(ev) is StopEvent else ev
189+
self.finalize_run(ev, result)
190+
self.write_event_to_stream(ev)
191+
# Signal we want to stop the workflow
192+
raise WorkflowDone
193+
114194
@property
115195
def is_running(self) -> bool:
116196
return self._state.is_running
@@ -359,73 +439,6 @@ async def shutdown(self) -> None:
359439
await asyncio.gather(*self._tasks, return_exceptions=True)
360440
self._tasks.clear()
361441

362-
def __init__(
363-
self,
364-
workflow: Workflow,
365-
context: Context[MODEL_T],
366-
state: WorkflowBrokerState,
367-
run_id: str,
368-
) -> None:
369-
self._context = context
370-
self._handler = None
371-
372-
# Store the step configs of this workflow, to be used in send_event
373-
self._step_configs = {}
374-
for step_name, step_func in workflow._get_steps().items():
375-
self._step_configs[step_name] = getattr(step_func, "__step_config", None)
376-
377-
# Transient runtime fields (always reinitialized)
378-
self._tasks = set()
379-
self._cancel_flag = asyncio.Event()
380-
self._step_flags = {}
381-
self._step_events_holding = None
382-
self._step_lock = asyncio.Lock()
383-
self._retval = None
384-
385-
self._lock = asyncio.Lock()
386-
387-
self._dispatcher = workflow._dispatcher
388-
389-
self._step_condition = asyncio.Condition(lock=self._step_lock)
390-
self._step_event_written = asyncio.Condition(lock=self._step_lock)
391-
# Keep track of the steps currently running (transient)
392-
self._currently_running_steps = defaultdict(int)
393-
# Default initial values for persistent fields
394-
395-
self._state = state
396-
397-
# initialize running state from workflow
398-
for name, step_func in workflow._get_steps().items():
399-
if name not in self._state.queues:
400-
self._state.queues[name] = asyncio.Queue()
401-
402-
if name not in self._step_flags:
403-
self._step_flags[name] = asyncio.Event()
404-
405-
# At this point, step_func is guaranteed to have the `__step_config` attribute
406-
step_config: StepConfig = getattr(step_func, "__step_config")
407-
408-
# Make the system step "_done" accept custom stop events
409-
if (
410-
name == "_done"
411-
and workflow._stop_event_class not in step_config.accepted_events
412-
):
413-
step_config.accepted_events.append(workflow._stop_event_class)
414-
415-
for _ in range(step_config.num_workers):
416-
self._add_step_worker(
417-
name=name,
418-
step=step_func,
419-
config=step_config,
420-
verbose=workflow._verbose,
421-
run_id=run_id,
422-
worker_id=str(uuid.uuid4()),
423-
resource_manager=workflow._resource_manager,
424-
)
425-
426-
# add dedicated cancel task
427-
self._add_cancel_worker()
428-
429442
async def _mark_in_progress(
430443
self, name: str, ev: Event, worker_id: str = ""
431444
) -> None:

src/workflows/server/representation_utils.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ def _extract_workflow_structure(
8585
# Assuming that `Workflow` is validated before drawing, it's enough to find the first one.
8686
current_stop_event = None
8787
for step_name, step_func in steps.items():
88-
step_config = getattr(step_func, "__step_config", None)
89-
if step_config is None:
90-
continue
88+
step_config = step_func._step_config
9189

9290
for return_type in step_config.return_types:
9391
if issubclass(return_type, StopEvent):
@@ -99,10 +97,8 @@ def _extract_workflow_structure(
9997

10098
# First pass: Add all nodes
10199
for step_name, step_func in steps.items():
102-
step_config = getattr(step_func, "__step_config", None)
103-
if step_config is None:
104-
continue
105-
100+
step_config = step_func._step_config
101+
106102
# Add step node
107103
step_label = (
108104
_truncate_label(step_name, max_label_length)
@@ -198,10 +194,8 @@ def _extract_workflow_structure(
198194

199195
# Second pass: Add edges
200196
for step_name, step_func in steps.items():
201-
step_config = getattr(step_func, "__step_config", None)
202-
if step_config is None:
203-
continue
204-
197+
step_config = step_func._step_config
198+
205199
# Edges from steps to return types
206200
for return_type in step_config.return_types:
207201
if return_type is not type(None):

src/workflows/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def get_steps_from_class(_class: object) -> dict[str, StepFunction]:
169169
all_methods = inspect.getmembers(_class, predicate=inspect.isfunction)
170170

171171
for name, method in all_methods:
172-
if hasattr(method, "__step_config"):
172+
if hasattr(method, "_step_config"):
173173
step_methods[name] = cast(StepFunction, method)
174174

175175
return step_methods
@@ -192,7 +192,7 @@ def get_steps_from_instance(workflow: object) -> dict[str, StepFunction]:
192192
all_methods = inspect.getmembers(workflow, predicate=inspect.ismethod)
193193

194194
for name, method in all_methods:
195-
if hasattr(method, "__step_config"):
195+
if hasattr(method, "_step_config"):
196196
step_methods[name] = cast(StepFunction, method)
197197

198198
return step_methods

src/workflows/workflow.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pydantic import ValidationError
1515

1616
from .context import Context
17-
from .decorators import StepConfig, StepFunction, step
17+
from .decorators import StepConfig, StepFunction
1818
from .errors import (
1919
WorkflowConfigurationError,
2020
WorkflowRuntimeError,
@@ -135,7 +135,7 @@ def _ensure_start_event_class(self) -> type[StartEvent]:
135135
"""
136136
start_events_found: set[type[StartEvent]] = set()
137137
for step_func in self._get_steps().values():
138-
step_config: StepConfig = getattr(step_func, "__step_config")
138+
step_config: StepConfig = step_func._step_config
139139
for event_type in step_config.accepted_events:
140140
if issubclass(event_type, StartEvent):
141141
start_events_found.add(event_type)
@@ -196,7 +196,7 @@ def _ensure_stop_event_class(self) -> type[RunResultT]:
196196
"""
197197
stop_events_found: set[type[StopEvent]] = set()
198198
for step_func in self._get_steps().values():
199-
step_config: StepConfig = getattr(step_func, "__step_config")
199+
step_config: StepConfig = step_func._step_config
200200
for event_type in step_config.return_types:
201201
if issubclass(event_type, StopEvent):
202202
stop_events_found.add(event_type)
@@ -226,7 +226,7 @@ def add_step(cls, func: StepFunction) -> None:
226226
227227
It raises an exception if a step with the same name was already added to the workflow.
228228
"""
229-
step_config: StepConfig | None = getattr(func, "__step_config", None)
229+
step_config: StepConfig | None = getattr(func, "_step_config", None)
230230
if not step_config:
231231
msg = f"Step function {func.__name__} is missing the `@step` decorator."
232232
raise WorkflowValidationError(msg)
@@ -336,13 +336,6 @@ def run(
336336
workflow=self, start_event=start_event_instance, semaphore=self._sem
337337
)
338338

339-
@step(num_workers=1)
340-
async def _done(self, ctx: Context, ev: StopEvent) -> None:
341-
"""Tears down the whole workflow and stop execution."""
342-
ctx._internal_finalize_run(
343-
ev, ev.result if self._stop_event_class is StopEvent else ev
344-
)
345-
346339
def _validate(self) -> bool:
347340
"""
348341
Validate the workflow to ensure it's well-formed.
@@ -359,9 +352,7 @@ def _validate(self) -> bool:
359352
steps_accepting_stop_event: list[str] = []
360353

361354
for name, step_func in self._get_steps().items():
362-
step_config: StepConfig | None = getattr(step_func, "__step_config")
363-
# At this point we know step config is not None, let's make the checker happy
364-
assert step_config is not None
355+
step_config: StepConfig = step_func._step_config
365356

366357
# Check that no user-defined step accepts StopEvent (only _done step should)
367358
if name != "_done":

0 commit comments

Comments
 (0)