Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions backend/agent_templates/travel_vacation_planner.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@ templates:
nodes:
- id: places_list_task
type: llm
internal: true
prompt: |
List 5 specific places, neighborhoods, or attractions in {inputs.destination}
that match {inputs.interests}. Return a JSON array of strings only.

- id: destination_research_task
type: mcp_tool_map
depends_on:
- places_list_task
server: travel_research
tool: tavily_travel_search
items_path: outputs.places_list_task
Expand All @@ -45,14 +48,14 @@ templates:
type: llm
depends_on:
- destination_research_task
- hotel_research_task
- flight_research_task
prompt: |
Create 2-3 itinerary options for the trip.
Include day-by-day highlights and note must-see places.

- id: hotel_research_task
type: mcp_tool
depends_on:
- itinerary_options_task
server: hotel
tool: google_hotels_search
args:
Expand All @@ -64,6 +67,8 @@ templates:

- id: flight_research_task
type: mcp_tool
depends_on:
- itinerary_options_task
server: flight
tool: google_flights_search
args:
Expand Down
27 changes: 25 additions & 2 deletions backend/app/services/runners/graph_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,19 @@ def __init__(self, config: Dict[str, Any], llm: Any) -> None:
self.nodes = nodes
self.edges = cfg.get("edges") or []
self.entry_id: Optional[str] = cfg.get("entry")
self._internal_ids: set = {
str(n["id"]) for n in self.nodes if n.get("internal")
}
# Auto-detect: nodes referenced only via items_path in mcp_tool_map
# nodes are intermediate data producers, not user-facing output.
items_source_ids = set()
for n in self.nodes:
if str(n.get("type", "")).strip().lower() in ("mcp_tool_map", "mcp_map"):
ip = str(n.get("items_path", ""))
m = _OUTPUT_REF_RE.search(ip)
if m:
items_source_ids.add(m.group(1))
self._internal_ids |= items_source_ids

def _build_graph(self):
"""Build a compiled LangGraph StateGraph from the config."""
Expand Down Expand Up @@ -592,6 +605,7 @@ def _make_step_fn(self, step: dict):
mcp_servers = self.mcp_cfg.get("servers") or {}
mcp_transport = str(self.mcp_cfg.get("transport") or "streamable-http").lower()
llm = self.llm
internal_ids = self._internal_ids

async def _run(state: GraphState) -> dict:
inputs = state.get("inputs") or {}
Expand All @@ -602,7 +616,10 @@ async def _run(state: GraphState) -> dict:
result = ""

if step_type in ("llm", "prompt"):
result = await _run_llm_node(llm, step, inputs, outputs)
visible_outputs = {
k: v for k, v in outputs.items() if k not in internal_ids
}
result = await _run_llm_node(llm, step, inputs, visible_outputs)

elif step_type in ("mcp_tool", "mcp"):
result = await _run_mcp_tool_node(
Expand All @@ -628,6 +645,7 @@ async def _run(state: GraphState) -> dict:
"name": step_id,
"summary": _summarize_output(result),
"raw": result,
"internal": step_id in internal_ids,
}
],
}
Expand Down Expand Up @@ -661,6 +679,12 @@ async def run_streaming(
tasks = node_state.get("tasks_output", [])
for task in tasks:
task_name = task.get("name", node_name)
total_tasks += 1

if task.get("internal"):
logger.debug("Suppressing SSE for internal node: %s", task_name)
continue

yield _sse("node_started", {"node": task_name}, session_id)

raw = task.get("raw", "")
Expand All @@ -676,7 +700,6 @@ async def run_streaming(
)

yield _sse("node_completed", {"node": task_name}, session_id)
total_tasks += 1

if total_tasks > 0:
yield _sse(
Expand Down
121 changes: 121 additions & 0 deletions tests/unit/test_graph_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,127 @@ async def mock_ainvoke(messages):
nodes_started = [p["node"] for p in parsed if p["type"] == "node_started"]
assert nodes_started == ["step1", "step2"]

@pytest.mark.asyncio
async def test_internal_node_suppressed(self):
"""Nodes with internal: true produce no SSE events and their output
is excluded from the LLM context of downstream nodes."""
from backend.app.services.runners.graph_engine import GraphEngine

config = {
"nodes": [
{
"id": "hidden",
"type": "llm",
"internal": True,
"prompt": "List items",
},
{
"id": "visible",
"type": "llm",
"depends_on": ["hidden"],
"prompt": "Use {outputs.hidden}",
},
]
}

mock_llm = AsyncMock()
call_count = 0
captured_prompts: list = []

async def mock_ainvoke(messages):
nonlocal call_count
call_count += 1
for msg in messages:
if isinstance(msg, dict):
captured_prompts.append(msg.get("content", ""))
resp = MagicMock()
resp.content = f"Result {call_count}"
return resp

mock_llm.ainvoke = mock_ainvoke

engine = GraphEngine(config=config, llm=mock_llm)
events = []
async for event in engine.run_streaming({}, "sess"):
events.append(event)

parsed = [
json.loads(e[len("data: ") : -2]) for e in events if e.startswith("data: ")
]

node_names = [p["node"] for p in parsed if p["type"] == "node_started"]
assert "hidden" not in node_names
assert "visible" in node_names

responses = [
p for p in parsed if p["type"] == "response" and p.get("id") == "hidden"
]
assert len(responses) == 0

visible_responses = [
p for p in parsed if p["type"] == "response" and p.get("id") == "visible"
]
assert len(visible_responses) == 1

assert call_count == 2, "Both nodes should execute even if one is internal"

visible_prompt = captured_prompts[1]
assert (
"hidden" not in visible_prompt
), "Internal node output should not appear in downstream LLM context"

@pytest.mark.asyncio
async def test_items_path_source_auto_internal(self):
"""Nodes referenced only via items_path are auto-detected as internal,
even without an explicit internal: true flag."""
from backend.app.services.runners.graph_engine import GraphEngine

config = {
"nodes": [
{
"id": "places",
"type": "llm",
"prompt": "List places",
},
{
"id": "research",
"type": "mcp_tool_map",
"depends_on": ["places"],
"server": "travel",
"tool": "search",
"items_path": "outputs.places",
"query_template": "Research {item}",
},
],
"mcp": {
"transport": "streamable-http",
"servers": {
"travel": {"url": "http://localhost:7001/mcp"},
},
},
}

mock_llm = AsyncMock()
mock_resp = MagicMock()
mock_resp.content = '["Tokyo", "Kyoto"]'
mock_llm.ainvoke.return_value = mock_resp

engine = GraphEngine(config=config, llm=mock_llm)
assert "places" in engine._internal_ids

events = []
async for event in engine.run_streaming({}, "sess"):
events.append(event)

parsed = [
json.loads(e[len("data: ") : -2]) for e in events if e.startswith("data: ")
]

node_names = [p["node"] for p in parsed if p["type"] == "node_started"]
assert (
"places" not in node_names
), "items_path source node should be auto-suppressed"

def test_empty_nodes_raises(self):
"""GraphEngine raises on empty node list."""
from backend.app.services.runners.graph_engine import GraphEngine
Expand Down
Loading