diff --git a/backend/agent_templates/travel_vacation_planner.yaml b/backend/agent_templates/travel_vacation_planner.yaml index f32182f7..33c9a12c 100644 --- a/backend/agent_templates/travel_vacation_planner.yaml +++ b/backend/agent_templates/travel_vacation_planner.yaml @@ -28,12 +28,15 @@ templates: nodes: - id: places_list_task type: llm + internal: true prompt: | List 5 specific places, neighborhoods, or attractions in {inputs.destination} that match {inputs.interests}. Return a JSON array of strings only. - id: destination_research_task type: mcp_tool_map + depends_on: + - places_list_task server: travel_research tool: tavily_travel_search items_path: outputs.places_list_task @@ -45,14 +48,14 @@ templates: type: llm depends_on: - destination_research_task - - hotel_research_task - - flight_research_task prompt: | Create 2-3 itinerary options for the trip. Include day-by-day highlights and note must-see places. - id: hotel_research_task type: mcp_tool + depends_on: + - itinerary_options_task server: hotel tool: google_hotels_search args: @@ -64,6 +67,8 @@ templates: - id: flight_research_task type: mcp_tool + depends_on: + - itinerary_options_task server: flight tool: google_flights_search args: diff --git a/backend/app/services/runners/graph_engine.py b/backend/app/services/runners/graph_engine.py index 34ddd06f..af1730a1 100644 --- a/backend/app/services/runners/graph_engine.py +++ b/backend/app/services/runners/graph_engine.py @@ -486,6 +486,19 @@ def __init__(self, config: Dict[str, Any], llm: Any) -> None: self.nodes = nodes self.edges = cfg.get("edges") or [] self.entry_id: Optional[str] = cfg.get("entry") + self._internal_ids: set = { + str(n["id"]) for n in self.nodes if n.get("internal") + } + # Auto-detect: nodes referenced only via items_path in mcp_tool_map + # nodes are intermediate data producers, not user-facing output. + items_source_ids = set() + for n in self.nodes: + if str(n.get("type", "")).strip().lower() in ("mcp_tool_map", "mcp_map"): + ip = str(n.get("items_path", "")) + m = _OUTPUT_REF_RE.search(ip) + if m: + items_source_ids.add(m.group(1)) + self._internal_ids |= items_source_ids def _build_graph(self): """Build a compiled LangGraph StateGraph from the config.""" @@ -592,6 +605,7 @@ def _make_step_fn(self, step: dict): mcp_servers = self.mcp_cfg.get("servers") or {} mcp_transport = str(self.mcp_cfg.get("transport") or "streamable-http").lower() llm = self.llm + internal_ids = self._internal_ids async def _run(state: GraphState) -> dict: inputs = state.get("inputs") or {} @@ -602,7 +616,10 @@ async def _run(state: GraphState) -> dict: result = "" if step_type in ("llm", "prompt"): - result = await _run_llm_node(llm, step, inputs, outputs) + visible_outputs = { + k: v for k, v in outputs.items() if k not in internal_ids + } + result = await _run_llm_node(llm, step, inputs, visible_outputs) elif step_type in ("mcp_tool", "mcp"): result = await _run_mcp_tool_node( @@ -628,6 +645,7 @@ async def _run(state: GraphState) -> dict: "name": step_id, "summary": _summarize_output(result), "raw": result, + "internal": step_id in internal_ids, } ], } @@ -661,6 +679,12 @@ async def run_streaming( tasks = node_state.get("tasks_output", []) for task in tasks: task_name = task.get("name", node_name) + total_tasks += 1 + + if task.get("internal"): + logger.debug("Suppressing SSE for internal node: %s", task_name) + continue + yield _sse("node_started", {"node": task_name}, session_id) raw = task.get("raw", "") @@ -676,7 +700,6 @@ async def run_streaming( ) yield _sse("node_completed", {"node": task_name}, session_id) - total_tasks += 1 if total_tasks > 0: yield _sse( diff --git a/tests/unit/test_graph_engine.py b/tests/unit/test_graph_engine.py index c0fabd5a..61d1030f 100644 --- a/tests/unit/test_graph_engine.py +++ b/tests/unit/test_graph_engine.py @@ -360,6 +360,127 @@ async def mock_ainvoke(messages): nodes_started = [p["node"] for p in parsed if p["type"] == "node_started"] assert nodes_started == ["step1", "step2"] + @pytest.mark.asyncio + async def test_internal_node_suppressed(self): + """Nodes with internal: true produce no SSE events and their output + is excluded from the LLM context of downstream nodes.""" + from backend.app.services.runners.graph_engine import GraphEngine + + config = { + "nodes": [ + { + "id": "hidden", + "type": "llm", + "internal": True, + "prompt": "List items", + }, + { + "id": "visible", + "type": "llm", + "depends_on": ["hidden"], + "prompt": "Use {outputs.hidden}", + }, + ] + } + + mock_llm = AsyncMock() + call_count = 0 + captured_prompts: list = [] + + async def mock_ainvoke(messages): + nonlocal call_count + call_count += 1 + for msg in messages: + if isinstance(msg, dict): + captured_prompts.append(msg.get("content", "")) + resp = MagicMock() + resp.content = f"Result {call_count}" + return resp + + mock_llm.ainvoke = mock_ainvoke + + engine = GraphEngine(config=config, llm=mock_llm) + events = [] + async for event in engine.run_streaming({}, "sess"): + events.append(event) + + parsed = [ + json.loads(e[len("data: ") : -2]) for e in events if e.startswith("data: ") + ] + + node_names = [p["node"] for p in parsed if p["type"] == "node_started"] + assert "hidden" not in node_names + assert "visible" in node_names + + responses = [ + p for p in parsed if p["type"] == "response" and p.get("id") == "hidden" + ] + assert len(responses) == 0 + + visible_responses = [ + p for p in parsed if p["type"] == "response" and p.get("id") == "visible" + ] + assert len(visible_responses) == 1 + + assert call_count == 2, "Both nodes should execute even if one is internal" + + visible_prompt = captured_prompts[1] + assert ( + "hidden" not in visible_prompt + ), "Internal node output should not appear in downstream LLM context" + + @pytest.mark.asyncio + async def test_items_path_source_auto_internal(self): + """Nodes referenced only via items_path are auto-detected as internal, + even without an explicit internal: true flag.""" + from backend.app.services.runners.graph_engine import GraphEngine + + config = { + "nodes": [ + { + "id": "places", + "type": "llm", + "prompt": "List places", + }, + { + "id": "research", + "type": "mcp_tool_map", + "depends_on": ["places"], + "server": "travel", + "tool": "search", + "items_path": "outputs.places", + "query_template": "Research {item}", + }, + ], + "mcp": { + "transport": "streamable-http", + "servers": { + "travel": {"url": "http://localhost:7001/mcp"}, + }, + }, + } + + mock_llm = AsyncMock() + mock_resp = MagicMock() + mock_resp.content = '["Tokyo", "Kyoto"]' + mock_llm.ainvoke.return_value = mock_resp + + engine = GraphEngine(config=config, llm=mock_llm) + assert "places" in engine._internal_ids + + events = [] + async for event in engine.run_streaming({}, "sess"): + events.append(event) + + parsed = [ + json.loads(e[len("data: ") : -2]) for e in events if e.startswith("data: ") + ] + + node_names = [p["node"] for p in parsed if p["type"] == "node_started"] + assert ( + "places" not in node_names + ), "items_path source node should be auto-suppressed" + def test_empty_nodes_raises(self): """GraphEngine raises on empty node list.""" from backend.app.services.runners.graph_engine import GraphEngine