Skip to content

Commit 7eadc7e

Browse files
committed
Adding registry workaround
1 parent 4f3ac7d commit 7eadc7e

File tree

5 files changed

+92
-80
lines changed

5 files changed

+92
-80
lines changed

src/workflows/plugins/dbos.py

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@
1616
RegisteredWorkflow,
1717
)
1818
from workflows.runtime.types.internal_state import BrokerState
19-
from workflows.context.context import Context
2019
from workflows.runtime.types.step_function import StepWorkerFunction
2120
from workflows.runtime.types.ticks import WorkflowTick
22-
from workflows.decorators import R
21+
2322
from workflows.workflow import Workflow
2423

2524

@@ -33,50 +32,31 @@ def register(
3332
self,
3433
workflow: Workflow,
3534
workflow_function: ControlLoopFunction,
36-
steps: dict[str, StepWorkerFunction[R]],
35+
steps: dict[str, StepWorkerFunction],
3736
) -> RegisteredWorkflow | None:
3837
"""
3938
Wrap the workflow control loop in a DBOS workflow so ticks are received via DBOS.recv
4039
and sent via DBOS.send, enabling durable orchestration.
4140
"""
4241

43-
44-
45-
# DBOS Python supports async workflow functions; we wrap and return replacement.
46-
# Note: We do not wrap individual step workers here; the control loop retains
47-
# responsibility for step execution. This can be extended later to split steps
48-
# into child workflows if needed.
49-
5042
@DBOS.workflow()
5143
async def _dbos_control_loop(
5244
start_event: Event | None,
5345
init_state: BrokerState | None,
54-
plugin: WorkflowRuntime,
55-
context: Context,
56-
step_workers: dict[str, StepWorkerFunction],
46+
run_id: str,
5747
) -> StopEvent:
58-
# Ensure our runtime knows the workflow id for routing incoming ticks
59-
assert isinstance(plugin, DBOSWorkflowRuntime)
60-
# Pin a stable workflow id for this run using the runtime's run_id
61-
with SetWorkflowID(plugin.run_id):
62-
# Delegate to the original control loop function
63-
return await workflow_function( # type: ignore[misc]
64-
start_event, init_state, plugin, context, step_workers
65-
)
48+
with SetWorkflowID(run_id):
49+
return await workflow_function(start_event, init_state, run_id)
6650

6751
async def wrapper(
6852
start_event: Event | None,
6953
init_state: BrokerState | None,
70-
plugin: WorkflowRuntime,
71-
context: Context,
72-
step_workers: dict[str, StepWorkerFunction],
54+
run_id: str,
7355
) -> StopEvent:
7456
# Call the DBOS workflow directly; DBOS will orchestrate execution
75-
return await _dbos_control_loop(
76-
start_event, init_state, plugin, context, step_workers
77-
)
57+
return await _dbos_control_loop(start_event, init_state, run_id)
7858

79-
return RegisteredWorkflow(workflow_function=wrapper, steps=steps)
59+
return RegisteredWorkflow(workflow_function=_dbos_control_loop, steps=steps)
8060

8161
def new_runtime(self, run_id: str) -> WorkflowRuntime:
8262
runtime: WorkflowRuntime = DBOSWorkflowRuntime(run_id)

src/workflows/runtime/broker.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
StartEvent,
2828
StopEvent,
2929
)
30-
from workflows.runtime.control_loop import create_control_loop, rebuild_state_from_ticks
30+
from workflows.runtime.control_loop import control_loop, rebuild_state_from_ticks
3131
from workflows.runtime.types.internal_state import BrokerState
3232
from workflows.runtime.types.plugin import Plugin, WorkflowRuntime, as_snapshottable
3333
from workflows.runtime.types.results import (
@@ -39,7 +39,10 @@
3939
StepWorkerStateContextVar,
4040
WaitingForEvent,
4141
)
42-
from workflows.runtime.types.step_function import as_step_worker_function
42+
from workflows.runtime.types.step_function import (
43+
StepWorkerFunction,
44+
as_step_worker_function,
45+
)
4346
from workflows.runtime.types.ticks import TickAddEvent, TickCancelRun, WorkflowTick
4447
from workflows.runtime.workflow_registry import workflow_registry
4548

