Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions app/api/v1/admin/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ def _sanitize_token_text(value) -> str:
@router.get("/tokens", dependencies=[Depends(verify_app_key)])
async def get_tokens():
"""获取所有 Token"""
storage = get_storage()
tokens = await storage.load_tokens()
return tokens or {}
mgr = await get_token_manager()
results = {}
for pool_name, pool in mgr.pools.items():
results[pool_name] = [t.model_dump() for t in pool.list()]
return results or {}


@router.post("/tokens", dependencies=[Depends(verify_app_key)])
Expand Down Expand Up @@ -146,10 +148,15 @@ async def refresh_tokens(data: dict):
mgr,
)

# 强制保存变更到存储
await mgr._save(force=True)

results = {}
for token, res in raw_results.items():
if res.get("ok"):
results[token] = res.get("data", False)
# 只要请求执行了(不论 Token 是否可用),对于刷新动作来说都是完成的
# 我们通过检查是否包含 ok 字段来判定任务是否真正执行过
if "ok" in res:
results[token] = res.get("ok")
else:
results[token] = False

Expand Down
3 changes: 2 additions & 1 deletion app/services/grok/batch_services/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ async def get(self, token: str) -> Dict:
)
return data

except Exception:
except Exception as e:
# 最后一次失败已经被记录
logger.debug(f"UsageService.get failed for token {token[:10]}...: {str(e)}")
raise


Expand Down
55 changes: 47 additions & 8 deletions app/services/reverse/rate_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,48 @@ async def _do_request():
)

if response.status_code != 200:
try:
resp_text = response.text
except Exception:
resp_text = "N/A"

# --- 识别逻辑开始 ---
# 区分是真正的 Token 过期还是 Cloudflare 拦截
is_token_expired = False
server_header = response.headers.get("Server", "").lower()
content_type = response.headers.get("Content-Type", "").lower()

# 1. 只有当返回不是 JSON 且包含 cloudflare 关键字,或者包含特定的 challenge 标志时,才认为是网络拦截
is_cloudflare = "challenge-platform" in resp_text
if "cloudflare" in server_header and "application/json" not in content_type:
is_cloudflare = True

# 2. 如果是 401 且返回 JSON 内容包含认证失败关键字,则确认为 Token 过期
if response.status_code == 401 and "application/json" in content_type:
# 增加 unauthenticated 和 bad-credentials 等更精确的关键字
body_lower = resp_text.lower()
auth_error_keywords = ["unauthorized", "not logged in", "unauthenticated", "bad-credentials"]
if any(k in body_lower for k in auth_error_keywords):
is_token_expired = True
# --- 识别逻辑结束 ---

logger.error(
f"RateLimitsReverse: Request failed, {response.status_code}",
"RateLimitsReverse: Request failed, status={}, is_token_expired={}, is_cloudflare={}, Body: {}",
response.status_code,
is_token_expired,
is_cloudflare,
resp_text[:300],
extra={"error_type": "UpstreamException"},
)

raise UpstreamException(
message=f"RateLimitsReverse: Request failed, {response.status_code}",
details={"status": response.status_code},
details={
"status": response.status_code,
"body": resp_text,
"is_token_expired": is_token_expired,
"is_cloudflare": is_cloudflare
},
)

return response
Expand All @@ -80,20 +115,24 @@ async def _do_request():
# Handle upstream exception
if isinstance(e, UpstreamException):
status = None
if e.details and "status" in e.details:
status = e.details["status"]
else:
if e.details and isinstance(e.details, dict):
status = e.details.get("status")

if status is None:
status = getattr(e, "status_code", None)

logger.debug(f"RateLimitsReverse: Upstream error caught: {str(e)}, status={status}")
raise

# Handle other non-upstream exceptions
import traceback
error_details = traceback.format_exc()
logger.error(
f"RateLimitsReverse: Request failed, {str(e)}",
extra={"error_type": type(e).__name__},
f"RateLimitsReverse: Unexpected error, {type(e).__name__}: {str(e)}\n{error_details}"
)
raise UpstreamException(
message=f"RateLimitsReverse: Request failed, {str(e)}",
details={"status": 502, "error": str(e)},
details={"status": 502, "error": str(e), "traceback": error_details},
)


