|
22 | 22 | from typing_extensions import NotRequired, TypedDict |
23 | 23 |
|
24 | 24 | if TYPE_CHECKING: |
25 | | - from collections.abc import Callable, Sequence |
| 25 | + from collections.abc import Awaitable, Callable, Sequence |
26 | 26 |
|
27 | 27 | from langchain.tools.tool_node import ToolCallRequest |
28 | 28 |
|
@@ -209,6 +209,32 @@ def wrap_model_call( |
209 | 209 |
|
210 | 210 | return handler(request) |
211 | 211 |
|
| 212 | + async def awrap_model_call( |
| 213 | + self, |
| 214 | + request: ModelRequest, |
| 215 | + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], |
| 216 | + ) -> ModelResponse: |
| 217 | + """Inject tool and optional system prompt (async version).""" |
| 218 | + # Add tool |
| 219 | + tools = list(request.tools or []) |
| 220 | + tools.append( |
| 221 | + { |
| 222 | + "type": self.tool_type, |
| 223 | + "name": self.tool_name, |
| 224 | + } |
| 225 | + ) |
| 226 | + request.tools = tools |
| 227 | + |
| 228 | + # Inject system prompt if provided |
| 229 | + if self.system_prompt: |
| 230 | + request.system_prompt = ( |
| 231 | + request.system_prompt + "\n\n" + self.system_prompt |
| 232 | + if request.system_prompt |
| 233 | + else self.system_prompt |
| 234 | + ) |
| 235 | + |
| 236 | + return await handler(request) |
| 237 | + |
212 | 238 | def wrap_tool_call( |
213 | 239 | self, |
214 | 240 | request: ToolCallRequest, |
@@ -255,6 +281,52 @@ def wrap_tool_call( |
255 | 281 | status="error", |
256 | 282 | ) |
257 | 283 |
|
| 284 | + async def awrap_tool_call( |
| 285 | + self, |
| 286 | + request: ToolCallRequest, |
| 287 | + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], |
| 288 | + ) -> ToolMessage | Command: |
| 289 | + """Intercept tool calls (async version).""" |
| 290 | + tool_call = request.tool_call |
| 291 | + tool_name = tool_call.get("name") |
| 292 | + |
| 293 | + if tool_name != self.tool_name: |
| 294 | + return await handler(request) |
| 295 | + |
| 296 | + # Handle tool call |
| 297 | + try: |
| 298 | + args = tool_call.get("args", {}) |
| 299 | + command = args.get("command") |
| 300 | + state = request.state |
| 301 | + |
| 302 | + if command == "view": |
| 303 | + return self._handle_view(args, state, tool_call["id"]) |
| 304 | + if command == "create": |
| 305 | + return self._handle_create(args, state, tool_call["id"]) |
| 306 | + if command == "str_replace": |
| 307 | + return self._handle_str_replace(args, state, tool_call["id"]) |
| 308 | + if command == "insert": |
| 309 | + return self._handle_insert(args, state, tool_call["id"]) |
| 310 | + if command == "delete": |
| 311 | + return self._handle_delete(args, state, tool_call["id"]) |
| 312 | + if command == "rename": |
| 313 | + return self._handle_rename(args, state, tool_call["id"]) |
| 314 | + |
| 315 | + msg = f"Unknown command: {command}" |
| 316 | + return ToolMessage( |
| 317 | + content=msg, |
| 318 | + tool_call_id=tool_call["id"], |
| 319 | + name=tool_name, |
| 320 | + status="error", |
| 321 | + ) |
| 322 | + except (ValueError, FileNotFoundError) as e: |
| 323 | + return ToolMessage( |
| 324 | + content=str(e), |
| 325 | + tool_call_id=tool_call["id"], |
| 326 | + name=tool_name, |
| 327 | + status="error", |
| 328 | + ) |
| 329 | + |
258 | 330 | def _handle_view( |
259 | 331 | self, args: dict, state: AnthropicToolsState, tool_call_id: str | None |
260 | 332 | ) -> Command: |
@@ -645,6 +717,32 @@ def wrap_model_call( |
645 | 717 |
|
646 | 718 | return handler(request) |
647 | 719 |
|
| 720 | + async def awrap_model_call( |
| 721 | + self, |
| 722 | + request: ModelRequest, |
| 723 | + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], |
| 724 | + ) -> ModelResponse: |
| 725 | + """Inject tool and optional system prompt (async version).""" |
| 726 | + # Add tool |
| 727 | + tools = list(request.tools or []) |
| 728 | + tools.append( |
| 729 | + { |
| 730 | + "type": self.tool_type, |
| 731 | + "name": self.tool_name, |
| 732 | + } |
| 733 | + ) |
| 734 | + request.tools = tools |
| 735 | + |
| 736 | + # Inject system prompt if provided |
| 737 | + if self.system_prompt: |
| 738 | + request.system_prompt = ( |
| 739 | + request.system_prompt + "\n\n" + self.system_prompt |
| 740 | + if request.system_prompt |
| 741 | + else self.system_prompt |
| 742 | + ) |
| 743 | + |
| 744 | + return await handler(request) |
| 745 | + |
648 | 746 | def wrap_tool_call( |
649 | 747 | self, |
650 | 748 | request: ToolCallRequest, |
@@ -690,6 +788,51 @@ def wrap_tool_call( |
690 | 788 | status="error", |
691 | 789 | ) |
692 | 790 |
|
| 791 | + async def awrap_tool_call( |
| 792 | + self, |
| 793 | + request: ToolCallRequest, |
| 794 | + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], |
| 795 | + ) -> ToolMessage | Command: |
| 796 | + """Intercept tool calls (async version).""" |
| 797 | + tool_call = request.tool_call |
| 798 | + tool_name = tool_call.get("name") |
| 799 | + |
| 800 | + if tool_name != self.tool_name: |
| 801 | + return await handler(request) |
| 802 | + |
| 803 | + # Handle tool call |
| 804 | + try: |
| 805 | + args = tool_call.get("args", {}) |
| 806 | + command = args.get("command") |
| 807 | + |
| 808 | + if command == "view": |
| 809 | + return self._handle_view(args, tool_call["id"]) |
| 810 | + if command == "create": |
| 811 | + return self._handle_create(args, tool_call["id"]) |
| 812 | + if command == "str_replace": |
| 813 | + return self._handle_str_replace(args, tool_call["id"]) |
| 814 | + if command == "insert": |
| 815 | + return self._handle_insert(args, tool_call["id"]) |
| 816 | + if command == "delete": |
| 817 | + return self._handle_delete(args, tool_call["id"]) |
| 818 | + if command == "rename": |
| 819 | + return self._handle_rename(args, tool_call["id"]) |
| 820 | + |
| 821 | + msg = f"Unknown command: {command}" |
| 822 | + return ToolMessage( |
| 823 | + content=msg, |
| 824 | + tool_call_id=tool_call["id"], |
| 825 | + name=tool_name, |
| 826 | + status="error", |
| 827 | + ) |
| 828 | + except (ValueError, FileNotFoundError) as e: |
| 829 | + return ToolMessage( |
| 830 | + content=str(e), |
| 831 | + tool_call_id=tool_call["id"], |
| 832 | + name=tool_name, |
| 833 | + status="error", |
| 834 | + ) |
| 835 | + |
693 | 836 | def _validate_and_resolve_path(self, path: str) -> Path: |
694 | 837 | """Validate and resolve a virtual path to filesystem path. |
695 | 838 |
|
|
0 commit comments