Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@
"name": "chat_history",
"type": "CHAT_HISTORY",
"value": null,
"schema": {"type": "array", "items": {"$ref": "#/$defs/vellum.client.types.chat_message.ChatMessage"}}
"schema": {"type": "array", "items": {"$ref": "#/$defs/ChatMessage"}}
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def test_serialize_workflow():
"name": "chat_history",
"type": "CHAT_HISTORY",
"value": None,
"schema": {"type": "array", "items": {"$ref": "#/$defs/vellum.client.types.chat_message.ChatMessage"}},
"schema": {"type": "array", "items": {"$ref": "#/$defs/ChatMessage"}},
},
],
}
122 changes: 25 additions & 97 deletions src/vellum/workflows/utils/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
Callable,
ForwardRef,
List,
Literal,
Optional,
Type,
Union,
get_args,
get_origin,
)

from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from pydantic import BaseModel, TypeAdapter
from pydash import snake_case

from vellum import Vellum
Expand All @@ -38,29 +36,22 @@
if TYPE_CHECKING:
from vellum.workflows.workflows.base import BaseWorkflow

type_map: dict[Any, str] = {
str: "string",
int: "integer",
float: "number",
bool: "boolean",
list: "array",
dict: "object",
None: "null",
type(None): "null",
inspect._empty: "null",
"None": "null",
}

for k, v in list(type_map.items()):
if isinstance(k, type):
type_map[k.__name__] = v


def _get_def_name(annotation: Type) -> str:
return f"{annotation.__module__}.{annotation.__qualname__}"


def _strip_titles(value: Any) -> Any:
"""Recursively remove 'title' keys from a schema dictionary."""
if isinstance(value, dict):
return {k: _strip_titles(v) for k, v in value.items() if k != "title"}
if isinstance(value, list):
return [_strip_titles(v) for v in value]
return value


def compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict:
# Handle special cases that TypeAdapter doesn't handle well
if annotation is None:
return {"type": "null"}

Expand All @@ -70,89 +61,26 @@ def compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict:
if annotation is datetime:
return {"type": "string", "format": "date-time"}

if get_origin(annotation) is Union:
if is_json_type(get_args(annotation)):
return {"$ref": "#/$defs/vellum.workflows.types.core.Json"}

return {"anyOf": [compile_annotation(a, defs) for a in get_args(annotation)]}

if get_origin(annotation) is Literal:
values = list(get_args(annotation))
types = {type(value) for value in values}
if len(types) == 1:
value_type = types.pop()
if value_type in type_map:
return {"type": type_map[value_type], "enum": values}
else:
return {"enum": values}
else:
return {"enum": values}

if get_origin(annotation) is dict:
_, value_type = get_args(annotation)
return {"type": "object", "additionalProperties": compile_annotation(value_type, defs)}

if get_origin(annotation) is list:
item_type = get_args(annotation)[0]
return {"type": "array", "items": compile_annotation(item_type, defs)}

if get_origin(annotation) is tuple:
args = get_args(annotation)
if len(args) == 2 and args[1] is Ellipsis:
# Tuple[int, ...] with homogeneous items
return {"type": "array", "items": compile_annotation(args[0], defs)}
else:
# Tuple[int, str] with fixed length items
result = {
"type": "array",
"prefixItems": [compile_annotation(arg, defs) for arg in args],
"minItems": len(args),
"maxItems": len(args),
}
return result

if dataclasses.is_dataclass(annotation) and isinstance(annotation, type):
def_name = _get_def_name(annotation)
if def_name not in defs:
properties = {}
required = []
for field in dataclasses.fields(annotation):
properties[field.name] = compile_annotation(field.type, defs)
if field.default is dataclasses.MISSING:
required.append(field.name)
else:
properties[field.name]["default"] = _compile_default_value(field.default)
defs[def_name] = {"type": "object", "properties": properties, "required": required}
return {"$ref": f"#/$defs/{def_name}"}

if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
def_name = _get_def_name(annotation)
if def_name not in defs:
properties = {}
required = []
for field_name, field_info in annotation.model_fields.items():
# field_info is a FieldInfo object which has an annotation attribute
properties[field_name] = compile_annotation(field_info.annotation, defs)

