Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nadirclaw/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ def status():
click.echo("NadirClaw Status")
click.echo("-" * 40)
click.echo(f"Simple model: {settings.SIMPLE_MODEL}")
if settings.has_mid_tier:
click.echo(f"Mid model: {settings.MID_MODEL}")
click.echo(f"Complex model: {settings.COMPLEX_MODEL}")
if settings.has_explicit_tiers:
click.echo("Tier config: explicit (env vars)")
Expand Down
8 changes: 5 additions & 3 deletions nadirclaw/savings.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def calculate_actual_cost(entries: List[Dict[str, Any]]) -> float:
"""Calculate the actual cost of all requests using the models NadirClaw chose."""
total = 0.0
for e in entries:
model = e.get("selected_model", "")
model = e.get("selected_model") or ""
pt = _safe_int(e.get("prompt_tokens", 0))
ct = _safe_int(e.get("completion_tokens", 0))
cost_in, cost_out = get_model_cost(model)
Expand Down Expand Up @@ -217,10 +217,12 @@ def format_savings_text(report: Dict[str, Any]) -> str:
lines.append("Routing Distribution")
lines.append("-" * 30)
total = sum(tiers.values())
for tier, count in sorted(tiers.items()):
# Sort carefully handling None keys
for tier, count in sorted(tiers.items(), key=lambda item: str(item[0])):
tier_str = str(tier) if tier is not None else "unknown"
pct = count / total * 100 if total else 0
bar = "█" * int(pct / 2)
lines.append(f" {tier:12s} {count:>5} ({pct:4.1f}%) {bar}")
lines.append(f" {tier_str:12s} {count:>5} ({pct:4.1f}%) {bar}")

# Monthly projection
proj = report.get("projection", {})
Expand Down
116 changes: 77 additions & 39 deletions nadirclaw/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def _is_oauth_token(token: str) -> bool:
_VERTEX_DEFAULT_LOCATION = "us-central1"


def _get_gemini_client(api_key: str):
def _get_gemini_client(api_key: Optional[str]):
"""Get or create a thread-safe, per-key google-genai Client.

Handles both API keys (AIza...) and OAuth access tokens (ya29...).
Expand All @@ -643,10 +643,11 @@ def _get_gemini_client(api_key: str):
OAuth tokens (from OpenClaw/Gemini CLI) must use the Vertex AI path.
"""
with _gemini_client_lock:
if api_key not in _gemini_clients:
cache_key = api_key if api_key is not None else "adc_default"
if cache_key not in _gemini_clients:
from google import genai

if _is_oauth_token(api_key):
if api_key and _is_oauth_token(api_key):
from google.oauth2.credentials import Credentials
from nadirclaw.credentials import get_gemini_oauth_config

