diff --git a/app/api/v1/admin/token.py b/app/api/v1/admin/token.py index d417ee887..f2afa11cb 100644 --- a/app/api/v1/admin/token.py +++ b/app/api/v1/admin/token.py @@ -46,9 +46,11 @@ def _sanitize_token_text(value) -> str: @router.get("/tokens", dependencies=[Depends(verify_app_key)]) async def get_tokens(): """获取所有 Token""" - storage = get_storage() - tokens = await storage.load_tokens() - return tokens or {} + mgr = await get_token_manager() + results = {} + for pool_name, pool in mgr.pools.items(): + results[pool_name] = [t.model_dump() for t in pool.list()] + return results or {} @router.post("/tokens", dependencies=[Depends(verify_app_key)]) @@ -146,10 +148,15 @@ async def refresh_tokens(data: dict): mgr, ) + # 强制保存变更到存储 + await mgr._save(force=True) + results = {} for token, res in raw_results.items(): - if res.get("ok"): - results[token] = res.get("data", False) + # 只要请求执行了(不论 Token 是否可用),对于刷新动作来说都是完成的 + # 我们通过检查是否包含 ok 字段来判定任务是否真正执行过 + if "ok" in res: + results[token] = res.get("ok") else: results[token] = False diff --git a/app/services/grok/batch_services/usage.py b/app/services/grok/batch_services/usage.py index fa4dd76c2..9904737da 100644 --- a/app/services/grok/batch_services/usage.py +++ b/app/services/grok/batch_services/usage.py @@ -60,8 +60,9 @@ async def get(self, token: str) -> Dict: ) return data - except Exception: + except Exception as e: # 最后一次失败已经被记录 + logger.debug(f"UsageService.get failed for token {token[:10]}...: {str(e)}") raise diff --git a/app/services/reverse/rate_limits.py b/app/services/reverse/rate_limits.py index 10e6d71f6..e33a32c62 100644 --- a/app/services/reverse/rate_limits.py +++ b/app/services/reverse/rate_limits.py @@ -63,13 +63,48 @@ async def _do_request(): ) if response.status_code != 200: + try: + resp_text = response.text + except Exception: + resp_text = "N/A" + + # --- 识别逻辑开始 --- + # 区分是真正的 Token 过期还是 Cloudflare 拦截 + is_token_expired = False + server_header = response.headers.get("Server", "").lower() + content_type = response.headers.get("Content-Type", "").lower() + + # 1. 只有当返回不是 JSON 且包含 cloudflare 关键字,或者包含特定的 challenge 标志时,才认为是网络拦截 + is_cloudflare = "challenge-platform" in resp_text + if "cloudflare" in server_header and "application/json" not in content_type: + is_cloudflare = True + + # 2. 如果是 401 且返回 JSON 内容包含认证失败关键字,则确认为 Token 过期 + if response.status_code == 401 and "application/json" in content_type: + # 增加 unauthenticated 和 bad-credentials 等更精确的关键字 + body_lower = resp_text.lower() + auth_error_keywords = ["unauthorized", "not logged in", "unauthenticated", "bad-credentials"] + if any(k in body_lower for k in auth_error_keywords): + is_token_expired = True + # --- 识别逻辑结束 --- + logger.error( - f"RateLimitsReverse: Request failed, {response.status_code}", + "RateLimitsReverse: Request failed, status={}, is_token_expired={}, is_cloudflare={}, Body: {}", + response.status_code, + is_token_expired, + is_cloudflare, + resp_text[:300], extra={"error_type": "UpstreamException"}, ) + raise UpstreamException( message=f"RateLimitsReverse: Request failed, {response.status_code}", - details={"status": response.status_code}, + details={ + "status": response.status_code, + "body": resp_text, + "is_token_expired": is_token_expired, + "is_cloudflare": is_cloudflare + }, ) return response @@ -80,20 +115,24 @@ async def _do_request(): # Handle upstream exception if isinstance(e, UpstreamException): status = None - if e.details and "status" in e.details: - status = e.details["status"] - else: + if e.details and isinstance(e.details, dict): + status = e.details.get("status") + + if status is None: status = getattr(e, "status_code", None) + + logger.debug(f"RateLimitsReverse: Upstream error caught: {str(e)}, status={status}") raise # Handle other non-upstream exceptions + import traceback + error_details = traceback.format_exc() logger.error( - f"RateLimitsReverse: Request failed, {str(e)}", - extra={"error_type": type(e).__name__}, + f"RateLimitsReverse: Unexpected error, {type(e).__name__}: {str(e)}\n{error_details}" ) raise UpstreamException( message=f"RateLimitsReverse: Request failed, {str(e)}", - details={"status": 502, "error": str(e)}, + details={"status": 502, "error": str(e), "traceback": error_details}, ) diff --git a/app/services/reverse/utils/retry.py b/app/services/reverse/utils/retry.py index 971eab05b..16f81fb4c 100644 --- a/app/services/reverse/utils/retry.py +++ b/app/services/reverse/utils/retry.py @@ -32,7 +32,7 @@ def __init__(self): # Decorrelated jitter state self._last_delay = self.backoff_base - def should_retry(self, status_code: int) -> bool: + def should_retry(self, status_code: int, error: Exception = None) -> bool: """Check if should retry.""" if self.attempt >= self.max_retry: return False @@ -40,6 +40,15 @@ def should_retry(self, status_code: int) -> bool: return False if self.total_delay >= self.retry_budget: return False + + # --- 准确判定逻辑开始 --- + # 如果已经明确判定为 Token 过期,则不进行重试 + if isinstance(error, UpstreamException) and error.details: + if error.details.get("is_token_expired", False): + logger.warning("Confirmed Token Expired, skipping retries.") + return False + # --- 准确判定逻辑结束 --- + return True def record_error(self, status_code: int, error: Exception): @@ -175,14 +184,16 @@ def extract_status(e: Exception) -> Optional[int]: if status_code is None: # Error cannot be identified as retryable - logger.error(f"Non-retryable error: {e}") + import traceback + error_details = traceback.format_exc() + logger.error(f"Non-retryable error: {type(e).__name__}: {e}\n{error_details}") raise # Record error ctx.record_error(status_code, e) # Check if should retry - if ctx.should_retry(status_code): + if ctx.should_retry(status_code, e): # Extract Retry-After retry_after = extract_retry_after(e) diff --git a/app/services/token/manager.py b/app/services/token/manager.py index eb668edb6..fde666619 100644 --- a/app/services/token/manager.py +++ b/app/services/token/manager.py @@ -576,16 +576,29 @@ async def sync_usage( except Exception as e: if isinstance(e, UpstreamException): - status = None - if e.details and "status" in e.details: - status = e.details["status"] - else: - status = getattr(e, "status_code", None) + status = e.details.get("status") if e.details else getattr(e, "status_code", None) + is_token_expired = e.details.get("is_token_expired", False) if e.details else False + if status == 401: - await self.record_fail(token_str, status, "rate_limits_auth_failed") + # 只要是 401,都应该记录一次失败,增加 fail_count + reason = "rate_limits_auth_failed" if is_token_expired else "rate_limits_auth_unknown" + + # 如果确认为过期,传入 threshold=1 强制立即失效 + await self.record_fail(token_str, status, reason, threshold=1 if is_token_expired else None) + + if is_token_expired: + # 只有确认过期的才跳过 fallback + logger.warning( + f"Token {raw_token[:10]}...: API sync failed (Confirmed Token Expired), skipping fallback" + ) + return False + logger.warning( - f"Token {raw_token[:10]}...: API sync failed, fallback to local ({e})" + f"Token {raw_token[:10]}...: API sync failed, error: {e}" ) + # 如果不执行降级扣费(例如在刷新状态时),则直接返回 False 表示同步失败 + if not consume_on_fail: + return False # 降级:本地预估扣费 if consume_on_fail: @@ -598,15 +611,16 @@ async def sync_usage( return False async def record_fail( - self, token_str: str, status_code: int = 401, reason: str = "" + self, token_str: str, status_code: int = 401, reason: str = "", threshold: Optional[int] = None ) -> bool: """ 记录 Token 失败 Args: token_str: Token 字符串 - status_code: HTTP 状态码 + status_code: HTTP Status Code reason: 失败原因 + threshold: 强制失败阈值 Returns: 是否成功 @@ -617,18 +631,22 @@ async def record_fail( token = pool.get(raw_token) if token: if status_code == 401: - threshold = get_config("token.fail_threshold", FAIL_THRESHOLD) - try: - threshold = int(threshold) - except (TypeError, ValueError): - threshold = FAIL_THRESHOLD + if threshold is None: + threshold = get_config("token.fail_threshold", FAIL_THRESHOLD) + try: + threshold = int(threshold) + except (TypeError, ValueError): + threshold = FAIL_THRESHOLD + if threshold < 1: threshold = 1 token.record_fail(status_code, reason, threshold=threshold) - logger.warning( + + log_level = logger.warning if token.status == TokenStatus.EXPIRED else logger.info + log_level( f"Token {raw_token[:10]}...: recorded {status_code} failure " - f"({token.fail_count}/{threshold}) - {reason}" + f"({token.fail_count}/{threshold}) - {reason} - status: {token.status}" ) self._track_token_change(token, pool.name, "state") self._schedule_save()