diff --git a/src/deepseek_cursor_proxy/transform.py b/src/deepseek_cursor_proxy/transform.py index 0053d98..b44e82f 100644 --- a/src/deepseek_cursor_proxy/transform.py +++ b/src/deepseek_cursor_proxy/transform.py @@ -577,6 +577,19 @@ def recover_messages_from_missing_reasoning( messages: list[dict[str, Any]], missing_indexes: list[int], ) -> tuple[list[dict[str, Any]], int, str | None, dict[str, Any]]: + """Surgically remove only the messages with unrecoverable reasoning and their + associated tool results, preserving all other conversation context. + + Previous behavior (nuclear): dropped *all* messages except system + last user, + causing catastrophic context loss (production: 307 of 309 messages dropped). + + New behavior (surgical): for each assistant message with missing reasoning: + - If it had tool_calls, remove it and all subsequent tool result messages + (until the next user/system/non-tool assistant message). + - If it had no tool_calls but needed reasoning (it follows tool results), + remove only that message; the preceding tool-call pair is kept.
+ """ + # --- Strategy 1: recovery_boundary (existing logic, unchanged) --- recovery_boundary_index = next( ( index @@ -623,42 +636,79 @@ def recover_messages_from_missing_reasoning( }, ) - last_user_index = next( - ( - index - for index in range(len(messages) - 1, -1, -1) - if messages[index].get("role") == "user" - ), - -1, - ) - if last_user_index == -1: - return ( - messages, - 0, - None, - { - "strategy": "none", - "missing_indexes": missing_indexes, - "last_user_index": None, - "dropped_messages": 0, - "notice": None, - }, - ) + # --- Strategy 2: surgical removal (replaces nuclear latest_user) --- + removed_indices: set[int] = set() + missing_set = set(missing_indexes) + + for missing_idx in sorted(missing_indexes): + if missing_idx >= len(messages) or missing_idx in removed_indices: + continue + msg = messages[missing_idx] + if msg.get("role") != "assistant": + continue + + removed_indices.add(missing_idx) + + if msg.get("tool_calls"): + j = missing_idx + 1 + while j < len(messages): + nxt = messages[j] + if nxt.get("role") == "tool": + removed_indices.add(j) + j += 1 + elif nxt.get("role") == "assistant" and j in missing_set: + removed_indices.add(j) + if nxt.get("tool_calls"): + j += 1 + continue + j += 1 + else: + break + else: + # Message has no tool_calls but was marked as needing reasoning + # (preceded by tool results from a valid tool-call assistant). + # Remove only this message; the tool results remain valid context. 
+ pass - recovered = leading_system_messages(messages) - omitted_messages = len(messages) - len(recovered) - 1 - recovered.append({"role": "system", "content": RECOVERY_SYSTEM_CONTENT}) - recovered.append(messages[last_user_index]) + if not removed_indices: + last_user_index = next( + ( + index + for index in range(len(messages) - 1, -1, -1) + if messages[index].get("role") == "user" + ), + -1, + ) + if last_user_index == -1: + return ( + messages, + 0, + None, + { + "strategy": "none", + "missing_indexes": missing_indexes, + "last_user_index": None, + "dropped_messages": 0, + "notice": None, + }, + ) + + recovered = [msg for i, msg in enumerate(messages) if i not in removed_indices] + omitted_messages = len(messages) - len(recovered) + # Surgical removal preserves all surrounding context, so there is no need to + # plant a user-facing recovery notice / boundary marker. Suppressing it keeps + # the conversation clean and prevents the recovery_boundary strategy from + # nuking earlier context on later turns. 
return ( recovered, omitted_messages, - RECOVERY_NOTICE_CONTENT, + None, { - "strategy": "latest_user", + "strategy": "surgical", "missing_indexes": missing_indexes, - "last_user_index": last_user_index, "dropped_messages": omitted_messages, - "notice": RECOVERY_NOTICE_CONTENT, + "removed_indices": sorted(removed_indices), + "notice": None, }, ) @@ -820,12 +870,6 @@ def prepare_upstream_request( recovery_dropped_messages = 0 recovery_notice = None recovery_steps: list[dict[str, Any]] = [] - if thinking_enabled and config.missing_reasoning_strategy == "recover": - boundary = active_messages_from_recovery_boundary(pre_repair_messages) - if boundary is not None: - messages_for_repair, retired_prefix_messages, boundary_step = boundary - continued_recovery_boundary = True - recovery_steps.append(boundary_step) messages, patched_count, missing_indexes, reasoning_diagnostics = ( normalize_messages( @@ -836,6 +880,29 @@ def prepare_upstream_request( keep_reasoning=not thinking_disabled, ) ) + + if ( + missing_indexes + and thinking_enabled + and config.missing_reasoning_strategy == "recover" + ): + boundary = active_messages_from_recovery_boundary(pre_repair_messages) + if boundary is not None: + messages_for_repair, retired_prefix_messages, boundary_step = boundary + continued_recovery_boundary = True + recovery_steps.append(boundary_step) + ( + messages, + patched_count, + missing_indexes, + reasoning_diagnostics, + ) = normalize_messages( + messages_for_repair, + store, + cache_namespace, + repair_reasoning=thinking_enabled, + keep_reasoning=not thinking_disabled, + ) while missing_indexes and config.missing_reasoning_strategy == "recover": recovered_messages, dropped_messages, notice, recovery_step = ( recover_messages_from_missing_reasoning(messages, missing_indexes) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index e763356..2a7a5f1 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -543,10 +543,11 @@ def 
test_disabled_does_not_inject_reasoning(self) -> None: class RecoveryTests(_StrictUpstreamCase): - def test_cold_cache_recovers_to_latest_user_with_notice(self) -> None: - """Stale tool history with no cached reasoning: proxy keeps only - the latest user message + recovery system message and prefixes a - user-facing notice into the response.""" + def test_cold_cache_surgically_removes_unrecoverable_tool_chain(self) -> None: + """Stale tool history with no cached reasoning: the proxy surgically + removes only the assistant tool-call message and its tool result, + preserving all surrounding user/system context and emitting no + user-facing recovery notice.""" status, response = _post( f"{self.proxy.url}/v1/chat/completions", { @@ -577,12 +578,13 @@ def test_cold_cache_recovers_to_latest_user_with_notice(self) -> None: self.assertEqual(status, 200, response) sent = StrictFakeDeepSeek.requests[-1] self.assertEqual( - [m["role"] for m in sent["messages"]], ["system", "system", "user"] + [m["role"] for m in sent["messages"]], ["system", "user", "user"] ) + self.assertEqual(sent["messages"][1]["content"], "old work") self.assertEqual( sent["messages"][-1]["content"], "Thanks. What about Saturday?" 
) - self.assertIn( + self.assertNotIn( "[deepseek-cursor-proxy] Refreshed reasoning", response["choices"][0]["message"]["content"], ) diff --git a/tests/test_trace.py b/tests/test_trace.py index bf6db0e..4c37d45 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -314,7 +314,7 @@ def test_captures_recovery_diagnostics(self) -> None: ) trace = _read_single_trace(self.writer.session_dir) self.assertEqual( - trace["transform"]["recovery_steps"][0]["strategy"], "latest_user" + trace["transform"]["recovery_steps"][0]["strategy"], "surgical" ) self.assertGreaterEqual( len( diff --git a/tests/test_transform.py b/tests/test_transform.py index ef67574..94304ba 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -24,12 +24,113 @@ normalize_reasoning_effort, prepare_upstream_request, reasoning_cache_namespace, + recover_messages_from_missing_reasoning, rewrite_response_body, strip_cursor_thinking_blocks, strip_recovery_notice_for_upstream, ) +class SurgicalRecoveryTests(unittest.TestCase): + """Direct coverage for recover_messages_from_missing_reasoning. + + Regression guard for the production incident where a single uncacheable + tool-call assistant inside a long conversation caused the proxy to drop + *all* but the last user message (observed: 307 of 309 messages dropped). 
+ """ + + def _convo(self, n_turns: int) -> list[dict]: + messages: list[dict] = [{"role": "system", "content": "sys"}] + for i in range(n_turns): + messages.append({"role": "user", "content": f"user {i}"}) + messages.append( + { + "role": "assistant", + "content": "", + "reasoning_content": f"reasoning {i}", + "tool_calls": [ + { + "id": f"call_{i}", + "type": "function", + "function": {"name": "do", "arguments": "{}"}, + } + ], + } + ) + messages.append( + {"role": "tool", "tool_call_id": f"call_{i}", "content": f"r{i}"} + ) + messages.append({"role": "assistant", "content": f"answer {i}"}) + return messages + + def test_single_miss_in_long_conversation_preserves_bulk_context(self) -> None: + messages = self._convo(50) # 1 + 50*4 = 201 messages + # One uncacheable tool-call assistant somewhere in the middle. + missing_idx = 1 + 25 * 4 + 1 # the assistant of turn 25 + self.assertEqual(messages[missing_idx]["role"], "assistant") + self.assertTrue(messages[missing_idx].get("tool_calls")) + del messages[missing_idx]["reasoning_content"] + + recovered, dropped, notice, step = recover_messages_from_missing_reasoning( + messages, [missing_idx] + ) + + # Only the missing tool-call assistant + its tool result are removed. + self.assertEqual(dropped, 2) + self.assertEqual(len(recovered), len(messages) - 2) + self.assertEqual(step["strategy"], "surgical") + # No user-facing notice / boundary is planted for surgical removals. + self.assertIsNone(notice) + # Every other turn's content survives intact. 
+ self.assertIn({"role": "user", "content": "user 0"}, recovered) + self.assertIn({"role": "user", "content": "user 49"}, recovered) + self.assertIn({"role": "assistant", "content": "answer 24"}, recovered) + + def test_assistant_without_tool_calls_removes_only_itself(self) -> None: + messages = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "u"}, + { + "role": "assistant", + "content": "", + "reasoning_content": "ok", + "tool_calls": [ + { + "id": "call_a", + "type": "function", + "function": {"name": "do", "arguments": "{}"}, + } + ], + }, + {"role": "tool", "tool_call_id": "call_a", "content": "res"}, + {"role": "assistant", "content": "final answer"}, + ] + # The trailing plain-text assistant (needs reasoning because it follows + # a tool result) is the one missing reasoning. + recovered, dropped, notice, step = recover_messages_from_missing_reasoning( + messages, [4] + ) + self.assertEqual(dropped, 1) + self.assertEqual(step["strategy"], "surgical") + self.assertIsNone(notice) + # The valid tool-call pair is untouched (no orphaned tool message). 
+ roles = [m["role"] for m in recovered] + self.assertEqual(roles, ["system", "user", "assistant", "tool"]) + + def test_multiple_misses_remove_each_chain_independently(self) -> None: + messages = self._convo(10) + idx_a = 1 + 2 * 4 + 1 + idx_b = 1 + 7 * 4 + 1 + del messages[idx_a]["reasoning_content"] + del messages[idx_b]["reasoning_content"] + recovered, dropped, _notice, step = recover_messages_from_missing_reasoning( + messages, [idx_a, idx_b] + ) + self.assertEqual(dropped, 4) # two assistant+tool pairs + self.assertEqual(step["strategy"], "surgical") + self.assertEqual(len(recovered), len(messages) - 4) + + def _default_cache_namespace() -> str: return reasoning_cache_namespace( ProxyConfig(), @@ -687,8 +788,10 @@ def test_recovered_response_is_recorded_under_pre_recovery_scope(self) -> None: recovered_assistant.pop("reasoning_content", None) # Cursor's next request echoes the recovered assistant + tool result. - # The proxy should detect the recovery boundary, retire the prefix, - # and continue cleanly without recovering again. + # The stale `call_old` assistant is still missing its reasoning, so the + # proxy surgically removes that pair again (idempotent) while restoring + # the new assistant's reasoning from cache. Surrounding user context — + # including the original "old model turn" — is preserved. 
second_payload = { "model": "deepseek-v4-pro", "messages": [ @@ -705,10 +808,14 @@ def test_recovered_response_is_recorded_under_pre_recovery_scope(self) -> None: ) self.assertEqual(second_prepared.missing_reasoning_messages, 0) - self.assertEqual(second_prepared.recovered_reasoning_messages, 0) - self.assertEqual(second_prepared.recovery_dropped_messages, 0) - self.assertTrue(second_prepared.continued_recovery_boundary) - self.assertGreater(second_prepared.retired_prefix_messages, 0) + self.assertEqual(second_prepared.recovered_reasoning_messages, 1) + self.assertGreater(second_prepared.recovery_dropped_messages, 0) + self.assertFalse(second_prepared.continued_recovery_boundary) + self.assertEqual(second_prepared.retired_prefix_messages, 0) + self.assertEqual( + [m["role"] for m in second_prepared.payload["messages"]], + ["user", "user", "assistant", "tool"], + ) self.assertEqual( second_prepared.payload["messages"][2]["reasoning_content"], "Need the new lookup.",