Expand All @@ -661,7 +662,7 @@ def _get_gemini_client(api_key: str):
"credentials include a project_id."
)
creds = Credentials(token=api_key)
_gemini_clients[api_key] = genai.Client(
_gemini_clients[cache_key] = genai.Client(
vertexai=True,
credentials=creds,
project=project_id,
Expand All @@ -671,10 +672,43 @@ def _get_gemini_client(api_key: str):
"Created Gemini client with OAuth credentials (Vertex AI, project=%s)",
project_id,
)
else:
_gemini_clients[api_key] = genai.Client(api_key=api_key)
elif api_key:
_gemini_clients[cache_key] = genai.Client(api_key=api_key)
logger.debug("Created Gemini client with API key")
return _gemini_clients[api_key]
else:
import google.auth
from google.auth.exceptions import DefaultCredentialsError
from fastapi import HTTPException

try:
credentials, project_id = google.auth.default()
except DefaultCredentialsError as e:
raise HTTPException(
status_code=500,
detail="No Google/Gemini API key configured and no Application Default Credentials found. "
"Set GEMINI_API_KEY, GOOGLE_API_KEY, or configure ADC.",
) from e

project_id = os.environ.get("GOOGLE_CLOUD_PROJECT") or project_id

if not project_id:
logger.warning(
"Gemini ADC detected but no project_id found. "
"Set GOOGLE_CLOUD_PROJECT env var."
)

_gemini_clients[cache_key] = genai.Client(
vertexai=True,
credentials=credentials,
project=project_id,
location=os.environ.get("GOOGLE_CLOUD_LOCATION", _VERTEX_DEFAULT_LOCATION),
)
logger.debug(
"Created Gemini client with Application Default Credentials (Vertex AI, project=%s)",
project_id,
)

return _gemini_clients[cache_key]


async def _call_gemini(
Expand All @@ -698,13 +732,8 @@ async def _call_gemini(
MAX_RETRIES = 1 # Keep low — fallback handles the rest

api_key = get_credential(provider)
if not api_key:
raise HTTPException(
status_code=500,
detail="No Google/Gemini API key configured. "
"Set GEMINI_API_KEY or GOOGLE_API_KEY, or run: nadirclaw auth add -p google",
)

# Allow api_key to be None here; _get_gemini_client will attempt to use ADC instead.
client = _get_gemini_client(api_key)
native_model = _strip_gemini_prefix(model)

Expand Down Expand Up @@ -1672,7 +1701,10 @@ async def _true_stream_wrapper():
},
}

resp = JSONResponse(content=response_body)
resp = JSONResponse(
content=response_body,
headers=_routing_headers(selected_model, analysis_info),
)
if _injection_signal and should_warn():
resp.headers["X-Prompt-Guard-Warning"] = _injection_signal.pattern_name
return resp
Expand Down Expand Up @@ -1909,12 +1941,8 @@ async def _stream_gemini(
from nadirclaw.credentials import get_credential

api_key = get_credential(provider)
if not api_key:
raise HTTPException(
status_code=500,
detail="No Google/Gemini API key configured.",
)

# Allow api_key to be None here; _get_gemini_client will attempt to use ADC instead.
client = _get_gemini_client(api_key)
native_model = _strip_gemini_prefix(model)

Expand Down Expand Up @@ -1986,13 +2014,16 @@ def _iter_stream():
for chunk in all_chunks:
delta_dict: dict[str, Any] = {}
text = ""
if hasattr(chunk, "text") and chunk.text:
text = chunk.text
elif chunk.candidates:
candidate = chunk.candidates[0]
if hasattr(candidate, "content") and candidate.content and candidate.content.parts:
text_parts = [p.text for p in candidate.content.parts if hasattr(p, "text") and p.text]
text = "".join(text_parts)
try:
if getattr(chunk, "text", None):
text = getattr(chunk, "text")
elif getattr(chunk, "candidates", None):
candidate = chunk.candidates[0]
if getattr(candidate, "content", None) and getattr(candidate.content, "parts", None):
text_parts = [str(p.text) if getattr(p, "text", None) else "" for p in candidate.content.parts]
text = "".join(text_parts)
except Exception as e:
logger.warning("Error parsing Gemini stream chunk text: %s", e)

if text:
delta_dict["content"] = text
Expand All @@ -2006,16 +2037,25 @@ def _iter_stream():
}

finish_reason = None
if chunk.candidates:
raw_reason = getattr(chunk.candidates[0], "finish_reason", None)
if raw_reason:
reason_str = str(raw_reason).lower()
if "safety" in reason_str:
finish_reason = "content_filter"
elif "length" in reason_str or "max_tokens" in reason_str:
finish_reason = "length"
elif "stop" in reason_str:
finish_reason = "stop"
try:
if getattr(chunk, "candidates", None):
raw_reason = getattr(chunk.candidates[0], "finish_reason", None)
if raw_reason:
try:
reason_str = str(getattr(raw_reason, "value", raw_reason)).lower()
except Exception:
try:
reason_str = str(raw_reason).lower()
except TypeError:
reason_str = getattr(raw_reason, "name", "").lower()
if "safety" in reason_str:
finish_reason = "content_filter"
elif "length" in reason_str or "max_tokens" in reason_str:
finish_reason = "length"
elif "stop" in reason_str:
finish_reason = "stop"
except Exception as e:
logger.warning("Error parsing Gemini stream finish_reason: %s", e)

if delta_dict or finish_reason:
yield delta_dict, usage, finish_reason
Expand All @@ -2039,9 +2079,7 @@ async def _dispatch_model_stream(
raise RateLimitExhausted(model=model, retry_after=retry_after)

if provider == "google":
async_gen = None
# _stream_gemini is a sync generator; wrap it
for item in _stream_gemini(model, request, provider):
async for item in _stream_gemini(model, request, provider):
yield item
else:
async for item in _stream_litellm(model, request, provider):
Expand Down
8 changes: 6 additions & 2 deletions tests/test_fallback_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@ def test_default_chain_includes_tier_models(self):
chain = settings.FALLBACK_CHAIN
assert settings.COMPLEX_MODEL in chain
assert settings.SIMPLE_MODEL in chain
# Complex should come first
assert chain.index(settings.COMPLEX_MODEL) < chain.index(settings.SIMPLE_MODEL)

# In testing environments without explicit models configured,
# COMPLEX_MODEL might equal SIMPLE_MODEL if the default MODELS list is altered.
# Only verify ordering if they are distinct.
if settings.COMPLEX_MODEL != settings.SIMPLE_MODEL:
assert chain.index(settings.COMPLEX_MODEL) < chain.index(settings.SIMPLE_MODEL)

def test_custom_chain_from_env(self, monkeypatch):
"""NADIRCLAW_FALLBACK_CHAIN env var should override defaults."""
Expand Down
Loading