Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
14 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions ee/codegen/src/__test__/__snapshots__/generate-code.test.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,45 @@ class UseApiWithSecret(FinalOutputNode[BaseState, Any]):
"
`;

exports[`generateCode > should generate code for %1 chat-message-trigger.ts > triggers/chat_message.py 1`] = `
"from vellum.workflows.references import LazyReference
from vellum.workflows.triggers import ChatMessageTrigger

from ..nodes.bottom_node import BottomNode


class ChatMessage(ChatMessageTrigger):
message: str

class Config(ChatMessageTrigger.Config):
output = LazyReference(lambda: BottomNode.Outputs.result)

class Display(ChatMessageTrigger.Display):
label = "Chat Message"
x = 100
y = 200
z_index = 1
icon = "vellum:icon:message"
color = "blue"
"
`;

exports[`generateCode > should generate code for %1 chat-message-trigger.ts > workflow.py 1`] = `
"from vellum.workflows import BaseWorkflow

from .nodes.bottom_node import BottomNode
from .nodes.top_node import TopNode
from .triggers.chat_message import ChatMessage


class Workflow(BaseWorkflow):
graph = {
TopNode,
ChatMessage >> BottomNode,
}
"
`;

exports[`generateCode > should generate code for %1 code-execution-node-with-await-all.ts > nodes/code_execution_with_await_all/__init__.py 1`] = `
"from vellum.workflows.nodes.displayable import CodeExecutionNode
from vellum.workflows.state import BaseState
Expand Down
127 changes: 127 additions & 0 deletions ee/codegen/src/__test__/generate-code-fixtures/chat-message-trigger.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
export default {
workflow_raw_data: {
nodes: [
{
id: "entrypoint-node",
type: "ENTRYPOINT",
data: {
label: "Entrypoint",
source_handle_id: "entrypoint-source",
},
inputs: [],
},
{
id: "top-node",
type: "GENERIC",
label: "Top Node",
display_data: null,
base: {
name: "BaseNode",
module: ["vellum", "workflows", "nodes", "bases", "base"],
},
definition: {
name: "TopNode",
module: ["testing", "nodes", "top_node"],
},
trigger: {
id: "top-target",
merge_behavior: "AWAIT_ATTRIBUTES",
},
ports: [
{
id: "top-default-port-id",
name: "default",
type: "DEFAULT",
},
],
outputs: [],
attributes: [],
},
{
id: "bottom-node",
type: "GENERIC",
label: "Bottom Node",
display_data: null,
base: {
name: "BaseNode",
module: ["vellum", "workflows", "nodes", "bases", "base"],
},
definition: {
name: "BottomNode",
module: ["testing", "nodes", "bottom_node"],
},
trigger: {
id: "bottom-target",
merge_behavior: "AWAIT_ATTRIBUTES",
},
ports: [
{
id: "bottom-default-port-id",
name: "default",
type: "DEFAULT",
},
],
outputs: [
{
id: "bottom-output-id",
name: "result",
type: "STRING",
},
],
attributes: [],
},
],
edges: [
{
id: "edge-1",
source_node_id: "entrypoint-node",
source_handle_id: "entrypoint-source",
target_node_id: "top-node",
target_handle_id: "top-target",
type: "DEFAULT",
},
{
id: "edge-2",
source_node_id: "chat-message-trigger",
source_handle_id: "chat-message-trigger",
target_node_id: "bottom-node",
target_handle_id: "bottom-target",
type: "DEFAULT",
},
],
output_values: [],
},
input_variables: [],
output_variables: [],
triggers: [
{
id: "chat-message-trigger",
type: "CHAT_MESSAGE",
attributes: [
{
id: "message-attribute-id",
key: "message",
type: "JSON",
},
],
exec_config: {
output: {
type: "NODE_OUTPUT",
node_id: "bottom-node",
node_output_id: "bottom-output-id",
},
},
display_data: {
label: "Chat Message",
position: {
x: 100,
y: 200,
},
z_index: 1,
icon: "vellum:icon:message",
color: "blue",
},
},
],
assertions: ["workflow.py", "triggers/chat_message.py"],
};
15 changes: 15 additions & 0 deletions ee/codegen/src/__test__/utils/triggers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,19 @@ describe("getTriggerClassInfo", () => {
modulePath: ["tests", "fixtures", "triggers", "slack_new_message"],
});
});

it("should return correct info for CHAT_MESSAGE trigger", () => {
const trigger: WorkflowTrigger = {
id: "chat-message-trigger-id",
type: WorkflowTriggerType.CHAT_MESSAGE,
attributes: [{ id: "attr-1", type: "JSON", key: "message" }],
};

const result = getTriggerClassInfo(trigger, workflowContextFactory());

expect(result).toEqual({
className: "ChatMessageTrigger",
modulePath: ["code", "triggers", "chat_message"],
});
});
});
37 changes: 37 additions & 0 deletions ee/codegen/src/context/trigger-context/chat-message-trigger.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { GENERATED_TRIGGERS_MODULE_NAME } from "src/constants";
import { BaseTriggerContext } from "src/context/trigger-context/base";
import { ChatMessageTrigger } from "src/types/vellum";
import { createPythonClassName, toPythonSafeSnakeCase } from "src/utils/casing";