if field_info.description is not None:
properties[field_name]["description"] = field_info.description

if field_info.default is PydanticUndefined:
required.append(field_name)
else:
properties[field_name]["default"] = _compile_default_value(field_info.default)
defs[def_name] = {"type": "object", "properties": properties, "required": required}

return {"$ref": f"#/$defs/{def_name}"}
if get_origin(annotation) is Union and is_json_type(get_args(annotation)):
return {"$ref": "#/$defs/vellum.workflows.types.core.Json"}

if type(annotation) is ForwardRef:
# Ignore forward references for now
return {}

if annotation not in type_map:
raise ValueError(f"Failed to compile type: {annotation}")
# Use Pydantic's TypeAdapter for everything else
try:
schema = TypeAdapter(annotation).json_schema()
Comment on lines +71 to +73

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Restore null schema for unannotated parameters

Parameters without type hints arrive as inspect._empty, but compile_annotation now falls through to TypeAdapter(annotation) which raises because _empty is not a valid type. Previously these were mapped to {"type": "null"}, so compiling an unannotated tool/workflow function succeeded; now the same call surfaces a ValueError and blocks schema generation for any function with untyped args.

Useful? React with 👍 / 👎.

schema = _strip_titles(schema)
Comment on lines +71 to +74

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve support for unannotated parameters

Parameters without type hints now hit the TypeAdapter call and raise instead of defaulting to {"type": "null"} as before. Any call to compile_function_definition on a function whose parameters are unannotated (where param.annotation is inspect._empty) will now propagate a ValueError from TypeAdapter, whereas the previous type_map explicitly handled that case. This breaks existing usage of untyped functions.

Useful? React with 👍 / 👎.


# Merge any nested $defs into the top-level defs dict
if "$defs" in schema:
nested_defs = schema.pop("$defs")
defs.update(nested_defs)
Comment on lines +76 to +79

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid clobbering $defs when class names collide

Nested $defs emitted by TypeAdapter are merged with a raw defs.update, so if two different dataclasses/Pydantic models share the same class name (e.g., two ChatMessage types from different modules), the second schema overwrites the first while both parameters still point to #/$defs/ChatMessage. That produces an incorrect schema for the first parameter; the prior implementation used fully qualified names to avoid this collision.

Useful? React with 👍 / 👎.


return {"type": type_map[annotation]}
return schema
except Exception as exc:
raise ValueError(f"Failed to compile schema for annotation {annotation!r}") from exc


def _compile_default_value(default: Any) -> Any:
Expand Down
67 changes: 26 additions & 41 deletions src/vellum/workflows/utils/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def my_function(a: str, b: int, c: float, d: bool, e: list, f: dict):
"b": {"type": "integer"},
"c": {"type": "number"},
"d": {"type": "boolean"},
"e": {"type": "array"},
"f": {"type": "object"},
"e": {"type": "array", "items": {}},
"f": {"type": "object", "additionalProperties": True},
},
"required": ["a", "b", "c", "d", "e", "f"],
},
Expand Down Expand Up @@ -187,20 +187,19 @@ def my_function(c: MyDataClass):
compiled_function = compile_function_definition(my_function)

# THEN it should return the compiled function definition
ref_name = f"{__name__}.test_compile_function_definition__dataclasses.<locals>.MyDataClass"
# Pydantic's TypeAdapter inlines the schema instead of using $refs
assert compiled_function == FunctionDefinition(
name="my_function",
parameters={
"type": "object",
"properties": {"c": {"$ref": f"#/$defs/{ref_name}"}},
"required": ["c"],
"$defs": {
ref_name: {
"properties": {
"c": {
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "string"}},
"required": ["a", "b"],
}
},
"required": ["c"],
},
)

Expand All @@ -218,15 +217,13 @@ def my_function(c: MyPydanticModel):
compiled_function = compile_function_definition(my_function)

