diff --git a/ee/vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_workflow_deployment_tool_wrapper_serialization.py b/ee/vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_workflow_deployment_tool_wrapper_serialization.py new file mode 100644 index 000000000..f992856f3 --- /dev/null +++ b/ee/vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_workflow_deployment_tool_wrapper_serialization.py @@ -0,0 +1,116 @@ +from unittest.mock import MagicMock + +from vellum.client.types.vellum_variable import VellumVariable +from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display + +from tests.workflows.basic_tool_calling_node_workflow_deployment_tool_wrapper.workflow import ( + BasicToolCallingNodeWorkflowDeploymentToolWrapperWorkflow, +) + + +def test_serialize_workflow__workflow_deployment_with_tool_wrapper(vellum_client): + """ + Tests that a workflow deployment with tool wrapper serializes correctly with definition. + """ + + # GIVEN a workflow that uses a workflow deployment with tool wrapper + workflow_class = BasicToolCallingNodeWorkflowDeploymentToolWrapperWorkflow + + # AND a mock for the workflow deployment release info + mock_release = MagicMock() + mock_release.deployment.name = "weather-workflow-deployment" + mock_release.description = "A workflow that gets the weather for a given city and date." + mock_release.workflow_version.input_variables = [ + VellumVariable( + id="city-id", + key="city", + type="STRING", + required=True, + default=None, + ), + VellumVariable( + id="date-id", + key="date", + type="STRING", + required=True, + default=None, + ), + VellumVariable( + id="context-id", + key="context", + type="STRING", + required=False, + default=None, + ), + ] + vellum_client.workflow_deployments.retrieve_workflow_deployment_release.return_value = mock_release + + # WHEN we serialize it + workflow_display = get_workflow_display(workflow_class=workflow_class) + serialized_workflow: dict = workflow_display.serialize() + + # THEN we should get a serialized representation of the Workflow + assert serialized_workflow.keys() == { + "workflow_raw_data", + "input_variables", + "state_variables", + "output_variables", + } + + # AND its input variables should include both query and context + input_variables = serialized_workflow["input_variables"] + assert len(input_variables) == 2 + input_keys = {var["key"] for var in input_variables} + assert input_keys == {"query", "context"} + + # AND its output variables should be what we expect + output_variables = serialized_workflow["output_variables"] + assert len(output_variables) == 2 + output_keys = {var["key"] for var in output_variables} + assert output_keys == {"text", "chat_history"} + + # AND the workflow deployment tool should have the definition attribute serialized + workflow_raw_data = serialized_workflow["workflow_raw_data"] + tool_calling_node = workflow_raw_data["nodes"][1] + function_attributes = next( + attribute for attribute in tool_calling_node["attributes"] if attribute["name"] == "functions" + ) + assert function_attributes["value"]["type"] == "CONSTANT_VALUE" + assert function_attributes["value"]["value"]["type"] == "JSON" + workflow_deployment_tool = function_attributes["value"]["value"]["value"][0] + + # AND the workflow deployment tool should have the correct type + assert workflow_deployment_tool["type"] == "WORKFLOW_DEPLOYMENT" + assert workflow_deployment_tool["name"] == "weather-workflow-deployment" + + # AND the workflow deployment tool should have a definition attribute (like code tool) + assert "definition" in workflow_deployment_tool + definition = workflow_deployment_tool["definition"] + + # AND the definition should match the expected structure with inputs and examples + context_var = next(var for var in input_variables if var["key"] == "context") + context_input_variable_id = context_var["id"] + + assert definition == { + "state": None, + "cache_config": None, + "name": "weatherworkflowdeployment", + "description": "A workflow that gets the weather for a given city and date.", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "date": {"type": "string"}, + }, + "required": ["city", "date"], + "examples": [{"city": "San Francisco", "date": "2025-01-01"}], + }, + "inputs": { + "context": { + "type": "WORKFLOW_INPUT", + "input_variable_id": context_input_variable_id, + } + }, + "forced": None, + "strict": None, + } diff --git a/ee/vellum_ee/workflows/display/utils/expressions.py b/ee/vellum_ee/workflows/display/utils/expressions.py index a42023d0f..12229610a 100644 --- a/ee/vellum_ee/workflows/display/utils/expressions.py +++ b/ee/vellum_ee/workflows/display/utils/expressions.py @@ -56,9 +56,14 @@ from vellum.workflows.references.workflow_input import WorkflowInputReference from vellum.workflows.state.base import BaseState from vellum.workflows.types.core import JsonArray, JsonObject +from vellum.workflows.types.definition import DeploymentDefinition from vellum.workflows.types.generics import is_workflow_class from vellum.workflows.utils.files import virtual_open -from vellum.workflows.utils.functions import compile_function_definition, compile_inline_workflow_function_definition +from vellum.workflows.utils.functions import ( + compile_function_definition, + compile_inline_workflow_function_definition, + compile_workflow_deployment_function_definition, +) from vellum.workflows.utils.uuids import uuid4_from_hash from vellum.workflows.workflows.base import BaseWorkflow from vellum_ee.workflows.display.utils.exceptions import ( @@ -485,6 +490,35 @@ def serialize_value(executable_id: UUID, display_context: "WorkflowDisplayContex }, } + if isinstance(value, DeploymentDefinition): + # Handle DeploymentDefinition with tool wrapper support (similar to inline workflows) + context = {"executable_id": executable_id, "client": display_context.client} + dict_value = value.model_dump(context=context) + + # Compile the function definition (which now handles __vellum_inputs__ and __vellum_examples__) + if display_context.client is not None: + function_definition = compile_workflow_deployment_function_definition(value, display_context.client) + + # Handle __vellum_inputs__ for workflow deployments (similar to function tools) + inputs = getattr(value, "__vellum_inputs__", {}) + if inputs: + serialized_inputs = {} + for param_name, input_ref in inputs.items(): + serialized_input = serialize_value(executable_id, display_context, input_ref) + if serialized_input is not None: + serialized_inputs[param_name] = serialized_input + + model_data = function_definition.model_dump() + model_data["inputs"] = serialized_inputs + function_definition_data = model_data + else: + function_definition_data = function_definition.model_dump() + + dict_value["definition"] = function_definition_data + + dict_ref = serialize_value(executable_id, display_context, dict_value) + return dict_ref + if isinstance(value, BaseModel): context = {"executable_id": executable_id, "client": display_context.client} dict_value = value.model_dump(context=context) diff --git a/src/vellum/workflows/nodes/displayable/tool_calling_node/utils.py b/src/vellum/workflows/nodes/displayable/tool_calling_node/utils.py index 03629985c..4fb59c828 100644 --- a/src/vellum/workflows/nodes/displayable/tool_calling_node/utils.py +++ b/src/vellum/workflows/nodes/displayable/tool_calling_node/utils.py @@ -153,10 +153,24 @@ class RouterNode(BaseNode[ToolCallingState]): class DynamicSubworkflowDeploymentNode(SubworkflowDeploymentNode[ToolCallingState], FunctionCallNodeMixin): """Node that executes a deployment definition with function call output.""" + deployment_definition: Optional[DeploymentDefinition] = None + def run(self) -> Iterator[BaseOutput]: + # Merge arguments with resolved inputs from __vellum_inputs__ + merged_inputs = self.arguments.copy() + if self.deployment_definition is not None: + vellum_inputs = getattr(self.deployment_definition, "__vellum_inputs__", {}) + if vellum_inputs: + for param_name, param_ref in vellum_inputs.items(): + if isinstance(param_ref, BaseDescriptor): + resolved_value = param_ref.resolve(self.state) + else: + resolved_value = param_ref + merged_inputs[param_name] = resolved_value + # Mypy doesn't like instance assignments of class attributes. It's safe in our case tho bc it's what # we do in the `__init__` method. - self.subworkflow_inputs = self.arguments # type:ignore[misc] + self.subworkflow_inputs = merged_inputs # type:ignore[misc] # Call the parent run method to execute the subworkflow outputs = {} @@ -507,6 +521,7 @@ def create_function_node( { "deployment": deployment, "release_tag": release_tag, + "deployment_definition": function, "arguments": arguments_expr, "function_call_id": function_call_id_expr, "__module__": __name__, diff --git a/src/vellum/workflows/utils/functions.py b/src/vellum/workflows/utils/functions.py index 19a48bdcd..9af5e337a 100644 --- a/src/vellum/workflows/utils/functions.py +++ b/src/vellum/workflows/utils/functions.py @@ -11,6 +11,7 @@ Literal, Optional, Type, + TypeVar, Union, get_args, get_origin, @@ -304,16 +305,27 @@ def compile_workflow_deployment_function_definition( description = release_info["description"] input_variables = release_info["input_variables"] + # Get inputs from the decorator if present (to exclude from schema) + inputs = getattr(deployment_definition, "__vellum_inputs__", {}) + examples = getattr(deployment_definition, "__vellum_examples__", None) + exclude_params = set(inputs.keys()) + properties = {} required = [] for input_var in input_variables: + # Skip parameters that are in the exclude_params set + if exclude_params and input_var.key in exclude_params: + continue + properties[input_var.key] = _compile_workflow_deployment_input(input_var) if input_var.required and input_var.default is None: required.append(input_var.key) parameters = {"type": "object", "properties": properties, "required": required} + if examples is not None: + parameters["examples"] = examples return FunctionDefinition( name=name.replace("-", ""), @@ -405,16 +417,21 @@ def compile_vellum_integration_tool_definition( return FunctionDefinition(name=tool_def.name, description=tool_def.description, parameters={}) -ToolType = Union[Callable[..., Any], Type["BaseWorkflow"]] +ToolT = TypeVar( + "ToolT", + Callable[..., Any], + Type["BaseWorkflow[Any, Any]"], + DeploymentDefinition, +) def tool( *, inputs: Optional[dict[str, Any]] = None, examples: Optional[List[dict[str, Any]]] = None, -) -> Callable[[ToolType], ToolType]: +) -> Callable[[ToolT], ToolT]: """ - Decorator to configure a tool function or inline workflow. + Decorator to configure a tool function, inline workflow, or workflow deployment. Currently supports specifying which parameters should come from parent workflow inputs via the `inputs` mapping. Also supports providing `examples` which will be hoisted @@ -440,13 +457,18 @@ class MyInlineWorkflow(BaseWorkflow): class Outputs(BaseWorkflow.Outputs): result = MyNode.Outputs.result + + Example with workflow deployment: + tool(inputs={ + "context": ParentInputs.context, + })(DeploymentDefinition(deployment="my-workflow-deployment")) """ - def decorator(func: ToolType) -> ToolType: - # Store the inputs mapping on the function/workflow for later use + def decorator(func: ToolT) -> ToolT: + # Store the inputs mapping on the function/workflow/deployment for later use if inputs is not None: setattr(func, "__vellum_inputs__", inputs) - # Store the examples on the function/workflow for later use + # Store the examples on the function/workflow/deployment for later use if examples is not None: setattr(func, "__vellum_examples__", examples) return func @@ -454,7 +476,7 @@ def decorator(func: ToolType) -> ToolType: return decorator -def use_tool_inputs(**inputs: Any) -> Callable[[Callable], Callable]: +def use_tool_inputs(**inputs: Any) -> Callable[[ToolT], ToolT]: """ Decorator to specify which parameters of a tool function should be provided from the parent workflow inputs rather than from the LLM. diff --git a/tests/workflows/basic_tool_calling_node_workflow_deployment_tool_wrapper/__init__.py b/tests/workflows/basic_tool_calling_node_workflow_deployment_tool_wrapper/__init__.py new file mode 100644 index 000000000..1e59e41a5 --- /dev/null +++ b/tests/workflows/basic_tool_calling_node_workflow_deployment_tool_wrapper/__init__.py @@ -0,0 +1,3 @@ +from .workflow import BasicToolCallingNodeWorkflowDeploymentToolWrapperWorkflow + +__all__ = ["BasicToolCallingNodeWorkflowDeploymentToolWrapperWorkflow"] diff --git a/tests/workflows/basic_tool_calling_node_workflow_deployment_tool_wrapper/tests/__init__.py b/tests/workflows/basic_tool_calling_node_workflow_deployment_tool_wrapper/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/workflows/basic_tool_calling_node_workflow_deployment_tool_wrapper/tests/test_workflow.py b/tests/workflows/basic_tool_calling_node_workflow_deployment_tool_wrapper/tests/test_workflow.py new file mode 100644 index 000000000..662624db3 --- /dev/null +++ b/tests/workflows/basic_tool_calling_node_workflow_deployment_tool_wrapper/tests/test_workflow.py @@ -0,0 +1,188 @@ +from uuid import uuid4 +from typing import Iterator, List + +from vellum.client.types.chat_message import ChatMessage +from vellum.client.types.execute_prompt_event import ExecutePromptEvent +from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent +from vellum.client.types.function_call import FunctionCall +from vellum.client.types.function_call_chat_message_content import FunctionCallChatMessageContent +from vellum.client.types.function_call_chat_message_content_value import FunctionCallChatMessageContentValue +from vellum.client.types.function_call_vellum_value import FunctionCallVellumValue +from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent +from vellum.client.types.prompt_output import PromptOutput +from vellum.client.types.string_chat_message_content import StringChatMessageContent +from vellum.client.types.string_vellum_value import StringVellumValue +from vellum.client.types.workflow_deployment_release import WorkflowDeploymentRelease +from vellum.client.types.workflow_execution_workflow_result_event import WorkflowExecutionWorkflowResultEvent +from vellum.client.types.workflow_output_string import WorkflowOutputString +from vellum.client.types.workflow_result_event import WorkflowResultEvent +from vellum.workflows.events.workflow import WorkflowExecutionFulfilledEvent + +from tests.workflows.basic_tool_calling_node_workflow_deployment_tool_wrapper.workflow import ( + BasicToolCallingNodeWorkflowDeploymentToolWrapperWorkflow, + WorkflowInputs, +) + + +def test_workflow_deployment_tool_wrapper__merges_inputs_from_parent(vellum_adhoc_prompt_client, vellum_client): + """ + Tests that a workflow deployment with tool wrapper correctly merges inputs from the parent workflow. + """ + + # GIVEN a mock that returns function call events followed by a final response + def generate_prompt_events(*args, **kwargs) -> Iterator[ExecutePromptEvent]: # noqa: U100 + execution_id = str(uuid4()) + + call_count = vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.call_count + expected_outputs: List[PromptOutput] + if call_count == 1: + expected_outputs = [ + FunctionCallVellumValue( + value=FunctionCall( + arguments={"city": "San Francisco", "date": "2025-01-01"}, + id="call_workflow_deployment", + name="weatherworkflowdeployment", + state="FULFILLED", + ), + ), + ] + else: + expected_outputs = [ + StringVellumValue( + value="Based on the function call, the current temperature in San Francisco is 70 degrees." + ) + ] + + events: List[ExecutePromptEvent] = [ + InitiatedExecutePromptEvent(execution_id=execution_id), + FulfilledExecutePromptEvent( + execution_id=execution_id, + outputs=expected_outputs, + ), + ] + yield from events + + vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events + + # AND a mock for the workflow deployment release info + mock_workflow_deployment_release = WorkflowDeploymentRelease( + id="mock-deployment-id", + created="2024-01-01T00:00:00Z", + environment={"id": "mock-env-id", "name": "mock-env-name", "label": "mock-env-label"}, + created_by={"id": "mock-user-id", "email": "mock@example.com"}, + workflow_version={ + "id": "mock-version-id", + "input_variables": [ + {"id": "city-input-id", "key": "city", "type": "STRING", "required": True, "default": None}, + {"id": "date-input-id", "key": "date", "type": "STRING", "required": True, "default": None}, + {"id": "context-input-id", "key": "context", "type": "STRING", "required": False, "default": None}, + ], + "output_variables": [], + }, + deployment={"id": "mock-deployment-id", "name": "weather-workflow-deployment"}, + description="A workflow that gets the weather for a given city and date.", + release_tags=[], + reviews=[], + ) + vellum_client.workflow_deployments.retrieve_workflow_deployment_release.return_value = ( + mock_workflow_deployment_release + ) + + # AND a mock for the workflow deployment execution that includes the context + def mock_workflow_execution(*args, **kwargs): # noqa: U100 + # Check that the context was passed to the workflow deployment + inputs = kwargs.get("inputs", []) + context_input = next((i for i in inputs if i.name == "context"), None) + context_value = context_input.value if context_input else "" + + yield WorkflowExecutionWorkflowResultEvent( + execution_id="mock-execution-id", + type="WORKFLOW", + data=WorkflowResultEvent(id="mock-event-id", state="INITIATED", ts="2024-01-01T00:00:00Z"), + ) + yield WorkflowExecutionWorkflowResultEvent( + execution_id="mock-execution-id", + type="WORKFLOW", + data=WorkflowResultEvent(id="mock-event-id", state="STREAMING", ts="2024-01-01T00:00:00Z"), + ) + yield WorkflowExecutionWorkflowResultEvent( + execution_id="mock-execution-id", + type="WORKFLOW", + data=WorkflowResultEvent( + id="mock-event-id", + state="FULFILLED", + ts="2024-01-01T00:00:00Z", + outputs=[ + WorkflowOutputString( + id="mock-output-id", + name="result", + type="STRING", + value=f"The weather in San Francisco on 2025-01-01 was hot. Context: {context_value}", + ) + ], + ), + ) + + vellum_client.execute_workflow_stream.side_effect = mock_workflow_execution + + # AND a workflow that uses a workflow deployment with tool wrapper + workflow = BasicToolCallingNodeWorkflowDeploymentToolWrapperWorkflow() + + # WHEN the workflow is executed + terminal_event = workflow.run( + inputs=WorkflowInputs( + query="What's the weather like in San Francisco?", + context="This is additional context from parent workflow", + ) + ) + + # THEN the workflow should complete successfully + assert terminal_event.name == "workflow.execution.fulfilled" + assert isinstance(terminal_event, WorkflowExecutionFulfilledEvent) + + # AND the output should contain the expected text + assert terminal_event.outputs.text == ( + "Based on the function call, the current temperature in San Francisco is 70 degrees." + ) + + # AND the chat history should include the function call with merged context + chat_history = terminal_event.outputs.chat_history + assert len(chat_history) == 3 + + # AND the function result should include the context from the parent workflow + function_result_message = chat_history[1] + assert function_result_message.role == "FUNCTION" + assert isinstance(function_result_message.content, StringChatMessageContent) + assert "This is additional context from parent workflow" in function_result_message.content.value + + # AND the chat history should have the expected structure + assert chat_history == [ + ChatMessage( + text=None, + role="ASSISTANT", + content=FunctionCallChatMessageContent( + type="FUNCTION_CALL", + value=FunctionCallChatMessageContentValue( + name="weatherworkflowdeployment", + arguments={"city": "San Francisco", "date": "2025-01-01"}, + id="call_workflow_deployment", + ), + ), + source=None, + ), + ChatMessage( + text=None, + role="FUNCTION", + content=StringChatMessageContent( + type="STRING", + value='{"result": "The weather in San Francisco on 2025-01-01 was hot. Context: This is additional context from parent workflow"}', # noqa: E501 + ), + source="call_workflow_deployment", + ), + ChatMessage( + text="Based on the function call, the current temperature in San Francisco is 70 degrees.", + role="ASSISTANT", + content=None, + source=None, + ), + ] diff --git a/tests/workflows/basic_tool_calling_node_workflow_deployment_tool_wrapper/workflow.py b/tests/workflows/basic_tool_calling_node_workflow_deployment_tool_wrapper/workflow.py new file mode 100644 index 000000000..e27d325d0 --- /dev/null +++ b/tests/workflows/basic_tool_calling_node_workflow_deployment_tool_wrapper/workflow.py @@ -0,0 +1,76 @@ +from vellum.client.types.chat_message_prompt_block import ChatMessagePromptBlock +from vellum.client.types.plain_text_prompt_block import PlainTextPromptBlock +from vellum.client.types.rich_text_prompt_block import RichTextPromptBlock +from vellum.client.types.variable_prompt_block import VariablePromptBlock +from vellum.workflows import BaseWorkflow +from vellum.workflows.inputs import BaseInputs +from vellum.workflows.nodes.displayable.tool_calling_node import ToolCallingNode +from vellum.workflows.state import BaseState +from vellum.workflows.types.definition import DeploymentDefinition +from vellum.workflows.utils.functions import tool + + +class WorkflowInputs(BaseInputs): + query: str + context: str + + +workflow_deployment_tool = tool( + inputs={"context": WorkflowInputs.context}, + examples=[{"city": "San Francisco", "date": "2025-01-01"}], +)( + DeploymentDefinition( + deployment="weather-workflow-deployment", + release_tag="LATEST", + ) +) + + +class GetCurrentWeatherNode(ToolCallingNode): + """ + A tool calling node that calls a workflow deployment with tool wrapper. + """ + + ml_model = "gpt-4o-mini" + blocks = [ + ChatMessagePromptBlock( + chat_role="SYSTEM", + blocks=[ + RichTextPromptBlock( + blocks=[ + PlainTextPromptBlock( + text="You are a weather expert", + ), + ], + ), + ], + ), + ChatMessagePromptBlock( + chat_role="USER", + blocks=[ + RichTextPromptBlock( + blocks=[ + VariablePromptBlock( + input_variable="question", + ), + ], + ), + ], + ), + ] + functions = [workflow_deployment_tool] + prompt_inputs = { + "question": WorkflowInputs.query, + } + + +class BasicToolCallingNodeWorkflowDeploymentToolWrapperWorkflow(BaseWorkflow[WorkflowInputs, BaseState]): + """ + A workflow that uses the GetCurrentWeatherNode with workflow deployment tool wrapper. + """ + + graph = GetCurrentWeatherNode + + class Outputs(BaseWorkflow.Outputs): + text = GetCurrentWeatherNode.Outputs.text + chat_history = GetCurrentWeatherNode.Outputs.chat_history