Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,43 @@ class Workflow(BaseWorkflow):
]


def test_serialize_node__lazy_reference_workflow_output():
"""Test that LazyReference can resolve to workflow output references during serialization."""

# GIVEN a node with a LazyReference to a workflow output
class NodeWithWorkflowOutputReference(BaseNode):
workflow_output_ref = LazyReference(lambda: TestWorkflow.Outputs.final_result)

class TestNode(BaseNode):
class Outputs(BaseNode.Outputs):
result: str = "test result"

# AND a workflow that defines an output
class TestWorkflow(BaseWorkflow):
graph = NodeWithWorkflowOutputReference >> TestNode

class Outputs(BaseWorkflow.Outputs):
final_result = TestNode.Outputs.result

# WHEN the node is serialized in the context of the workflow
workflow_display = get_workflow_display(workflow_class=TestWorkflow)
serialized_workflow: dict = workflow_display.serialize()

# THEN the node should properly serialize the workflow output reference
node_with_lazy_reference = next(
node
for node in serialized_workflow["workflow_raw_data"]["nodes"]
if node["id"] == str(NodeWithWorkflowOutputReference.__id__)
)

# AND the workflow output reference should resolve to the underlying node output
assert len(node_with_lazy_reference["attributes"]) == 1
attr = node_with_lazy_reference["attributes"][0]
assert attr["name"] == "workflow_output_ref"
assert attr["value"]["type"] == "NODE_OUTPUT"
assert attr["value"]["node_id"] == str(TestNode.__id__)


def test_serialize_node__workflow_input(serialize_node):
class WorkflowInputGenericNode(BaseNode):
attr: str = Inputs.input
Expand Down
3 changes: 3 additions & 0 deletions ee/vellum_ee/workflows/display/utils/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,9 @@ def serialize_value(executable_id: UUID, display_context: "WorkflowDisplayContex
}

if isinstance(value, OutputReference):
if issubclass(value.outputs_class, BaseWorkflow.Outputs):
return serialize_value(executable_id, display_context, value.instance)

if value not in display_context.global_node_output_displays:
if issubclass(value.outputs_class, BaseNode.Outputs):
raise InvalidOutputReferenceError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,31 @@ class Outputs(BaseWorkflow.Outputs):

# WHEN we serialize it
workflow_display = get_workflow_display(workflow_class=Workflow)
serialized_workflow = workflow_display.serialize()

# THEN it should successfully serialize the workflow output reference to a constant
assert isinstance(serialized_workflow, dict)
output_variables = serialized_workflow["output_variables"]
assert isinstance(output_variables, list)
assert output_variables == [{"id": "2b32416b-ccfc-4231-a3a6-d08e76327815", "key": "final", "type": "STRING"}]

# AND the output value should be a constant value
workflow_raw_data = serialized_workflow["workflow_raw_data"]
assert isinstance(workflow_raw_data, dict)
output_values = workflow_raw_data["output_values"]
assert isinstance(output_values, list)
assert output_values == [
{
"output_variable_id": "2b32416b-ccfc-4231-a3a6-d08e76327815",
"value": {"type": "CONSTANT_VALUE", "value": {"type": "STRING", "value": "bar"}},
}
]

# THEN it should raise an error
with pytest.raises(UserFacingException) as exc_info:
workflow_display.serialize()

# AND the error message should be user friendly
assert (
str(exc_info.value)
== """Failed to serialize output 'final': Reference to outputs \
'test_serialize_workflow__workflow_outputs_reference_non_node_outputs.<locals>.FirstWorkflow.Outputs' is invalid."""
)
first_output_variable = output_variables[0]
assert isinstance(first_output_variable, dict)
first_output_value = output_values[0]
assert isinstance(first_output_value, dict)
assert first_output_variable["id"] == first_output_value["output_variable_id"]


def test_serialize_workflow__node_display_class_not_registered():
Expand Down