# THEN it should return the compiled function definition
ref_name = f"{__name__}.test_compile_function_definition__pydantic.<locals>.MyPydanticModel"
# Pydantic's TypeAdapter inlines the schema instead of using $refs
assert compiled_function == FunctionDefinition(
name="my_function",
parameters={
"type": "object",
"properties": {"c": {"$ref": f"#/$defs/{ref_name}"}},
"required": ["c"],
"$defs": {
ref_name: {
"properties": {
"c": {
"type": "object",
"properties": {
"a": {"type": "integer", "description": "The first number"},
Expand All @@ -235,6 +232,7 @@ def my_function(c: MyPydanticModel):
"required": ["a", "b"],
}
},
"required": ["c"],
},
)

Expand All @@ -253,20 +251,20 @@ def my_function(c: MyDataClass = MyDataClass(a=1, b="hello")):
compiled_function = compile_function_definition(my_function)

# THEN it should return the compiled function definition
ref_name = f"{__name__}.test_compile_function_definition__default_dataclass.<locals>.MyDataClass"
# Pydantic's TypeAdapter inlines the schema instead of using $refs
assert compiled_function == FunctionDefinition(
name="my_function",
parameters={
"type": "object",
"properties": {"c": {"$ref": f"#/$defs/{ref_name}", "default": {"a": 1, "b": "hello"}}},
"required": [],
"$defs": {
ref_name: {
"properties": {
"c": {
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "string"}},
"required": ["a", "b"],
"default": {"a": 1, "b": "hello"},
}
},
"required": [],
},
)

Expand All @@ -284,38 +282,24 @@ def my_function(c: MyPydanticModel = MyPydanticModel(a=1, b="hello")):
compiled_function = compile_function_definition(my_function)

# THEN it should return the compiled function definition
ref_name = f"{__name__}.test_compile_function_definition__default_pydantic.<locals>.MyPydanticModel"
# Pydantic's TypeAdapter inlines the schema instead of using $refs
assert compiled_function == FunctionDefinition(
name="my_function",
parameters={
"type": "object",
"properties": {"c": {"$ref": f"#/$defs/{ref_name}", "default": {"a": 1, "b": "hello"}}},
"required": [],
"$defs": {
ref_name: {
"properties": {
"c": {
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "string"}},
"required": ["a", "b"],
"default": {"a": 1, "b": "hello"},
}
},
"required": [],
},
)


def test_compile_function_definition__lambda():
    """A lambda with an unannotated parameter compiles to a null-typed schema."""
    # GIVEN a lambda whose sole parameter carries no type annotation
    fn = lambda x: x + 1  # noqa: E731

    # WHEN compiling the function
    result = compile_function_definition(fn)

    # THEN the name is "<lambda>" and the untyped param maps to {"type": "null"}
    expected_parameters = {
        "type": "object",
        "properties": {"x": {"type": "null"}},
        "required": ["x"],
    }
    assert result == FunctionDefinition(name="<lambda>", parameters=expected_parameters)


def test_compile_inline_workflow_function_definition():
class MyNode(BaseNode):
pass
Expand Down Expand Up @@ -383,8 +367,8 @@ class MyWorkflow(BaseWorkflow[MyInputs, BaseState]):
"b": {"type": "integer"},
"c": {"type": "number"},
"d": {"type": "boolean"},
"e": {"type": "array"},
"f": {"type": "object"},
"e": {"type": "array", "items": {}},
"f": {"type": "object", "additionalProperties": True},
},
"required": ["a", "b", "c", "d", "e", "f"],
},
Expand Down Expand Up @@ -622,7 +606,8 @@ def my_function(a: Literal[MyEnum.FOO, MyEnum.BAR]):

compiled_function = compile_function_definition(my_function)
assert isinstance(compiled_function.parameters, dict)
assert compiled_function.parameters["properties"]["a"] == {"enum": [MyEnum.FOO, MyEnum.BAR]}
# Pydantic's TypeAdapter converts enum values to their actual values
assert compiled_function.parameters["properties"]["a"] == {"enum": ["foo", "bar"], "type": "string"}


def test_compile_function_definition__annotated_descriptions():
Expand Down Expand Up @@ -770,8 +755,8 @@ def my_function_with_string_annotations(
"b": {"type": "integer"},
"c": {"type": "number"},
"d": {"type": "boolean"},
"e": {"type": "array"},
"f": {"type": "object"},
"e": {"type": "array", "items": {}},
"f": {"type": "object", "additionalProperties": True},
"g": {"type": "null"},
},
"required": ["a", "b", "c", "d", "e", "f", "g"],
Expand Down