-
Notifications
You must be signed in to change notification settings - Fork 128
feat: rework tool calling #753
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 12 commits
6ab653a
a335bfe
d8a8a8c
29200f7
2b3d5b3
1336778
c165266
3d9fbfe
841e990
34d1ee3
e2b85c4
3ac384a
38565b3
1e2e9b1
6f4e213
786b369
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import json | ||
| import types | ||
| from concurrent.futures import as_completed | ||
| from typing import Any, Callable, Literal, Mapping | ||
| from typing import Any, Callable, Literal, Mapping, Union, get_args, get_origin | ||
|
|
||
| from litellm import get_supported_openai_params, supports_function_calling, supports_response_schema | ||
| from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator | ||
|
|
@@ -39,7 +40,7 @@ | |
| StreamingMode, | ||
| ) | ||
| from dynamiq.utils import generate_uuid, serialize_files_in_value | ||
| from dynamiq.utils.json_parser import parse_llm_json_output | ||
| from dynamiq.utils.json_parser import parse_llm_json_output, repair_truncated_json | ||
| from dynamiq.utils.logger import logger | ||
|
|
||
|
|
||
|
|
@@ -79,8 +80,11 @@ def parse_arguments(cls, v: Any) -> Any: | |
| if isinstance(v, str): | ||
| try: | ||
| return json.loads(v, strict=False) | ||
| except json.JSONDecodeError as e: | ||
| raise ValueError(f"Tool call arguments are not valid JSON: {e}") | ||
| except json.JSONDecodeError: | ||
| try: | ||
| return json.loads(repair_truncated_json(v), strict=False) | ||
| except json.JSONDecodeError as e: | ||
| raise ValueError(f"Tool call arguments are not valid JSON: {e}") | ||
| return v or {} | ||
|
|
||
| def parse_as_tool_call(self) -> ToolCallArguments: | ||
|
|
@@ -320,6 +324,93 @@ def _emit_tool_input_error( | |
| ) | ||
| self._streaming_tool_run_id = None | ||
|
|
||
| @staticmethod | ||
| def _annotation_accepts_none(annotation: Any) -> bool: | ||
| """Return True if a Pydantic field annotation includes ``NoneType``.""" | ||
| if annotation is type(None): | ||
| return True | ||
| origin = get_origin(annotation) | ||
| if origin in (Union, types.UnionType): | ||
| return type(None) in get_args(annotation) | ||
| return False | ||
|
|
||
| @staticmethod | ||
| def _annotation_is_dict_like(annotation: Any) -> bool: | ||
| """Return True if the annotation is ``dict`` / ``dict[...]`` or a union including one.""" | ||
| if annotation is dict: | ||
| return True | ||
| origin = get_origin(annotation) | ||
| if origin is dict: | ||
| return True | ||
| if origin in (Union, types.UnionType): | ||
| return any(Agent._annotation_is_dict_like(arg) for arg in get_args(annotation)) | ||
| return False | ||
|
|
||
@staticmethod
def _extract_basemodel(annotation: Any) -> type[BaseModel] | None:
    """Find the ``BaseModel`` subclass named by an annotation.

    The annotation may be the model class itself or a ``Union`` / ``X | Y``
    (such as ``Model | None``); for a union, the first member that is a
    ``BaseModel`` subclass is returned.

    Args:
        annotation: The annotation object to inspect.

    Returns:
        The ``BaseModel`` subclass, or None when the annotation contains none.
    """

    def _is_model(candidate: Any) -> bool:
        # ``issubclass`` requires a class, so guard with ``isinstance`` first.
        return isinstance(candidate, type) and issubclass(candidate, BaseModel)

    if _is_model(annotation):
        return annotation
    if get_origin(annotation) not in (Union, types.UnionType):
        return None
    return next((member for member in get_args(annotation) if _is_model(member)), None)
