diff --git a/dynamiq/callbacks/streaming.py b/dynamiq/callbacks/streaming.py index 62f535bf7..fde20d59a 100644 --- a/dynamiq/callbacks/streaming.py +++ b/dynamiq/callbacks/streaming.py @@ -1122,19 +1122,17 @@ def _try_initialize_next_json_field(self, buf: str, final_answer_only: bool) -> buf, JSONStreamingField.ANSWER.value, StreamingState.ANSWER ) - def _emit_tool_input_state(self, buf: str) -> None: - """Emit content for the current TOOL_INPUT state.""" + def _emit_tool_input_state(self, buf: str) -> bool: + """Emit content for the current TOOL_INPUT state. Returns True when complete.""" if self._fc_object_tool_input: - self._emit_json_object_field_content(buf, StreamingState.TOOL_INPUT) - else: - self._emit_json_field_content(buf, StreamingState.TOOL_INPUT) + return self._emit_json_object_field_content(buf, StreamingState.TOOL_INPUT) + return self._emit_json_field_content(buf, StreamingState.TOOL_INPUT) - def _emit_answer_state(self, buf: str) -> None: - """Emit content for the current ANSWER state.""" + def _emit_answer_state(self, buf: str) -> bool: + """Emit content for the current ANSWER state. Returns True when complete.""" if self._fc_object_answer: - self._emit_json_object_field_content(buf, StreamingState.ANSWER) - else: - self._emit_json_field_content(buf, StreamingState.ANSWER) + return self._emit_json_object_field_content(buf, StreamingState.ANSWER) + return self._emit_json_field_content(buf, StreamingState.ANSWER) def _process_json_mode(self, final_answer_only: bool) -> None: """ diff --git a/dynamiq/nodes/agents/agent.py b/dynamiq/nodes/agents/agent.py index 7b5c9e96a..0e1715a23 100644 --- a/dynamiq/nodes/agents/agent.py +++ b/dynamiq/nodes/agents/agent.py @@ -1,9 +1,11 @@ import json +import types from concurrent.futures import as_completed -from typing import Any, Callable, Literal, Mapping +from typing import Any, Callable, Literal, Mapping, Union, get_args, get_origin from litellm import get_supported_openai_params, supports_function_calling, supports_response_schema from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator, model_validator +from pydantic_core import from_json from dynamiq.callbacks import AgentStreamingParserCallback, StreamingQueueCallbackHandler from dynamiq.executors.context import ContextAwareThreadPoolExecutor @@ -99,8 +101,13 @@ def parse_arguments(cls, v: Any) -> Any: if isinstance(v, str): try: return json.loads(v, strict=False) - except json.JSONDecodeError as e: - raise ValueError(f"Tool call arguments are not valid JSON: {e}") + except json.JSONDecodeError: + # Truncated mid-emission (LLM stopped mid-tool-call): parse the + # partial document, dropping the incomplete trailing value. + try: + return from_json(v, allow_partial=True) + except ValueError as e: + raise ValueError(f"Tool call arguments are not valid JSON: {e}") return v or {} def parse_as_tool_call(self) -> ToolCallArguments: @@ -340,6 +347,129 @@ def _emit_tool_input_error( ) self._streaming_tool_run_id = None + @staticmethod + def _annotation_accepts_none(annotation: Any) -> bool: + """Return True if a Pydantic field annotation includes ``NoneType``.""" + if annotation is type(None): + return True + origin = get_origin(annotation) + if origin in (Union, types.UnionType): + return type(None) in get_args(annotation) + return False + + @staticmethod + def _annotation_is_dict_like(annotation: Any) -> bool: + """Return True if the annotation is ``dict`` / ``dict[...]`` or a union including one.""" + if annotation is dict: + return True + origin = get_origin(annotation) + if origin is dict: + return True + if origin in (Union, types.UnionType): + return any(Agent._annotation_is_dict_like(arg) for arg in get_args(annotation)) + return False + + @staticmethod + def _extract_basemodel(annotation: Any) -> type[BaseModel] | None: + """Return the BaseModel subclass in an annotation (handles ``Model | None``), else None.""" + if isinstance(annotation, type) and issubclass(annotation, BaseModel): + return annotation + origin = get_origin(annotation) + if origin in (Union, types.UnionType): + for arg in get_args(annotation): + if isinstance(arg, type) and issubclass(arg, BaseModel): + return arg + return None + + def _normalize_fields(self, fields: Mapping[str, Any], data: Any) -> None: + """Reconcile raw tool ``data`` against its Pydantic ``fields`` in one walk. + + Strict tool-calling encodes two things on the wire that the tool's schema + can't accept directly; this single pass reverses both, recursing into + nested ``BaseModel`` fields *and list elements* so they apply at any depth + (e.g. ``list[SubModel]``). Call with ``tool.input_schema.model_fields`` and + the raw ``action_input``. + + - **Free-form ``dict[str, Any]`` shipped as a JSON string.** Strict mode + can't express an open object, so the provider converters send those + fields as JSON-encoded strings. When the field is dict-typed and the + model supplied a JSON string, parse it back into a dict. + - **``None`` for a non-nullable field.** Strict mode keeps every property in + ``required`` and uses ``null`` as the "leave it at the default" signal. A + field with a non-nullable default (``encoding: str = "utf-8"``) can't take + that ``None`` — drop the key so the Pydantic default applies. Fields that + genuinely accept ``None`` (``encoding: str | None = None``) keep it. + """ + if not isinstance(data, dict): + return + for name in list(data): + field = fields.get(name) + if field is None: + continue + value = data[name] + if value is None: + # Strict mode's "use the default" signal — drop unless the field + # genuinely accepts None. (Only meaningful for dict keys; a None + # list element is left in place by ``_normalize_value``.) + if not self._annotation_accepts_none(field.annotation): + del data[name] + else: + data[name] = self._normalize_value(field.annotation, value) + + def _normalize_value(self, annotation: Any, value: Any) -> Any: + """Normalize one value against its annotation; return the (possibly new) value. + + Dicts and lists are mutated in place; a parsed JSON string yields a new dict. + Recurses through nested ``BaseModel`` fields and list/tuple/set elements so + both wire-encoding reversals apply at any depth. + """ + if isinstance(value, str): + # Free-form dict shipped as a JSON string — parse it back. + if self._annotation_is_dict_like(annotation): + stripped = value.strip() + if stripped.startswith("{") and stripped.endswith("}"): + try: + return json.loads(stripped) + except json.JSONDecodeError: + pass # leave as string; Pydantic will surface the error + return value + if isinstance(value, dict): + nested_model = self._extract_basemodel(annotation) + if nested_model is not None: + self._normalize_fields(nested_model.model_fields, value) + return value + if isinstance(value, list): + element_annotation = self._list_element_annotation(annotation) + if element_annotation is not None: + for i, item in enumerate(value): + value[i] = self._normalize_value(element_annotation, item) + return value + return value + + @staticmethod + def _list_element_annotation(annotation: Any) -> Any | None: + """Return the element type of a ``list``/``set``/``tuple`` annotation + (handles ``list[X] | None`` / ``Optional[list[X]]``), else None. + + Heterogeneous ``tuple[X, Y]`` is skipped (returns None) — only homogeneous + ``tuple[X, ...]`` yields an element type. + """ + origin = get_origin(annotation) + if origin in (Union, types.UnionType): + for arg in get_args(annotation): + elem = Agent._list_element_annotation(arg) + if elem is not None: + return elem + return None + if origin in (list, set, frozenset): + args = get_args(annotation) + return args[0] if args else None + if origin is tuple: + args = get_args(annotation) + if len(args) == 2 and args[1] is Ellipsis: # tuple[X, ...] + return args[0] + return None + def _should_delegate_final( self, tool: Node | None, @@ -811,6 +941,7 @@ def _handle_function_calling_mode( if len(actual_tool_calls) > 1 and self.parallel_tool_calls_enabled: tool_items = [] for tc in actual_tool_calls: + tc_name = tc.function.name.strip() args = tc.function.parse_as_tool_call() tc_input = args.to_action_input() if isinstance(tc_input, str): @@ -820,9 +951,12 @@ def _handle_function_calling_mode( raise ActionParsingException(f"Error parsing action_input string. {e}", recoverable=True) if not isinstance(tc_input, dict): tc_input = {"input": tc_input} + tc_tool = self.tool_by_names.get(self.sanitize_tool_name(tc_name)) + if tc_tool is not None: + self._normalize_fields(tc_tool.input_schema.model_fields, tc_input) tool_items.append( ToolCallItem( - name=tc.function.name.strip(), + name=tc_name, input=tc_input, thought=args.thought, ) @@ -846,6 +980,10 @@ def _handle_function_calling_mode( if not isinstance(action_input, dict): action_input = {"input": action_input} + tool = self.tool_by_names.get(self.sanitize_tool_name(action)) + if tool is not None: + self._normalize_fields(tool.input_schema.model_fields, action_input) + self.log_reasoning(thought, action, action_input, loop_num) return thought, action, action_input @@ -1524,12 +1662,24 @@ def _run_react_llm_step(self, config: RunnableConfig | None, loop_num: int, **kw try: native_parallel = self.parallel_tool_calls_enabled and self.inference_mode == InferenceMode.FUNCTION_CALLING + # In FUNCTION_CALLING mode with tools present, force a tool call so + # the model cannot bail out with a text-only response. Honour any + # explicit caller override (kwargs / self.llm.tool_choice). + forced_tool_choice = None + if ( + self.inference_mode == InferenceMode.FUNCTION_CALLING + and self._tools + and "tool_choice" not in kwargs + and getattr(self.llm, "tool_choice", None) is None + ): + forced_tool_choice = "required" llm_result = self._run_llm( messages=messages, tools=self._tools, response_format=self._response_format, config=llm_config, parallel_tool_calls=True if native_parallel else None, + **({"tool_choice": forced_tool_choice} if forced_tool_choice else {}), **kwargs, ) finally: diff --git a/dynamiq/nodes/agents/components/schema_generator.py b/dynamiq/nodes/agents/components/schema_generator.py index 25481a4bc..523b10e49 100644 --- a/dynamiq/nodes/agents/components/schema_generator.py +++ b/dynamiq/nodes/agents/components/schema_generator.py @@ -19,7 +19,6 @@ FINAL_ANSWER_FUNCTION_SCHEMA = { "type": "function", - "strict": True, "function": { "name": "provide_final_answer", "description": "Function should be called when if you can answer the initial request" @@ -38,6 +37,7 @@ }, }, "required": ["thought", "answer", "output_files"], + "additionalProperties": False, }, }, } @@ -75,7 +75,6 @@ def build_final_answer_function_schema(response_format: dict | type[BaseModel] | return FINAL_ANSWER_FUNCTION_SCHEMA answer_schema = unwrap_response_format(response_format) - strict = _is_strict_compatible(answer_schema) parameters = { "type": "object", @@ -91,13 +90,11 @@ def build_final_answer_function_schema(response_format: dict | type[BaseModel] | }, }, "required": ["thought", "answer", "output_files"], + "additionalProperties": False, } - if strict: - parameters["additionalProperties"] = False return { "type": "function", - "strict": strict, "function": { "name": "provide_final_answer", "description": ( @@ -239,8 +236,8 @@ def _resolve_type_schema(param: Any, _seen: set | None = None) -> dict[str, Any] ``properties`` so the LLM produces correctly structured output. Generic ``dict`` types become bare ``{"type": "object"}``. - Tools whose schemas contain bare objects automatically get - ``strict: false`` via ``_is_strict_compatible``. + Per-provider transforms (in ``BaseLLM`` subclasses) decide whether + ``strict`` is engaged for a given schema and provider. """ if param is type(None): return {"type": "null"} @@ -323,29 +320,6 @@ def _basemodel_to_schema(model: type[BaseModel], _seen: set | None = None) -> di return result -def _is_strict_compatible(schema: Any) -> bool: - """Return ``False`` if the schema contains an object that OpenAI strict mode - would reject — bare objects without ``properties``, or objects missing - ``additionalProperties: False``.""" - if not isinstance(schema, dict): - return True - schema_type = schema.get("type") - is_object = schema_type == "object" or (isinstance(schema_type, list) and "object" in schema_type) - if is_object: - if "properties" not in schema: - return False - if schema.get("additionalProperties") is not False: - return False - for value in schema.values(): - if isinstance(value, dict) and not _is_strict_compatible(value): - return False - if isinstance(value, list): - for item in value: - if isinstance(item, dict) and not _is_strict_compatible(item): - return False - return True - - def _is_nullable(annotation: Any) -> bool: """Return True if the annotation is a Union that includes NoneType.""" origin = get_origin(annotation) @@ -427,20 +401,19 @@ def generate_function_calling_schemas( properties["delegate_final"] = { "type": "boolean", "description": ( - "Set to true to return the sub-agent's response verbatim " - "as the parent agent's final output." + "Set to true to return the sub-agent's response verbatim " "as the parent agent's final output." ), } - has_optional = len(required_fields) < len(properties) # Flat-args: prepend `thought` so it streams first and the model sees it before tool params. properties = { "thought": {"type": "string", "description": "Your reasoning about using this tool."}, **properties, } - use_strict = _is_strict_compatible(properties) and not has_optional - required = ["thought", *properties.keys()] if use_strict else ["thought", *required_fields] - required = list(dict.fromkeys(required)) + # Only genuinely-required fields are required here. Strict mode (which + # promotes every property to ``required``) is applied per-provider in + # ``BaseLLM.transform_tool_schemas``, not at generation time. + required = list(dict.fromkeys(["thought", *required_fields])) schema = { "type": "function", @@ -453,12 +426,8 @@ def generate_function_calling_schemas( "additionalProperties": False, "required": required, }, - "strict": use_strict, }, } - - schemas.append(schema) - else: # `extra="allow"` tools (e.g. generic Python) take arbitrary params: # keep the object open and non-strict so the model can pass them as @@ -480,10 +449,9 @@ def generate_function_calling_schemas( "additionalProperties": allows_extra, "required": ["thought"], }, - "strict": not allows_extra, }, } - schemas.append(schema) + schemas.append(schema) return schemas diff --git a/dynamiq/nodes/agents/prompts/react/instructions.py b/dynamiq/nodes/agents/prompts/react/instructions.py index 5b8d52b32..5a42c9ea1 100644 --- a/dynamiq/nodes/agents/prompts/react/instructions.py +++ b/dynamiq/nodes/agents/prompts/react/instructions.py @@ -107,9 +107,12 @@ - Make sure to adhere to AGENT PERSONA & STYLE & ADDITIONAL BEHAVIORAL GUIDELINES. ## Single Action Per Turn -- Execute exactly ONE / pair per response, then wait for its Observation before continuing -- Do NOT include multiple action blocks or answer blocks in the same response -- After receiving an Observation, decide the next single action based on the result +- Emit EXACTLY ONE ... block per response. Never write a second block. +- A response either takes an action (/) OR gives an — never both in the same response. +- When you write an , STOP immediately after . Do NOT continue, do NOT write an "Observation:", + a tool result, or an — the Observation is given back to you by the system, not written by you. +- NEVER predict, assume, or fabricate the tool's Observation/result. Wait for the real Observation to be returned, + then decide your next single action (or final answer) based on it in your NEXT response. ## JSON Formatting Requirements - Put JSON on single line within tags diff --git a/dynamiq/nodes/llms/anthropic.py b/dynamiq/nodes/llms/anthropic.py index 64a70442c..49a85ad54 100644 --- a/dynamiq/nodes/llms/anthropic.py +++ b/dynamiq/nodes/llms/anthropic.py @@ -1,4 +1,4 @@ -from typing import Any, Literal +from typing import Any, ClassVar, Literal from pydantic import BaseModel @@ -8,6 +8,139 @@ from dynamiq.utils.logger import logger +def _patch_litellm_anthropic_strict_forward() -> bool: + """Make LiteLLM forward ``strict: true`` from OpenAI-shape tools to Anthropic. + + LiteLLM's ``AnthropicConfig._map_tool_helper`` translates an OpenAI-style + tool definition into Anthropic's native shape but drops the ``strict`` + field on the way through. Until that's fixed upstream, this monkey-patch + wraps the method and lifts ``function.strict`` onto the resulting Anthropic + tool dict so Anthropic's grammar-constrained sampling actually engages. + + The patch is defensive: it logs and returns ``False`` if anything looks + different from what we expect (LiteLLM moved the method, refactored the + config class, etc.) — in that case strict-on-Anthropic via LiteLLM stops + working but nothing else breaks. It is also idempotent: re-importing this + module won't double-patch. + + Returns: + True if the patch was applied (or already in place), False if it was + skipped due to unexpected LiteLLM internals. + """ + try: + from litellm.llms.anthropic.chat import transformation as _xform + except Exception as exc: # ImportError, ModuleNotFoundError, etc. + logger.debug("LiteLLM Anthropic strict patch: import failed: %s", exc) + return False + + config_cls = getattr(_xform, "AnthropicConfig", None) + if config_cls is None: + logger.warning("LiteLLM Anthropic strict patch: AnthropicConfig not found; skipping.") + return False + + original = getattr(config_cls, "_map_tool_helper", None) + if original is None: + logger.warning("LiteLLM Anthropic strict patch: _map_tool_helper not found; skipping.") + return False + + if getattr(original, "__dynamiq_strict_patch__", False): + return True # already applied + + def _patched(self, tool): + returned_tool, mcp_server = original(self, tool) + try: + if returned_tool is not None and isinstance(tool, dict) and tool.get("type") == "function": + fn = tool.get("function") or {} + strict = fn.get("strict") + if strict is not None: + returned_tool["strict"] = strict + except Exception as exc: + # Never let a patching failure break the LiteLLM call path. + logger.debug("LiteLLM Anthropic strict patch: lift failed: %s", exc) + return returned_tool, mcp_server + + _patched.__dynamiq_strict_patch__ = True + config_cls._map_tool_helper = _patched + return True + + +_LITELLM_ANTHROPIC_STRICT_PATCHED = _patch_litellm_anthropic_strict_forward() +if _LITELLM_ANTHROPIC_STRICT_PATCHED: + logger.debug("Patched LiteLLM AnthropicConfig to forward `strict: true` for Anthropic tools.") + +# Anthropic strict tool use — per-request caps documented in the API: +# https://platform.claude.com/docs/en/agents-and-tools/tool-use/strict-tool-use +ANTHROPIC_MAX_STRICT_TOOLS = 20 + + +def _clean_anthropic_strict_schema(schema: Any) -> Any: + """Recursively clean a schema for Anthropic's strict tool-use mode. + + - Forces ``additionalProperties: false`` on every object that declares + ``properties``. + - Free-form objects (``dict[str, Any]`` → ``{"type": "object"}`` with no + ``properties``) are converted to JSON-encoded string fields, since strict + mode can't express an open object. The agent parses them back to dicts + before tool validation (see ``_normalize_fields``). + - Optional fields stay omitted from ``required`` (Anthropic's native shape; + no null-union trick). + """ + if not isinstance(schema, dict): + return schema + + schema_type = schema.get("type") + is_object = schema_type == "object" or (isinstance(schema_type, list) and "object" in schema_type) + if is_object and "properties" not in schema: + desc = schema.get("description", "") + return { + "type": "string", + "description": (f"{desc} " if desc else "") + "Provide as a JSON-encoded object string.", + } + + cleaned: dict = {} + for key, value in schema.items(): + if key == "default" and value is None: + # A null default conveys optionality, which Anthropic expresses via + # ``required`` omission. Drop it so it can't clash with a now non-null type. + continue + if key == "type" and isinstance(value, list): + # Anthropic conveys optionality by omitting the field from ``required``, + # not via a null-union. Strip ``null`` so a nullable scalar/enum keeps a + # single declared type (e.g. ``["string", "null"]`` -> ``"string"``); + # Anthropic rejects an enum whose declared type is ``["string", "null"]``. + non_null = [t for t in value if t != "null"] + cleaned["type"] = non_null[0] if len(non_null) == 1 else (non_null or value) + elif key == "properties" and isinstance(value, dict): + cleaned["properties"] = {k: _clean_anthropic_strict_schema(v) for k, v in value.items()} + elif key == "items" and isinstance(value, dict): + cleaned["items"] = _clean_anthropic_strict_schema(value) + elif key in ("anyOf", "oneOf", "allOf") and isinstance(value, list): + branches = [_clean_anthropic_strict_schema(v) if isinstance(v, dict) else v for v in value] + if key in ("anyOf", "oneOf"): + # Drop the ``{"type": "null"}`` branch — nullability is conveyed by + # leaving the field out of ``required`` (Anthropic's native shape). + non_null = [b for b in branches if not (isinstance(b, dict) and b.get("type") == "null")] + branches = non_null or branches + cleaned[key] = branches + else: + cleaned[key] = value + + # Inline a single-branch anyOf/oneOf left over after dropping the null branch, so + # Anthropic sees a plain typed schema (e.g. a nullable enum) instead of a 1-item union. + for union_key in ("anyOf", "oneOf"): + branches = cleaned.get(union_key) + if isinstance(branches, list) and len(branches) == 1 and isinstance(branches[0], dict): + del cleaned[union_key] + for k, v in branches[0].items(): + cleaned.setdefault(k, v) + + cleaned_type = cleaned.get("type") + if cleaned_type == "object" or (isinstance(cleaned_type, list) and "object" in cleaned_type): + cleaned["additionalProperties"] = False + + return cleaned + + class AnthropicCacheControl(BaseModel): """Anthropic prompt caching configuration.""" @@ -24,9 +157,19 @@ class Anthropic(BaseLLM): Attributes: connection (AnthropicConnection | None): The connection to use for the Anthropic LLM. cache_control (AnthropicCacheControl | None): The cache control configuration. + strict_tools: Inherited from :class:`BaseLLM`. False (default, or an empty + list) ships every tool as-is with no strict guarantee; True cleans each + tool's schema to Anthropic's strict subset and attaches ``strict: true`` + (up to :data:`ANTHROPIC_MAX_STRICT_TOOLS` per request); a list of tool + (function) names makes only those tools strict and ships the rest + untouched. Use a list to exclude tools whose schema exceeds Anthropic's + strict grammar-compilation budget (the ``Schema is too complex for + compilation`` 400). """ + connection: AnthropicConnection | None = None MODEL_PREFIX = "anthropic/" + MAX_STRICT_TOOLS: ClassVar[int] = ANTHROPIC_MAX_STRICT_TOOLS cache_control: AnthropicCacheControl | None = None def __init__(self, **kwargs): @@ -89,3 +232,23 @@ def update_completion_params(self, params: dict[str, Any]) -> dict[str, Any]: } ) return params + + def _to_strict_function(self, fn: dict) -> dict: + """Clean one tool's schema to Anthropic's strict shape and attach ``strict``. + + Cleans the parameter schema to Anthropic's strict shape (optionality via + ``required`` omission, free-form objects → JSON-string fields, + ``additionalProperties: false``) and attaches ``strict: true``. A function + without a dict ``parameters`` is returned unchanged (nothing to make strict). + + See :meth:`BaseLLM.transform_tool_schemas` for the shared gating, whitelist, + per-request cap (:attr:`MAX_STRICT_TOOLS`), and fail-safe fallback that drive + this hook. + """ + out = dict(fn) + parameters = out.get("parameters") + if not isinstance(parameters, dict): + return out + out["parameters"] = _clean_anthropic_strict_schema(parameters) + out["strict"] = True + return out diff --git a/dynamiq/nodes/llms/base.py b/dynamiq/nodes/llms/base.py index bd8d3cb55..ae855c7c1 100644 --- a/dynamiq/nodes/llms/base.py +++ b/dynamiq/nodes/llms/base.py @@ -234,6 +234,17 @@ class BaseLLM(ConnectionNode): default=None, description="Configuration for fallback behavior including the fallback LLM.", ) + strict_tools: bool | list[str] = Field( + default=False, + description=( + "Controls provider strict tool-calling (only honored by providers that " + "implement it, e.g. OpenAI and Anthropic). False (default) disables strict " + "for all tools; True makes every tool strict; a list of tool (function) " + "names makes only those tools strict and ships the rest non-strict. Use a " + "list to exclude tools whose schema is too complex for a provider's strict " + "grammar compilation." + ), + ) model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) @@ -654,6 +665,73 @@ def update_completion_params(self, params: dict[str, Any]) -> dict[str, Any]: params["stream_options"]["include_usage"] = True return params + # Per-request cap on strict tools. ``None`` means no cap. Providers with a + # hard limit (e.g. Anthropic) override this. + MAX_STRICT_TOOLS: ClassVar[int | None] = None + + def transform_tool_schemas(self, tools: list[dict]) -> list[dict]: + """Provider-specific schema transform applied before dispatch. + + Shared scaffolding for strict tool-calling, identical across providers: + gate on ``strict_tools``, apply the optional whitelist, respect + :attr:`MAX_STRICT_TOOLS`, and convert each eligible tool via + :meth:`_to_strict_function`. Conversion is **fail-safe** — if a tool's + schema can't be made strict (an exotic/malformed shape that makes the + converter raise), we log a warning, drop ``strict``, and ship that tool + non-strict so the request still succeeds. + + The default :meth:`_to_strict_function` is a no-op, so a generic provider + leaves every tool untouched. OpenAI/Anthropic override only that hook. + + Args: + tools: Function-calling tool schemas (already in OpenAI shape). + + Returns: + Transformed tool list (new dicts for converted tools). + """ + if not self.strict_tools: + return tools + + whitelist = self.strict_tools if isinstance(self.strict_tools, list) else None + cap = self.MAX_STRICT_TOOLS if self.MAX_STRICT_TOOLS is not None else float("inf") + out: list[dict] = [] + strict_count = 0 + for tool in tools: + if not isinstance(tool, dict): + out.append(tool) + continue + tool = dict(tool) + fn = tool.get("function") + eligible = ( + isinstance(fn, dict) and strict_count < cap and (whitelist is None or fn.get("name") in whitelist) + ) + if eligible: + try: + new_fn = self._to_strict_function(fn) + except Exception as exc: + logger.warning( + "Strict tool conversion failed for tool %s: %s; keeping non-strict schema", + fn.get("name", ""), + exc, + ) + new_fn = dict(fn) + new_fn.pop("strict", None) + tool["function"] = new_fn + if new_fn.get("strict"): + strict_count += 1 + out.append(tool) + return out + + def _to_strict_function(self, fn: dict) -> dict: + """Convert one function-tool definition into this provider's strict form. + + Returns a new function dict with ``strict: true`` attached when the schema + was made strict-compatible. The base implementation is a no-op (generic + providers don't support strict). May raise on an unconvertible schema — + :meth:`transform_tool_schemas` catches it and falls back to non-strict. + """ + return fn + @staticmethod def _sanitize_fc_messages(messages: list[dict]) -> list[dict]: """Repair FC messages before dispatch. See ``_fc_sanitization`` module.""" @@ -700,12 +778,14 @@ def _build_completion_params( response_format=response_format, ) if tools: + tools = self.transform_tool_schemas(tools) messages = self._sanitize_fc_messages(messages) # Check if a streaming callback is available in the config and enable streaming only if it is. # This is to avoid unnecessary streaming to reduce CPU usage. is_streaming_callback_available = any( isinstance(callback, BaseStreamingCallbackHandler) for callback in config.callbacks ) + effective_tool_choice = tool_choice if tool_choice is not None else self.tool_choice common_params: dict[str, Any] = { "model": self.model, "messages": messages, @@ -713,7 +793,7 @@ def _build_completion_params( "temperature": self.temperature, "max_tokens": self.max_tokens, "tools": tools, - "tool_choice": tool_choice if tool_choice is not None else self.tool_choice, + "tool_choice": effective_tool_choice, "stop": self.stop if self.stop else None, "top_p": self.top_p, "seed": self.seed, diff --git a/dynamiq/nodes/llms/openai.py b/dynamiq/nodes/llms/openai.py index ede172a95..77a26c6bc 100644 --- a/dynamiq/nodes/llms/openai.py +++ b/dynamiq/nodes/llms/openai.py @@ -6,6 +6,107 @@ from dynamiq.nodes.llms.base import BaseLLM +def _add_null_to_type(prop: Any) -> Any: + """Make a converted property nullable so it can sit in ``required`` while + letting the model signal "leave it at the default" by emitting ``null``. + + Handles plain types (``"x"`` → ``["x", "null"]``), type arrays, and complex + ``anyOf`` unions (adds a ``{"type": "null"}`` branch). No-op if already nullable + or if it's a stringified free-form object. + + The input is never mutated: a shallow copy is returned when a change is needed, + otherwise the original object is returned unchanged. + """ + if not isinstance(prop, dict): + return prop + t = prop.get("type") + if isinstance(t, str) and t != "null": + return {**prop, "type": [t, "null"]} + if isinstance(t, list) and "null" not in t: + return {**prop, "type": [*t, "null"]} + if "anyOf" in prop: + branches = prop["anyOf"] + if not any(isinstance(b, dict) and b.get("type") == "null" for b in branches): + return {**prop, "anyOf": [*branches, {"type": "null"}]} + return prop + + +def _to_openai_strict_property(prop: Any) -> Any: + """Convert a property schema into OpenAI strict-mode form. + + - ``anyOf`` of ``[primitive, null]`` is flattened to a type array ``["X", "null"]``. + - Nested objects get ``additionalProperties: false`` and every property in + ``required``. Nested fields that have a default (i.e. were NOT in the + object's own ``required``) are made nullable, so the model can emit ``null`` + to leave them at their default. The agent strips those nulls before tool + validation so the Python default applies (see ``_normalize_fields``). + - Arrays' ``items`` are converted recursively. + """ + if not isinstance(prop, dict): + return prop + + if "anyOf" in prop and "type" not in prop: + primitive_types: list[str] = [] + has_complex = False + for option in prop["anyOf"]: + if isinstance(option, dict) and "type" in option: + opt_type = option["type"] + if opt_type in ("object", "array") or isinstance(opt_type, list): + has_complex = True + break + primitive_types.append(opt_type) + else: + has_complex = True + break + if not has_complex and primitive_types: + out: dict[str, Any] = {"type": primitive_types} + for key in ("description", "default", "enum", "title"): + if key in prop: + out[key] = prop[key] + return out + return { + "anyOf": [_to_openai_strict_property(opt) for opt in prop["anyOf"]], + **{k: v for k, v in prop.items() if k in ("description", "title", "default")}, + } + + cleaned: dict = dict(prop) + + param_type = cleaned.get("type") + # `type` may be a plain string ("object") or a nullable type-array + # (["object", "null"] for `dict | None`). Detect both. + is_nullable_type = isinstance(param_type, list) and "null" in param_type + is_object = param_type == "object" or (isinstance(param_type, list) and "object" in param_type) + is_array = param_type == "array" or (isinstance(param_type, list) and "array" in param_type) + + if is_object: + if "properties" not in cleaned: + # Free-form object (dict[str, Any]). Strict can't express an open + # object, so represent it as a JSON-encoded string. The agent parses + # it back to a dict before tool validation (see _normalize_fields). + # Preserve nullability so a `dict | None` field can still be null. + desc = cleaned.get("description", "") + return { + "type": ["string", "null"] if is_nullable_type else "string", + "description": (f"{desc} " if desc else "") + "Provide as a JSON-encoded object string.", + } + nested = cleaned["properties"] + nested_required = set(cleaned.get("required", [])) + converted_nested: dict[str, Any] = {} + for k, v in nested.items(): + cv = _to_openai_strict_property(v) + if k not in nested_required: + cv = _add_null_to_type(cv) + converted_nested[k] = cv + cleaned["properties"] = converted_nested + cleaned["required"] = list(nested.keys()) + cleaned["additionalProperties"] = False + return cleaned + if is_array and isinstance(cleaned.get("items"), dict): + cleaned["items"] = _to_openai_strict_property(cleaned["items"]) + return cleaned + return cleaned + + class ReasoningEffort(str, enum.Enum): """ The reasoning effort to use for the OpenAI LLM. @@ -65,6 +166,12 @@ class OpenAI(BaseLLM): Attributes: connection (OpenAIConnection | None): The connection to use for the OpenAI LLM. + strict_tools: Inherited from :class:`BaseLLM`. False (default, or an empty + list) ships every tool as-is; True converts each tool's parameter schema + into OpenAI structured-outputs strict form (every property required, + optionals re-encoded as nullable types, ``additionalProperties: false`` + on every object) and attaches ``strict: true``; a list of tool (function) + names converts only those tools and ships the rest untouched. """ connection: OpenAIConnection | None = None reasoning_effort: ReasoningEffort | None = ReasoningEffort.AUTO @@ -105,6 +212,49 @@ def _apply_reasoning_effort(params: dict[str, Any], effort: ReasoningEffort | No else: params["reasoning_effort"] = effort + def _to_strict_function(self, fn: dict) -> dict: + """Convert one tool's schema into OpenAI structured-outputs strict form. + + Every property is promoted to ``required`` and ``additionalProperties: false`` + is set at every object level. Fields that have a default (NOT in the original + ``required``) — or that are already nullable — are made nullable so the model + can emit ``null`` to leave them at their default; the agent then strips those + nulls before tool validation so the Python default applies. Genuinely-required + fields (no default) stay non-nullable and must be emitted. ``strict: true`` + is attached. Free-form objects (``dict[str, Any]``) are handled inside + ``_to_openai_strict_property`` (converted to JSON-string fields). + + See :meth:`BaseLLM.transform_tool_schemas` for the shared gating, + whitelist, and fail-safe fallback that drive this hook. + """ + out = dict(fn) + parameters = out.get("parameters") + if not isinstance(parameters, dict): + return out + + properties = parameters.get("properties", {}) + original_required = set(parameters.get("required", [])) + converted_props: dict[str, Any] = {} + for name, prop in properties.items(): + converted = _to_openai_strict_property(prop) + if name not in original_required: + converted = _add_null_to_type(converted) + converted_props[name] = converted + + new_parameters: dict[str, Any] = { + "type": "object", + "properties": converted_props, + "required": list(converted_props.keys()), + "additionalProperties": False, + } + for key in ("description", "title"): + if key in parameters: + new_parameters[key] = parameters[key] + + out["parameters"] = new_parameters + out["strict"] = True + return out + def update_completion_params(self, params: dict[str, Any]) -> dict[str, Any]: """ Override the base method to update the completion parameters for OpenAI. diff --git a/tests/integration_with_creds/agents/streaming_assertions.py b/tests/integration_with_creds/agents/streaming_assertions.py index 6f06cd5b1..fdc9c8b2f 100644 --- a/tests/integration_with_creds/agents/streaming_assertions.py +++ b/tests/integration_with_creds/agents/streaming_assertions.py @@ -271,10 +271,39 @@ def _handle_tool_result(event_name, content, tool_blocks, reasoning_blocks, run_ return run_parallel_count -def _handle_answer(event_name, reasoning_blocks): - """Pop reasoning block on answer event.""" - if event_name == "answer" and reasoning_blocks: +def _handle_answer(event_name, content, reasoning_blocks, answer_chunks): + """Pop reasoning block on answer event and accumulate answer chunks.""" + if event_name != "answer": + return + if reasoning_blocks: reasoning_blocks.pop(0) + if isinstance(content, str): + answer_chunks.append(content) + elif isinstance(content, dict) and isinstance(content.get("answer"), str): + answer_chunks.append(content["answer"]) + + +def _assert_answer_clean(answer_chunks: list[str]) -> None: + """Assert the streamed answer is the answer text itself, not the raw wrapper. + + Regression guard: an SO finish answer lives under ``action_input["answer"]``. + A chunk-boundary race used to lock the stream onto the brace-delimited + ``action_input`` object and emit the whole ``{"answer": "..."}`` wrapper + instead of just the answer text, diverging from the non-streaming output. + Reject an accumulated answer that parses as a JSON object exposing an + ``answer`` key — that means the wrapper leaked into the stream. + """ + accumulated = "".join(answer_chunks).strip() + if not accumulated.startswith("{"): + return + try: + decoded = json.loads(accumulated) + except json.JSONDecodeError: + return + assert not (isinstance(decoded, dict) and "answer" in decoded), ( + "Streamed answer leaked the action_input wrapper instead of the answer text. " + f"Accumulated answer: {accumulated!r}" + ) def _track_parallel_individual_post_parse(content, idx, parallel_post_parse_tids): @@ -366,6 +395,17 @@ def _match_action_input(accumulated: str, expected) -> bool: return True attempts.append(f"decoded==expected: {decoded!r} == {expected!r} -> False") + # Strict null-default handling: under OpenAI strict, optional non-nullable fields + # are emitted as ``null`` ("use default") and stripped before the tool receives + # them (see _normalize_fields). The streamed tool_input is the raw model + # output and still carries those nulls, so prune null-valued keys absent from the + # (post-strip) expected action_input before comparing. + if isinstance(decoded, dict) and isinstance(expected, dict): + pruned = {k: v for k, v in decoded.items() if not (v is None and k not in expected)} + if pruned == expected: + return True + attempts.append(f"pruned==expected: {pruned!r} == {expected!r} -> False") + if isinstance(expected, dict) and "input" in expected: inner = expected["input"] if decoded == inner: @@ -483,6 +523,7 @@ def _run_fsm_fc(ordered_events, streaming_mode): state = State.INIT visited = {state} reasoning_blocks: list[str] = [] + answer_chunks: list[str] = [] tool_blocks: dict[str, dict] = {} run_parallel_count = 0 parallel_post_parse_tids: set[str] = set() @@ -569,12 +610,13 @@ def _run_fsm_fc(ordered_events, streaming_mode): _validate_single_post_parse(content, tool_blocks, reasoning_blocks) run_parallel_count = _handle_tool_result(event_name, content, tool_blocks, reasoning_blocks, run_parallel_count) - _handle_answer(event_name, reasoning_blocks) + _handle_answer(event_name, content, reasoning_blocks, answer_chunks) state = next_state visited.add(state) _assert_fsm_end(tool_blocks, reasoning_blocks, run_parallel_count, parallel_post_parse_tids) + _assert_answer_clean(answer_chunks) return state, visited, reasoning_blocks @@ -592,6 +634,7 @@ def _run_fsm_blob(ordered_events, streaming_mode): state = State.INIT visited = {state} reasoning_blocks: list[str] = [] + answer_chunks: list[str] = [] tool_blocks: dict[str, dict] = {} run_parallel_count = 0 parallel_post_parse_tids: set[str] = set() @@ -650,12 +693,13 @@ def _run_fsm_blob(ordered_events, streaming_mode): _validate_single_post_parse(content, tool_blocks, reasoning_blocks) run_parallel_count = _handle_tool_result(event_name, content, tool_blocks, reasoning_blocks, run_parallel_count) - _handle_answer(event_name, reasoning_blocks) + _handle_answer(event_name, content, reasoning_blocks, answer_chunks) state = next_state visited.add(state) _assert_fsm_end(tool_blocks, reasoning_blocks, run_parallel_count, parallel_post_parse_tids) + _assert_answer_clean(answer_chunks) return state, visited, reasoning_blocks @@ -669,6 +713,7 @@ def _run_fsm_default(ordered_events, streaming_mode): state = State.INIT visited = {state} reasoning_blocks: list[str] = [] + answer_chunks: list[str] = [] run_parallel_count = 0 tool_blocks: dict[str, dict] = {} @@ -685,12 +730,13 @@ def _run_fsm_default(ordered_events, streaming_mode): run_parallel_count += 1 run_parallel_count = _handle_tool_result(event_name, content, tool_blocks, reasoning_blocks, run_parallel_count) - _handle_answer(event_name, reasoning_blocks) + _handle_answer(event_name, content, reasoning_blocks, answer_chunks) state = next_state visited.add(state) _assert_fsm_end(tool_blocks, reasoning_blocks, run_parallel_count) + _assert_answer_clean(answer_chunks) return state, visited, reasoning_blocks diff --git a/tests/integration_with_creds/agents/test_agent_python_tool.py b/tests/integration_with_creds/agents/test_agent_python_tool.py index 197e6ad31..098506026 100644 --- a/tests/integration_with_creds/agents/test_agent_python_tool.py +++ b/tests/integration_with_creds/agents/test_agent_python_tool.py @@ -34,10 +34,11 @@ class Priority(str, Enum): class FilterOptions(BaseModel): - """Nested model -- tests Model | None union in the parent schema.""" + """Nested model -- tests Model | None union plus a nested free-form dict sub-field.""" min_score: float = Field(default=0.0, description="Minimum score threshold.") tags: list[str] = Field(default_factory=list, description="Tags to filter by.") + metadata: dict[str, Any] = Field(default_factory=dict, description="Free-form metadata (arbitrary keys).") class ActionType(str, Enum): @@ -48,6 +49,20 @@ class ActionType(str, Enum): EXTRACT = "extract" +class FilterRule(BaseModel): + """List-element model -- exercises ``list[Model]`` normalization. + + Carries a non-nullable defaulted field (``op``; strict emits ``null`` to mean + "use default") and a free-form ``dict[str, Any]`` (``params``; strict ships a + JSON-encoded string), both *inside list items* so the return-path normalization + must recurse into each element. + """ + + field: str = Field(..., description="Field name to match on.") + op: str = Field(default="eq", description="Comparison operator.") + params: dict[str, Any] = Field(default_factory=dict, description="Free-form operator params (arbitrary keys).") + + class ComprehensiveInputSchema(BaseModel): """Single schema covering all type patterns found in real tool schemas. @@ -58,6 +73,12 @@ class ComprehensiveInputSchema(BaseModel): Nullable: int|None (with ge/le), int|None (bare, no Field), str|None, bool|None, Enum|None, list[str]|None, Model|None + Nested: + Model|None whose sub-model carries a free-form dict[str, Any] + (FilterOptions.metadata) -- exercises nested strict-string coercion + List of models: + list[Model] (FilterRule) whose elements carry a non-nullable defaulted + field and a free-form dict -- exercises per-element list normalization Special: is_accessible_to_agent=False, ConfigDict(extra='allow') """ @@ -79,6 +100,9 @@ class ComprehensiveInputSchema(BaseModel): priority: Priority | None = Field(default=None, description="Optional priority level.") domains: list[str] | None = Field(default=None, description="Whitelist of domains.") filters: FilterOptions | None = Field(default=None, description="Optional filter configuration.") + rules: list[FilterRule] = Field( + default_factory=list, description="Ordered filter rules; each may carry free-form params." + ) internal_trace_id: str | None = Field( default=None, description="Internal tracing ID.", @@ -136,6 +160,8 @@ def execute(self, input_data: ComprehensiveInputSchema, config: RunnableConfig = extras.append(f"domains={input_data.domains}") if input_data.filters is not None: extras.append(f"filters(min_score={input_data.filters.min_score}, tags={input_data.filters.tags})") + if input_data.rules: + extras.append(f"rules={[(r.field, r.op, r.params) for r in input_data.rules]}") if extras: body += " Options: " + ", ".join(extras) + "." @@ -277,7 +303,7 @@ def run_and_assert_agent(agent: Agent, agent_input, expected_length, run_config) def test_react_agent_inference_modes( llm_instance, string_length_tool_instance, agent_role, agent_input, expected_length, run_config, inference_mode ): - """Test agent with Python tool across different inference modes.""" + """Test agent with a simple Python tool across different inference modes.""" agent = Agent( name=f"Test Agent {inference_mode.value}", llm=llm_instance, @@ -293,8 +319,15 @@ def test_react_agent_inference_modes( run_and_assert_agent(agent, agent_input, expected_length, run_config) -def _run_comprehensive_schema_test(llm, comprehensive_tool, run_config, inference_mode, label): - """Helper: run agent with ComprehensiveTool and assert success with streaming validation.""" +def _run_comprehensive_schema_test(llm, comprehensive_tool, run_config, inference_mode, label, strict_tools=False): + """Helper: run agent with ComprehensiveTool and assert success with streaming validation. + + When ``strict_tools`` is set, the LLM is copied with ``strict_tools=True`` so the + provider strict transform (``_to_strict_function``) and the agent-side normalization + round-trip (``_normalize_fields``) are exercised end-to-end against the complex schema. + """ + if strict_tools: + llm = llm.model_copy(update={"strict_tools": True}) agent = Agent( name=f"Comprehensive Schema Test ({label})", llm=llm, @@ -339,19 +372,38 @@ def _run_comprehensive_schema_test(llm, comprehensive_tool, run_config, inferenc @pytest.mark.integration @pytest.mark.flaky(reruns=3) @pytest.mark.parametrize( - "inference_mode", - [InferenceMode.STRUCTURED_OUTPUT, InferenceMode.FUNCTION_CALLING], - ids=["structured_output", "function_calling"], -) -@pytest.mark.parametrize( - "llm_fixture", - ["llm_instance", "anthropic_llm"], - ids=["openai", "anthropic"], + "llm_fixture, inference_mode, strict_tools", + [ + ("llm_instance", InferenceMode.STRUCTURED_OUTPUT, False), + ("anthropic_llm", InferenceMode.STRUCTURED_OUTPUT, False), + ("llm_instance", InferenceMode.FUNCTION_CALLING, False), + ("anthropic_llm", InferenceMode.FUNCTION_CALLING, False), + ("llm_instance", InferenceMode.FUNCTION_CALLING, True), + ("anthropic_llm", InferenceMode.FUNCTION_CALLING, True), + ], + ids=[ + "openai-so", + "anthropic-so", + "openai-fc-nonstrict", + "anthropic-fc-nonstrict", + "openai-fc-strict", + "anthropic-fc-strict", + ], ) -def test_comprehensive_schema_tool_modes(llm_fixture, comprehensive_tool, run_config, inference_mode, request): - """Comprehensive typed schema (non-nullable + nullable + hidden fields) with OpenAI and Anthropic.""" +def test_comprehensive_schema_tool_modes( + llm_fixture, inference_mode, strict_tools, comprehensive_tool, run_config, request +): + """Comprehensive typed schema on the complex tool across an explicit provider/mode/strict matrix. + + - OpenAI and Anthropic: STRUCTURED_OUTPUT, FUNCTION_CALLING (non-strict and strict). + + ``strict_tools`` only varies for FUNCTION_CALLING — it is the only mode where tool + schemas are sent to the provider as function tools (SO uses ``response_format``), so + it's the only mode the strict transform applies to. + """ llm = request.getfixturevalue(llm_fixture) provider = "openai" if "llm_instance" in llm_fixture else "anthropic" + label = f"{provider}-{inference_mode.value}{'-strict' if strict_tools else ''}" _run_comprehensive_schema_test( - llm, comprehensive_tool, run_config, inference_mode, f"{provider}-{inference_mode.value}" + llm, comprehensive_tool, run_config, inference_mode, label, strict_tools=strict_tools ) diff --git a/tests/unit/nodes/agents/test_agent_parsing.py b/tests/unit/nodes/agents/test_agent_parsing.py index 697733817..b4f2cbd0c 100644 --- a/tests/unit/nodes/agents/test_agent_parsing.py +++ b/tests/unit/nodes/agents/test_agent_parsing.py @@ -1,5 +1,9 @@ +from typing import Any, ClassVar, Literal + import pytest +from pydantic import BaseModel, Field +from dynamiq.nodes import Node, NodeGroup from dynamiq.nodes.agents.components import parser from dynamiq.nodes.agents.exceptions import ActionParsingException @@ -600,29 +604,6 @@ class Doc(BaseModel): assert "title" in answer["properties"] -def test_agent_response_format_structured_output_schema_unchanged(): - """STRUCTURED_OUTPUT keeps its simple `action_input: string` schema; the - user's response_format is enforced via prompt injection + coerce instead.""" - from pydantic import BaseModel - - from dynamiq.nodes.types import InferenceMode - - class Doc(BaseModel): - title: str - tags: list[str] - - agent = _make_agent(inference_mode=InferenceMode.STRUCTURED_OUTPUT, response_format=Doc) - schema = agent._response_format["json_schema"]["schema"] - assert schema["properties"]["action_input"] == { - "type": "string", - "description": "Input for chosen action.", - } - rendered = agent.generate_prompt() - assert "MUST be a valid JSON document" in rendered - assert "title" in rendered - assert "tags" in rendered - - def test_agent_structured_output_finish_with_json_string_action_input(): """STRUCTURED_OUTPUT finish emits a JSON string; coerce parses it into a dict.""" from pydantic import BaseModel @@ -727,3 +708,43 @@ class Schema(BaseModel): with pytest.raises(ValueError, match="not exposed to the agent"): apply_param_modes(Schema, {"internal_id": "required"}) + + +def test_normalize_fields_coerces_nested_model(): + """A stringified free-form dict nested inside a sub-model is coerced back to a dict. + + Strict mode ships a free-form ``dict[str, Any]`` as a JSON-encoded string. For a + dict declared on a nested model (``FilterOptions.metadata``), ``_normalize_fields`` + must recurse into the sub-model and parse it back -- otherwise the string survives + and the nested model's Pydantic validation rejects it. + """ + class FilterOptions(BaseModel): + min_score: float = Field(default=0.0) + metadata: dict[str, Any] = Field(default_factory=dict) + + class ComprehensiveInputSchema(BaseModel): + text: str + filters: FilterOptions | None = None + + class ComprehensiveTool(Node): + group: Literal[NodeGroup.TOOLS] = NodeGroup.TOOLS + name: str = "Comprehensive Tool" + input_schema: ClassVar[type[ComprehensiveInputSchema]] = ComprehensiveInputSchema + + def execute(self, input_data, config=None, **kwargs): + return {} + + tool = ComprehensiveTool() + agent = _make_agent() + + action_input = { + "text": "hello", + "filters": {"min_score": 0.5, "metadata": '{"source": "web", "score": 1}'}, + } + + agent._normalize_fields(tool.input_schema.model_fields, action_input) + + # Nested free-form dict string parsed back into a dict. + assert action_input["filters"]["metadata"] == {"source": "web", "score": 1} + # Non-string nested values are left untouched. + assert action_input["filters"]["min_score"] == 0.5 diff --git a/tests/unit/nodes/agents/test_schema_generator_flat_args.py b/tests/unit/nodes/agents/test_schema_generator_flat_args.py index 934fb8b79..c234befc4 100644 --- a/tests/unit/nodes/agents/test_schema_generator_flat_args.py +++ b/tests/unit/nodes/agents/test_schema_generator_flat_args.py @@ -5,8 +5,10 @@ chain-of-thought behavior). * Tool params are TOP-LEVEL siblings of `thought` (no `action_input` wrapper). * `additionalProperties: false` is set on `parameters`. - * Strict mode is enabled only when every property is required AND the schema - contains no shapes OpenAI strict mode would reject. + * Only genuinely-required fields go in `required`. Strict mode (and its + all-required promotion) is applied per-provider in + `BaseLLM.transform_tool_schemas`, not at generation time, so no `strict` + key is emitted here. """ from unittest.mock import MagicMock @@ -91,7 +93,7 @@ class _Schema(BaseModel): properties = schema["function"]["parameters"]["properties"] assert "action_input" not in properties, f"action_input wrapper leaked into {schema['function']['name']}" - def test_strict_when_all_params_required(self): + def test_all_required_params_in_required_list(self): class _Schema(BaseModel): a: str = Field(..., description="A") b: str = Field(..., description="B") @@ -100,11 +102,12 @@ class _Schema(BaseModel): schemas = _gen(tool) tool_schema = next(s for s in schemas if s["function"]["name"] == "op") - assert tool_schema["function"]["strict"] is True + # No gen-time strict key — strict is applied per-provider downstream. + assert "strict" not in tool_schema["function"] assert tool_schema["function"]["parameters"]["additionalProperties"] is False assert set(tool_schema["function"]["parameters"]["required"]) == {"thought", "a", "b"} - def test_strict_dropped_when_any_param_is_optional(self): + def test_optional_params_excluded_from_required_list(self): class _Schema(BaseModel): required_field: str = Field(..., description="Required") optional_field: str | None = Field(default=None, description="Optional") @@ -113,7 +116,7 @@ class _Schema(BaseModel): schemas = _gen(tool) tool_schema = next(s for s in schemas if s["function"]["name"] == "op") - assert tool_schema["function"]["strict"] is False + assert "strict" not in tool_schema["function"] # required list excludes the optional field but always starts with thought assert tool_schema["function"]["parameters"]["required"][0] == "thought" assert "required_field" in tool_schema["function"]["parameters"]["required"] @@ -133,12 +136,12 @@ class _Schema(BaseModel): if schema["function"]["name"] == "provide_final_answer": continue # final-answer schema is untouched by this migration params = schema["function"]["parameters"] - assert params.get("additionalProperties") is False, ( - f"{schema['function']['name']} missing additionalProperties: false at parameters level" - ) + assert ( + params.get("additionalProperties") is False + ), f"{schema['function']['name']} missing additionalProperties: false at parameters level" - def test_zero_param_tool_is_strict_by_default(self): - """A tool with no params produces a fully-required, strict schema.""" + def test_zero_param_tool_requires_only_thought(self): + """A tool with no params produces a closed schema requiring only `thought`.""" class _Empty(BaseModel): pass @@ -146,13 +149,16 @@ class _Empty(BaseModel): tool = _make_tool("ping", _Empty) schemas = _gen(tool) tool_schema = next(s for s in schemas if s["function"]["name"] == "ping") + params = tool_schema["function"]["parameters"] - assert tool_schema["function"]["strict"] is True + assert "strict" not in tool_schema["function"] + assert params["additionalProperties"] is False + assert params["required"] == ["thought"] - def test_extra_allow_tool_is_open_and_non_strict(self): + def test_extra_allow_tool_is_open(self): """A no-declared-fields tool that accepts extras (e.g. the generic Python - tool) must stay OPEN: additionalProperties true and non-strict, so the model - can pass arbitrary params as top-level siblings of `thought`.""" + tool) must stay OPEN: additionalProperties true, so the model can pass + arbitrary params as top-level siblings of `thought`.""" from pydantic import ConfigDict class _Dynamic(BaseModel): @@ -164,7 +170,7 @@ class _Dynamic(BaseModel): params = tool_schema["function"]["parameters"] assert params["additionalProperties"] is True - assert tool_schema["function"]["strict"] is False + assert "strict" not in tool_schema["function"] assert params["required"] == ["thought"]