Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,34 @@ class ToolCallingNode(BaseToolCallingNode):
"
`;

exports[`ToolCallingNode > inline workflow > should generate inline workflow with tool wrapper when definition has inputs and examples 1`] = `
"from .get_weather.workflow import GetWeather

from vellum.workflows.nodes.displayable.tool_calling_node import (
ToolCallingNode as BaseToolCallingNode,
)
from vellum.workflows.utils.functions import tool

from ..inputs import Inputs


class ToolCallingNode(BaseToolCallingNode):
functions = [
tool(
inputs={
"context": Inputs.location,
},
examples=[
{
"city": "San Francisco",
"date": "2025-01-01",
},
],
)(GetWeather)
]
"
`;

exports[`ToolCallingNode > input variables > should generate input variables 1`] = `
"from vellum import (
ChatMessagePromptBlock,
Expand Down
85 changes: 85 additions & 0 deletions ee/codegen/src/__test__/nodes/tool-calling-node.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,91 @@ describe("ToolCallingNode", () => {
node.getNodeFile().write(writer);
expect(await writer.toStringFormatted()).toMatchSnapshot();
});

it("should generate inline workflow with tool wrapper when definition has inputs and examples", async () => {
const nodePortData: NodePort[] = [
nodePortFactory({
id: "port-id",
}),
];

const inlineWorkflowWithToolWrapper = {
name: "GetWeather",
type: "INLINE_WORKFLOW",
description: "Get weather for a city",
definition: {
name: "get_weather",
description: "Get weather for a city",
parameters: {
type: "object",
properties: {
city: { type: "string" },
date: { type: "string" },
},
required: ["city", "date"],
examples: [{ city: "San Francisco", date: "2025-01-01" }],
},
inputs: {
context: {
type: "WORKFLOW_INPUT",
input_variable_id: "input-1",
},
},
state: null,
cache_config: null,
forced: null,
strict: null,
},
exec_config: {
runner_config: {},
input_variables: [
{ id: "city-input", key: "city", type: "STRING" },
{ id: "date-input", key: "date", type: "STRING" },
{ id: "context-input", key: "context", type: "STRING" },
],
state_variables: [],
output_variables: [
{ id: "output-1", key: "temperature", type: "NUMBER" },
],
workflow_raw_data: {
edges: [],
nodes: [],
definition: null,
output_values: [],
},
},
};

const functionsAttribute = nodeAttributeFactory(
"functions-attr-id",
"functions",
{
type: "CONSTANT_VALUE",
value: {
type: "JSON",
value: [inlineWorkflowWithToolWrapper],
},
}
);

const nodeData = toolCallingNodeFactory({
nodePorts: nodePortData,
nodeAttributes: [functionsAttribute],
});

const nodeContext = (await createNodeContext({
workflowContext,
nodeData,
})) as GenericNodeContext;

const node = new GenericNode({
workflowContext,
nodeContext,
});

node.getNodeFile().write(writer);
expect(await writer.toStringFormatted()).toMatchSnapshot();
});
});