@@ -118,8 +121,6 @@ def start(
118121
self._init_state = previous
119122

120123
async def _run_workflow() -> None:
121-
from workflows.context.context import Context
122-
123124
# defer execution to make sure the task can be captured and passed
124125
# to the handler as async exception, protecting against exceptions from before_start
125126
self._is_running = True
@@ -132,27 +133,35 @@ async def _run_workflow() -> None:
132133
try:
133134
exception_raised = None
134135

135-
step_workers = {}
136+
step_workers: dict[str, StepWorkerFunction] = {}
136137
for name, step_func in workflow._get_steps().items():
137138
# Avoid capturing a bound method (which retains the instance).
138139
# If it's a bound method, extract the unbound function from the class.
139140
unbound = getattr(step_func, "__func__", step_func)
140141
step_workers[name] = as_step_worker_function(unbound)
141142

142-
control_loop_fn = create_control_loop(
143-
workflow,
144-
)
145143
registered = workflow_registry.get_registered_workflow(
146-
workflow, self._plugin, control_loop_fn, step_workers
144+
workflow, self._plugin, control_loop, step_workers
147145
)
148146

149-
workflow_result = await registered.workflow_function(
150-
start_event,
151-
init_state,
152-
self._runtime,
153-
cast(Context, self._context),
154-
registered.steps,
147+
# Register run context prior to invoking control loop
148+
workflow_registry.register_run(
149+
run_id=run_id,
150+
workflow=workflow,
151+
plugin=self._runtime,
152+
context=self._context, # type: ignore
153+
steps=registered.steps,
155154
)
155+
156+
try:
157+
workflow_result = await registered.workflow_function(
158+
start_event,
159+
init_state,
160+
run_id,
161+
)
162+
finally:
163+
# ensure run context is cleaned up even on failure
164+
workflow_registry.delete_run(run_id)
156165
result.set_result(
157166
workflow_result.result
158167
if type(workflow_result) is StopEvent

src/workflows/runtime/control_loop.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from copy import deepcopy
99
import time
1010
from typing import TYPE_CHECKING
11-
import weakref
11+
1212

1313
from workflows.decorators import R
1414
from workflows.errors import (
@@ -41,7 +41,6 @@
4141
InternalStepWorkerState,
4242
)
4343
from workflows.runtime.types.plugin import (
44-
ControlLoopFunction,
4544
WorkflowRuntime,
4645
as_snapshottable,
4746
)
@@ -69,6 +68,7 @@
6968
import logging
7069

7170
from workflows.workflow import Workflow
71+
from workflows.runtime.workflow_registry import workflow_registry
7272

7373
if TYPE_CHECKING:
7474
from workflows.context.context import Context
@@ -282,37 +282,23 @@ async def _pull() -> None:
282282
await self.cleanup_tasks()
283283

284284

285-
def create_control_loop(
286-
workflow: Workflow,
287-
) -> ControlLoopFunction:
285+
async def control_loop(
286+
start_event: Event | None,
287+
init_state: BrokerState | None,
288+
run_id: str,
289+
) -> StopEvent:
288290
"""
289-
Creates a control loop for a workflow run. Dependencies are provided as initial args, and returns a simple function
290-
with only start event and an optional start state as an arg, that can be easily decorated.
291-
292-
Returns a function that can be called to start a workflow run.
291+
The main async control loop for a workflow run.
293292
"""
294-
295-
workflow_ref = weakref.ref(workflow)
296-
297-
async def control_loop(
298-
start_event: Event | None,
299-
init_state: BrokerState | None,
300-
# TODO - get these 3 out of here! Needs to be inferred from scope somehow for proper distributed, static execution
301-
plugin: WorkflowRuntime,
302-
context: Context,
303-
step_workers: dict[str, StepWorkerFunction],
304-
) -> StopEvent:
305-
"""
306-
The main async control loop for a workflow run.
307-
"""
308-
wf = workflow_ref()
309-
if wf is None:
310-
raise WorkflowRuntimeError("Workflow instance no longer available")
311-
state = init_state or BrokerState.from_workflow(wf)
312-
runner = _ControlLoopRunner(wf, plugin, context, step_workers, state)
313-
return await runner.run(start_event=start_event)
314-
315-
return control_loop
293+
# Prefer run-scoped context if available (set by broker)
294+
current = workflow_registry.get_run(run_id)
295+
if current is None:
296+
raise WorkflowRuntimeError("Run context not found for control loop")
297+
state = init_state or BrokerState.from_workflow(current.workflow)
298+
runner = _ControlLoopRunner(
299+
current.workflow, current.plugin, current.context, current.steps, state
300+
)
301+
return await runner.run(start_event=start_event)
316302

317303

318304
def rebuild_state_from_ticks(

src/workflows/runtime/types/plugin.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,5 @@ def __call__(
130130
self,
131131
start_event: Event | None,
132132
init_state: BrokerState | None,
133-
# these will likely be refactored out from the control loop function in the future.
134-
plugin: WorkflowRuntime,
135-
context: Context,
136-
step_workers: dict[str, StepWorkerFunction],
133+
run_id: str,
137134
) -> Coroutine[None, None, StopEvent]: ...

src/workflows/runtime/workflow_registry.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
from threading import Lock
22
from weakref import WeakKeyDictionary
3+
from dataclasses import dataclass
4+
from typing import Optional
35
from workflows.runtime.types.plugin import (
46
ControlLoopFunction,
57
Plugin,
68
RegisteredWorkflow,
9+
WorkflowRuntime,
710
)
811
from workflows.workflow import Workflow
912
from workflows.runtime.types.step_function import StepWorkerFunction
10-
from workflows.decorators import R
13+
from typing import TYPE_CHECKING
14+
15+
if TYPE_CHECKING:
16+
from workflows.context.context import Context
1117

1218

1319
class WorkflowPluginRegistry:
@@ -22,13 +28,14 @@ def __init__(self) -> None:
2228
Workflow, dict[type[Plugin], RegisteredWorkflow]
2329
] = WeakKeyDictionary()
2430
self.lock = Lock()
31+
self.run_contexts: dict[str, RegisteredRunContext] = {}
2532

2633
def get_registered_workflow(
2734
self,
2835
workflow: Workflow,
2936
plugin: Plugin,
3037
workflow_function: ControlLoopFunction,
31-
steps: dict[str, StepWorkerFunction[R]],
38+
steps: dict[str, StepWorkerFunction],
3239
) -> RegisteredWorkflow:
3340
plugin_type = type(plugin)
3441

@@ -51,5 +58,38 @@ def get_registered_workflow(
5158
plugin_map[plugin_type] = registered_workflow
5259
return registered_workflow
5360

61+
def register_run(
62+
self,
63+
run_id: str,
64+
workflow: Workflow,
65+
plugin: WorkflowRuntime,
66+
context: "Context",
67+
steps: dict[str, StepWorkerFunction],
68+
) -> None:
69+
self.run_contexts[run_id] = RegisteredRunContext(
70+
run_id=run_id,
71+
workflow=workflow,
72+
plugin=plugin,
73+
context=context,
74+
steps=steps,
75+
)
76+
77+
def get_run(self, run_id: str) -> Optional["RegisteredRunContext"]:
78+
return self.run_contexts.get(run_id)
79+
80+
def delete_run(self, run_id: str) -> None:
81+
self.run_contexts.pop(run_id, None)
82+
5483

5584
workflow_registry = WorkflowPluginRegistry()
85+
86+
87+
@dataclass
88+
class RegisteredRunContext:
89+
run_id: str
90+
workflow: Workflow
91+
plugin: WorkflowRuntime
92+
context: "Context"
93+
steps: dict[str, StepWorkerFunction]
94+
95+

0 commit comments

Comments
 (0)