@@ -0,0 +1,116 @@
from unittest.mock import MagicMock

from vellum.client.types.vellum_variable import VellumVariable
from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display

from tests.workflows.basic_tool_calling_node_workflow_deployment_tool_wrapper.workflow import (
    BasicToolCallingNodeWorkflowDeploymentToolWrapperWorkflow,
)


def test_serialize_workflow__workflow_deployment_with_tool_wrapper(vellum_client):
"""
Tests that a workflow deployment with tool wrapper serializes correctly with definition.
"""

    # GIVEN a workflow that uses a workflow deployment with a tool wrapper
    workflow_class = BasicToolCallingNodeWorkflowDeploymentToolWrapperWorkflow

    # AND a mock for the workflow deployment release info
    mock_release = MagicMock()
    mock_release.deployment.name = "weather-workflow-deployment"
    mock_release.description = "A workflow that gets the weather for a given city and date."
    mock_release.workflow_version.input_variables = [
        VellumVariable(
            id="city-id",
            key="city",
            type="STRING",
            required=True,
            default=None,
        ),
        VellumVariable(
            id="date-id",
            key="date",
            type="STRING",
            required=True,
            default=None,
        ),
        VellumVariable(
            id="context-id",
            key="context",
            type="STRING",
            required=False,
            default=None,
        ),
    ]
    vellum_client.workflow_deployments.retrieve_workflow_deployment_release.return_value = mock_release

    # WHEN we serialize it
    workflow_display = get_workflow_display(workflow_class=workflow_class)
    serialized_workflow: dict = workflow_display.serialize()

    # THEN we should get a serialized representation of the Workflow
    assert serialized_workflow.keys() == {
        "workflow_raw_data",
        "input_variables",
        "state_variables",
        "output_variables",
    }

    # AND its input variables should include both query and context
    input_variables = serialized_workflow["input_variables"]
    assert len(input_variables) == 2
    input_keys = {var["key"] for var in input_variables}
    assert input_keys == {"query", "context"}

    # AND its output variables should be what we expect
    output_variables = serialized_workflow["output_variables"]
    assert len(output_variables) == 2
    output_keys = {var["key"] for var in output_variables}
    assert output_keys == {"text", "chat_history"}

    # AND the workflow deployment tool should have the definition attribute serialized
    workflow_raw_data = serialized_workflow["workflow_raw_data"]
    tool_calling_node = workflow_raw_data["nodes"][1]
    function_attributes = next(
        attribute for attribute in tool_calling_node["attributes"] if attribute["name"] == "functions"
    )
    assert function_attributes["value"]["type"] == "CONSTANT_VALUE"
    assert function_attributes["value"]["value"]["type"] == "JSON"
    workflow_deployment_tool = function_attributes["value"]["value"]["value"][0]

    # AND the workflow deployment tool should have the correct type
    assert workflow_deployment_tool["type"] == "WORKFLOW_DEPLOYMENT"
    assert workflow_deployment_tool["name"] == "weather-workflow-deployment"

    # AND the workflow deployment tool should have a definition attribute (like a code tool)
    assert "definition" in workflow_deployment_tool
    definition = workflow_deployment_tool["definition"]

    # AND the definition should match the expected structure with inputs and examples
    context_var = next(var for var in input_variables if var["key"] == "context")
    context_input_variable_id = context_var["id"]

    assert definition == {
        "state": None,
        "cache_config": None,
        "name": "weatherworkflowdeployment",
        "description": "A workflow that gets the weather for a given city and date.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "date": {"type": "string"},
            },
            "required": ["city", "date"],
            "examples": [{"city": "San Francisco", "date": "2025-01-01"}],
        },
        "inputs": {
            "context": {
                "type": "WORKFLOW_INPUT",
                "input_variable_id": context_input_variable_id,
            }
        },
        "forced": None,
        "strict": None,
    }
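The workflow module imported by this test is not part of the diff. Based on the assertions above and the `tool` docstring below, it presumably wraps the deployment along these lines (a hypothetical sketch; the input class, import paths, and wiring are inferred rather than taken from the source):

# Hypothetical sketch of the workflow module under test; not part of this diff.
from vellum.workflows.inputs import BaseInputs  # assumed import path
from vellum.workflows.types.definition import DeploymentDefinition
from vellum.workflows.utils.functions import tool


class Inputs(BaseInputs):
    query: str
    context: str


# "context" is bound to the parent workflow input, so only "city" and "date"
# remain in the schema exposed to the LLM.
get_weather = tool(
    inputs={"context": Inputs.context},
    examples=[{"city": "San Francisco", "date": "2025-01-01"}],
)(DeploymentDefinition(deployment="weather-workflow-deployment"))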
36 changes: 35 additions & 1 deletion ee/vellum_ee/workflows/display/utils/expressions.py
@@ -56,9 +56,14 @@
from vellum.workflows.references.workflow_input import WorkflowInputReference
from vellum.workflows.state.base import BaseState
from vellum.workflows.types.core import JsonArray, JsonObject
from vellum.workflows.types.definition import DeploymentDefinition
from vellum.workflows.types.generics import is_workflow_class
from vellum.workflows.utils.files import virtual_open
-from vellum.workflows.utils.functions import compile_function_definition, compile_inline_workflow_function_definition
+from vellum.workflows.utils.functions import (
+    compile_function_definition,
+    compile_inline_workflow_function_definition,
+    compile_workflow_deployment_function_definition,
+)
from vellum.workflows.utils.uuids import uuid4_from_hash
from vellum.workflows.workflows.base import BaseWorkflow
from vellum_ee.workflows.display.utils.exceptions import (
@@ -485,6 +490,35 @@ def serialize_value(executable_id: UUID, display_context: "WorkflowDisplayContext"
            },
        }

    if isinstance(value, DeploymentDefinition):
        # Handle DeploymentDefinition with tool wrapper support (similar to inline workflows)
        context = {"executable_id": executable_id, "client": display_context.client}
        dict_value = value.model_dump(context=context)

        # Compile the function definition (which now handles __vellum_inputs__ and __vellum_examples__)
        if display_context.client is not None:
            function_definition = compile_workflow_deployment_function_definition(value, display_context.client)

            # Handle __vellum_inputs__ for workflow deployments (similar to function tools)
            inputs = getattr(value, "__vellum_inputs__", {})
            if inputs:
                serialized_inputs = {}
                for param_name, input_ref in inputs.items():
                    serialized_input = serialize_value(executable_id, display_context, input_ref)
                    if serialized_input is not None:
                        serialized_inputs[param_name] = serialized_input

                model_data = function_definition.model_dump()
                model_data["inputs"] = serialized_inputs
                function_definition_data = model_data
            else:
                function_definition_data = function_definition.model_dump()

            dict_value["definition"] = function_definition_data

        dict_ref = serialize_value(executable_id, display_context, dict_value)
        return dict_ref

    if isinstance(value, BaseModel):
        context = {"executable_id": executable_id, "client": display_context.client}
        dict_value = value.model_dump(context=context)
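Once the enriched dict is routed back through `serialize_value`, the tool calling node's `functions` attribute comes out nested exactly as the test above unpacks it. A sketch of that shape, taken from the test's assertions:

# Shape of the serialized `functions` attribute value, per the test assertions.
serialized_functions_value = {
    "type": "CONSTANT_VALUE",
    "value": {
        "type": "JSON",
        "value": [
            {
                "type": "WORKFLOW_DEPLOYMENT",
                "name": "weather-workflow-deployment",
                "definition": {
                    # compiled schema (with "examples") plus the serialized
                    # "inputs" mapping for wrapper-provided parameters
                },
            }
        ],
    },
}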
@@ -153,10 +153,24 @@ class RouterNode(BaseNode[ToolCallingState]):
class DynamicSubworkflowDeploymentNode(SubworkflowDeploymentNode[ToolCallingState], FunctionCallNodeMixin):
    """Node that executes a deployment definition with function call output."""

    deployment_definition: Optional[DeploymentDefinition] = None

    def run(self) -> Iterator[BaseOutput]:
        # Merge arguments with resolved inputs from __vellum_inputs__
        merged_inputs = self.arguments.copy()
        if self.deployment_definition is not None:
            vellum_inputs = getattr(self.deployment_definition, "__vellum_inputs__", {})
            if vellum_inputs:
                for param_name, param_ref in vellum_inputs.items():
                    if isinstance(param_ref, BaseDescriptor):
                        resolved_value = param_ref.resolve(self.state)
                    else:
                        resolved_value = param_ref
                    merged_inputs[param_name] = resolved_value

        # Mypy doesn't like instance assignments of class attributes. It's safe in our case, though,
        # because it's what we do in the `__init__` method.
-        self.subworkflow_inputs = self.arguments  # type:ignore[misc]
+        self.subworkflow_inputs = merged_inputs  # type:ignore[misc]

        # Call the parent run method to execute the subworkflow
        outputs = {}
@@ -507,6 +521,7 @@ def create_function_node(
        {
            "deployment": deployment,
            "release_tag": release_tag,
            "deployment_definition": function,
            "arguments": arguments_expr,
            "function_call_id": function_call_id_expr,
            "__module__": __name__,
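The merge order above matters: a wrapper-resolved input overwrites an LLM-provided argument of the same name. A standalone illustration of that semantics, using a stub in place of `BaseDescriptor` (names here are stand-ins, not the real vellum classes):

# Standalone illustration of the merge semantics; StubRef mimics a
# BaseDescriptor whose resolve(state) returns a parent-workflow value.
class StubRef:
    def __init__(self, value):
        self.value = value

    def resolve(self, state):
        return self.value


arguments = {"city": "San Francisco", "context": "llm-guess"}
vellum_inputs = {"context": StubRef("parent-workflow-context")}

merged_inputs = arguments.copy()
for param_name, param_ref in vellum_inputs.items():
    resolved = param_ref.resolve(None) if isinstance(param_ref, StubRef) else param_ref
    merged_inputs[param_name] = resolved

assert merged_inputs == {"city": "San Francisco", "context": "parent-workflow-context"}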
36 changes: 29 additions & 7 deletions src/vellum/workflows/utils/functions.py
@@ -11,6 +11,7 @@
    Literal,
    Optional,
    Type,
    TypeVar,
    Union,
    get_args,
    get_origin,
@@ -304,16 +305,27 @@ def compile_workflow_deployment_function_definition(
    description = release_info["description"]
    input_variables = release_info["input_variables"]

    # Get inputs from the decorator if present (to exclude them from the schema)
    inputs = getattr(deployment_definition, "__vellum_inputs__", {})
    examples = getattr(deployment_definition, "__vellum_examples__", None)
    exclude_params = set(inputs.keys())

    properties = {}
    required = []

    for input_var in input_variables:
        # Skip parameters supplied via the tool wrapper's `inputs` mapping
        if exclude_params and input_var.key in exclude_params:
            continue

        properties[input_var.key] = _compile_workflow_deployment_input(input_var)

        if input_var.required and input_var.default is None:
            required.append(input_var.key)

    parameters = {"type": "object", "properties": properties, "required": required}
    if examples is not None:
        parameters["examples"] = examples

    return FunctionDefinition(
        name=name.replace("-", ""),
@@ -405,16 +417,21 @@ def compile_vellum_integration_tool_definition(
    return FunctionDefinition(name=tool_def.name, description=tool_def.description, parameters={})


-ToolType = Union[Callable[..., Any], Type["BaseWorkflow"]]
+ToolT = TypeVar(
+    "ToolT",
+    Callable[..., Any],
+    Type["BaseWorkflow[Any, Any]"],
+    DeploymentDefinition,
+)


def tool(
    *,
    inputs: Optional[dict[str, Any]] = None,
    examples: Optional[List[dict[str, Any]]] = None,
-) -> Callable[[ToolType], ToolType]:
+) -> Callable[[ToolT], ToolT]:
"""
Decorator to configure a tool function or inline workflow.
Decorator to configure a tool function, inline workflow, or workflow deployment.

Currently supports specifying which parameters should come from parent workflow inputs
via the `inputs` mapping. Also supports providing `examples` which will be hoisted
@@ -440,21 +457,26 @@ class MyInlineWorkflow(BaseWorkflow):

            class Outputs(BaseWorkflow.Outputs):
                result = MyNode.Outputs.result

    Example with workflow deployment:
        tool(inputs={
            "context": ParentInputs.context,
        })(DeploymentDefinition(deployment="my-workflow-deployment"))
    """

-    def decorator(func: ToolType) -> ToolType:
-        # Store the inputs mapping on the function/workflow for later use
+    def decorator(func: ToolT) -> ToolT:
+        # Store the inputs mapping on the function/workflow/deployment for later use
        if inputs is not None:
            setattr(func, "__vellum_inputs__", inputs)
-        # Store the examples on the function/workflow for later use
+        # Store the examples on the function/workflow/deployment for later use
        if examples is not None:
            setattr(func, "__vellum_examples__", examples)
        return func

    return decorator


-def use_tool_inputs(**inputs: Any) -> Callable[[Callable], Callable]:
+def use_tool_inputs(**inputs: Any) -> Callable[[ToolT], ToolT]:
    """
    Decorator to specify which parameters of a tool function should be provided
    from the parent workflow inputs rather than from the LLM.
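Putting the decorator and the compiler together, the intended flow looks roughly like this (a sketch: `client` stands for a Vellum client mocked the same way as in the test at the top of this diff, and the assertions mirror that test's expectations):

from vellum.workflows.types.definition import DeploymentDefinition
from vellum.workflows.utils.functions import (
    compile_workflow_deployment_function_definition,
    tool,
)

# Wrap the deployment; "context" will come from the parent workflow, so it is
# excluded from the JSON schema shown to the LLM, while the examples are
# hoisted into the schema.
deployment = tool(
    inputs={"context": "<parent workflow input reference>"},
    examples=[{"city": "San Francisco", "date": "2025-01-01"}],
)(DeploymentDefinition(deployment="weather-workflow-deployment"))

fn = compile_workflow_deployment_function_definition(deployment, client)
assert fn.name == "weatherworkflowdeployment"  # hyphens stripped via name.replace("-", "")
assert set(fn.parameters["properties"]) == {"city", "date"}  # "context" excluded
assert fn.parameters["required"] == ["city", "date"]
assert fn.parameters["examples"] == [{"city": "San Francisco", "date": "2025-01-01"}]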
@@ -0,0 +1,3 @@
from .workflow import BasicToolCallingNodeWorkflowDeploymentToolWrapperWorkflow

__all__ = ["BasicToolCallingNodeWorkflowDeploymentToolWrapperWorkflow"]