export class ChatMessageTriggerContext extends BaseTriggerContext<ChatMessageTrigger> {
protected getTriggerModuleInfo(): {
moduleName: string;
className: string;
modulePath: string[];
} {
const label = this.triggerData.displayData?.label || "chat_message";
const rawModuleName = toPythonSafeSnakeCase(label);
let moduleName = rawModuleName;
let numRenameAttempts = 0;
while (this.workflowContext.isTriggerModuleNameUsed(moduleName)) {
moduleName = `${rawModuleName}_${numRenameAttempts + 1}`;
numRenameAttempts += 1;
}
const className = createPythonClassName(
this.triggerData.displayData?.label || "ChatMessageTrigger",
{ force: true }
);

const modulePath = [
...this.workflowContext.modulePath.slice(0, -1),
GENERATED_TRIGGERS_MODULE_NAME,
moduleName,
];

return {
moduleName,
className,
modulePath,
};
}
}
10 changes: 10 additions & 0 deletions ee/codegen/src/context/trigger-context/create-trigger-context.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { WorkflowContext } from "src/context";
import { ChatMessageTriggerContext } from "src/context/trigger-context/chat-message-trigger";
import { IntegrationTriggerContext } from "src/context/trigger-context/integration-trigger";
import { ScheduledTriggerContext } from "src/context/trigger-context/scheduled-trigger";
import {
ChatMessageTrigger,
IntegrationTrigger,
ScheduledTrigger,
WorkflowTrigger,
Expand Down Expand Up @@ -37,6 +39,14 @@ export function createTriggerContext({
workflowContext.addTriggerContext(triggerContext);
break;
}
case "CHAT_MESSAGE": {
const triggerContext = new ChatMessageTriggerContext({
workflowContext,
triggerData: triggerData as ChatMessageTrigger,
});
workflowContext.addTriggerContext(triggerContext);
break;
}
case "MANUAL":
// For now, we don't create contexts for MANUAL triggers
// as they don't have associated files
Expand Down
1 change: 1 addition & 0 deletions ee/codegen/src/context/trigger-context/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export { BaseTriggerContext } from "./base";
export { ChatMessageTriggerContext } from "./chat-message-trigger";
export { createTriggerContext } from "./create-trigger-context";
export { IntegrationTriggerContext } from "./integration-trigger";
export { ScheduledTriggerContext } from "./scheduled-trigger";
120 changes: 120 additions & 0 deletions ee/codegen/src/generators/triggers/chat-message-trigger.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import { python } from "@fern-api/python-ast";

import {
OUTPUTS_CLASS_NAME,
VELLUM_WORKFLOW_TRIGGERS_MODULE_PATH,
} from "src/constants";
import { AccessAttribute } from "src/generators/extensions/access-attribute";
import { Class } from "src/generators/extensions/class";
import { ClassInstantiation } from "src/generators/extensions/class-instantiation";
import { Field } from "src/generators/extensions/field";
import { MethodArgument } from "src/generators/extensions/method-argument";
import { Reference } from "src/generators/extensions/reference";
import { BaseTrigger } from "src/generators/triggers/base-trigger";
import { createPythonClassName, toPythonSafeSnakeCase } from "src/utils/casing";

import type { AstNode } from "src/generators/extensions/ast-node";
import type { ChatMessageTrigger as ChatMessageTriggerType } from "src/types/vellum";

export declare namespace ChatMessageTriggerGenerator {
interface Args {
workflowContext: BaseTrigger.Args<ChatMessageTriggerType>["workflowContext"];
trigger: ChatMessageTriggerType;
}
}

export class ChatMessageTrigger extends BaseTrigger<ChatMessageTriggerType> {
protected generateClassName(): string {
const label = this.trigger.displayData?.label || "ChatMessageTrigger";
return createPythonClassName(label, {
force: true,
});
}

protected getModuleName(): string {
const label = this.trigger.displayData?.label || "chat_message";
return toPythonSafeSnakeCase(label);
}

protected getBaseTriggerClassName(): string {
return "ChatMessageTrigger";
}

protected getTriggerClassBody(): AstNode[] {
const body: AstNode[] = [];

// Add attribute fields
body.push(...this.createAttributeFields());

// Create Config class if execConfig.output is present
const execConfig = this.trigger.execConfig;
if (execConfig?.output) {
body.push(this.createConfigClass(execConfig.output));
}

return body;
}

private createConfigClass(
output: NonNullable<ChatMessageTriggerType["execConfig"]>["output"]
): AstNode {
const configClass = new Class({
name: "Config",
extends_: [
new Reference({
name: "ChatMessageTrigger",
modulePath: VELLUM_WORKFLOW_TRIGGERS_MODULE_PATH,
attribute: ["Config"],
}),
],
});

if (output && output.type === "NODE_OUTPUT") {
const nodeContext = this.workflowContext.findNodeContext(output.nodeId);

if (nodeContext) {
const nodeOutputName = nodeContext.getNodeOutputNameById(
output.nodeOutputId
);

if (nodeOutputName) {
const lazyReferenceValue = new ClassInstantiation({
classReference: new Reference({
name: "LazyReference",
modulePath: [
...this.workflowContext.sdkModulePathNames
.WORKFLOWS_MODULE_PATH,
"references",
],
}),
arguments_: [
new MethodArgument({
value: python.lambda({
body: new AccessAttribute({
lhs: new Reference({
name: nodeContext.nodeClassName,
modulePath: nodeContext.nodeModulePath,
}),
rhs: new Reference({
name: `${OUTPUTS_CLASS_NAME}.${nodeOutputName}`,
modulePath: [],
}),
}),
}),
}),
],
});

configClass.add(
new Field({
name: "output",
initializer: lazyReferenceValue,
})
);
}
}
}

return configClass;
}
}
Loading