|
23 | 23 | ChatCompletionChunk,
|
24 | 24 | ChatCompletionMessageCustomToolCall,
|
25 | 25 | ChatCompletionMessageFunctionToolCall,
|
| 26 | + ChatCompletionMessageParam, |
26 | 27 | )
|
27 | 28 | from openai.types.chat.chat_completion_message import (
|
28 | 29 | Annotation,
|
@@ -267,6 +268,10 @@ async def _fetch_response(
|
267 | 268 | input, preserve_thinking_blocks=preserve_thinking_blocks
|
268 | 269 | )
|
269 | 270 |
|
| 271 | + # Fix for interleaved thinking bug: reorder messages to ensure tool_use comes before tool_result # noqa: E501 |
| 272 | + if preserve_thinking_blocks: |
| 273 | + converted_messages = self._fix_tool_message_ordering(converted_messages) |
| 274 | + |
270 | 275 | if system_instructions:
|
271 | 276 | converted_messages.insert(
|
272 | 277 | 0,
|
@@ -379,6 +384,121 @@ async def _fetch_response(
|
379 | 384 | )
|
380 | 385 | return response, ret
|
381 | 386 |
|
| 387 | + def _fix_tool_message_ordering( |
| 388 | + self, messages: list[ChatCompletionMessageParam] |
| 389 | + ) -> list[ChatCompletionMessageParam]: |
| 390 | + """ |
| 391 | + Fix the ordering of tool messages to ensure tool_use messages come before tool_result messages. |
| 392 | +
|
| 393 | + This addresses the interleaved thinking bug where conversation histories may contain |
| 394 | + tool results before their corresponding tool calls, causing Anthropic API to reject the request. |
| 395 | + """ # noqa: E501 |
| 396 | + if not messages: |
| 397 | + return messages |
| 398 | + |
| 399 | + # Collect all tool calls and tool results |
| 400 | + tool_call_messages = {} # tool_id -> (index, message) |
| 401 | + tool_result_messages = {} # tool_id -> (index, message) |
| 402 | + other_messages = [] # (index, message) for non-tool messages |
| 403 | + |
| 404 | + for i, message in enumerate(messages): |
| 405 | + if not isinstance(message, dict): |
| 406 | + other_messages.append((i, message)) |
| 407 | + continue |
| 408 | + |
| 409 | + role = message.get("role") |
| 410 | + |
| 411 | + if role == "assistant" and message.get("tool_calls"): |
| 412 | + # Extract tool calls from this assistant message |
| 413 | + tool_calls = message.get("tool_calls", []) |
| 414 | + if isinstance(tool_calls, list): |
| 415 | + for tool_call in tool_calls: |
| 416 | + if isinstance(tool_call, dict): |
| 417 | + tool_id = tool_call.get("id") |
| 418 | + if tool_id: |
| 419 | + # Create a separate assistant message for each tool call |
| 420 | + single_tool_msg = cast(dict[str, Any], message.copy()) |
| 421 | + single_tool_msg["tool_calls"] = [tool_call] |
| 422 | + tool_call_messages[tool_id] = ( |
| 423 | + i, |
| 424 | + cast(ChatCompletionMessageParam, single_tool_msg), |
| 425 | + ) |
| 426 | + |
| 427 | + elif role == "tool": |
| 428 | + tool_call_id = message.get("tool_call_id") |
| 429 | + if tool_call_id: |
| 430 | + tool_result_messages[tool_call_id] = (i, message) |
| 431 | + else: |
| 432 | + other_messages.append((i, message)) |
| 433 | + else: |
| 434 | + other_messages.append((i, message)) |
| 435 | + |
| 436 | + # First, identify which tool results will be paired to avoid duplicates |
| 437 | + paired_tool_result_indices = set() |
| 438 | + for tool_id in tool_call_messages: |
| 439 | + if tool_id in tool_result_messages: |
| 440 | + tool_result_idx, _ = tool_result_messages[tool_id] |
| 441 | + paired_tool_result_indices.add(tool_result_idx) |
| 442 | + |
| 443 | + # Create the fixed message sequence |
| 444 | + fixed_messages: list[ChatCompletionMessageParam] = [] |
| 445 | + used_indices = set() |
| 446 | + |
| 447 | + # Add messages in their original order, but ensure tool_use → tool_result pairing |
| 448 | + for i, original_message in enumerate(messages): |
| 449 | + if i in used_indices: |
| 450 | + continue |
| 451 | + |
| 452 | + if not isinstance(original_message, dict): |
| 453 | + fixed_messages.append(original_message) |
| 454 | + used_indices.add(i) |
| 455 | + continue |
| 456 | + |
| 457 | + role = original_message.get("role") |
| 458 | + |
| 459 | + if role == "assistant" and original_message.get("tool_calls"): |
| 460 | + # Process each tool call in this assistant message |
| 461 | + tool_calls = original_message.get("tool_calls", []) |
| 462 | + if isinstance(tool_calls, list): |
| 463 | + for tool_call in tool_calls: |
| 464 | + if isinstance(tool_call, dict): |
| 465 | + tool_id = tool_call.get("id") |
| 466 | + if ( |
| 467 | + tool_id |
| 468 | + and tool_id in tool_call_messages |
| 469 | + and tool_id in tool_result_messages |
| 470 | + ): |
| 471 | + # Add tool_use → tool_result pair |
| 472 | + _, tool_call_msg = tool_call_messages[tool_id] |
| 473 | + tool_result_idx, tool_result_msg = tool_result_messages[tool_id] |
| 474 | + |
| 475 | + fixed_messages.append(tool_call_msg) |
| 476 | + fixed_messages.append(tool_result_msg) |
| 477 | + |
| 478 | + # Mark both as used |
| 479 | + used_indices.add(tool_call_messages[tool_id][0]) |
| 480 | + used_indices.add(tool_result_idx) |
| 481 | + elif tool_id and tool_id in tool_call_messages: |
| 482 | + # Tool call without result - add just the tool call |
| 483 | + _, tool_call_msg = tool_call_messages[tool_id] |
| 484 | + fixed_messages.append(tool_call_msg) |
| 485 | + used_indices.add(tool_call_messages[tool_id][0]) |
| 486 | + |
| 487 | + used_indices.add(i) # Mark original multi-tool message as used |
| 488 | + |
| 489 | + elif role == "tool": |
| 490 | + # Only preserve unmatched tool results to avoid duplicates |
| 491 | + if i not in paired_tool_result_indices: |
| 492 | + fixed_messages.append(original_message) |
| 493 | + used_indices.add(i) |
| 494 | + |
| 495 | + else: |
| 496 | + # Regular message - add it normally |
| 497 | + fixed_messages.append(original_message) |
| 498 | + used_indices.add(i) |
| 499 | + |
| 500 | + return fixed_messages |
| 501 | + |
382 | 502 | def _remove_not_given(self, value: Any) -> Any:
|
383 | 503 | if isinstance(value, NotGiven):
|
384 | 504 | return None
|
|
0 commit comments