Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,43 @@ class Workflow(BaseWorkflow):
]


def test_serialize_node__lazy_reference_workflow_output():
"""Test that LazyReference can resolve to workflow output references during serialization."""

# GIVEN a node with a LazyReference to a workflow output
class NodeWithWorkflowOutputReference(BaseNode):
workflow_output_ref = LazyReference(lambda: TestWorkflow.Outputs.final_result)

class TestNode(BaseNode):
class Outputs(BaseNode.Outputs):
result: str = "test result"

# AND a workflow that defines an output
class TestWorkflow(BaseWorkflow):
graph = NodeWithWorkflowOutputReference >> TestNode

class Outputs(BaseWorkflow.Outputs):
final_result = TestNode.Outputs.result

# WHEN the node is serialized in the context of the workflow
workflow_display = get_workflow_display(workflow_class=TestWorkflow)
serialized_workflow: dict = workflow_display.serialize()

# THEN the node should properly serialize the workflow output reference
node_with_lazy_reference = next(
node
for node in serialized_workflow["workflow_raw_data"]["nodes"]
if node["id"] == str(NodeWithWorkflowOutputReference.__id__)
)

# AND the workflow output reference should resolve to the underlying node output
assert len(node_with_lazy_reference["attributes"]) == 1
attr = node_with_lazy_reference["attributes"][0]
assert attr["name"] == "workflow_output_ref"
assert attr["value"]["type"] == "NODE_OUTPUT"
assert attr["value"]["node_id"] == str(TestNode.__id__)


def test_serialize_node__workflow_input(serialize_node):
class WorkflowInputGenericNode(BaseNode):
attr: str = Inputs.input
Expand Down
9 changes: 9 additions & 0 deletions ee/vellum_ee/workflows/display/utils/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,15 @@ def serialize_value(executable_id: UUID, display_context: "WorkflowDisplayContex
}

if isinstance(value, OutputReference):
if issubclass(value.outputs_class, BaseWorkflow.Outputs):
# Only resolve workflow output references that point to node outputs
if isinstance(value.instance, OutputReference) and issubclass(
value.instance.outputs_class, BaseNode.Outputs
):
return serialize_value(executable_id, display_context, value.instance)
# For other instances (constants, other workflow outputs), raise an error
raise InvalidOutputReferenceError(f"Reference to outputs '{value.outputs_class.__qualname__}' is invalid.")

if value not in display_context.global_node_output_displays:
if issubclass(value.outputs_class, BaseNode.Outputs):
raise InvalidOutputReferenceError(
Expand Down