diff --git a/ee/codegen/src/__test__/nodes/__snapshots__/tool-calling-node.test.ts.snap b/ee/codegen/src/__test__/nodes/__snapshots__/tool-calling-node.test.ts.snap index 7511d329c..daa429570 100644 --- a/ee/codegen/src/__test__/nodes/__snapshots__/tool-calling-node.test.ts.snap +++ b/ee/codegen/src/__test__/nodes/__snapshots__/tool-calling-node.test.ts.snap @@ -244,6 +244,34 @@ class ToolCallingNode(BaseToolCallingNode): " `; +exports[`ToolCallingNode > inline workflow > should generate inline workflow with tool wrapper when definition has inputs and examples 1`] = ` +"from .get_weather.workflow import GetWeather + +from vellum.workflows.nodes.displayable.tool_calling_node import ( + ToolCallingNode as BaseToolCallingNode, +) +from vellum.workflows.utils.functions import tool + +from ..inputs import Inputs + + +class ToolCallingNode(BaseToolCallingNode): + functions = [ + tool( + inputs={ + "context": Inputs.location, + }, + examples=[ + { + "city": "San Francisco", + "date": "2025-01-01", + }, + ], + )(GetWeather) + ] +" +`; + exports[`ToolCallingNode > input variables > should generate input variables 1`] = ` "from vellum import ( ChatMessagePromptBlock, diff --git a/ee/codegen/src/__test__/nodes/tool-calling-node.test.ts b/ee/codegen/src/__test__/nodes/tool-calling-node.test.ts index c1bf1420e..45704f61c 100644 --- a/ee/codegen/src/__test__/nodes/tool-calling-node.test.ts +++ b/ee/codegen/src/__test__/nodes/tool-calling-node.test.ts @@ -500,6 +500,91 @@ describe("ToolCallingNode", () => { node.getNodeFile().write(writer); expect(await writer.toStringFormatted()).toMatchSnapshot(); }); + + it("should generate inline workflow with tool wrapper when definition has inputs and examples", async () => { + const nodePortData: NodePort[] = [ + nodePortFactory({ + id: "port-id", + }), + ]; + + const inlineWorkflowWithToolWrapper = { + name: "GetWeather", + type: "INLINE_WORKFLOW", + description: "Get weather for a city", + definition: { + name: "get_weather", + description: "Get weather for a city", + parameters: { + type: "object", + properties: { + city: { type: "string" }, + date: { type: "string" }, + }, + required: ["city", "date"], + examples: [{ city: "San Francisco", date: "2025-01-01" }], + }, + inputs: { + context: { + type: "WORKFLOW_INPUT", + input_variable_id: "input-1", + }, + }, + state: null, + cache_config: null, + forced: null, + strict: null, + }, + exec_config: { + runner_config: {}, + input_variables: [ + { id: "city-input", key: "city", type: "STRING" }, + { id: "date-input", key: "date", type: "STRING" }, + { id: "context-input", key: "context", type: "STRING" }, + ], + state_variables: [], + output_variables: [ + { id: "output-1", key: "temperature", type: "NUMBER" }, + ], + workflow_raw_data: { + edges: [], + nodes: [], + definition: null, + output_values: [], + }, + }, + }; + + const functionsAttribute = nodeAttributeFactory( + "functions-attr-id", + "functions", + { + type: "CONSTANT_VALUE", + value: { + type: "JSON", + value: [inlineWorkflowWithToolWrapper], + }, + } + ); + + const nodeData = toolCallingNodeFactory({ + nodePorts: nodePortData, + nodeAttributes: [functionsAttribute], + }); + + const nodeContext = (await createNodeContext({ + workflowContext, + nodeData, + })) as GenericNodeContext; + + const node = new GenericNode({ + workflowContext, + nodeContext, + }); + + node.getNodeFile().write(writer); + expect(await writer.toStringFormatted()).toMatchSnapshot(); + }); }); describe("workflow deployment", () => { diff --git a/ee/codegen/src/generators/nodes/generic-node.ts b/ee/codegen/src/generators/nodes/generic-node.ts index 9096a2357..10f20981b 100644 --- a/ee/codegen/src/generators/nodes/generic-node.ts +++ b/ee/codegen/src/generators/nodes/generic-node.ts @@ -3,6 +3,7 @@ import { join } from "path"; import { python } from "@fern-api/python-ast"; import { Field } from "@fern-api/python-ast/Field"; +import { FunctionDefinition } from "vellum-ai/api/types"; import { PromptBlock as PromptBlockSerializer, PromptParameters as PromptParametersSerializer, @@ -317,35 +318,16 @@ export class GenericNode extends BaseNode { const codeExecutionFunction = f as FunctionArgs; this.generateFunctionFile([codeExecutionFunction]); const snakeName = toPythonSafeSnakeCase(codeExecutionFunction.name); - // Use toValidPythonIdentifier to ensure the name is safe for Python references - // but preserve original casing when possible (see APO-1372) const safeName = toValidPythonIdentifier(codeExecutionFunction.name); const functionReference = python.reference({ - name: safeName, // Use safe Python identifier that preserves original casing - modulePath: [`.${snakeName}`], // Import from snake_case module + name: safeName, + modulePath: [`.${snakeName}`], }); - // Check if function has inputs or examples that need to be wrapped with tool() - const parsedInputs = this.parseToolInputs(codeExecutionFunction); - // Read examples from definition.parameters.examples (JSON Schema examples keyword) - const parameters = codeExecutionFunction.definition?.parameters as - | Record - | undefined; - const examples = - (parameters?.examples as Array>) ?? null; - const hasInputs = parsedInputs && Object.keys(parsedInputs).length > 0; - const hasExamples = Array.isArray(examples) && examples.length > 0; - - if (hasInputs || hasExamples) { - // Wrap the function reference with tool(...)(func) - const wrapper = this.getToolInvocation(parsedInputs, examples); - return new WrappedCall({ - wrapper, - inner: functionReference, - }); - } - - return functionReference; + return this.wrapWithToolIfNeeded( + functionReference, + codeExecutionFunction.definition + ); } private handleInlineWorkflowFunction(f: ToolArgs): AstNode | null { @@ -390,10 +372,15 @@ export class GenericNode extends BaseNode { const workflowClassName = nestedWorkflowProject.workflowContext.workflowClassName; - return python.reference({ + const workflowReference = python.reference({ name: workflowClassName, modulePath: [`.${workflowName}`, GENERATED_WORKFLOW_MODULE_NAME], }); + + return this.wrapWithToolIfNeeded( + workflowReference, + inlineWorkflow.definition + ); } return null; @@ -870,11 +857,11 @@ export class GenericNode extends BaseNode { * Parses the tool inputs from a function definition. * Returns null if there are no inputs or if parsing fails. */ - private parseToolInputs( - f: FunctionArgs + private parseToolInputsFromDefinition( + definition: FunctionDefinition | undefined ): Record | null { - const inputs = f.definition?.inputs; - if (!f.definition || !inputs) { + const inputs = definition?.inputs; + if (!definition || !inputs) { return null; } @@ -901,9 +888,45 @@ export class GenericNode extends BaseNode { return parsedInputs; } - /** - * Creates a tool(inputs={...}, examples=[...]) method invocation for wrapping function references. - */ + private parseToolExamplesFromDefinition( + definition: FunctionDefinition | undefined + ): Array> | null { + const parameters = definition?.parameters as + | Record + | undefined + | null; + if (!parameters) { + return null; + } + + const examples = parameters.examples as + | Array> + | undefined; + if (!Array.isArray(examples) || examples.length === 0) { + return null; + } + + return examples; + } + + private wrapWithToolIfNeeded( + inner: python.AstNode, + definition: FunctionDefinition | undefined + ): python.AstNode { + const parsedInputs = this.parseToolInputsFromDefinition(definition); + const examples = this.parseToolExamplesFromDefinition(definition); + + if (parsedInputs || examples) { + const wrapper = this.getToolInvocation(parsedInputs, examples); + return new WrappedCall({ + wrapper, + inner, + }); + } + + return inner; + } + private getToolInvocation( inputs: Record | null, examples: Array> | null diff --git a/ee/codegen/src/types/vellum.ts b/ee/codegen/src/types/vellum.ts index f8a38d9af..c09a39337 100644 --- a/ee/codegen/src/types/vellum.ts +++ b/ee/codegen/src/types/vellum.ts @@ -1073,6 +1073,7 @@ export type FunctionArgs = { export type InlineWorkflowFunctionArgs = { type: "INLINE_WORKFLOW"; exec_config: WorkflowVersionExecConfig; + definition?: FunctionDefinition; } & NameDescription; export type WorkflowDeploymentFunctionArgs = {