|
23 | 23 |
|
24 | 24 | from llama_index_instrumentation.dispatcher import Dispatcher |
25 | 25 |
|
26 | | -from workflows.decorators import StepConfig |
| 26 | +from workflows.decorators import StepConfig, StepFunction, step |
27 | 27 | from workflows.errors import ( |
28 | 28 | WorkflowCancelledByUser, |
29 | 29 | WorkflowDone, |
@@ -111,6 +111,86 @@ class WorkflowBroker(Generic[MODEL_T]): |
111 | 111 | # run state |
112 | 112 | _handler: WorkflowHandler | None |
113 | 113 |
|
| 114 | + def __init__( |
| 115 | + self, |
| 116 | + workflow: Workflow, |
| 117 | + context: Context[MODEL_T], |
| 118 | + state: WorkflowBrokerState, |
| 119 | + run_id: str, |
| 120 | + ) -> None: |
| 121 | + self._context = context |
| 122 | + self._handler = None |
| 123 | + |
| 124 | + # Store the step configs of this workflow, to be used in send_event |
| 125 | + self._step_configs = {} |
| 126 | + |
| 127 | + steps: dict[str, StepFunction] = { |
| 128 | + "_done": self._done, |
| 129 | + **workflow._get_steps(), |
| 130 | + } |
| 131 | + for step_name, step_func in steps.items(): |
| 132 | + self._step_configs[step_name] = step_func._step_config |
| 133 | + |
| 134 | + # Transient runtime fields (always reinitialized) |
| 135 | + self._tasks = set() |
| 136 | + self._cancel_flag = asyncio.Event() |
| 137 | + self._step_flags = {} |
| 138 | + self._step_events_holding = None |
| 139 | + self._step_lock = asyncio.Lock() |
| 140 | + self._retval = None |
| 141 | + |
| 142 | + self._lock = asyncio.Lock() |
| 143 | + |
| 144 | + self._dispatcher = workflow._dispatcher |
| 145 | + |
| 146 | + self._step_condition = asyncio.Condition(lock=self._step_lock) |
| 147 | + self._step_event_written = asyncio.Condition(lock=self._step_lock) |
| 148 | + # Keep track of the steps currently running (transient) |
| 149 | + self._currently_running_steps = defaultdict(int) |
| 150 | + # Default initial values for persistent fields |
| 151 | + |
| 152 | + self._state = state |
| 153 | + |
| 154 | + # initialize running state from workflow |
| 155 | + for name, step_func in steps.items(): |
| 156 | + if name not in self._state.queues: |
| 157 | + self._state.queues[name] = asyncio.Queue() |
| 158 | + |
| 159 | + if name not in self._step_flags: |
| 160 | + self._step_flags[name] = asyncio.Event() |
| 161 | + |
| 162 | + step_config: StepConfig = step_func._step_config |
| 163 | + |
| 164 | + # Make the system step "_done" accept custom stop events |
| 165 | + if ( |
| 166 | + name == "_done" |
| 167 | + and workflow._stop_event_class not in step_config.accepted_events |
| 168 | + ): |
| 169 | + step_config.accepted_events.append(workflow._stop_event_class) |
| 170 | + |
| 171 | + for _ in range(step_config.num_workers): |
| 172 | + self._add_step_worker( |
| 173 | + name=name, |
| 174 | + step=step_func, |
| 175 | + config=step_config, |
| 176 | + verbose=workflow._verbose, |
| 177 | + run_id=run_id, |
| 178 | + worker_id=str(uuid.uuid4()), |
| 179 | + resource_manager=workflow._resource_manager, |
| 180 | + ) |
| 181 | + |
| 182 | + # add dedicated cancel task |
| 183 | + self._add_cancel_worker() |
| 184 | + |
| 185 | + @step(num_workers=1) |
| 186 | + async def _done(self, ev: StopEvent) -> None: |
| 187 | + """Tears down the whole workflow and stop execution.""" |
| 188 | + result = ev.result if type(ev) is StopEvent else ev |
| 189 | + self.finalize_run(ev, result) |
| 190 | + self.write_event_to_stream(ev) |
| 191 | + # Signal we want to stop the workflow |
| 192 | + raise WorkflowDone |
| 193 | + |
114 | 194 | @property |
115 | 195 | def is_running(self) -> bool: |
116 | 196 | return self._state.is_running |
@@ -359,73 +439,6 @@ async def shutdown(self) -> None: |
359 | 439 | await asyncio.gather(*self._tasks, return_exceptions=True) |
360 | 440 | self._tasks.clear() |
361 | 441 |
|
362 | | - def __init__( |
363 | | - self, |
364 | | - workflow: Workflow, |
365 | | - context: Context[MODEL_T], |
366 | | - state: WorkflowBrokerState, |
367 | | - run_id: str, |
368 | | - ) -> None: |
369 | | - self._context = context |
370 | | - self._handler = None |
371 | | - |
372 | | - # Store the step configs of this workflow, to be used in send_event |
373 | | - self._step_configs = {} |
374 | | - for step_name, step_func in workflow._get_steps().items(): |
375 | | - self._step_configs[step_name] = getattr(step_func, "__step_config", None) |
376 | | - |
377 | | - # Transient runtime fields (always reinitialized) |
378 | | - self._tasks = set() |
379 | | - self._cancel_flag = asyncio.Event() |
380 | | - self._step_flags = {} |
381 | | - self._step_events_holding = None |
382 | | - self._step_lock = asyncio.Lock() |
383 | | - self._retval = None |
384 | | - |
385 | | - self._lock = asyncio.Lock() |
386 | | - |
387 | | - self._dispatcher = workflow._dispatcher |
388 | | - |
389 | | - self._step_condition = asyncio.Condition(lock=self._step_lock) |
390 | | - self._step_event_written = asyncio.Condition(lock=self._step_lock) |
391 | | - # Keep track of the steps currently running (transient) |
392 | | - self._currently_running_steps = defaultdict(int) |
393 | | - # Default initial values for persistent fields |
394 | | - |
395 | | - self._state = state |
396 | | - |
397 | | - # initialize running state from workflow |
398 | | - for name, step_func in workflow._get_steps().items(): |
399 | | - if name not in self._state.queues: |
400 | | - self._state.queues[name] = asyncio.Queue() |
401 | | - |
402 | | - if name not in self._step_flags: |
403 | | - self._step_flags[name] = asyncio.Event() |
404 | | - |
405 | | - # At this point, step_func is guaranteed to have the `__step_config` attribute |
406 | | - step_config: StepConfig = getattr(step_func, "__step_config") |
407 | | - |
408 | | - # Make the system step "_done" accept custom stop events |
409 | | - if ( |
410 | | - name == "_done" |
411 | | - and workflow._stop_event_class not in step_config.accepted_events |
412 | | - ): |
413 | | - step_config.accepted_events.append(workflow._stop_event_class) |
414 | | - |
415 | | - for _ in range(step_config.num_workers): |
416 | | - self._add_step_worker( |
417 | | - name=name, |
418 | | - step=step_func, |
419 | | - config=step_config, |
420 | | - verbose=workflow._verbose, |
421 | | - run_id=run_id, |
422 | | - worker_id=str(uuid.uuid4()), |
423 | | - resource_manager=workflow._resource_manager, |
424 | | - ) |
425 | | - |
426 | | - # add dedicated cancel task |
427 | | - self._add_cancel_worker() |
428 | | - |
429 | 442 | async def _mark_in_progress( |
430 | 443 | self, name: str, ev: Event, worker_id: str = "" |
431 | 444 | ) -> None: |
|
0 commit comments