diff --git a/dynamiq/callbacks/inner_thoughts_extractor.py b/dynamiq/callbacks/inner_thoughts_extractor.py new file mode 100644 index 000000000..125ef5268 --- /dev/null +++ b/dynamiq/callbacks/inner_thoughts_extractor.py @@ -0,0 +1,253 @@ +"""Inline ``thought`` extraction for streaming FC arguments. + +Splits the LLM's streaming JSON object into two output streams: the tool's +real params (with ``thought`` removed) and just the thought value content. +""" + +INNER_THOUGHTS_DEFAULT_KEY = "thought" + + +class JSONInnerThoughtsExtractor: + """Streaming JSON parser that routes the ``thought`` field into a separate buffer.""" + + def __init__( + self, + inner_thoughts_key: str = INNER_THOUGHTS_DEFAULT_KEY, + wait_for_first_key: bool = False, + ) -> None: + self.inner_thoughts_key = inner_thoughts_key + self.wait_for_first_key = wait_for_first_key + + # Cumulative buffers across all process_fragment calls. + self.main_buffer: str = "" + self.inner_thoughts_buffer: str = "" + self.main_json_held_buffer: str = "" + + # Parser state. + self.state: str = "start" + self.in_string: bool = False + self.escaped: bool = False + self.current_key: str = "" + self.is_inner_thoughts_value: bool = False + self.inner_thoughts_processed: bool = False + self.hold_main_json: bool = wait_for_first_key + + # Deferred top-level separator: emitted before the next main field, or dropped. + self.pending_comma: bool = False + + # Top-level transitions only fire at depth == 1; deeper structures pass through. + self.depth: int = 0 + + @property + def thought_complete(self) -> bool: + """Whether the thought field's value has been fully processed.""" + return self.inner_thoughts_processed + + @property + def held_main_buffer(self) -> str: + """Held bytes not yet flushed; drained at end-of-stream when thought was missing.""" + return self.main_json_held_buffer + + def process_fragment(self, fragment: str) -> tuple[str, str]: + """Feed a chunk; returns ``(main_delta, thought_delta)`` for this fragment.""" + updates_main: list[str] = [] + updates_thought: list[str] = [] + for c in fragment: + main_chunk, thought_chunk = self._process_char(c) + if main_chunk: + updates_main.append(main_chunk) + if thought_chunk: + updates_thought.append(thought_chunk) + return "".join(updates_main), "".join(updates_thought) + + def _emit_main(self, s: str) -> str: + """Append to main buffer (held or live); return the delta to surface.""" + if self.hold_main_json: + self.main_json_held_buffer += s + return "" + self.main_buffer += s + return s + + def _flush_held_buffer(self) -> str: + """Move held bytes to the live buffer and surface them as a delta.""" + if not self.main_json_held_buffer: + self.hold_main_json = False + return "" + delta = self.main_json_held_buffer + self.main_buffer += delta + self.main_json_held_buffer = "" + self.hold_main_json = False + return delta + + def _emit_thought(self, s: str) -> str: + self.inner_thoughts_buffer += s + return s + + def _process_char(self, c: str) -> tuple[str, str]: + """Process a single character and return its ``(main, thought)`` delta.""" + + if self.escaped: + self.escaped = False + return self._consume_value_char(c) + + if c == "\\": + self.escaped = True + if self.in_string: + return self._consume_value_char(c) + return "", "" + + if c == '"': + return self._handle_quote() + + if self.in_string: + return self._consume_value_char(c) + + # Structural characters (outside any string). + if c == "{": + return self._handle_open_object() + + if c == "[": + return self._handle_open_bracket() + + if c == "}": + return self._handle_close_object() + + if c == "]": + return self._handle_close_bracket() + + if self.depth >= 2: + # Inside nested object/array — passthrough to value's target buffer. + return self._consume_value_char(c) + + if c == ":" and self.state == "colon": + return self._handle_colon() + + # A top-level comma always ends the current value — including after a + # number/array/object value, where state stays "value" (no closing quote). + if c == "," and self.state in ("comma_or_end", "value"): + return self._handle_comma() + + if self.state == "value": + # Non-string scalar in top-level value (number, bool, null). + return self._consume_value_char(c) + + # Whitespace or non-structural chars outside any value — ignore. + return "", "" + + def _consume_value_char(self, c: str) -> tuple[str, str]: + """Route char to thought or main based on whose value we're in.""" + if self.in_string and self.state == "key": + self.current_key += c + return "", "" + if self.is_inner_thoughts_value and self.depth == 1: + return "", self._emit_thought(c) + return self._emit_main(c), "" + + def _handle_quote(self) -> tuple[str, str]: + """Handle an unescaped ``"``.""" + self.in_string = not self.in_string + + if self.in_string: + # Opening quote. + if self.depth >= 2: + # Inside nested value — passthrough. + return self._emit_main('"'), "" + + if self.state in ("start", "comma_or_end"): + # Start of a new top-level key — flush held bytes if thought is done. + main_delta = "" + if self.wait_for_first_key and self.hold_main_json and self.inner_thoughts_processed: + main_delta = self._flush_held_buffer() + self.state = "key" + self.current_key = "" + return main_delta, "" + + if self.state == "value": + if self.is_inner_thoughts_value: + return "", "" + return self._emit_main('"'), "" + + return "", "" + + # Closing quote. + if self.depth >= 2: + return self._emit_main('"'), "" + + if self.state == "key": + self.state = "colon" + return "", "" + + if self.state == "value": + if self.is_inner_thoughts_value: + self.inner_thoughts_processed = True + self.state = "comma_or_end" + return "", "" + self.state = "comma_or_end" + return self._emit_main('"'), "" + + return "", "" + + def _handle_open_object(self) -> tuple[str, str]: + self.depth += 1 + if self.depth == 1: + # Outermost ``{``. + return self._emit_main("{"), "" + # Nested object literal as a value — passthrough. + if self.is_inner_thoughts_value and self.depth == 2: + return "", self._emit_thought("{") + return self._emit_main("{"), "" + + def _handle_open_bracket(self) -> tuple[str, str]: + self.depth += 1 + if self.is_inner_thoughts_value and self.depth == 2: + return "", self._emit_thought("[") + return self._emit_main("["), "" + + def _handle_close_object(self) -> tuple[str, str]: + self.depth -= 1 + if self.depth >= 1: + # Closing a nested object — passthrough. + if self.is_inner_thoughts_value and self.depth == 1: + return "", self._emit_thought("}") + return self._emit_main("}"), "" + + # Outermost `}`. A deferred separator (thought-last case) is dropped: + # nothing follows it, so no comma was ever emitted to revoke. + self.pending_comma = False + self.state = "end" + if self.hold_main_json: + self.main_json_held_buffer += "}" + return "", "" + self.main_buffer += "}" + return "}", "" + + def _handle_close_bracket(self) -> tuple[str, str]: + self.depth -= 1 + if self.is_inner_thoughts_value and self.depth == 1: + return "", self._emit_thought("]") + return self._emit_main("]"), "" + + def _handle_colon(self) -> tuple[str, str]: + """Top-level `:` — colon → value transition.""" + self.state = "value" + self.is_inner_thoughts_value = self.current_key == self.inner_thoughts_key + if self.is_inner_thoughts_value: + # Skip the `"thought":` prefix from main. + return "", "" + # Surface any deferred separator right before this field's key. + prefix = "," if self.pending_comma else "" + self.pending_comma = False + return self._emit_main(f'{prefix}"{self.current_key}":'), "" + + def _handle_comma(self) -> tuple[str, str]: + """Top-level `,` — separates fields.""" + if self.is_inner_thoughts_value: + # Drop comma after thought to avoid a dangling separator. + self.is_inner_thoughts_value = False + self.state = "start" + return "", "" + + # Defer the separator until the next main field confirms it's needed. + self.pending_comma = True + self.state = "start" + return "", "" diff --git a/dynamiq/callbacks/streaming.py b/dynamiq/callbacks/streaming.py index dac8bc9e0..62f535bf7 100644 --- a/dynamiq/callbacks/streaming.py +++ b/dynamiq/callbacks/streaming.py @@ -6,6 +6,7 @@ from dynamiq.callbacks import BaseCallbackHandler from dynamiq.callbacks.base import get_run_id +from dynamiq.callbacks.inner_thoughts_extractor import JSONInnerThoughtsExtractor from dynamiq.types.streaming import ( AgentToolData, AgentToolInputDeltaData, @@ -369,6 +370,9 @@ def __init__(self, agent: "Agent", config, loop_num: int, **kwargs): self._current_action_name: str | None = None self._fc_object_tool_input: bool = False self._fc_object_answer: bool = False + # FC inline thought extractor + per-chunk delta. + self._fc_extractor: JSONInnerThoughtsExtractor | None = None + self._latest_fc_args_delta: str = "" self._brace_depth: int = 0 self._brace_scan_index: int = 0 self._so_action_emitted: bool = False @@ -417,6 +421,7 @@ def on_node_execute_stream(self, serialized: dict[str, Any], chunk: dict[str, An new_id = generate_uuid() self.agent._streaming_tool_run_id = new_id self.agent._streaming_tool_run_ids.append(new_id) + self._latest_fc_args_delta = text_delta or "" else: text_delta = self._extract_text_delta(chunk) @@ -452,7 +457,26 @@ def on_node_execute_end(self, serialized: dict[str, Any], output_data: dict[str, def _flush_buffer(self) -> None: """Flush the remaining buffer content by streaming it as one chunk.""" - if not self._buffer or len(self._buffer) <= self._state_last_emit_index: + if not self._buffer: + self._flush_chunk_buffer() + return + + # FC fallback: drain extractor's held buffer when thought was missing. + if ( + self.mode_name == InferenceMode.FUNCTION_CALLING.value + and self._fc_extractor is not None + and self._tool_input_started + and not self._answer_started + and not self._state_has_emitted.get(StreamingState.TOOL_INPUT, False) + ): + held = self._fc_extractor.held_main_buffer + if held: + self._emit(held, step=StreamingState.TOOL_INPUT) + self._state_last_emit_index = len(self._buffer) + self._flush_chunk_buffer() + return + + if len(self._buffer) <= self._state_last_emit_index: self._flush_chunk_buffer() return @@ -475,6 +499,8 @@ def _reset_tool_call_state(self) -> None: self._current_action_name = None self._fc_object_tool_input = False self._fc_object_answer = False + self._fc_extractor = None + self._latest_fc_args_delta = "" self._brace_depth = 0 self._brace_scan_index = 0 self._state_has_emitted = { @@ -905,8 +931,27 @@ def _process_structured_output_mode(self, final_answer_only: bool) -> None: self._so_action_emitted = True def _process_function_calling_mode(self, final_answer_only: bool) -> None: - """Process function calling mode.""" - self._process_json_mode(final_answer_only) + """FC mode: route `thought` to REASONING, tool args to TOOL_INPUT.""" + if self._answer_started: + self._process_json_mode(final_answer_only) + return + + if self._fc_extractor is None: + self._fc_extractor = JSONInnerThoughtsExtractor( + inner_thoughts_key=JSONStreamingField.THOUGHT.value, + wait_for_first_key=self.agent.streaming.fc_wait_for_first_key, + ) + + delta = self._latest_fc_args_delta + self._latest_fc_args_delta = "" + if not delta: + return + + main_delta, thought_delta = self._fc_extractor.process_fragment(delta) + if thought_delta and not final_answer_only: + self._emit(thought_delta, step=StreamingState.REASONING) + if main_delta: + self._emit(main_delta, step=StreamingState.TOOL_INPUT) def _find_unescaped_quote_end(self, input_string: str, start_quote_index: int) -> int: """ @@ -1060,10 +1105,10 @@ def _initialize_json_object_field_state(self, buf: str, field_name: str, state: return False def _try_initialize_next_json_field(self, buf: str, final_answer_only: bool) -> None: - """Try to initialize the next JSON field state (thought, answer, or action_input). + """Try to initialize the next JSON field state (thought or answer). - Each initializer is a no-op when _current_state is already set, so this is safe - to call multiple times within a single chunk processing cycle. + Used by the ANSWER path (provide_final_answer). FC tool calls use the + inline extractor in ``_process_function_calling_mode`` instead. """ if not self._state_has_emitted.get(StreamingState.REASONING, False): self._initialize_json_field_state( @@ -1077,15 +1122,6 @@ def _try_initialize_next_json_field(self, buf: str, final_answer_only: bool) -> buf, JSONStreamingField.ANSWER.value, StreamingState.ANSWER ) - if self._tool_input_started and not self._answer_started: - if not self._initialize_json_field_state( - buf, JSONStreamingField.ACTION_INPUT.value, StreamingState.TOOL_INPUT - ): - if self._current_state is None: - self._initialize_json_object_field_state( - buf, JSONStreamingField.ACTION_INPUT.value, StreamingState.TOOL_INPUT - ) - def _emit_tool_input_state(self, buf: str) -> None: """Emit content for the current TOOL_INPUT state.""" if self._fc_object_tool_input: diff --git a/dynamiq/nodes/agents/agent.py b/dynamiq/nodes/agents/agent.py index 50eef5530..b593f144d 100644 --- a/dynamiq/nodes/agents/agent.py +++ b/dynamiq/nodes/agents/agent.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Literal, Mapping from litellm import get_supported_openai_params, supports_function_calling, supports_response_schema -from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator, model_validator from dynamiq.callbacks import AgentStreamingParserCallback, StreamingQueueCallbackHandler from dynamiq.executors.context import ContextAwareThreadPoolExecutor @@ -44,8 +44,19 @@ class ToolCallArguments(BaseModel): + """Flat function-calling arguments: `thought` sibling of the tool's real params. + + Tool params arrive via Pydantic's `extra="allow"` and are extracted via + `to_action_input()`. + """ + + model_config = ConfigDict(extra="allow") + thought: str = "" - action_input: dict | str + + def to_action_input(self) -> dict: + """Return the non-thought fields as a plain dict (the tool's real params).""" + return self.model_dump(exclude={"thought"}) @field_validator("thought", mode="before") @classmethod @@ -90,7 +101,7 @@ def parse_as_tool_call(self) -> ToolCallArguments: raise ActionParsingException( "Your tool call is missing required fields. " "Every tool call must include 'thought' (your reasoning) " - "and 'action_input' (the tool parameters as an object).", + "and the tool's parameters as top-level fields.", recoverable=True, ) @@ -792,7 +803,7 @@ def _handle_function_calling_mode( tool_items = [] for tc in actual_tool_calls: args = tc.function.parse_as_tool_call() - tc_input = args.action_input + tc_input = args.to_action_input() if isinstance(tc_input, str): try: tc_input = json.loads(tc_input, strict=False) @@ -817,7 +828,7 @@ def _handle_function_calling_mode( action = single_call.function.name.strip() args = single_call.function.parse_as_tool_call() thought = args.thought - action_input = args.action_input + action_input = args.to_action_input() if isinstance(action_input, str): try: action_input = json.loads(action_input, strict=False) diff --git a/dynamiq/nodes/agents/components/schema_generator.py b/dynamiq/nodes/agents/components/schema_generator.py index d49caa38b..af096cab1 100644 --- a/dynamiq/nodes/agents/components/schema_generator.py +++ b/dynamiq/nodes/agents/components/schema_generator.py @@ -433,19 +433,14 @@ def generate_function_calling_schemas( } has_optional = len(required_fields) < len(properties) - use_strict = _is_strict_compatible(properties) and not has_optional - - action_input_schema: dict[str, Any] = { - "type": "object", - "description": "Tool parameters as a JSON object, not a string.", - "properties": properties, + # Flat-args: prepend `thought` so it streams first and the model sees it before tool params. + properties = { + "thought": {"type": "string", "description": "Your reasoning about using this tool."}, + **properties, } - if use_strict: - action_input_schema["required"] = list(properties.keys()) - action_input_schema["additionalProperties"] = False - else: - if required_fields: - action_input_schema["required"] = required_fields + use_strict = _is_strict_compatible(properties) and not has_optional + required = ["thought", *properties.keys()] if use_strict else ["thought", *required_fields] + required = list(dict.fromkeys(required)) schema = { "type": "function", @@ -454,15 +449,9 @@ def generate_function_calling_schemas( "description": tool.description[:1024], "parameters": { "type": "object", - "properties": { - "thought": { - "type": "string", - "description": "Your reasoning about using this tool.", - }, - "action_input": action_input_schema, - }, + "properties": properties, "additionalProperties": False, - "required": ["thought", "action_input"], + "required": required, }, "strict": use_strict, }, @@ -471,6 +460,10 @@ def generate_function_calling_schemas( schemas.append(schema) else: + # `extra="allow"` tools (e.g. generic Python) take arbitrary params: + # keep the object open and non-strict so the model can pass them as + # top-level siblings of `thought`. Real zero-param tools stay closed. + allows_extra = getattr(tool.input_schema, "model_config", {}).get("extra") == "allow" schema = { "type": "function", "function": { @@ -483,15 +476,11 @@ def generate_function_calling_schemas( "type": "string", "description": "Your reasoning about using this tool.", }, - "action_input": { - "type": "string", - "description": "Input for the selected tool in JSON string format.", - }, }, - "additionalProperties": False, - "required": ["thought", "action_input"], + "additionalProperties": allows_extra, + "required": ["thought"], }, - "strict": True, + "strict": not allows_extra, }, } diff --git a/dynamiq/nodes/agents/prompts/react/instructions.py b/dynamiq/nodes/agents/prompts/react/instructions.py index 751f447c5..5b8d52b32 100644 --- a/dynamiq/nodes/agents/prompts/react/instructions.py +++ b/dynamiq/nodes/agents/prompts/react/instructions.py @@ -198,12 +198,13 @@ you call `provide_final_answer` to deliver the final response. ## Function Calling Guidelines -- ALWAYS populate the "thought" field FIRST before any other field (particularly "action_input") in your function calls +- ALWAYS populate the "thought" field FIRST before any other field in your function calls +- Pass tool parameters as top-level fields of the function arguments (alongside "thought"), not nested inside an "action_input" wrapper - Analyze the request carefully to determine if tools are needed - Call functions with properly formatted arguments - Handle tool responses appropriately before providing final answer - Chain multiple tool calls when necessary for complex tasks -- If you want an agent tool's response returned verbatim as the final output, include "delegate_final": true inside that tool's action_input. Use this only for a single agent tool call and do not call provide_final_answer yourself; the system will return the agent's result directly. +- If you want an agent tool's response returned verbatim as the final output, include "delegate_final": true at the top level of that tool's arguments. Use this only for a single agent tool call and do not call provide_final_answer yourself; the system will return the agent's result directly. ## File Handling - Tools may generate or process files (images, CSVs, PDFs, etc.) diff --git a/dynamiq/types/streaming.py b/dynamiq/types/streaming.py index 2cc9dc764..af9f53f41 100644 --- a/dynamiq/types/streaming.py +++ b/dynamiq/types/streaming.py @@ -198,6 +198,11 @@ class StreamingConfig(BaseModel): min_chunk_chars (int): Minimum number of characters to accumulate before emitting a streaming event. Helps reduce event count by combining small fragments. 0 means no accumulation (emit immediately). Defaults to 0. + fc_wait_for_first_key (bool): FUNCTION_CALLING mode only. When True, hold + tool-input bytes until `thought` is processed so REASONING streams before + TOOL_INPUT (for providers that emit `thought` first, e.g. OpenAI strict). + When False, stream tool-input immediately regardless of key order. + Defaults to True. """ enabled: bool = False stream_tool_input: list[str] | None = None @@ -209,6 +214,7 @@ class StreamingConfig(BaseModel): mode: StreamingMode = StreamingMode.FINAL include_usage: bool = False min_chunk_chars: NonNegativeInt = 0 + fc_wait_for_first_key: bool = True model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/tests/integration/cancellation/test_agent_cancellation_fc_memory.py b/tests/integration/cancellation/test_agent_cancellation_fc_memory.py index d48088e50..866d05efc 100644 --- a/tests/integration/cancellation/test_agent_cancellation_fc_memory.py +++ b/tests/integration/cancellation/test_agent_cancellation_fc_memory.py @@ -125,10 +125,13 @@ def test_sanitization_on_save_clears_orphan_from_memory(): ), "expected _prompt.messages to retain the cancelled orphan tool_call (canonical state policy)" -def test_fc_agent_recovers_when_tool_call_missing_required_argument(): - """FC mode: LLM emits a tool_call whose arguments lack the required - `action_input` field. parse_as_tool_call raises ActionParsingException. - The agent appends a Correction Instruction and recovers on the next loop.""" +def test_fc_agent_recovers_when_tool_call_arguments_are_malformed_json(): + """FC mode: LLM emits a tool_call whose ``arguments`` string is not valid JSON. + + ``FunctionCall.parse_arguments`` raises ValueError → wrapped as + ActionParsingException → the agent appends a Correction Instruction user + message and recovers on the next loop. + """ conn = connections.OpenAI(id="fake-conn", api_key="fake-key") llm = OpenAI( name="TestLLM", @@ -150,7 +153,7 @@ def test_fc_agent_recovers_when_tool_call_missing_required_argument(): max_loops=3, ) - def _llm_missing_action_input(*_a, **_kw): + def _llm_malformed_arguments(*_a, **_kw): return RunnableResult( status=RunnableStatus.SUCCESS, input={}, @@ -158,9 +161,13 @@ def _llm_missing_action_input(*_a, **_kw): "content": "", "tool_calls": [ { - "id": "call_missing", + "id": "call_bad", "type": "function", - "function": {"name": "search_tool", "arguments": '{"thought": "x"}'}, + "function": { + "name": "search_tool", + # Not valid JSON — triggers parse_arguments validator failure. + "arguments": "not valid json {", + }, } ], }, @@ -185,7 +192,7 @@ def _llm_final(*_a, **_kw): }, ) - responses = iter([_llm_missing_action_input(), _llm_final()]) + responses = iter([_llm_malformed_arguments(), _llm_final()]) with patch.object(agent, "_run_llm", side_effect=lambda *a, **kw: next(responses)): result = agent.run(input_data={"input": "go"}, config=RunnableConfig()) assert result.status == RunnableStatus.SUCCESS @@ -195,5 +202,5 @@ def _llm_final(*_a, **_kw): for m in agent._prompt.messages if m.role == MessageRole.USER and "Correction Instruction" in (m.content or "") ] - assert recovery, "no recovery instruction added for missing required argument" + assert recovery, "no recovery instruction added for malformed tool_call arguments" assert "ActionParsingException" in recovery[-1].content diff --git a/tests/integration/nodes/agents/test_agent_methods.py b/tests/integration/nodes/agents/test_agent_methods.py index aee71b400..cf6883f9d 100644 --- a/tests/integration/nodes/agents/test_agent_methods.py +++ b/tests/integration/nodes/agents/test_agent_methods.py @@ -531,14 +531,22 @@ def test_generate_function_calling_schemas(openai_node, mock_tool): assert "thought" in final_answer_schema["function"]["parameters"]["properties"] assert "answer" in final_answer_schema["function"]["parameters"]["properties"] - # Verify all schemas have required structure + # Verify all schemas have required structure and flat-args shape for schema in schemas: assert "type" in schema assert schema["type"] == "function" assert "function" in schema assert "name" in schema["function"] assert "parameters" in schema["function"] - assert "properties" in schema["function"]["parameters"] + parameters = schema["function"]["parameters"] + properties = parameters["properties"] + assert "properties" in parameters + # Flat-args: no action_input wrapper, tool params (if any) are top-level siblings of thought + assert "action_input" not in properties + # thought is the first property (load-bearing for streaming UX and model behavior) + assert next(iter(properties)) == "thought" + # thought is in required + assert parameters["required"][0] == "thought" def test_agent_injects_file_store_into_python_code_executor(openai_node, mock_llm_executor): diff --git a/tests/integration_with_creds/callbacks/__init__.py b/tests/integration_with_creds/callbacks/__init__.py new file mode 100644 index 000000000..bf6bd6c59 --- /dev/null +++ b/tests/integration_with_creds/callbacks/__init__.py @@ -0,0 +1,3 @@ +from dotenv import load_dotenv + +load_dotenv() diff --git a/tests/integration_with_creds/nodes/llms/test_fc_sanitization_live.py b/tests/integration_with_creds/callbacks/test_fc_sanitization_live.py similarity index 100% rename from tests/integration_with_creds/nodes/llms/test_fc_sanitization_live.py rename to tests/integration_with_creds/callbacks/test_fc_sanitization_live.py diff --git a/tests/integration_with_creds/callbacks/test_inner_thoughts_extractor_live.py b/tests/integration_with_creds/callbacks/test_inner_thoughts_extractor_live.py new file mode 100644 index 000000000..217d01657 --- /dev/null +++ b/tests/integration_with_creds/callbacks/test_inner_thoughts_extractor_live.py @@ -0,0 +1,139 @@ +"""Live test for ``JSONInnerThoughtsExtractor`` against a real OpenAI model. + +Drives a real FC stream through ``AgentStreamingParserCallback`` — the same +``on_node_execute_stream`` path the production pipeline uses — and asserts: + + * REASONING events accumulate to the model's thought text (non-empty). + * TOOL_INPUT events accumulate into JSON args with the deterministic + ``message`` we asked for, with ``thought`` correctly stripped. + +This exercises the full inline-thought split inside the streaming layer end +to end, not just the extractor in isolation. +""" + +import json +import os +from unittest.mock import MagicMock + +import litellm +import pytest + +from dynamiq.callbacks.streaming import AgentStreamingParserCallback, StreamingState +from dynamiq.nodes.types import InferenceMode +from dynamiq.types.streaming import StreamingMode + +pytestmark = pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + + +EXPECTED_MESSAGE = "hello world from inner-thoughts test" +LLM_ID = "llm-1" + + +def _make_fc_callback() -> AgentStreamingParserCallback: + """Build the same callback the production pipeline wires up, but with a + mock agent so we can inspect emissions directly via ``call_args_list``.""" + agent = MagicMock() + agent.streaming.enabled = True + agent.streaming.mode = StreamingMode.ALL + agent.streaming.stream_tool_input = None + agent.streaming.min_chunk_chars = 0 + agent.inference_mode.name = InferenceMode.FUNCTION_CALLING.value + agent.name = "live-extractor-agent" + agent._streaming_tool_run_id = None + agent._streaming_tool_run_ids = [] + agent.tool_by_names = {} + agent.sanitize_tool_name = lambda name: name + agent.llm = MagicMock() + agent.llm.id = LLM_ID + return AgentStreamingParserCallback(agent=agent, config=None, loop_num=1) + + +def _collect_reasoning(cb) -> list[str]: + """REASONING chunks are wrapped as ``{"thought": ""}``.""" + out: list[str] = [] + for c in cb.agent.stream_content.call_args_list: + if c.kwargs.get("step") != StreamingState.REASONING: + continue + content = c.kwargs.get("content") + if isinstance(content, dict) and isinstance(content.get("thought"), str): + out.append(content["thought"]) + return out + + +def _collect_tool_input(cb) -> list[str]: + """TOOL_INPUT chunks are dicts with ``action_input`` string fragments.""" + out: list[str] = [] + for c in cb.agent.stream_content.call_args_list: + if c.kwargs.get("step") != StreamingState.TOOL_INPUT: + continue + content = c.kwargs.get("content") + if isinstance(content, dict) and isinstance(content.get("action_input"), str): + out.append(content["action_input"]) + return out + + +def test_extractor_emits_correct_reasoning_and_tool_input_via_callback(): + tool = { + "type": "function", + "function": { + "name": "echo", + "description": "Echo back the given message verbatim.", + "parameters": { + "type": "object", + "properties": { + "thought": { + "type": "string", + "description": "Your reasoning about using this tool.", + }, + "message": { + "type": "string", + "description": "The message to echo back.", + }, + }, + "required": ["thought", "message"], + "additionalProperties": False, + }, + }, + } + + stream = litellm.completion( + model="gpt-4o-mini", + messages=[ + { + "role": "user", + "content": ( + f"Call the `echo` tool. The `message` field MUST be exactly " + f"'{EXPECTED_MESSAGE}' (verbatim, no extra words). Use the " + f"`thought` field to briefly explain why you're calling it." + ), + } + ], + tools=[tool], + tool_choice={"type": "function", "function": {"name": "echo"}}, + stream=True, + temperature=0, + max_tokens=200, + ) + + cb = _make_fc_callback() + serialized = {"group": "llms", "id": LLM_ID} + + for chunk in stream: + cb.on_node_execute_stream(serialized, chunk.model_dump()) + cb.on_node_execute_end(serialized, output_data={}) + + reasoning_chunks = _collect_reasoning(cb) + tool_input_chunks = _collect_tool_input(cb) + + assert reasoning_chunks, "no REASONING events were emitted" + assert tool_input_chunks, "no TOOL_INPUT events were emitted" + + accumulated_thought = "".join(reasoning_chunks).strip() + assert accumulated_thought, "REASONING channel reassembled to empty string" + + accumulated_tool_input = "".join(tool_input_chunks) + parsed = json.loads(accumulated_tool_input) + assert "thought" not in parsed, f"thought leaked into TOOL_INPUT: {parsed}" + assert parsed.get("message") == EXPECTED_MESSAGE, ( + f"message mismatch — expected {EXPECTED_MESSAGE!r}, " f"got {parsed.get('message')!r} (full payload: {parsed})" + ) diff --git a/tests/unit/nodes/agents/test_agent_parsing.py b/tests/unit/nodes/agents/test_agent_parsing.py index b69914e90..302d13a5c 100644 --- a/tests/unit/nodes/agents/test_agent_parsing.py +++ b/tests/unit/nodes/agents/test_agent_parsing.py @@ -367,43 +367,6 @@ def test_structured_output_fallback_decoder_with_literal_newlines(): assert action_input == {"command": "cat > f.py\nprint(1)"} -def test_function_calling_action_input_with_literal_newlines(mocker): - """FC mode: strict=False allows action_input with literal newlines.""" - import uuid - - from dynamiq import connections, prompts - from dynamiq.nodes.agents import Agent - from dynamiq.nodes.llms import OpenAI - from dynamiq.nodes.types import InferenceMode - - conn = connections.OpenAI(id=str(uuid.uuid4()), api_key="fake-key") - llm = OpenAI( - name="TestLLM", - model="gpt-4o-mini", - connection=conn, - prompt=prompts.Prompt(messages=[prompts.Message(role="user", content="{{input}}")]), - ) - agent = Agent(name="test-agent", llm=llm, tools=[], inference_mode=InferenceMode.FUNCTION_CALLING) - - # Simulate an LLM result with a tool_call whose action_input has a literal newline - mock_result = mocker.MagicMock() - mock_result.output = { - "tool_calls": [ - { - "function": { - "name": "SandboxShellTool", - "arguments": {"thought": "run it", "action_input": '{"cmd": "ls\nls -la"}'}, - } - } - ] - } - - thought, action, action_input = agent._handle_function_calling_mode(mock_result, loop_num=1) - assert thought == "run it" - assert action == "SandboxShellTool" - assert action_input == {"cmd": "ls\nls -la"} - - def _mock_llm_response(text: str): from litellm import ModelResponse diff --git a/tests/unit/nodes/agents/test_inner_thoughts_extractor.py b/tests/unit/nodes/agents/test_inner_thoughts_extractor.py new file mode 100644 index 000000000..2fe7943e4 --- /dev/null +++ b/tests/unit/nodes/agents/test_inner_thoughts_extractor.py @@ -0,0 +1,225 @@ +"""Unit tests for ``JSONInnerThoughtsExtractor``.""" + +import json + +import pytest + +from dynamiq.callbacks.inner_thoughts_extractor import JSONInnerThoughtsExtractor + + +def _drive(raw: str, *, wait_for_first_key: bool = False, char_by_char: bool = False): + """Feed ``raw`` into a fresh extractor and return ``(main, thought)``. + + When ``wait_for_first_key=True``, the held buffer (if any remains at the + end) is appended to ``main`` so callers can compare the final state. + """ + ext = JSONInnerThoughtsExtractor(wait_for_first_key=wait_for_first_key) + if char_by_char: + for ch in raw: + ext.process_fragment(ch) + else: + ext.process_fragment(raw) + main = ext.main_buffer + if wait_for_first_key and ext.held_main_buffer: + main += ext.held_main_buffer + return main, ext.inner_thoughts_buffer + + +class TestThoughtPositions: + def test_thought_first(self): + main, thought = _drive('{"thought":"hi","query":"weather"}') + assert json.loads(main) == {"query": "weather"} + assert thought == "hi" + + def test_thought_middle(self): + main, thought = _drive('{"a":"x","thought":"hi","b":"y"}') + assert json.loads(main) == {"a": "x", "b": "y"} + assert thought == "hi" + + def test_thought_last(self): + # Trailing comma must be stripped before the closing }. + main, thought = _drive('{"a":"x","thought":"hi"}') + assert json.loads(main) == {"a": "x"} + assert thought == "hi" + + def test_thought_missing(self): + # With wait_for_first_key=False, main streams unchanged. + main, thought = _drive('{"query":"weather","limit":5}') + assert json.loads(main) == {"query": "weather", "limit": 5} + assert thought == "" + + def test_thought_missing_with_wait_for_first_key(self): + # With wait=True, held buffer never flushes; the safety net is the + # ``held_main_buffer`` property which the streaming layer flushes + # at end of stream. + main, thought = _drive( + '{"query":"weather","limit":5}', wait_for_first_key=True + ) + assert json.loads(main) == {"query": "weather", "limit": 5} + assert thought == "" + + +class TestWaitForFirstKey: + def test_thought_late_held_then_flushed(self): + ext = JSONInnerThoughtsExtractor(wait_for_first_key=True) + # Feed pre-thought field — main delta must be empty (held). + main_a, thought_a = ext.process_fragment('{"a":"x",') + assert main_a == "", "main should be held while thought is unresolved" + assert thought_a == "" + + # Feed thought. + main_b, thought_b = ext.process_fragment('"thought":"hi"') + assert thought_b == "hi" + assert main_b == "", "main still held until next key starts" + + # Feed next field — held buffer flushes when its opening quote arrives. + main_c, thought_c = ext.process_fragment(',"b":"y"}') + assert thought_c == "" + # First chunk delta carries the held bytes + the new field bytes. + assert main_c.startswith("{") + assert json.loads(ext.main_buffer) == {"a": "x", "b": "y"} + + def test_thought_first_flushes_quickly(self): + ext = JSONInnerThoughtsExtractor(wait_for_first_key=True) + main, thought = ext.process_fragment('{"thought":"hi","q":"x"}') + assert thought == "hi" + assert json.loads(ext.main_buffer) == {"q": "x"} + + +class TestThoughtInStringValue: + def test_word_thought_inside_other_value(self): + # "thought" appears as substring inside the value of `query`. + main, thought = _drive('{"thought":"hi","query":"what is a thought?"}') + assert json.loads(main) == {"query": "what is a thought?"} + assert thought == "hi" + + def test_quoted_thought_key_inside_value(self): + # Even an escaped `"thought":` inside a value must not be routed. + raw = '{"thought":"hi","query":"contains \\"thought\\": pattern"}' + main, thought = _drive(raw) + assert json.loads(main) == {"query": 'contains "thought": pattern'} + assert thought == "hi" + + +class TestNestedStructures: + def test_nested_object_param(self): + main, thought = _drive('{"thought":"hi","config":{"x":1,"y":2}}') + assert json.loads(main) == {"config": {"x": 1, "y": 2}} + assert thought == "hi" + + def test_deeply_nested(self): + raw = '{"thought":"hi","a":{"b":{"c":{"d":42}}}}' + main, thought = _drive(raw) + assert json.loads(main) == {"a": {"b": {"c": {"d": 42}}}} + assert thought == "hi" + + def test_nested_thought_key_treated_as_regular(self): + # Inner `thought` is just a regular field — must NOT be routed. + raw = '{"thought":"outer","config":{"thought":"inner","x":1}}' + main, thought = _drive(raw) + assert json.loads(main) == {"config": {"thought": "inner", "x": 1}} + assert thought == "outer" + + def test_array_param(self): + raw = '{"thought":"hi","items":[{"id":1},{"id":2}]}' + main, thought = _drive(raw) + assert json.loads(main) == {"items": [{"id": 1}, {"id": 2}]} + assert thought == "hi" + + def test_array_of_strings(self): + raw = '{"thought":"hi","tags":["a","b","c"]}' + main, thought = _drive(raw) + assert json.loads(main) == {"tags": ["a", "b", "c"]} + assert thought == "hi" + + +class TestEscapeSequences: + def test_escaped_quote_in_thought(self): + raw = '{"thought":"he said \\"hi\\"","q":"x"}' + main, thought = _drive(raw) + assert json.loads(main) == {"q": "x"} + assert thought == 'he said \\"hi\\"' + + def test_escaped_backslash(self): + raw = '{"thought":"path \\\\ here","q":"x"}' + main, thought = _drive(raw) + assert json.loads(main) == {"q": "x"} + + def test_newline_in_thought(self): + raw = '{"thought":"line1\\nline2","q":"x"}' + main, thought = _drive(raw) + assert json.loads(main) == {"q": "x"} + assert thought == "line1\\nline2" + + +class TestChunkedFeeding: + @pytest.mark.parametrize( + "raw", + [ + '{"thought":"hello world","q":"x"}', + '{"q":"x","thought":"hello world"}', + '{"thought":"hi","config":{"a":1,"b":2}}', + '{"thought":"hi","items":[1,2,3]}', + ], + ids=["thought_first", "thought_last", "nested_object", "array"], + ) + def test_char_by_char_matches_whole(self, raw): + whole_main, whole_thought = _drive(raw) + char_main, char_thought = _drive(raw, char_by_char=True) + assert whole_main == char_main + assert whole_thought == char_thought + + +class TestStreamingDeltas: + """Ensure ``process_fragment`` returns proper deltas, not just cumulative state.""" + + def test_thought_streams_progressively(self): + ext = JSONInnerThoughtsExtractor() + chunks = ['{"th', 'ought":"hel', 'lo wor', 'ld","q":"x"}'] + thoughts = [] + for chunk in chunks: + _, td = ext.process_fragment(chunk) + thoughts.append(td) + assert "".join(thoughts) == "hello world" + + def test_main_streams_progressively_without_wait(self): + ext = JSONInnerThoughtsExtractor(wait_for_first_key=False) + chunks = ['{"thought":"hi","q":"x', '","r":"y"}'] + mains = [] + for chunk in chunks: + md, _ = ext.process_fragment(chunk) + mains.append(md) + assert json.loads("".join(mains)) == {"q": "x", "r": "y"} + + +class TestDeltaBufferInvariant: + """The streamed main deltas must always reconstruct ``main_buffer`` and valid JSON. + + Guards against eagerly emitting a separator that is later only stripped from + the buffer — leaving a dangling comma in the delta stream (thought-last case). + """ + + RAWS = [ + '{"thought":"hi","q":"x"}', + '{"q":"x","thought":"hi"}', # thought last — the regression case + '{"a":"x","thought":"hi","b":"y"}', + '{"a":"x","b":"y"}', # no thought + '{"thought":"hi","config":{"a":1,"b":2}}', + '{"a":"x","items":[1,2,3],"thought":"hi"}', + ] + + @pytest.mark.parametrize("raw", RAWS) + @pytest.mark.parametrize("wait", [False, True]) + def test_delta_sum_matches_buffer(self, raw, wait): + ext = JSONInnerThoughtsExtractor(wait_for_first_key=wait) + mains = [] + for ch in raw: + md, _ = ext.process_fragment(ch) + mains.append(md) + streamed = "".join(mains) + # Core invariant: deltas reconstruct the buffer exactly (no phantom comma). + assert streamed == ext.main_buffer + # Effective output the streaming layer surfaces = deltas + drained held bytes. + effective = streamed + ext.held_main_buffer + parsed = json.loads(effective) + assert "thought" not in parsed diff --git a/tests/unit/nodes/agents/test_native_parallel_tool_calls.py b/tests/unit/nodes/agents/test_native_parallel_tool_calls.py index d9e01c866..ec691c11e 100644 --- a/tests/unit/nodes/agents/test_native_parallel_tool_calls.py +++ b/tests/unit/nodes/agents/test_native_parallel_tool_calls.py @@ -27,8 +27,8 @@ class TestNativeParallelToolCalling: def test_tool_calls_returned_as_list(self): from dynamiq.nodes.llms.base import BaseLLM - tc1 = {"function": {"name": "search", "arguments": json.dumps({"thought": "t1", "action_input": {"q": "a"}})}} - tc2 = {"function": {"name": "search", "arguments": json.dumps({"thought": "t2", "action_input": {"q": "b"}})}} + tc1 = {"function": {"name": "search", "arguments": json.dumps({"thought": "t1", "q": "a"})}} + tc2 = {"function": {"name": "search", "arguments": json.dumps({"thought": "t2", "q": "b"})}} tc_objects = [] for tc in [tc1, tc2]: @@ -61,8 +61,8 @@ def test_multiple_tool_calls_routed_as_parallel_batch(self): llm_result = SimpleNamespace( output={ "tool_calls": [ - {"function": {"name": "search", "arguments": {"thought": "first", "action_input": {"q": "a"}}}}, - {"function": {"name": "calc", "arguments": {"thought": "second", "action_input": {"expr": "1+1"}}}}, + {"function": {"name": "search", "arguments": {"thought": "first", "q": "a"}}}, + {"function": {"name": "calc", "arguments": {"thought": "second", "expr": "1+1"}}}, ] } ) @@ -72,7 +72,9 @@ def test_multiple_tool_calls_routed_as_parallel_batch(self): assert action == PARALLEL_TOOL_NAME assert len(action_input["tools"]) == 2 assert action_input["tools"][0]["name"] == "search" + assert action_input["tools"][0]["input"] == {"q": "a"} assert action_input["tools"][1]["name"] == "calc" + assert action_input["tools"][1]["input"] == {"expr": "1+1"} def test_single_tool_call_unchanged(self): from dynamiq.nodes.agents.agent import Agent @@ -85,7 +87,7 @@ def test_single_tool_call_unchanged(self): llm_result = SimpleNamespace( output={ "tool_calls": [ - {"function": {"name": "search", "arguments": {"thought": "t", "action_input": {"q": "a"}}}}, + {"function": {"name": "search", "arguments": {"thought": "t", "q": "a"}}}, ] } ) @@ -108,6 +110,7 @@ def test_no_tool_calls_raises(self): Agent._handle_function_calling_mode(agent, llm_result, loop_num=1) def test_arguments_as_json_string(self): + """OpenAI's wire protocol: function.arguments arrives as a JSON-encoded string.""" from dynamiq.nodes.agents.agent import Agent agent = _make_agent() @@ -117,7 +120,7 @@ def test_arguments_as_json_string(self): { "function": { "name": "search", - "arguments": json.dumps({"thought": "t", "action_input": {"q": "a"}}), + "arguments": json.dumps({"thought": "t", "q": "a"}), } } ] @@ -139,7 +142,7 @@ def test_missing_thought_defaults_to_empty(self): llm_result = SimpleNamespace( output={ "tool_calls": [ - {"function": {"name": "search", "arguments": {"action_input": {"q": "a"}}}}, + {"function": {"name": "search", "arguments": {"q": "a"}}}, ] } ) @@ -150,7 +153,12 @@ def test_missing_thought_defaults_to_empty(self): assert thought == "" assert action_input == {"q": "a"} - def test_missing_action_input_raises(self): + def test_only_thought_yields_empty_action_input(self): + """A tool call with only `thought` and no real params produces an empty dict. + + Some tools genuinely take no parameters (their schema has only `thought`), + so this is not an error. + """ from dynamiq.nodes.agents.agent import Agent agent = _make_agent() @@ -162,8 +170,11 @@ def test_missing_action_input_raises(self): } ) - with pytest.raises(ActionParsingException): - Agent._handle_function_calling_mode(agent, llm_result, loop_num=1) + thought, action, action_input = Agent._handle_function_calling_mode(agent, llm_result, loop_num=1) + + assert action == "search" + assert thought == "t" + assert action_input == {} def test_final_answer(self): from dynamiq.nodes.agents.agent import Agent diff --git a/tests/unit/nodes/agents/test_schema_generator_flat_args.py b/tests/unit/nodes/agents/test_schema_generator_flat_args.py new file mode 100644 index 000000000..302599b67 --- /dev/null +++ b/tests/unit/nodes/agents/test_schema_generator_flat_args.py @@ -0,0 +1,216 @@ +"""Tests for the flat-args function-calling schema shape. + +Verifies that the schema generator produces schemas where: + * `thought` is the FIRST property (load-bearing for streaming UX and model + chain-of-thought behavior). + * Tool params are TOP-LEVEL siblings of `thought` (no `action_input` wrapper). + * `additionalProperties: false` is set on `parameters`. + * Strict mode is enabled only when every property is required AND the schema + contains no shapes OpenAI strict mode would reject. +""" + +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel, Field + +from dynamiq.nodes.agents.components.schema_generator import generate_function_calling_schemas + + +def _sanitize(name: str) -> str: + return name + + +def _gen(tool, delegation_allowed: bool = False): + return generate_function_calling_schemas( + tools=[tool], delegation_allowed=delegation_allowed, sanitize_tool_name=_sanitize + ) + + +def _make_tool(name: str, input_schema_cls: type[BaseModel] | None = None, description: str = "desc"): + """Build a mock tool with a Pydantic input schema.""" + + class _NoFields(BaseModel): + pass + + tool = MagicMock() + tool.name = name + tool.description = description + tool.input_schema = input_schema_cls if input_schema_cls is not None else _NoFields + return tool + + +class _RequiredOnlySchema(BaseModel): + file_path: str = Field(..., description="Path") + content: str = Field(..., description="Content") + + +class _MixedParamsSchema(BaseModel): + query: str = Field(..., description="Query") + limit: int = Field(default=10, description="Max results") + + +class _EmptySchema(BaseModel): + pass + + +class TestFlatArgsSchema: + @pytest.mark.parametrize( + "schema_cls, expected_properties", + [ + (_RequiredOnlySchema, ["thought", "file_path", "content"]), + (_MixedParamsSchema, ["thought", "query", "limit"]), + (_EmptySchema, ["thought"]), + ], + ids=["required_only", "mixed_params", "zero_params"], + ) + def test_thought_is_first_and_params_are_flat_siblings(self, schema_cls, expected_properties): + """`thought` is the first property and tool params are top-level siblings, + regardless of whether params are all-required, mixed, or absent.""" + tool = _make_tool("tool", schema_cls) + tool_schema = next(s for s in _gen(tool) if s["function"]["name"] == "tool") + + properties = tool_schema["function"]["parameters"]["properties"] + required = tool_schema["function"]["parameters"]["required"] + + assert list(properties.keys()) == expected_properties + assert required[0] == "thought" + + def test_no_action_input_wrapper_in_any_schema(self): + class _Schema(BaseModel): + x: str = Field(..., description="X") + + tool = _make_tool("foo", _Schema) + schemas = _gen(tool) + + for schema in schemas: + properties = schema["function"]["parameters"]["properties"] + assert "action_input" not in properties, f"action_input wrapper leaked into {schema['function']['name']}" + + def test_strict_when_all_params_required(self): + class _Schema(BaseModel): + a: str = Field(..., description="A") + b: str = Field(..., description="B") + + tool = _make_tool("op", _Schema) + schemas = _gen(tool) + tool_schema = next(s for s in schemas if s["function"]["name"] == "op") + + assert tool_schema["function"]["strict"] is True + assert tool_schema["function"]["parameters"]["additionalProperties"] is False + assert set(tool_schema["function"]["parameters"]["required"]) == {"thought", "a", "b"} + + def test_strict_dropped_when_any_param_is_optional(self): + class _Schema(BaseModel): + required_field: str = Field(..., description="Required") + optional_field: str | None = Field(default=None, description="Optional") + + tool = _make_tool("op", _Schema) + schemas = _gen(tool) + tool_schema = next(s for s in schemas if s["function"]["name"] == "op") + + assert tool_schema["function"]["strict"] is False + # required list excludes the optional field but always starts with thought + assert tool_schema["function"]["parameters"]["required"][0] == "thought" + assert "required_field" in tool_schema["function"]["parameters"]["required"] + assert "optional_field" not in tool_schema["function"]["parameters"]["required"] + + def test_additional_properties_false_on_tool_schemas(self): + """Flat-args tool schemas must set additionalProperties: false at the outer + parameters level so strict mode (when enabled) blocks malformed shapes.""" + + class _Schema(BaseModel): + x: str = Field(..., description="X") + + tool = _make_tool("foo", _Schema) + schemas = _gen(tool) + + for schema in schemas: + if schema["function"]["name"] == "provide_final_answer": + continue # final-answer schema is untouched by this migration + params = schema["function"]["parameters"] + assert params.get("additionalProperties") is False, ( + f"{schema['function']['name']} missing additionalProperties: false at parameters level" + ) + + def test_zero_param_tool_is_strict_by_default(self): + """A tool with no params produces a fully-required, strict schema.""" + + class _Empty(BaseModel): + pass + + tool = _make_tool("ping", _Empty) + schemas = _gen(tool) + tool_schema = next(s for s in schemas if s["function"]["name"] == "ping") + + assert tool_schema["function"]["strict"] is True + + def test_extra_allow_tool_is_open_and_non_strict(self): + """A no-declared-fields tool that accepts extras (e.g. the generic Python + tool) must stay OPEN: additionalProperties true and non-strict, so the model + can pass arbitrary params as top-level siblings of `thought`.""" + from pydantic import ConfigDict + + class _Dynamic(BaseModel): + model_config = ConfigDict(extra="allow") + + tool = _make_tool("run_code", _Dynamic) + schemas = _gen(tool) + tool_schema = next(s for s in schemas if s["function"]["name"] == "run_code") + params = tool_schema["function"]["parameters"] + + assert params["additionalProperties"] is True + assert tool_schema["function"]["strict"] is False + assert params["required"] == ["thought"] + + +class TestFinalAnswerSchema: + def test_final_answer_is_first_in_list(self): + class _Schema(BaseModel): + x: str = Field(..., description="X") + + tool = _make_tool("foo", _Schema) + schemas = _gen(tool) + + assert schemas[0]["function"]["name"] == "provide_final_answer" + + def test_final_answer_keeps_thought_first(self): + tool = _make_tool("foo", None) + schemas = _gen(tool) + + properties = schemas[0]["function"]["parameters"]["properties"] + assert list(properties.keys())[0] == "thought" + + +@pytest.fixture +def _sub_agent_tool(): + """Build a SubAgentTool-like mock for delegate_final tests. + + Uses ``spec=SubAgentTool`` so the ``isinstance(tool, SubAgentTool)`` branch in + ``generate_function_calling_schemas`` is taken without instantiating the real + Pydantic model (which has ``input_schema`` as a ClassVar). + """ + from dynamiq.nodes.tools.agent_tool import SubAgentTool + + class _Schema(BaseModel): + input: str = Field(..., description="Subtask") + + tool = MagicMock(spec=SubAgentTool) + tool.name = "Researcher" + tool.description = "Research tool" + tool.input_schema = _Schema + return tool + + +class TestSubAgentDelegateFinal: + def test_delegate_final_is_top_level_sibling_not_nested(self, _sub_agent_tool): + """In flat-args mode, `delegate_final` lives at the top level alongside the + agent tool's own params (e.g. `input`), not inside an `action_input` wrapper.""" + schemas = _gen(_sub_agent_tool, delegation_allowed=True) + tool_schema = next(s for s in schemas if s["function"]["name"] == "Researcher") + properties = tool_schema["function"]["parameters"]["properties"] + + assert "delegate_final" in properties + assert "action_input" not in properties + # thought is still first + assert list(properties.keys())[0] == "thought" diff --git a/tests/unit/nodes/agents/test_streaming_chunking.py b/tests/unit/nodes/agents/test_streaming_chunking.py index 78cb299a7..01c770d7d 100644 --- a/tests/unit/nodes/agents/test_streaming_chunking.py +++ b/tests/unit/nodes/agents/test_streaming_chunking.py @@ -158,13 +158,13 @@ def _feed_fc_chunks(cb, thought: str, answer: str) -> None: def _feed_fc_tool_chunks(cb, thought: str, tool_name: str, tool_input: str) -> None: - """FC single tool call: function name = tool, arguments contain thought + action_input.""" + """FC single tool call (flat-args schema): tool's `query` param is a sibling of `thought`.""" serialized = {"group": "llms", "id": "llm-1"} cb.on_node_execute_stream( serialized, {"choices": [{"delta": {"tool_calls": [{"index": 0, "type": "function", "function": {"name": tool_name}}]}}]}, ) - args = '{"thought": "' + thought + '", "action_input": "' + tool_input + '"}' + args = '{"thought": "' + thought + '", "query": "' + tool_input + '"}' for ch in args: cb.on_node_execute_stream( serialized, @@ -173,10 +173,10 @@ def _feed_fc_tool_chunks(cb, thought: str, tool_name: str, tool_input: str) -> N def _feed_fc_parallel_tool_chunks(cb, calls: list) -> None: - """Multiple FC tool calls. Each ``calls`` entry is (tool_name, thought, action_input). + """Multiple FC tool calls (flat-args schema). Each entry is (tool_name, thought, query_value). The tc_index increments per tool, which triggers the parser's _reset_tool_call_state.""" serialized = {"group": "llms", "id": "llm-1"} - for index, (tool_name, thought, action_input) in enumerate(calls): + for index, (tool_name, thought, query) in enumerate(calls): cb.on_node_execute_stream( serialized, { @@ -185,7 +185,7 @@ def _feed_fc_parallel_tool_chunks(cb, calls: list) -> None: ] }, ) - args = '{"thought": "' + thought + '", "action_input": "' + action_input + '"}' + args = '{"thought": "' + thought + '", "query": "' + query + '"}' for ch in args: cb.on_node_execute_stream( serialized, @@ -274,7 +274,19 @@ def test_streaming_chunking_tool_call_path(mode, min_chunk_chars): assert len(chunk) >= min_chunk_chars, f"TOOL_INPUT chunk {len(chunk)} < {min_chunk_chars}" assert "".join(reasoning).strip() == SAMPLE_THOUGHT.strip() - assert "".join(tool_inputs).strip() == SAMPLE_TOOL_INPUT.strip() + + if mode == InferenceMode.FUNCTION_CALLING: + # Flat-args FC mode: the inline extractor routes `thought` to REASONING + # events and emits the tool's real params (without thought) as TOOL_INPUT + # content. Consumer parses it directly without needing to pop thought. + import json + + joined = "".join(tool_inputs).strip() + parsed = json.loads(joined) + assert "thought" not in parsed, "TOOL_INPUT must NOT contain thought" + assert parsed["query"] == SAMPLE_TOOL_INPUT + else: + assert "".join(tool_inputs).strip() == SAMPLE_TOOL_INPUT.strip() # No final-answer events on this path. assert not _emitted_by_step( @@ -314,9 +326,18 @@ def test_streaming_chunking_parallel_tool_calls(min_chunk_chars): by_run_id = _tool_input_events_by_run_id(cb) assert len(by_run_id) == 2, f"expected 2 distinct tool_run_ids, got {len(by_run_id)}" + # Flat-args FC mode: each tool's TOOL_INPUT stream carries only the tool + # params (thought is routed to REASONING events instead). Verify `query` + # round-trips and `thought` is NOT in the parsed object. + import json + fragments_per_tool = list(by_run_id.values()) - assert "".join(fragments_per_tool[0]) == input_a - assert "".join(fragments_per_tool[1]) == input_b + parsed_a = json.loads("".join(fragments_per_tool[0])) + parsed_b = json.loads("".join(fragments_per_tool[1])) + assert "thought" not in parsed_a + assert parsed_a["query"] == input_a + assert "thought" not in parsed_b + assert parsed_b["query"] == input_b if min_chunk_chars > 0: for tid, frags in by_run_id.items(): @@ -326,3 +347,34 @@ def test_streaming_chunking_parallel_tool_calls(min_chunk_chars): reasoning = _emitted_by_step(cb, StreamingState.REASONING) assert reasoning, "no REASONING events for parallel tool calls" assert "".join(reasoning).strip() == (thought_a + thought_b).strip() + + +def test_fc_tool_input_emits_when_thought_is_missing(): + """Flat-args fallback: if the LLM emits a tool call without `thought`, the + streaming layer still emits the full outer JSON object as TOOL_INPUT at the + end of the stream. Without this fallback, the consumer would see no events + for that tool call even though it was dispatched successfully.""" + import json + + cb = _make_callback_for_mode(InferenceMode.FUNCTION_CALLING, min_chunk_chars=0) + serialized = {"group": "llms", "id": "llm-1"} + + cb.on_node_execute_stream( + serialized, + {"choices": [{"delta": {"tool_calls": [{"index": 0, "type": "function", "function": {"name": "search"}}]}}]}, + ) + args = '{"query": "weather", "model": "gpt"}' + for ch in args: + cb.on_node_execute_stream( + serialized, + {"choices": [{"delta": {"tool_calls": [{"index": 0, "type": "function", "function": {"arguments": ch}}]}}]}, + ) + cb.on_node_execute_end({"group": "llms"}, output_data={}) + + reasoning = _emitted_by_step(cb, StreamingState.REASONING) + tool_inputs = _emitted_tool_input_chunks(cb) + + assert not reasoning, "no REASONING expected when thought is absent" + assert tool_inputs, "TOOL_INPUT must be emitted as a fallback even without thought" + parsed = json.loads("".join(tool_inputs)) + assert parsed == {"query": "weather", "model": "gpt"} diff --git a/tests/unit/nodes/agents/test_streaming_parser.py b/tests/unit/nodes/agents/test_streaming_parser.py index fcb5156c3..e700aada8 100644 --- a/tests/unit/nodes/agents/test_streaming_parser.py +++ b/tests/unit/nodes/agents/test_streaming_parser.py @@ -87,24 +87,6 @@ def _make_fc_callback(tool_input_started=False, answer_started=False, action_nam @pytest.mark.parametrize( "buf, tool_input_started, answer_started, action_name, expected_state, expected_fc_object", [ - pytest.param( - '{"action_input": {"query": "hello"', - True, - False, - "exa_search", - StreamingState.TOOL_INPUT, - True, - id="object_action_input", - ), - pytest.param( - '{"action_input": "sub-query text', - True, - False, - "sub_agent", - StreamingState.TOOL_INPUT, - False, - id="string_action_input", - ), pytest.param( '{"answer": "Here is the result', False, @@ -128,6 +110,9 @@ def _make_fc_callback(tool_input_started=False, answer_started=False, action_nam def test_process_json_mode_function_calling( buf, tool_input_started, answer_started, action_name, expected_state, expected_fc_object ): + """ANSWER path (provide_final_answer) still routes through ``_process_json_mode``. + Real FC tool calls go through the inline extractor and don't touch this method. + """ cb = _make_fc_callback( tool_input_started=tool_input_started, answer_started=answer_started,