Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 102 additions & 35 deletions src/deepseek_cursor_proxy/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,19 @@ def recover_messages_from_missing_reasoning(
messages: list[dict[str, Any]],
missing_indexes: list[int],
) -> tuple[list[dict[str, Any]], int, str | None, dict[str, Any]]:
"""Surgically remove only the messages with unrecoverable reasoning and their
associated tool results, preserving all other conversation context.

Previous behavior (nuclear): dropped *all* messages except system + last user,
causing catastrophic context loss (production: 307 of 309 messages dropped).

New behavior (surgical): for each assistant message with missing reasoning:
- If it had tool_calls, remove it and all subsequent tool result messages
(until the next user/system/non-tool assistant message).
- If it had no tool_calls but needed reasoning (positioned between tool
calls), find the preceding tool-call pair and remove that chain.
"""
# --- Strategy 1: recovery_boundary (existing logic, unchanged) ---
recovery_boundary_index = next(
(
index
Expand Down Expand Up @@ -623,42 +636,79 @@ def recover_messages_from_missing_reasoning(
},
)

last_user_index = next(
(
index
for index in range(len(messages) - 1, -1, -1)
if messages[index].get("role") == "user"
),
-1,
)
if last_user_index == -1:
return (
messages,
0,
None,
{
"strategy": "none",
"missing_indexes": missing_indexes,
"last_user_index": None,
"dropped_messages": 0,
"notice": None,
},
)
# --- Strategy 2: surgical removal (replaces nuclear latest_user) ---
removed_indices: set[int] = set()
missing_set = set(missing_indexes)

for missing_idx in sorted(missing_indexes):
if missing_idx >= len(messages) or missing_idx in removed_indices:
continue
msg = messages[missing_idx]
if msg.get("role") != "assistant":
continue

removed_indices.add(missing_idx)

if msg.get("tool_calls"):
j = missing_idx + 1
while j < len(messages):
nxt = messages[j]
if nxt.get("role") == "tool":
removed_indices.add(j)
j += 1
elif nxt.get("role") == "assistant" and j in missing_set:
removed_indices.add(j)
if nxt.get("tool_calls"):
j += 1
continue
j += 1
else:
break
else:
# Message has no tool_calls but was marked as needing reasoning
# (preceded by tool results from a valid tool-call assistant).
# Remove only this message; the tool results remain valid context.
pass

recovered = leading_system_messages(messages)
omitted_messages = len(messages) - len(recovered) - 1
recovered.append({"role": "system", "content": RECOVERY_SYSTEM_CONTENT})
recovered.append(messages[last_user_index])
if not removed_indices:
last_user_index = next(
(
index
for index in range(len(messages) - 1, -1, -1)
if messages[index].get("role") == "user"
),
-1,
)
if last_user_index == -1:
return (
messages,
0,
None,
{
"strategy": "none",
"missing_indexes": missing_indexes,
"last_user_index": None,
"dropped_messages": 0,
"notice": None,
},
)

recovered = [msg for i, msg in enumerate(messages) if i not in removed_indices]
omitted_messages = len(messages) - len(recovered)
# Surgical removal preserves all surrounding context, so there is no need to
# plant a user-facing recovery notice / boundary marker. Suppressing it keeps
# the conversation clean and prevents the recovery_boundary strategy from
# nuking earlier context on later turns.
return (
recovered,
omitted_messages,
RECOVERY_NOTICE_CONTENT,
None,
{
"strategy": "latest_user",
"strategy": "surgical",
"missing_indexes": missing_indexes,
"last_user_index": last_user_index,
"dropped_messages": omitted_messages,
"notice": RECOVERY_NOTICE_CONTENT,
"removed_indices": sorted(removed_indices),
"notice": None,
},
)

Expand Down Expand Up @@ -820,12 +870,6 @@ def prepare_upstream_request(
recovery_dropped_messages = 0
recovery_notice = None
recovery_steps: list[dict[str, Any]] = []
if thinking_enabled and config.missing_reasoning_strategy == "recover":
boundary = active_messages_from_recovery_boundary(pre_repair_messages)
if boundary is not None:
messages_for_repair, retired_prefix_messages, boundary_step = boundary
continued_recovery_boundary = True
recovery_steps.append(boundary_step)

messages, patched_count, missing_indexes, reasoning_diagnostics = (
normalize_messages(
Expand All @@ -836,6 +880,29 @@ def prepare_upstream_request(
keep_reasoning=not thinking_disabled,
)
)

if (
missing_indexes
and thinking_enabled
and config.missing_reasoning_strategy == "recover"
):
boundary = active_messages_from_recovery_boundary(pre_repair_messages)
if boundary is not None:
messages_for_repair, retired_prefix_messages, boundary_step = boundary
continued_recovery_boundary = True
recovery_steps.append(boundary_step)
(
messages,
patched_count,
missing_indexes,
reasoning_diagnostics,
) = normalize_messages(
messages_for_repair,
store,
cache_namespace,
repair_reasoning=thinking_enabled,
keep_reasoning=not thinking_disabled,
)
while missing_indexes and config.missing_reasoning_strategy == "recover":
recovered_messages, dropped_messages, notice, recovery_step = (
recover_messages_from_missing_reasoning(messages, missing_indexes)
Expand Down
14 changes: 8 additions & 6 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,10 +543,11 @@ def test_disabled_does_not_inject_reasoning(self) -> None:


class RecoveryTests(_StrictUpstreamCase):
def test_cold_cache_recovers_to_latest_user_with_notice(self) -> None:
"""Stale tool history with no cached reasoning: proxy keeps only
the latest user message + recovery system message and prefixes a
user-facing notice into the response."""
def test_cold_cache_surgically_removes_unrecoverable_tool_chain(self) -> None:
"""Stale tool history with no cached reasoning: the proxy surgically
removes only the assistant tool-call message and its tool result,
preserving all surrounding user/system context and emitting no
user-facing recovery notice."""
status, response = _post(
f"{self.proxy.url}/v1/chat/completions",
{
Expand Down Expand Up @@ -577,12 +578,13 @@ def test_cold_cache_recovers_to_latest_user_with_notice(self) -> None:
self.assertEqual(status, 200, response)
sent = StrictFakeDeepSeek.requests[-1]
self.assertEqual(
[m["role"] for m in sent["messages"]], ["system", "system", "user"]
[m["role"] for m in sent["messages"]], ["system", "user", "user"]
)
self.assertEqual(sent["messages"][1]["content"], "old work")
self.assertEqual(
sent["messages"][-1]["content"], "Thanks. What about Saturday?"
)
self.assertIn(
self.assertNotIn(
"[deepseek-cursor-proxy] Refreshed reasoning",
response["choices"][0]["message"]["content"],
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def test_captures_recovery_diagnostics(self) -> None:
)
trace = _read_single_trace(self.writer.session_dir)
self.assertEqual(
trace["transform"]["recovery_steps"][0]["strategy"], "latest_user"
trace["transform"]["recovery_steps"][0]["strategy"], "surgical"
)
self.assertGreaterEqual(
len(
Expand Down
119 changes: 113 additions & 6 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,113 @@
normalize_reasoning_effort,
prepare_upstream_request,
reasoning_cache_namespace,
recover_messages_from_missing_reasoning,
rewrite_response_body,
strip_cursor_thinking_blocks,
strip_recovery_notice_for_upstream,
)


class SurgicalRecoveryTests(unittest.TestCase):
"""Direct coverage for recover_messages_from_missing_reasoning.

Regression guard for the production incident where a single uncacheable
tool-call assistant inside a long conversation caused the proxy to drop
*all* but the last user message (observed: 307 of 309 messages dropped).
"""

def _convo(self, n_turns: int) -> list[dict]:
messages: list[dict] = [{"role": "system", "content": "sys"}]
for i in range(n_turns):
messages.append({"role": "user", "content": f"user {i}"})
messages.append(
{
"role": "assistant",
"content": "",
"reasoning_content": f"reasoning {i}",
"tool_calls": [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": "do", "arguments": "{}"},
}
],
}
)
messages.append(
{"role": "tool", "tool_call_id": f"call_{i}", "content": f"r{i}"}
)
messages.append({"role": "assistant", "content": f"answer {i}"})
return messages

def test_single_miss_in_long_conversation_preserves_bulk_context(self) -> None:
messages = self._convo(50) # 1 + 50*4 = 201 messages
# One uncacheable tool-call assistant somewhere in the middle.
missing_idx = 1 + 25 * 4 + 1 # the assistant of turn 25
self.assertEqual(messages[missing_idx]["role"], "assistant")
self.assertTrue(messages[missing_idx].get("tool_calls"))
del messages[missing_idx]["reasoning_content"]

recovered, dropped, notice, step = recover_messages_from_missing_reasoning(
messages, [missing_idx]
)

# Only the missing tool-call assistant + its tool result are removed.
self.assertEqual(dropped, 2)
self.assertEqual(len(recovered), len(messages) - 2)
self.assertEqual(step["strategy"], "surgical")
# No user-facing notice / boundary is planted for surgical removals.
self.assertIsNone(notice)
# Every other turn's content survives intact.
self.assertIn({"role": "user", "content": "user 0"}, recovered)
self.assertIn({"role": "user", "content": "user 49"}, recovered)
self.assertIn({"role": "assistant", "content": "answer 24"}, recovered)

def test_assistant_without_tool_calls_removes_only_itself(self) -> None:
messages = [
{"role": "system", "content": "sys"},
{"role": "user", "content": "u"},
{
"role": "assistant",
"content": "",
"reasoning_content": "ok",
"tool_calls": [
{
"id": "call_a",
"type": "function",
"function": {"name": "do", "arguments": "{}"},
}
],
},
{"role": "tool", "tool_call_id": "call_a", "content": "res"},
{"role": "assistant", "content": "final answer"},
]
# The trailing plain-text assistant (needs reasoning because it follows
# a tool result) is the one missing reasoning.
recovered, dropped, notice, step = recover_messages_from_missing_reasoning(
messages, [4]
)
self.assertEqual(dropped, 1)
self.assertEqual(step["strategy"], "surgical")
self.assertIsNone(notice)
# The valid tool-call pair is untouched (no orphaned tool message).
roles = [m["role"] for m in recovered]
self.assertEqual(roles, ["system", "user", "assistant", "tool"])

def test_multiple_misses_remove_each_chain_independently(self) -> None:
messages = self._convo(10)
idx_a = 1 + 2 * 4 + 1
idx_b = 1 + 7 * 4 + 1
del messages[idx_a]["reasoning_content"]
del messages[idx_b]["reasoning_content"]
recovered, dropped, _notice, step = recover_messages_from_missing_reasoning(
messages, [idx_a, idx_b]
)
self.assertEqual(dropped, 4) # two assistant+tool pairs
self.assertEqual(step["strategy"], "surgical")
self.assertEqual(len(recovered), len(messages) - 4)


def _default_cache_namespace() -> str:
return reasoning_cache_namespace(
ProxyConfig(),
Expand Down Expand Up @@ -687,8 +788,10 @@ def test_recovered_response_is_recorded_under_pre_recovery_scope(self) -> None:
recovered_assistant.pop("reasoning_content", None)

# Cursor's next request echoes the recovered assistant + tool result.
# The proxy should detect the recovery boundary, retire the prefix,
# and continue cleanly without recovering again.
# The stale `call_old` assistant is still missing its reasoning, so the
# proxy surgically removes that pair again (idempotent) while restoring
# the new assistant's reasoning from cache. Surrounding user context —
# including the original "old model turn" — is preserved.
second_payload = {
"model": "deepseek-v4-pro",
"messages": [
Expand All @@ -705,10 +808,14 @@ def test_recovered_response_is_recorded_under_pre_recovery_scope(self) -> None:
)

self.assertEqual(second_prepared.missing_reasoning_messages, 0)
self.assertEqual(second_prepared.recovered_reasoning_messages, 0)
self.assertEqual(second_prepared.recovery_dropped_messages, 0)
self.assertTrue(second_prepared.continued_recovery_boundary)
self.assertGreater(second_prepared.retired_prefix_messages, 0)
self.assertEqual(second_prepared.recovered_reasoning_messages, 1)
self.assertGreater(second_prepared.recovery_dropped_messages, 0)
self.assertFalse(second_prepared.continued_recovery_boundary)
self.assertEqual(second_prepared.retired_prefix_messages, 0)
self.assertEqual(
[m["role"] for m in second_prepared.payload["messages"]],
["user", "user", "assistant", "tool"],
)
self.assertEqual(
second_prepared.payload["messages"][2]["reasoning_content"],
"Need the new lookup.",
Expand Down