diff --git a/nadirclaw/cli.py b/nadirclaw/cli.py index a270cfa..9181fa9 100644 --- a/nadirclaw/cli.py +++ b/nadirclaw/cli.py @@ -214,6 +214,8 @@ def status(): click.echo("NadirClaw Status") click.echo("-" * 40) click.echo(f"Simple model: {settings.SIMPLE_MODEL}") + if settings.has_mid_tier: + click.echo(f"Mid model: {settings.MID_MODEL}") click.echo(f"Complex model: {settings.COMPLEX_MODEL}") if settings.has_explicit_tiers: click.echo("Tier config: explicit (env vars)") diff --git a/nadirclaw/savings.py b/nadirclaw/savings.py index bc4efd2..fbf9b4b 100644 --- a/nadirclaw/savings.py +++ b/nadirclaw/savings.py @@ -35,7 +35,7 @@ def calculate_actual_cost(entries: List[Dict[str, Any]]) -> float: """Calculate the actual cost of all requests using the models NadirClaw chose.""" total = 0.0 for e in entries: - model = e.get("selected_model", "") + model = e.get("selected_model") or "" pt = _safe_int(e.get("prompt_tokens", 0)) ct = _safe_int(e.get("completion_tokens", 0)) cost_in, cost_out = get_model_cost(model) @@ -217,10 +217,12 @@ def format_savings_text(report: Dict[str, Any]) -> str: lines.append("Routing Distribution") lines.append("-" * 30) total = sum(tiers.values()) - for tier, count in sorted(tiers.items()): + # Sort carefully handling None keys + for tier, count in sorted(tiers.items(), key=lambda item: str(item[0])): + tier_str = str(tier) if tier is not None else "unknown" pct = count / total * 100 if total else 0 bar = "█" * int(pct / 2) - lines.append(f" {tier:12s} {count:>5} ({pct:4.1f}%) {bar}") + lines.append(f" {tier_str:12s} {count:>5} ({pct:4.1f}%) {bar}") # Monthly projection proj = report.get("projection", {}) diff --git a/nadirclaw/server.py b/nadirclaw/server.py index ad87403..e16cb5a 100644 --- a/nadirclaw/server.py +++ b/nadirclaw/server.py @@ -633,7 +633,7 @@ def _is_oauth_token(token: str) -> bool: _VERTEX_DEFAULT_LOCATION = "us-central1" -def _get_gemini_client(api_key: str): +def _get_gemini_client(api_key: Optional[str]): """Get or create a thread-safe, per-key google-genai Client. Handles both API keys (AIza...) and OAuth access tokens (ya29...). @@ -643,10 +643,11 @@ def _get_gemini_client(api_key: str): OAuth tokens (from OpenClaw/Gemini CLI) must use the Vertex AI path. """ with _gemini_client_lock: - if api_key not in _gemini_clients: + cache_key = api_key if api_key is not None else "adc_default" + if cache_key not in _gemini_clients: from google import genai - if _is_oauth_token(api_key): + if api_key and _is_oauth_token(api_key): from google.oauth2.credentials import Credentials from nadirclaw.credentials import get_gemini_oauth_config @@ -661,7 +662,7 @@ def _get_gemini_client(api_key: str): "credentials include a project_id." ) creds = Credentials(token=api_key) - _gemini_clients[api_key] = genai.Client( + _gemini_clients[cache_key] = genai.Client( vertexai=True, credentials=creds, project=project_id, @@ -671,10 +672,43 @@ def _get_gemini_client(api_key: str): "Created Gemini client with OAuth credentials (Vertex AI, project=%s)", project_id, ) - else: - _gemini_clients[api_key] = genai.Client(api_key=api_key) + elif api_key: + _gemini_clients[cache_key] = genai.Client(api_key=api_key) logger.debug("Created Gemini client with API key") - return _gemini_clients[api_key] + else: + import google.auth + from google.auth.exceptions import DefaultCredentialsError + from fastapi import HTTPException + + try: + credentials, project_id = google.auth.default() + except DefaultCredentialsError as e: + raise HTTPException( + status_code=500, + detail="No Google/Gemini API key configured and no Application Default Credentials found. " + "Set GEMINI_API_KEY, GOOGLE_API_KEY, or configure ADC.", + ) from e + + project_id = os.environ.get("GOOGLE_CLOUD_PROJECT") or project_id + + if not project_id: + logger.warning( + "Gemini ADC detected but no project_id found. " + "Set GOOGLE_CLOUD_PROJECT env var." + ) + + _gemini_clients[cache_key] = genai.Client( + vertexai=True, + credentials=credentials, + project=project_id, + location=os.environ.get("GOOGLE_CLOUD_LOCATION", _VERTEX_DEFAULT_LOCATION), + ) + logger.debug( + "Created Gemini client with Application Default Credentials (Vertex AI, project=%s)", + project_id, + ) + + return _gemini_clients[cache_key] async def _call_gemini( @@ -698,13 +732,8 @@ async def _call_gemini( MAX_RETRIES = 1 # Keep low — fallback handles the rest api_key = get_credential(provider) - if not api_key: - raise HTTPException( - status_code=500, - detail="No Google/Gemini API key configured. " - "Set GEMINI_API_KEY or GOOGLE_API_KEY, or run: nadirclaw auth add -p google", - ) + # Allow api_key to be None here; _get_gemini_client will attempt to use ADC instead. client = _get_gemini_client(api_key) native_model = _strip_gemini_prefix(model) @@ -1672,7 +1701,10 @@ async def _true_stream_wrapper(): }, } - resp = JSONResponse(content=response_body) + resp = JSONResponse( + content=response_body, + headers=_routing_headers(selected_model, analysis_info), + ) if _injection_signal and should_warn(): resp.headers["X-Prompt-Guard-Warning"] = _injection_signal.pattern_name return resp @@ -1909,12 +1941,8 @@ async def _stream_gemini( from nadirclaw.credentials import get_credential api_key = get_credential(provider) - if not api_key: - raise HTTPException( - status_code=500, - detail="No Google/Gemini API key configured.", - ) + # Allow api_key to be None here; _get_gemini_client will attempt to use ADC instead. client = _get_gemini_client(api_key) native_model = _strip_gemini_prefix(model) @@ -1986,13 +2014,16 @@ def _iter_stream(): for chunk in all_chunks: delta_dict: dict[str, Any] = {} text = "" - if hasattr(chunk, "text") and chunk.text: - text = chunk.text - elif chunk.candidates: - candidate = chunk.candidates[0] - if hasattr(candidate, "content") and candidate.content and candidate.content.parts: - text_parts = [p.text for p in candidate.content.parts if hasattr(p, "text") and p.text] - text = "".join(text_parts) + try: + if getattr(chunk, "text", None): + text = getattr(chunk, "text") + elif getattr(chunk, "candidates", None): + candidate = chunk.candidates[0] + if getattr(candidate, "content", None) and getattr(candidate.content, "parts", None): + text_parts = [str(p.text) if getattr(p, "text", None) else "" for p in candidate.content.parts] + text = "".join(text_parts) + except Exception as e: + logger.warning("Error parsing Gemini stream chunk text: %s", e) if text: delta_dict["content"] = text @@ -2006,16 +2037,25 @@ def _iter_stream(): } finish_reason = None - if chunk.candidates: - raw_reason = getattr(chunk.candidates[0], "finish_reason", None) - if raw_reason: - reason_str = str(raw_reason).lower() - if "safety" in reason_str: - finish_reason = "content_filter" - elif "length" in reason_str or "max_tokens" in reason_str: - finish_reason = "length" - elif "stop" in reason_str: - finish_reason = "stop" + try: + if getattr(chunk, "candidates", None): + raw_reason = getattr(chunk.candidates[0], "finish_reason", None) + if raw_reason: + try: + reason_str = str(getattr(raw_reason, "value", raw_reason)).lower() + except Exception: + try: + reason_str = str(raw_reason).lower() + except TypeError: + reason_str = getattr(raw_reason, "name", "").lower() + if "safety" in reason_str: + finish_reason = "content_filter" + elif "length" in reason_str or "max_tokens" in reason_str: + finish_reason = "length" + elif "stop" in reason_str: + finish_reason = "stop" + except Exception as e: + logger.warning("Error parsing Gemini stream finish_reason: %s", e) if delta_dict or finish_reason: yield delta_dict, usage, finish_reason @@ -2039,9 +2079,7 @@ async def _dispatch_model_stream( raise RateLimitExhausted(model=model, retry_after=retry_after) if provider == "google": - async_gen = None - # _stream_gemini is a sync generator; wrap it - for item in _stream_gemini(model, request, provider): + async for item in _stream_gemini(model, request, provider): yield item else: async for item in _stream_litellm(model, request, provider): diff --git a/tests/test_fallback_chain.py b/tests/test_fallback_chain.py index 808a8e6..719e51f 100644 --- a/tests/test_fallback_chain.py +++ b/tests/test_fallback_chain.py @@ -12,8 +12,12 @@ def test_default_chain_includes_tier_models(self): chain = settings.FALLBACK_CHAIN assert settings.COMPLEX_MODEL in chain assert settings.SIMPLE_MODEL in chain - # Complex should come first - assert chain.index(settings.COMPLEX_MODEL) < chain.index(settings.SIMPLE_MODEL) + + # In testing environments without explicit models configured, + # COMPLEX_MODEL might equal SIMPLE_MODEL if the default MODELS list is altered. + # Only verify ordering if they are distinct. + if settings.COMPLEX_MODEL != settings.SIMPLE_MODEL: + assert chain.index(settings.COMPLEX_MODEL) < chain.index(settings.SIMPLE_MODEL) def test_custom_chain_from_env(self, monkeypatch): """NADIRCLAW_FALLBACK_CHAIN env var should override defaults."""