describe("workflow deployment", () => {
Expand Down
89 changes: 56 additions & 33 deletions ee/codegen/src/generators/nodes/generic-node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { join } from "path";

import { python } from "@fern-api/python-ast";
import { Field } from "@fern-api/python-ast/Field";
import { FunctionDefinition } from "vellum-ai/api/types";
import {
PromptBlock as PromptBlockSerializer,
PromptParameters as PromptParametersSerializer,
Expand Down Expand Up @@ -317,35 +318,16 @@ export class GenericNode extends BaseNode<GenericNodeType, GenericNodeContext> {
const codeExecutionFunction = f as FunctionArgs;
this.generateFunctionFile([codeExecutionFunction]);
const snakeName = toPythonSafeSnakeCase(codeExecutionFunction.name);
// Use toValidPythonIdentifier to ensure the name is safe for Python references
// but preserve original casing when possible (see APO-1372)
const safeName = toValidPythonIdentifier(codeExecutionFunction.name);
const functionReference = python.reference({
name: safeName, // Use safe Python identifier that preserves original casing
modulePath: [`.${snakeName}`], // Import from snake_case module
name: safeName,
modulePath: [`.${snakeName}`],
});

// Check if function has inputs or examples that need to be wrapped with tool()
const parsedInputs = this.parseToolInputs(codeExecutionFunction);
// Read examples from definition.parameters.examples (JSON Schema examples keyword)
const parameters = codeExecutionFunction.definition?.parameters as
| Record<string, unknown>
| undefined;
const examples =
(parameters?.examples as Array<Record<string, unknown>>) ?? null;
const hasInputs = parsedInputs && Object.keys(parsedInputs).length > 0;
const hasExamples = Array.isArray(examples) && examples.length > 0;

if (hasInputs || hasExamples) {
// Wrap the function reference with tool(...)(func)
const wrapper = this.getToolInvocation(parsedInputs, examples);
return new WrappedCall({
wrapper,
inner: functionReference,
});
}

return functionReference;
return this.wrapWithToolIfNeeded(
functionReference,
codeExecutionFunction.definition
);
}

private handleInlineWorkflowFunction(f: ToolArgs): AstNode | null {
Expand Down Expand Up @@ -390,10 +372,15 @@ export class GenericNode extends BaseNode<GenericNodeType, GenericNodeContext> {

const workflowClassName =
nestedWorkflowProject.workflowContext.workflowClassName;
return python.reference({
const workflowReference = python.reference({
name: workflowClassName,
modulePath: [`.${workflowName}`, GENERATED_WORKFLOW_MODULE_NAME],
});

return this.wrapWithToolIfNeeded(
workflowReference,
inlineWorkflow.definition
);
}

return null;
Expand Down Expand Up @@ -870,11 +857,11 @@ export class GenericNode extends BaseNode<GenericNodeType, GenericNodeContext> {
* Parses the tool inputs from a function definition.
* Returns null if there are no inputs or if parsing fails.
*/
private parseToolInputs(
f: FunctionArgs
private parseToolInputsFromDefinition(
definition: FunctionDefinition | undefined
): Record<string, WorkflowValueDescriptorType> | null {
const inputs = f.definition?.inputs;
if (!f.definition || !inputs) {
const inputs = definition?.inputs;
if (!definition || !inputs) {
return null;
}

Expand All @@ -901,9 +888,45 @@ export class GenericNode extends BaseNode<GenericNodeType, GenericNodeContext> {
return parsedInputs;
}

/**
* Creates a tool(inputs={...}, examples=[...]) method invocation for wrapping function references.
*/
private parseToolExamplesFromDefinition(
definition: FunctionDefinition | undefined
): Array<Record<string, unknown>> | null {
const parameters = definition?.parameters as
| Record<string, unknown>
| undefined
| null;
if (!parameters) {
return null;
}

const examples = parameters.examples as
| Array<Record<string, unknown>>
| undefined;
if (!Array.isArray(examples) || examples.length === 0) {
return null;
}

return examples;
}

private wrapWithToolIfNeeded(
inner: python.AstNode,
definition: FunctionDefinition | undefined
): python.AstNode {
const parsedInputs = this.parseToolInputsFromDefinition(definition);
const examples = this.parseToolExamplesFromDefinition(definition);

if (parsedInputs || examples) {
const wrapper = this.getToolInvocation(parsedInputs, examples);
return new WrappedCall({
wrapper,
inner,
});
}

return inner;
}

private getToolInvocation(
inputs: Record<string, WorkflowValueDescriptorType> | null,
examples: Array<Record<string, unknown>> | null
Expand Down
1 change: 1 addition & 0 deletions ee/codegen/src/types/vellum.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,7 @@ export type FunctionArgs = {
export type InlineWorkflowFunctionArgs = {
type: "INLINE_WORKFLOW";
exec_config: WorkflowVersionExecConfig;
definition?: FunctionDefinition;
} & NameDescription;

export type WorkflowDeploymentFunctionArgs = {
Expand Down