|
|
||
def _coerce_json_fields(self, tool: Node, action_input: dict) -> dict:
    """Parse stringified free-form dict fields back into dicts.

    Strict mode can't express a free-form ``dict[str, Any]`` as an object, so
    the schema transforms ship those fields as JSON-encoded strings (see the
    provider converters). Here we reverse that: if the tool declares a
    dict-typed field and the model supplied a JSON string for it, parse it
    back so the tool's Pydantic schema validates the real dict.

    Args:
        tool: The tool whose ``input_schema.model_fields`` declares the
            expected field annotations.
        action_input: The tool-call arguments; updated in place.

    Returns:
        The same ``action_input`` dict, with every successfully parsed
        field replaced by its dict value. Values that fail to parse are
        left untouched.
    """
    for name, field in tool.input_schema.model_fields.items():
        value = action_input.get(name)
        if not isinstance(value, str) or not self._annotation_is_dict_like(field.annotation):
            continue
        candidate = value.strip()
        # Only attempt strings that look like a JSON object.
        if not (candidate.startswith("{") and candidate.endswith("}")):
            continue
        try:
            # ``strict=False`` matches how the tool-call arguments themselves
            # are decoded: raw control characters (e.g. a literal newline
            # inside a string value) are accepted instead of failing the parse.
            action_input[name] = json.loads(candidate, strict=False)
        except json.JSONDecodeError:
            continue  # leave as string; Pydantic will surface the error
    return action_input
|
cursor[bot] marked this conversation as resolved.
Outdated
|
||
|
|
||
def _strip_protocol_nulls(self, tool: Node, action_input: dict) -> dict:
    """Remove ``None`` entries for fields whose annotation does not allow ``None``.

    Under OpenAI strict mode every property is listed in ``required`` and
    ``"null"`` in the type union signals "use the default". A field with a
    non-nullable default (``encoding: str = "utf-8"``) cannot take that
    ``None``, so the key is removed and the tool's Pydantic default applies.
    A field that really accepts ``None`` (``encoding: str | None = None``)
    keeps it.

    The same rule is applied inside nested ``BaseModel`` fields
    (e.g. ``config.port`` where ``DBConfig.port: int = 8080``).

    Args:
        tool: The tool whose ``input_schema.model_fields`` drives the check.
        action_input: The tool-call arguments; updated in place.

    Returns:
        The same ``action_input`` dict after the offending keys were removed.
    """
    schema_fields = tool.input_schema.model_fields
    self._strip_nulls_for_fields(schema_fields, action_input)
    return action_input
