diff --git a/src/vellum/workflows/runner/runner.py b/src/vellum/workflows/runner/runner.py index 380617893..bfd33d1e2 100644 --- a/src/vellum/workflows/runner/runner.py +++ b/src/vellum/workflows/runner/runner.py @@ -1135,6 +1135,14 @@ def _stream(self) -> None: if self._trigger is not None: self._trigger.__on_workflow_fulfilled__(final_state) + # Drain any events produced by the trigger hook (e.g., snapshot events from state mutations) + # and forward them to the outer queue so stream consumers can observe them + try: + while event := self._workflow_event_inner_queue.get_nowait(): + self._workflow_event_outer_queue.put(event) + except Empty: + pass + fulfilled_outputs = self.workflow.Outputs() for descriptor, value in fulfilled_outputs: if isinstance(value, BaseDescriptor): diff --git a/tests/workflows/chat_message_trigger_execution/tests/test_chat_message_trigger_execution.py b/tests/workflows/chat_message_trigger_execution/tests/test_chat_message_trigger_execution.py index 65b6edd48..6b6a23519 100644 --- a/tests/workflows/chat_message_trigger_execution/tests/test_chat_message_trigger_execution.py +++ b/tests/workflows/chat_message_trigger_execution/tests/test_chat_message_trigger_execution.py @@ -1,5 +1,9 @@ """Tests for ChatMessageTrigger workflow execution.""" +from vellum.client.types import ChatMessage +from vellum.workflows.events.workflow import WorkflowExecutionSnapshottedEvent +from vellum.workflows.workflows.event_filters import all_workflow_event_filter + from tests.workflows.chat_message_trigger_execution.workflows.simple_chat_workflow import ( ChatState, SimpleChatTrigger, @@ -42,3 +46,44 @@ def test_chat_message_trigger__workflow_output_reference(): assert final_state.chat_history[0].text == "Hello" assert final_state.chat_history[1].role == "ASSISTANT" assert final_state.chat_history[1].text == "Hello from assistant!" + + +def test_chat_message_trigger__emits_snapshot_events_for_trigger_state_mutations(): + """Tests that snapshot events are emitted when trigger mutates state in lifecycle hooks.""" + + # GIVEN a workflow using SimpleChatTrigger + workflow = SimpleChatWorkflow() + + # AND a trigger with message + trigger = SimpleChatTrigger(message="Hello") + + # WHEN we stream the workflow events with all_workflow_event_filter to include snapshot events + events = list(workflow.stream(trigger=trigger, event_filter=all_workflow_event_filter)) + + # THEN we should have snapshot events for the trigger's state mutations + snapshot_events = [e for e in events if isinstance(e, WorkflowExecutionSnapshottedEvent)] + + # AND there should be at least 2 snapshot events (one for user message, one for assistant message) + # These are emitted when ChatMessageTrigger appends to chat_history in __on_workflow_initiated__ + # and __on_workflow_fulfilled__ + assert len(snapshot_events) >= 2, f"Expected at least 2 snapshot events, got {len(snapshot_events)}" + + # AND the first snapshot event should contain just the user message (from __on_workflow_initiated__) + user_message_snapshot = snapshot_events[0] + assert user_message_snapshot.state.chat_history == [ + ChatMessage(role="USER", text="Hello", content=None, source=None), + ] + + # AND the last snapshot event should contain the full chat history with both messages + # (from __on_workflow_fulfilled__) + last_snapshot = snapshot_events[-1] + assert last_snapshot.state.chat_history == [ + ChatMessage(role="USER", text="Hello", content=None, source=None), + ChatMessage(role="ASSISTANT", text="Hello from assistant!", content=None, source=None), + ] + + # AND the snapshot events should appear before the fulfilled event + event_names = [e.name for e in events] + last_snapshot_idx = max(i for i, e in enumerate(events) if isinstance(e, WorkflowExecutionSnapshottedEvent)) + fulfilled_idx = next(i for i, name in enumerate(event_names) if name == "workflow.execution.fulfilled") + assert last_snapshot_idx < fulfilled_idx, "Snapshot events should appear before fulfilled event"