|
| 1 | +# SPDX-License-Identifier: MIT |
| 2 | +# Copyright (c) 2025 LlamaIndex Inc. |
| 3 | + |
| 4 | +from __future__ import annotations |
| 5 | + |
| 6 | +import time |
| 7 | +from typing import AsyncGenerator |
| 8 | + |
| 9 | +from dbos import DBOS, SetWorkflowID # required extra, import must succeed |
| 10 | + |
| 11 | +from workflows.events import Event, StopEvent |
| 12 | +from workflows.runtime.types.plugin import ( |
| 13 | + ControlLoopFunction, |
| 14 | + Plugin, |
| 15 | + WorkflowRuntime, |
| 16 | + RegisteredWorkflow, |
| 17 | +) |
| 18 | +from workflows.runtime.types.internal_state import BrokerState |
| 19 | +from workflows.runtime.types.step_function import StepWorkerFunction |
| 20 | +from workflows.runtime.types.ticks import WorkflowTick |
| 21 | + |
| 22 | +from workflows.workflow import Workflow |
| 23 | + |
| 24 | + |
| 25 | +@DBOS.step() |
| 26 | +async def _durable_time() -> float: |
| 27 | + return time.time() |
| 28 | + |
| 29 | + |
| 30 | +class DBOSRuntime: |
| 31 | + def register( |
| 32 | + self, |
| 33 | + workflow: Workflow, |
| 34 | + workflow_function: ControlLoopFunction, |
| 35 | + steps: dict[str, StepWorkerFunction], |
| 36 | + ) -> RegisteredWorkflow | None: |
| 37 | + """ |
| 38 | + Wrap the workflow control loop in a DBOS workflow so ticks are received via DBOS.recv |
| 39 | + and sent via DBOS.send, enabling durable orchestration. |
| 40 | + """ |
| 41 | + |
| 42 | + @DBOS.workflow() |
| 43 | + async def _dbos_control_loop( |
| 44 | + start_event: Event | None, |
| 45 | + init_state: BrokerState | None, |
| 46 | + run_id: str, |
| 47 | + ) -> StopEvent: |
| 48 | + with SetWorkflowID(run_id): |
| 49 | + return await workflow_function(start_event, init_state, run_id) |
| 50 | + |
| 51 | + async def wrapper( |
| 52 | + start_event: Event | None, |
| 53 | + init_state: BrokerState | None, |
| 54 | + run_id: str, |
| 55 | + ) -> StopEvent: |
| 56 | + # Call the DBOS workflow directly; DBOS will orchestrate execution |
| 57 | + return await _dbos_control_loop(start_event, init_state, run_id) |
| 58 | + |
| 59 | + return RegisteredWorkflow(workflow_function=_dbos_control_loop, steps=steps) |
| 60 | + |
| 61 | + def new_runtime(self, run_id: str) -> WorkflowRuntime: |
| 62 | + runtime: WorkflowRuntime = DBOSWorkflowRuntime(run_id) |
| 63 | + return runtime |
| 64 | + |
| 65 | + |
| 66 | +dbos_runtime: Plugin = DBOSRuntime() |
| 67 | + |
| 68 | + |
| 69 | +class DBOSWorkflowRuntime: |
| 70 | + """ |
| 71 | + Workflow runtime backed by asyncio mailboxes, with durable timing via DBOS when available. |
| 72 | +
|
| 73 | + - send_event/wait_receive implement the tick mailbox used by the control loop |
| 74 | + - write_to_event_stream/stream_published_events expose published events to callers |
| 75 | + - get_now returns a stable value on first call within a run (durable if DBOS is installed) |
| 76 | + - sleep uses DBOS durable sleep when available, otherwise asyncio.sleep |
| 77 | + - on_tick/replay provide a lightweight snapshot for debug/replay via the broker |
| 78 | + """ |
| 79 | + |
| 80 | + def __init__( |
| 81 | + self, |
| 82 | + run_id: str, |
| 83 | + ) -> None: |
| 84 | + self.run_id = run_id |
| 85 | + |
| 86 | + # Mailbox used by control loop and broker |
| 87 | + async def wait_receive(self) -> WorkflowTick: |
| 88 | + # Receive next tick via DBOS durable notification |
| 89 | + tick = await DBOS.recv_async() |
| 90 | + return tick # type: ignore[return-value] |
| 91 | + |
| 92 | + async def send_event(self, tick: WorkflowTick) -> None: |
| 93 | + await DBOS.send_async(self.run_id, tick) |
| 94 | + |
| 95 | + # Event stream used by handlers/observers |
| 96 | + async def write_to_event_stream(self, event: Event) -> None: |
| 97 | + await DBOS.write_stream_async("published_events", event) |
| 98 | + |
| 99 | + async def stream_published_events(self) -> AsyncGenerator[Event, None]: |
| 100 | + async for event in DBOS.read_stream_async(self.run_id, "published_events"): |
| 101 | + yield event |
| 102 | + |
| 103 | + # Timing utilities |
| 104 | + async def get_now(self) -> float: |
| 105 | + return await _durable_time() |
| 106 | + |
| 107 | + async def sleep(self, seconds: float) -> None: |
| 108 | + await DBOS.sleep_async(seconds) |
| 109 | + |
| 110 | + async def close(self) -> None: |
| 111 | + pass |
0 commit comments