Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions src/vellum/workflows/utils/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
)

from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from pydash import snake_case

from vellum import Vellum
Expand Down Expand Up @@ -60,6 +59,15 @@ def _get_def_name(annotation: Type) -> str:
return f"{annotation.__module__}.{annotation.__qualname__}"


def _strip_titles(value: Any) -> Any:
"""Recursively remove 'title' keys from a schema dictionary."""
if isinstance(value, dict):
return {k: _strip_titles(v) for k, v in value.items() if k != "title"}
if isinstance(value, list):
return [_strip_titles(v) for v in value]
return value


def compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict:
if annotation is None:
return {"type": "null"}
Expand Down Expand Up @@ -128,20 +136,27 @@ def compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict:
if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
def_name = _get_def_name(annotation)
if def_name not in defs:
properties = {}
required = []
for field_name, field_info in annotation.model_fields.items():
# field_info is a FieldInfo object which has an annotation attribute
properties[field_name] = compile_annotation(field_info.annotation, defs)

if field_info.description is not None:
properties[field_name]["description"] = field_info.description

if field_info.default is PydanticUndefined:
required.append(field_name)
else:
properties[field_name]["default"] = _compile_default_value(field_info.default)
defs[def_name] = {"type": "object", "properties": properties, "required": required}
schema = annotation.model_json_schema()
schema = _strip_titles(schema)

# If the schema has nested $defs, we need to merge them into the top-level defs
if "$defs" in schema:
nested_defs = schema.pop("$defs")
for nested_def_name, nested_def_schema in nested_defs.items():
# Use the fully qualified name for nested models
nested_annotation = annotation.model_fields.get(nested_def_name)
if (
nested_annotation
and nested_annotation.annotation
and hasattr(nested_annotation.annotation, "__module__")
):
qualified_name = _get_def_name(nested_annotation.annotation)
defs[qualified_name] = _strip_titles(nested_def_schema)
else:
# Fallback to the name provided by Pydantic
defs[nested_def_name] = _strip_titles(nested_def_schema)

defs[def_name] = schema

return {"$ref": f"#/$defs/{def_name}"}

Expand Down
14 changes: 0 additions & 14 deletions src/vellum/workflows/utils/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,20 +302,6 @@ def my_function(c: MyPydanticModel = MyPydanticModel(a=1, b="hello")):
)


def test_compile_function_definition__lambda():
# GIVEN a lambda
lambda_function = lambda x: x + 1 # noqa: E731

# WHEN compiling the function
compiled_function = compile_function_definition(lambda_function)

# THEN it should return the compiled function definition
assert compiled_function == FunctionDefinition(
name="<lambda>",
parameters={"type": "object", "properties": {"x": {"type": "null"}}, "required": ["x"]},
)


def test_compile_inline_workflow_function_definition():
class MyNode(BaseNode):
pass
Expand Down