Skip to content

Commit eb1d2b5

Browse files
committed
more cleanup
1 parent a58dd90 commit eb1d2b5

File tree

6 files changed

+23
-22
lines changed

6 files changed

+23
-22
lines changed

src/workflows/context/context.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def __init__(
153153
if (
154154
step_config.context_state_type is not None
155155
and step_config.context_state_type != DictState
156-
and isinstance(step_config.context_state_type, type(BaseModel))
156+
and issubclass(step_config.context_state_type, BaseModel)
157157
):
158158
state_type = step_config.context_state_type
159159
state_types.add(state_type)
@@ -207,14 +207,14 @@ def _init_broker(self, workflow: Workflow) -> WorkflowBroker[MODEL_T]:
207207
)
208208
return self._broker_run
209209

210-
def _internal_run(
210+
def _workflow_run(
211211
self,
212212
workflow: Workflow,
213213
start_event: StartEvent | None = None,
214214
semaphore: asyncio.Semaphore | None = None,
215215
) -> WorkflowHandler:
216216
"""
217-
Called internally from the workflow to run it
217+
called by package internally from the workflow to run it
218218
"""
219219
prev_broker: WorkflowBroker[MODEL_T] | None = None
220220
if self._broker_run is not None:
@@ -243,7 +243,7 @@ async def after_complete() -> None:
243243
after_complete=after_complete,
244244
)
245245

246-
def _internal_cancel_run(self) -> None:
246+
def _workflow_cancel_run(self) -> None:
247247
"""
248248
Called internally from the handler to cancel a context's run
249249
"""

src/workflows/decorators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class StepConfig(BaseModel):
5555

5656

5757
class StepFunction(Protocol, Generic[P, R]):
58-
"""A decorated function, that has some step_config metadata from the @step decorator"""
58+
"""A decorated function, that has some _step_config metadata from the @step decorator"""
5959

6060
_step_config: StepConfig
6161

src/workflows/handler.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import asyncio
77
from typing import Any, AsyncGenerator
88

9-
from .context.context import Context
9+
from .context import Context
1010
from .errors import WorkflowRuntimeError
1111
from .events import Event, StopEvent, InternalDispatchEvent
1212
from .types import RunResultT
@@ -129,6 +129,9 @@ async def cancel_run(self) -> None:
129129
```
130130
"""
131131
if self.ctx:
132-
self.ctx._internal_cancel_run()
132+
self.ctx._workflow_cancel_run()
133133
if self._run_task is not None:
134-
await self._run_task
134+
try:
135+
await self._run_task
136+
except Exception:
137+
pass

src/workflows/runtime/broker.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,14 @@ def __init__(
124124
# Store the step configs of this workflow, to be used in send_event
125125
self._step_configs = {}
126126

127+
# Make the system step "_done" accept custom stop events
128+
done_step = self._done
129+
done_accepted_events = done_step._step_config.accepted_events
130+
if workflow._stop_event_class not in done_accepted_events:
131+
done_accepted_events.append(workflow._stop_event_class)
132+
127133
steps: dict[str, StepFunction] = {
128-
"_done": self._done,
134+
"_done": done_step,
129135
**workflow._get_steps(),
130136
}
131137
for step_name, step_func in steps.items():
@@ -161,13 +167,6 @@ def __init__(
161167

162168
step_config: StepConfig = step_func._step_config
163169

164-
# Make the system step "_done" accept custom stop events
165-
if (
166-
name == "_done"
167-
and workflow._stop_event_class not in step_config.accepted_events
168-
):
169-
step_config.accepted_events.append(workflow._stop_event_class)
170-
171170
for _ in range(step_config.num_workers):
172171
self._add_step_worker(
173172
name=name,
@@ -220,6 +219,9 @@ def start(
220219
)
221220

222221
async def _run_workflow() -> None:
222+
# defer execution to make sure the task can be captured and passed
223+
# to the handler, protecting against exceptions from before_start
224+
await asyncio.sleep(0)
223225
if before_start is not None:
224226
await before_start()
225227
try:

src/workflows/testing/runner.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,11 @@ async def run(
6767
"""
6868
handler = self._workflow.run(start_event=start_event, ctx=ctx)
6969
collected_events: list[Event] = []
70-
print("streaming events")
7170
async for event in handler.stream_events(expose_internal=expose_internal):
72-
print("streaming event", type(event).__name__)
7371
if exclude_events and type(event) in exclude_events:
7472
continue
7573
collected_events.append(event)
76-
print("awaiting handler")
7774
result = await handler
78-
print("result", result)
7975
event_freqs: dict[EventType, int] = dict(
8076
Counter([type(ev) for ev in collected_events])
8177
)

src/workflows/workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def _ensure_events_collected(self) -> list[type[Event]]:
173173
"""
174174
events_found: set[type[Event]] = set()
175175
for step_func in self._get_steps().values():
176-
step_config: StepConfig = getattr(step_func, "__step_config")
176+
step_config: StepConfig = step_func._step_config
177177

178178
# Do not collect events from the done step
179179
if step_func.__name__ == "_done":
@@ -332,7 +332,7 @@ def run(
332332
if ctx.is_running
333333
else self._get_start_event_instance(start_event, **kwargs)
334334
)
335-
return ctx._internal_run(
335+
return ctx._workflow_run(
336336
workflow=self, start_event=start_event_instance, semaphore=self._sem
337337
)
338338

0 commit comments

Comments
 (0)