Expand Down
17 changes: 14 additions & 3 deletions app/services/reverse/utils/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,23 @@ def __init__(self):
# Decorrelated jitter state
self._last_delay = self.backoff_base

def should_retry(self, status_code: int) -> bool:
def should_retry(self, status_code: int, error: Exception = None) -> bool:
"""Check if should retry."""
if self.attempt >= self.max_retry:
return False
if status_code not in self.retry_codes:
return False
if self.total_delay >= self.retry_budget:
return False

# --- 准确判定逻辑开始 ---
# 如果已经明确判定为 Token 过期,则不进行重试
if isinstance(error, UpstreamException) and error.details:
if error.details.get("is_token_expired", False):
logger.warning("Confirmed Token Expired, skipping retries.")
return False
# --- 准确判定逻辑结束 ---

return True

def record_error(self, status_code: int, error: Exception):
Expand Down Expand Up @@ -175,14 +184,16 @@ def extract_status(e: Exception) -> Optional[int]:

if status_code is None:
# Error cannot be identified as retryable
logger.error(f"Non-retryable error: {e}")
import traceback
error_details = traceback.format_exc()
logger.error(f"Non-retryable error: {type(e).__name__}: {e}\n{error_details}")
raise

# Record error
ctx.record_error(status_code, e)

# Check if should retry
if ctx.should_retry(status_code):
if ctx.should_retry(status_code, e):
# Extract Retry-After
retry_after = extract_retry_after(e)

Expand Down
50 changes: 34 additions & 16 deletions app/services/token/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,16 +576,29 @@ async def sync_usage(

except Exception as e:
if isinstance(e, UpstreamException):
status = None
if e.details and "status" in e.details:
status = e.details["status"]
else:
status = getattr(e, "status_code", None)
status = e.details.get("status") if e.details else getattr(e, "status_code", None)
is_token_expired = e.details.get("is_token_expired", False) if e.details else False

if status == 401:
await self.record_fail(token_str, status, "rate_limits_auth_failed")
# 只要是 401,都应该记录一次失败,增加 fail_count
reason = "rate_limits_auth_failed" if is_token_expired else "rate_limits_auth_unknown"

# 如果确认为过期,传入 threshold=1 强制立即失效
await self.record_fail(token_str, status, reason, threshold=1 if is_token_expired else None)

if is_token_expired:
# 只有确认过期的才跳过 fallback
logger.warning(
f"Token {raw_token[:10]}...: API sync failed (Confirmed Token Expired), skipping fallback"
)
return False

logger.warning(
f"Token {raw_token[:10]}...: API sync failed, fallback to local ({e})"
f"Token {raw_token[:10]}...: API sync failed, error: {e}"
)
# 如果不执行降级扣费(例如在刷新状态时),则直接返回 False 表示同步失败
if not consume_on_fail:
return False

# 降级:本地预估扣费
if consume_on_fail:
Expand All @@ -598,15 +611,16 @@ async def sync_usage(
return False

async def record_fail(
self, token_str: str, status_code: int = 401, reason: str = ""
self, token_str: str, status_code: int = 401, reason: str = "", threshold: Optional[int] = None
) -> bool:
"""
记录 Token 失败

Args:
token_str: Token 字符串
status_code: HTTP 状态码
status_code: HTTP Status Code
reason: 失败原因
threshold: 强制失败阈值

Returns:
是否成功
Expand All @@ -617,18 +631,22 @@ async def record_fail(
token = pool.get(raw_token)
if token:
if status_code == 401:
threshold = get_config("token.fail_threshold", FAIL_THRESHOLD)
try:
threshold = int(threshold)
except (TypeError, ValueError):
threshold = FAIL_THRESHOLD
if threshold is None:
threshold = get_config("token.fail_threshold", FAIL_THRESHOLD)
try:
threshold = int(threshold)
except (TypeError, ValueError):
threshold = FAIL_THRESHOLD

if threshold < 1:
threshold = 1

token.record_fail(status_code, reason, threshold=threshold)
logger.warning(

log_level = logger.warning if token.status == TokenStatus.EXPIRED else logger.info
log_level(
f"Token {raw_token[:10]}...: recorded {status_code} failure "
f"({token.fail_count}/{threshold}) - {reason}"
f"({token.fail_count}/{threshold}) - {reason} - status: {token.status}"
)
self._track_token_change(token, pool.name, "state")
self._schedule_save()
Expand Down