Skip to content

Commit 8a9dcd6

Browse files
committed
x
1 parent 9ced664 commit 8a9dcd6

File tree

1 file changed

+144
-1
lines changed

1 file changed

+144
-1
lines changed

libs/partners/anthropic/langchain_anthropic/middleware/anthropic_tools.py

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from typing_extensions import NotRequired, TypedDict
2323

2424
if TYPE_CHECKING:
25-
from collections.abc import Callable, Sequence
25+
from collections.abc import Awaitable, Callable, Sequence
2626

2727
from langchain.tools.tool_node import ToolCallRequest
2828

@@ -209,6 +209,32 @@ def wrap_model_call(
209209

210210
return handler(request)
211211

212+
async def awrap_model_call(
213+
self,
214+
request: ModelRequest,
215+
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
216+
) -> ModelResponse:
217+
"""Inject tool and optional system prompt (async version)."""
218+
# Add tool
219+
tools = list(request.tools or [])
220+
tools.append(
221+
{
222+
"type": self.tool_type,
223+
"name": self.tool_name,
224+
}
225+
)
226+
request.tools = tools
227+
228+
# Inject system prompt if provided
229+
if self.system_prompt:
230+
request.system_prompt = (
231+
request.system_prompt + "\n\n" + self.system_prompt
232+
if request.system_prompt
233+
else self.system_prompt
234+
)
235+
236+
return await handler(request)
237+
212238
def wrap_tool_call(
213239
self,
214240
request: ToolCallRequest,
@@ -255,6 +281,52 @@ def wrap_tool_call(
255281
status="error",
256282
)
257283

284+
async def awrap_tool_call(
285+
self,
286+
request: ToolCallRequest,
287+
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
288+
) -> ToolMessage | Command:
289+
"""Intercept tool calls (async version)."""
290+
tool_call = request.tool_call
291+
tool_name = tool_call.get("name")
292+
293+
if tool_name != self.tool_name:
294+
return await handler(request)
295+
296+
# Handle tool call
297+
try:
298+
args = tool_call.get("args", {})
299+
command = args.get("command")
300+
state = request.state
301+
302+
if command == "view":
303+
return self._handle_view(args, state, tool_call["id"])
304+
if command == "create":
305+
return self._handle_create(args, state, tool_call["id"])
306+
if command == "str_replace":
307+
return self._handle_str_replace(args, state, tool_call["id"])
308+
if command == "insert":
309+
return self._handle_insert(args, state, tool_call["id"])
310+
if command == "delete":
311+
return self._handle_delete(args, state, tool_call["id"])
312+
if command == "rename":
313+
return self._handle_rename(args, state, tool_call["id"])
314+
315+
msg = f"Unknown command: {command}"
316+
return ToolMessage(
317+
content=msg,
318+
tool_call_id=tool_call["id"],
319+
name=tool_name,
320+
status="error",
321+
)
322+
except (ValueError, FileNotFoundError) as e:
323+
return ToolMessage(
324+
content=str(e),
325+
tool_call_id=tool_call["id"],
326+
name=tool_name,
327+
status="error",
328+
)
329+
258330
def _handle_view(
259331
self, args: dict, state: AnthropicToolsState, tool_call_id: str | None
260332
) -> Command:
@@ -645,6 +717,32 @@ def wrap_model_call(
645717

646718
return handler(request)
647719

720+
async def awrap_model_call(
721+
self,
722+
request: ModelRequest,
723+
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
724+
) -> ModelResponse:
725+
"""Inject tool and optional system prompt (async version)."""
726+
# Add tool
727+
tools = list(request.tools or [])
728+
tools.append(
729+
{
730+
"type": self.tool_type,
731+
"name": self.tool_name,
732+
}
733+
)
734+
request.tools = tools
735+
736+
# Inject system prompt if provided
737+
if self.system_prompt:
738+
request.system_prompt = (
739+
request.system_prompt + "\n\n" + self.system_prompt
740+
if request.system_prompt
741+
else self.system_prompt
742+
)
743+
744+
return await handler(request)
745+
648746
def wrap_tool_call(
649747
self,
650748
request: ToolCallRequest,
@@ -690,6 +788,51 @@ def wrap_tool_call(
690788
status="error",
691789
)
692790

791+
async def awrap_tool_call(
792+
self,
793+
request: ToolCallRequest,
794+
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
795+
) -> ToolMessage | Command:
796+
"""Intercept tool calls (async version)."""
797+
tool_call = request.tool_call
798+
tool_name = tool_call.get("name")
799+
800+
if tool_name != self.tool_name:
801+
return await handler(request)
802+
803+
# Handle tool call
804+
try:
805+
args = tool_call.get("args", {})
806+
command = args.get("command")
807+
808+
if command == "view":
809+
return self._handle_view(args, tool_call["id"])
810+
if command == "create":
811+
return self._handle_create(args, tool_call["id"])
812+
if command == "str_replace":
813+
return self._handle_str_replace(args, tool_call["id"])
814+
if command == "insert":
815+
return self._handle_insert(args, tool_call["id"])
816+
if command == "delete":
817+
return self._handle_delete(args, tool_call["id"])
818+
if command == "rename":
819+
return self._handle_rename(args, tool_call["id"])
820+
821+
msg = f"Unknown command: {command}"
822+
return ToolMessage(
823+
content=msg,
824+
tool_call_id=tool_call["id"],
825+
name=tool_name,
826+
status="error",
827+
)
828+
except (ValueError, FileNotFoundError) as e:
829+
return ToolMessage(
830+
content=str(e),
831+
tool_call_id=tool_call["id"],
832+
name=tool_name,
833+
status="error",
834+
)
835+
693836
def _validate_and_resolve_path(self, path: str) -> Path:
694837
"""Validate and resolve a virtual path to filesystem path.
695838

0 commit comments

Comments
 (0)