Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/deepseek_cursor_proxy/reasoning_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ def tool_call_ids(message: dict[str, Any]) -> list[str]:
return ids


def tool_call_names(message: dict[str, Any]) -> list[str]:
    """Return the distinct tool-call function names found in *message*.

    Names are returned in first-seen order. A message that invokes the same
    function several times yields that name only once, so callers that build
    one cache key per name do not produce duplicate keys.

    Entries in ``message["tool_calls"]`` that are not dicts, whose
    ``"function"`` value is not a dict, or whose ``"name"`` is missing or
    falsy are skipped. Names are coerced to ``str``.

    Args:
        message: A message dict that may carry a ``"tool_calls"`` list.

    Returns:
        The unique function names, or an empty list when there are none.
    """
    # A dict is used as an insertion-ordered set: duplicates collapse while
    # the order of first appearance is preserved.
    unique_names: dict[str, None] = {}
    for tool_call in message.get("tool_calls") or []:
        if not isinstance(tool_call, dict):
            continue
        function = tool_call.get("function")
        if isinstance(function, dict) and function.get("name"):
            unique_names[str(function["name"])] = None
    return list(unique_names)


def message_signature(message: dict[str, Any]) -> str:
tool_calls = [
normalize_tool_call(tool_call)
Expand Down Expand Up @@ -172,6 +183,10 @@ def store_assistant_message(self, message: dict[str, Any], scope: str) -> int:
for tool_call in (message.get("tool_calls") or [])
if isinstance(tool_call, dict)
)
keys.extend(
f"scope:{scope}:tool_name:{tool_name}"
for tool_name in tool_call_names(message)
)
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`tool_call_names()` can return duplicate names when a single assistant message includes multiple tool calls to the same function. That will add duplicate `scope:{scope}:tool_name:...` keys, causing redundant SQLite writes (and inflating the returned key count). Consider de-duping tool names (or de-duping keys overall) before calling `put()` to avoid unnecessary DB churn.

Suggested change
)
)
keys = list(dict.fromkeys(keys))

Copilot uses AI. Check for mistakes.
for key in keys:
self.put(key, reasoning, message)
return len(keys)
Expand All @@ -192,6 +207,10 @@ def lookup_for_message(self, message: dict[str, Any], scope: str) -> str | None:
)
if reasoning is not None:
return reasoning
for tool_name in tool_call_names(message):
reasoning = self.get(f"scope:{scope}:tool_name:{tool_name}")
if reasoning is not None:
return reasoning
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new tool-name-only fallback in `lookup_for_message()` introduces new matching behavior that can materially affect reasoning restoration when tool-call arguments/IDs are incomplete (e.g., Stop during streaming). Please add/adjust unit tests to cover (1) storing an assistant tool-call message and restoring its reasoning when only the function name matches, and (2) ensuring this fallback is only used when the higher-specificity keys miss.

Copilot uses AI. Check for mistakes.
return None

