-
Notifications
You must be signed in to change notification settings - Fork 18
Refactor compile_annotation to use Pydantic TypeAdapter throughout #3194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
02385b8
ce808ee
2f11cb8
6b61cb2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,16 +8,14 @@ | |
| Callable, | ||
| ForwardRef, | ||
| List, | ||
| Literal, | ||
| Optional, | ||
| Type, | ||
| Union, | ||
| get_args, | ||
| get_origin, | ||
| ) | ||
|
|
||
| from pydantic import BaseModel | ||
| from pydantic_core import PydanticUndefined | ||
| from pydantic import BaseModel, TypeAdapter | ||
| from pydash import snake_case | ||
|
|
||
| from vellum import Vellum | ||
|
|
@@ -38,29 +36,22 @@ | |
| if TYPE_CHECKING: | ||
| from vellum.workflows.workflows.base import BaseWorkflow | ||
|
|
||
| type_map: dict[Any, str] = { | ||
| str: "string", | ||
| int: "integer", | ||
| float: "number", | ||
| bool: "boolean", | ||
| list: "array", | ||
| dict: "object", | ||
| None: "null", | ||
| type(None): "null", | ||
| inspect._empty: "null", | ||
| "None": "null", | ||
| } | ||
|
|
||
| for k, v in list(type_map.items()): | ||
| if isinstance(k, type): | ||
| type_map[k.__name__] = v | ||
|
|
||
|
|
||
| def _get_def_name(annotation: Type) -> str: | ||
| return f"{annotation.__module__}.{annotation.__qualname__}" | ||
|
|
||
|
|
||
| def _strip_titles(value: Any) -> Any: | ||
| """Recursively remove 'title' keys from a schema dictionary.""" | ||
| if isinstance(value, dict): | ||
| return {k: _strip_titles(v) for k, v in value.items() if k != "title"} | ||
| if isinstance(value, list): | ||
| return [_strip_titles(v) for v in value] | ||
| return value | ||
|
|
||
|
|
||
| def compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict: | ||
| # Handle special cases that TypeAdapter doesn't handle well | ||
| if annotation is None: | ||
| return {"type": "null"} | ||
|
|
||
|
|
@@ -70,89 +61,26 @@ def compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict: | |
| if annotation is datetime: | ||
| return {"type": "string", "format": "date-time"} | ||
|
|
||
| if get_origin(annotation) is Union: | ||
| if is_json_type(get_args(annotation)): | ||
| return {"$ref": "#/$defs/vellum.workflows.types.core.Json"} | ||
|
|
||
| return {"anyOf": [compile_annotation(a, defs) for a in get_args(annotation)]} | ||
|
|
||
| if get_origin(annotation) is Literal: | ||
| values = list(get_args(annotation)) | ||
| types = {type(value) for value in values} | ||
| if len(types) == 1: | ||
| value_type = types.pop() | ||
| if value_type in type_map: | ||
| return {"type": type_map[value_type], "enum": values} | ||
| else: | ||
| return {"enum": values} | ||
| else: | ||
| return {"enum": values} | ||
|
|
||
| if get_origin(annotation) is dict: | ||
| _, value_type = get_args(annotation) | ||
| return {"type": "object", "additionalProperties": compile_annotation(value_type, defs)} | ||
|
|
||
| if get_origin(annotation) is list: | ||
| item_type = get_args(annotation)[0] | ||
| return {"type": "array", "items": compile_annotation(item_type, defs)} | ||
|
|
||
| if get_origin(annotation) is tuple: | ||
| args = get_args(annotation) | ||
| if len(args) == 2 and args[1] is Ellipsis: | ||
| # Tuple[int, ...] with homogeneous items | ||
| return {"type": "array", "items": compile_annotation(args[0], defs)} | ||
| else: | ||
| # Tuple[int, str] with fixed length items | ||
| result = { | ||
| "type": "array", | ||
| "prefixItems": [compile_annotation(arg, defs) for arg in args], | ||
| "minItems": len(args), | ||
| "maxItems": len(args), | ||
| } | ||
| return result | ||
|
|
||
| if dataclasses.is_dataclass(annotation) and isinstance(annotation, type): | ||
| def_name = _get_def_name(annotation) | ||
| if def_name not in defs: | ||
| properties = {} | ||
| required = [] | ||
| for field in dataclasses.fields(annotation): | ||
| properties[field.name] = compile_annotation(field.type, defs) | ||
| if field.default is dataclasses.MISSING: | ||
| required.append(field.name) | ||
| else: | ||
| properties[field.name]["default"] = _compile_default_value(field.default) | ||
| defs[def_name] = {"type": "object", "properties": properties, "required": required} | ||
| return {"$ref": f"#/$defs/{def_name}"} | ||
|
|
||
| if inspect.isclass(annotation) and issubclass(annotation, BaseModel): | ||
| def_name = _get_def_name(annotation) | ||
| if def_name not in defs: | ||
| properties = {} | ||
| required = [] | ||
| for field_name, field_info in annotation.model_fields.items(): | ||
| # field_info is a FieldInfo object which has an annotation attribute | ||
| properties[field_name] = compile_annotation(field_info.annotation, defs) | ||
|
|
||
| if field_info.description is not None: | ||
| properties[field_name]["description"] = field_info.description | ||
|
|
||
| if field_info.default is PydanticUndefined: | ||
| required.append(field_name) | ||
| else: | ||
| properties[field_name]["default"] = _compile_default_value(field_info.default) | ||
| defs[def_name] = {"type": "object", "properties": properties, "required": required} | ||
|
|
||
| return {"$ref": f"#/$defs/{def_name}"} | ||
| if get_origin(annotation) is Union and is_json_type(get_args(annotation)): | ||
| return {"$ref": "#/$defs/vellum.workflows.types.core.Json"} | ||
|
|
||
| if type(annotation) is ForwardRef: | ||
| # Ignore forward references for now | ||
| return {} | ||
|
|
||
| if annotation not in type_map: | ||
| raise ValueError(f"Failed to compile type: {annotation}") | ||
| # Use Pydantic's TypeAdapter for everything else | ||
| try: | ||
| schema = TypeAdapter(annotation).json_schema() | ||
| schema = _strip_titles(schema) | ||
|
Comment on lines
+71
to
+74
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Parameters without type hints now hit the TypeAdapter call and raise instead of defaulting to Useful? React with 👍 / 👎. |
||
|
|
||
| # Merge any nested $defs into the top-level defs dict | ||
| if "$defs" in schema: | ||
| nested_defs = schema.pop("$defs") | ||
| defs.update(nested_defs) | ||
|
Comment on lines
+76
to
+79
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Nested Useful? React with 👍 / 👎. |
||
|
|
||
| return {"type": type_map[annotation]} | ||
| return schema | ||
| except Exception as exc: | ||
| raise ValueError(f"Failed to compile schema for annotation {annotation!r}") from exc | ||
|
|
||
|
|
||
| def _compile_default_value(default: Any) -> Any: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Parameters without type hints arrive as
inspect._empty, butcompile_annotationnow falls through toTypeAdapter(annotation)which raises because_emptyis not a valid type. Previously these were mapped to{"type": "null"}, so compiling an unannotated tool/workflow function succeeded; now the same call surfaces a ValueError and blocks schema generation for any function with untyped args.Useful? React with 👍 / 👎.