Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/deepagents/middleware/patch_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ def before_agent(self, state: AgentState, runtime: Runtime[Any]) -> dict[str, An
if not messages or len(messages) == 0:
return None

patched_messages = []
# Iterate over the messages and add any dangling tool calls
# Collect patches for dangling tool calls
patches_to_add = []
for i, msg in enumerate(messages):
patched_messages.append(msg)
if msg.type == "ai" and msg.tool_calls:
for tool_call in msg.tool_calls:
corresponding_tool_msg = next(
Expand All @@ -33,12 +32,31 @@ def before_agent(self, state: AgentState, runtime: Runtime[Any]) -> dict[str, An
f"Tool call {tool_call['name']} with id {tool_call['id']} was "
"cancelled - another message came in before it could be completed."
)
patched_messages.append(
patches_to_add.append(
ToolMessage(
content=tool_msg,
name=tool_call["name"],
tool_call_id=tool_call["id"],
)
)

# Only return patches if there are dangling tool calls
# This prevents unnecessary RemoveMessage from being streamed to clients
if not patches_to_add:
return None

# Rebuild message list with patches inserted after their corresponding AI messages
patched_messages = []
for i, msg in enumerate(messages):
patched_messages.append(msg)
if msg.type == "ai" and msg.tool_calls:
for tool_call in msg.tool_calls:
# Find patch for this specific tool call
patch = next(
(p for p in patches_to_add if p.tool_call_id == tool_call["id"]),
None,
)
if patch:
patched_messages.append(patch)

return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *patched_messages]}
10 changes: 3 additions & 7 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,9 @@ def test_no_missing_tool_calls(self) -> None:
]
middleware = PatchToolCallsMiddleware()
state_update = middleware.before_agent({"messages": input_messages}, None)
assert state_update is not None
assert len(state_update["messages"]) == 6
assert state_update["messages"][0].type == "remove"
assert state_update["messages"][1:] == input_messages
updated_messages = add_messages(input_messages, state_update["messages"])
assert len(updated_messages) == 5
assert updated_messages == input_messages
# When there are no missing tool calls, middleware should return None
# to avoid streaming unnecessary RemoveMessage to clients
assert state_update is None

def test_two_missing_tool_calls(self) -> None:
input_messages = [
Expand Down