def clear(self) -> int:
Expand Down
75 changes: 45 additions & 30 deletions src/deepseek_cursor_proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def do_POST(self) -> None:
prepared.payload["messages"],
prepared.cache_namespace,
prepared.recovery_notice,
prepared.record_response_scope,
)
else:
sent_response = self._proxy_regular_response(
Expand All @@ -274,6 +275,7 @@ def do_POST(self) -> None:
prepared.payload["messages"],
prepared.cache_namespace,
prepared.recovery_notice,
prepared.record_response_scope,
)
if not sent_response:
return
Expand Down Expand Up @@ -436,6 +438,7 @@ def _proxy_regular_response(
request_messages: list[dict[str, Any]],
cache_namespace: str,
recovery_notice: str | None = None,
record_response_scope: str | None = None,
) -> bool:
body = read_response_body(response)
try:
Expand All @@ -446,6 +449,7 @@ def _proxy_regular_response(
request_messages,
cache_namespace,
content_prefix=recovery_notice,
scope=record_response_scope,
)
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
LOG.warning("failed to rewrite upstream JSON response: %s", exc)
Expand Down Expand Up @@ -476,6 +480,7 @@ def _proxy_streaming_response(
request_messages: list[dict[str, Any]],
cache_namespace: str,
recovery_notice: str | None = None,
record_response_scope: str | None = None,
) -> bool:
sent_headers = self._send_response_headers(
getattr(response, "status", 200),
Expand All @@ -496,38 +501,48 @@ def _proxy_streaming_response(
if self.config.cursor_display_reasoning
else None
)
scope = conversation_scope(request_messages, cache_namespace)
scope = (
record_response_scope
if record_response_scope is not None
else conversation_scope(request_messages, cache_namespace)
)
finalized = False
pending_recovery_notice = recovery_notice
while True:
try:
line = response.readline()
except (HTTPException, OSError) as exc:
LOG.warning("upstream streaming response read failed: %s", exc)
return False
if not line:
break
rewritten, finalized, pending_recovery_notice = self._rewrite_sse_line(
line,
original_model,
accumulator,
scope,
display_adapter,
pending_recovery_notice,
)
if not self._write_to_client(
rewritten, "sending streaming response chunk", flush=True
):
return False
if finalized:
break

if not finalized:
if self.config.verbose:
log_json("model streaming assistant messages", accumulator.messages())
stored = accumulator.store_reasoning(self.reasoning_store, scope)
if stored:
LOG.info("stored %s streaming reasoning cache key(s)", stored)
try:
while True:
try:
line = response.readline()
except (HTTPException, OSError) as exc:
LOG.warning("upstream streaming response read failed: %s", exc)
return False
if not line:
break
rewritten, finalized, pending_recovery_notice = self._rewrite_sse_line(
line,
original_model,
accumulator,
scope,
display_adapter,
pending_recovery_notice,
)
if not self._write_to_client(
rewritten, "sending streaming response chunk", flush=True
):
return False
if finalized:
break
finally:
if not finalized:
if self.config.verbose:
log_json(
"model streaming assistant messages", accumulator.messages()
)
stored = accumulator.store_reasoning(self.reasoning_store, scope)
if stored:
LOG.info(
"stored %s streaming reasoning cache key(s) before exit",
stored,
)
return True

def _rewrite_sse_line(
Expand Down
18 changes: 15 additions & 3 deletions src/deepseek_cursor_proxy/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class PreparedRequest:
recovered_reasoning_messages: int = 0
recovery_dropped_messages: int = 0
recovery_notice: str | None = None
record_response_scope: str | None = None


def normalize_reasoning_effort(value: Any) -> str:
Expand Down Expand Up @@ -485,6 +486,8 @@ def prepare_upstream_request(
repair_reasoning=thinking_enabled,
keep_reasoning=not thinking_disabled,
)
record_response_scope = conversation_scope(messages, cache_namespace)

recovered_count = 0
recovery_dropped_messages = 0
recovery_notice = None
Expand Down Expand Up @@ -517,6 +520,7 @@ def prepare_upstream_request(
recovered_reasoning_messages=recovered_count,
recovery_dropped_messages=recovery_dropped_messages,
recovery_notice=recovery_notice,
record_response_scope=record_response_scope,
)


Expand All @@ -525,20 +529,23 @@ def record_response_reasoning(
store: ReasoningStore | None,
request_messages: list[dict[str, Any]],
cache_namespace: str = "",
scope: str | None = None,
) -> int:
if store is None:
return 0
stored = 0
choices = response_payload.get("choices")
if not isinstance(choices, list):
return stored
scope = conversation_scope(request_messages, cache_namespace)
response_scope = scope if scope is not None else conversation_scope(
request_messages, cache_namespace
)
for choice in choices:
if not isinstance(choice, dict):
continue
message = choice.get("message")
if isinstance(message, dict):
stored += store.store_assistant_message(message, scope)
stored += store.store_assistant_message(message, response_scope)
return stored


Expand All @@ -549,13 +556,18 @@ def rewrite_response_body(
request_messages: list[dict[str, Any]],
cache_namespace: str = "",
content_prefix: str | None = None,
scope: str | None = None,
) -> bytes:
response_payload = json.loads(body.decode("utf-8"))
if isinstance(response_payload, dict):
if content_prefix:
prefix_response_content(response_payload, content_prefix)
record_response_reasoning(
response_payload, store, request_messages, cache_namespace
response_payload,
store,
request_messages,
cache_namespace,
scope=scope,
)
if "model" in response_payload:
response_payload["model"] = original_model
Expand Down
Loading