Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions src/utils/json_utils.py
Original file line number Diff line number Diff line change
def extract_json_from_markdown(content: str) -> str:
    """Extract JSON from a markdown-fenced LLM reply and clean control characters.

    Handles both the plain fenced form (```json ... ```) and the
    quote-wrapped form some models (e.g. Gemini) emit ("```json ... ```"),
    then strips raw control characters that would break ``json.loads``.

    Args:
        content: Raw model output, possibly wrapped in a code fence.

    Returns:
        The inner text with fences removed and C0/DEL/C1 control
        characters deleted.
    """
    content = content.strip()

    # Quote-wrapped fence: '"```json\n...\n```"' (8 leading / 4 trailing chars).
    if content.startswith('"```json') and content.endswith('```"'):
        content = content[8:-4].strip()
    elif content.startswith('"```') and content.endswith('```"'):
        content = content[4:-4].strip()
    # Standard markdown fence: '```json\n...\n```'.
    elif content.startswith("```json") and content.endswith("```"):
        content = content[7:-3].strip()
    elif content.startswith("```") and content.endswith("```"):
        content = content[3:-3].strip()

    # JSON forbids raw control characters inside strings; models leak them.
    return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", content)


def fix_common_json_errors(content: str) -> str:
    """Fix common JSON syntax errors emitted by LLMs.

    Repairs, in order: a stray '=' between key and value, unquoted object
    keys, and trailing commas before a closing brace/bracket.

    Args:
        content: A candidate JSON string (already fence-stripped).

    Returns:
        The repaired string; valid input passes through unchanged.
    """
    # Fix extra equals signs (e.g. "area":="value" -> "area":"value").
    content = re.sub(r':\s*=\s*"', ':"', content)

    # Quote bare object keys (e.g. {name: "x"} -> {"name":"x"}).
    # Anchor on the preceding '{' or ',' so a colon inside a string value
    # (e.g. "See:") is not mangled into a fake key.
    content = re.sub(r'([{,]\s*)(\w+):\s*"', r'\1"\2":"', content)

    # Drop trailing commas before '}' or ']'.
    return re.sub(r",(\s*[}\]])", r"\1", content)


def parse_llm_json_response(raw_content: Union[str, Any]) -> Dict[str, Any]:
"""Parse LLM JSON response."""
try:
Expand All @@ -31,8 +49,12 @@ def parse_llm_json_response(raw_content: Union[str, Any]) -> Dict[str, Any]:
# Clean the content first
cleaned_content = extract_json_from_markdown(raw_content)

# Fix common JSON errors
cleaned_content = fix_common_json_errors(cleaned_content)

# Parse the JSON
return json.loads(cleaned_content)
result = json.loads(cleaned_content)
return result if isinstance(result, dict) else {}

except json.JSONDecodeError as e:
log.error(f"Failed to parse JSON response: {e}")
Expand All @@ -50,7 +72,8 @@ def parse_llm_json_response(raw_content: Union[str, Any]) -> Dict[str, Any]:
log.warning(
"Attempting to fix unterminated JSON by truncating to last complete entry"
)
return json.loads(fixed_content)
result = json.loads(fixed_content)
return result if isinstance(result, dict) else {}
except Exception as fix_error:
log.error(f"Failed to fix JSON: {fix_error}")

Expand Down
31 changes: 18 additions & 13 deletions src/utils/model_client_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
)


# Cap on tokens generated per completion (10 * 1024 = 10240); used as the
# shared default for the OpenAI, Anthropic, and Gemini clients below.
MAX_TOKENS = 1024 * 10

logger = logging.getLogger(__name__)

GEMINI_STUDIO_BASE = "https://generativelanguage.googleapis.com/v1beta/openai/"
Expand Down Expand Up @@ -48,30 +50,31 @@ def __init__(self, client: Any, max_retries: int = 3):
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
async def create(self, *args: Any, **kwargs: Any) -> Any:
    """Forward a create call to the wrapped client.

    Transient-error retries are applied by the decorator above this method.
    """
    inner = self.client
    return await inner.create(*args, **kwargs)

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
"""Delegate all other attributes to the wrapped client."""
return getattr(self.client, name)


def get_model_client(model_name: str, seed: Optional[int] = None, **kwargs) -> Any:
"""Return a model client for the given model name with retry logic."""
def get_model_client(model_name: str, seed: Optional[int] = None, **kwargs: Any) -> Any:
"""Get a model client for the given model name."""
n = model_name.lower()

if n.startswith(("gpt-", "o1-", "o3-")):
# Add max_tokens to prevent truncated responses
kwargs.setdefault("max_tokens", 4096)
client = OpenAIChatCompletionClient(model=model_name, seed=seed, **kwargs)
return RetryableModelClient(client)
if n.startswith(("gpt-", "o1-", "o3-", "gpt-5")):
kwargs.setdefault("max_completion_tokens", MAX_TOKENS)
openai_client = OpenAIChatCompletionClient(
model=model_name, seed=seed, **kwargs
)
return RetryableModelClient(openai_client)

if "claude" in n:
# Add max_tokens to prevent truncated responses
kwargs.setdefault("max_tokens", 4096)
client = AnthropicChatCompletionClient(model=model_name, **kwargs)
return RetryableModelClient(client)
kwargs.setdefault("max_tokens", MAX_TOKENS)
kwargs.setdefault("timeout", None)
anthropic_client = AnthropicChatCompletionClient(model=model_name, **kwargs)
return RetryableModelClient(anthropic_client)

if "gemini" in n:
api_key = kwargs.pop("api_key", os.getenv("GOOGLE_API_KEY"))
Expand All @@ -89,6 +92,8 @@ def get_model_client(model_name: str, seed: Optional[int] = None, **kwargs) -> A
),
)

kwargs.setdefault("max_completion_tokens", MAX_TOKENS)

client = OpenAIChatCompletionClient(
model=model_name,
base_url=GEMINI_STUDIO_BASE,
Expand Down
Loading