|
|
||
| def _strip_nulls_for_fields(self, fields: Mapping[str, Any], data: Any) -> None: | ||
| if not isinstance(data, dict): | ||
| return | ||
| for name in list(data): | ||
| field = fields.get(name) | ||
| if field is None: | ||
| continue | ||
| value = data[name] | ||
| if value is None: | ||
| if not self._annotation_accepts_none(field.annotation): | ||
| del data[name] | ||
| elif isinstance(value, dict): | ||
| nested_model = self._extract_basemodel(field.annotation) | ||
| if nested_model is not None: | ||
| self._strip_nulls_for_fields(nested_model.model_fields, value) | ||
|
|
||
| def _should_delegate_final( | ||
| self, | ||
| tool: Node | None, | ||
|
|
@@ -791,6 +882,7 @@ def _handle_function_calling_mode( | |
| if len(actual_tool_calls) > 1 and self.parallel_tool_calls_enabled: | ||
| tool_items = [] | ||
| for tc in actual_tool_calls: | ||
| tc_name = tc.function.name.strip() | ||
| args = tc.function.parse_as_tool_call() | ||
| tc_input = args.action_input | ||
| if isinstance(tc_input, str): | ||
|
|
@@ -800,9 +892,20 @@ def _handle_function_calling_mode( | |
| raise ActionParsingException(f"Error parsing action_input string. {e}", recoverable=True) | ||
| if not isinstance(tc_input, dict): | ||
| tc_input = {"input": tc_input} | ||
| tc_tool = self.tool_by_names.get(self.sanitize_tool_name(tc_name)) | ||
| if tc_tool is not None: | ||
| self._coerce_json_fields(tc_tool, tc_input) | ||
| self._strip_protocol_nulls(tc_tool, tc_input) | ||
| try: | ||
| tc_tool.input_schema.model_validate(tc_input) | ||
| except Exception as e: | ||
| raise ActionParsingException( | ||
| f"Tool call for '{tc_name}' has invalid arguments: {e}", | ||
| recoverable=True, | ||
| ) | ||
| tool_items.append( | ||
| ToolCallItem( | ||
| name=tc.function.name.strip(), | ||
| name=tc_name, | ||
| input=tc_input, | ||
| thought=args.thought, | ||
| ) | ||
|
|
@@ -826,6 +929,18 @@ def _handle_function_calling_mode( | |
| if not isinstance(action_input, dict): | ||
| action_input = {"input": action_input} | ||
|
|
||
| tool = self.tool_by_names.get(self.sanitize_tool_name(action)) | ||
| if tool is not None: | ||
| self._coerce_json_fields(tool, action_input) | ||
| self._strip_protocol_nulls(tool, action_input) | ||
| try: | ||
| tool.input_schema.model_validate(action_input) | ||
| except Exception as e: | ||
| raise ActionParsingException( | ||
| f"Tool call for '{action}' has invalid arguments: {e}", | ||
| recoverable=True, | ||
| ) | ||
|
|
||
| self.log_reasoning(thought, action, action_input, loop_num) | ||
| return thought, action, action_input | ||
|
|
||
|
|
@@ -866,8 +981,14 @@ def _handle_structured_output_mode( | |
| self._requested_output_files = self._parse_output_files_csv( | ||
| llm_generated_output_json.get("output_files") or "" | ||
| ) | ||
| self.log_final_output(thought, action_input, loop_num) | ||
| return thought, "final_answer", action_input | ||
| # action_input is now an object (per schema); the final answer lives | ||
| # under the ``answer`` key. Fall back to the raw value for backward | ||
| # compatibility with older models that still emit a plain string. | ||
| final_answer: Any = action_input | ||
| if isinstance(action_input, dict) and "answer" in action_input: | ||
| final_answer = action_input["answer"] | ||
| self.log_final_output(thought, final_answer, loop_num) | ||
| return thought, "final_answer", final_answer | ||
|
maksymbuleshnyi marked this conversation as resolved.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Finish string action_input not parsed (Medium Severity): On structured-output … Reviewed by Cursor Bugbot for commit 786b369. Configure here. |
||
|
|
||
| try: | ||
| if isinstance(action_input, str): | ||
|
|
@@ -1504,12 +1625,24 @@ def _run_react_llm_step(self, config: RunnableConfig | None, loop_num: int, **kw | |
|
|
||
| try: | ||
| native_parallel = self.parallel_tool_calls_enabled and self.inference_mode == InferenceMode.FUNCTION_CALLING | ||
| # In FUNCTION_CALLING mode with tools present, force a tool call so | ||
| # the model cannot bail out with a text-only response. Honour any | ||
| # explicit caller override (kwargs / self.llm.tool_choice). | ||
| forced_tool_choice = None | ||
| if ( | ||
| self.inference_mode == InferenceMode.FUNCTION_CALLING | ||
| and self._tools | ||
| and "tool_choice" not in kwargs | ||
| and getattr(self.llm, "tool_choice", None) is None | ||
| ): | ||
| forced_tool_choice = "required" | ||
| llm_result = self._run_llm( | ||
| messages=messages, | ||
| tools=self._tools, | ||
| response_format=self._response_format, | ||
| config=llm_config, | ||
| parallel_tool_calls=True if native_parallel else None, | ||
| **({"tool_choice": forced_tool_choice} if forced_tool_choice else {}), | ||
|
maksymbuleshnyi marked this conversation as resolved.
|
||
| **kwargs, | ||
| ) | ||
| finally: | ||
|
|
||


Uh oh!
There was an error while loading. Please reload this page.