Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 137 additions & 20 deletions langextract/providers/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.

"""Gemini provider for LangExtract."""

# pylint: disable=duplicate-code

from __future__ import annotations

import concurrent.futures
import dataclasses
import time
from typing import Any, Final, Iterator, Sequence

from absl import logging
Expand All @@ -37,6 +39,11 @@
_DEFAULT_LOCATION = 'us-central1'
_MIME_TYPE_JSON = 'application/json'

# Default retry configuration for transient errors (503, 429, etc.)
_DEFAULT_MAX_RETRIES = 3
_DEFAULT_RETRY_DELAY = 1.0 # Initial delay in seconds
_DEFAULT_MAX_RETRY_DELAY = 16.0 # Maximum delay in seconds

_API_CONFIG_KEYS: Final[set[str]] = {
'response_mime_type',
'response_schema',
Expand Down Expand Up @@ -68,6 +75,9 @@ class GeminiLanguageModel(base_model.BaseLanguageModel): # pylint: disable=too-
temperature: float = 0.0
max_workers: int = 10
fence_output: bool = False
max_retries: int = _DEFAULT_MAX_RETRIES
retry_delay: float = _DEFAULT_RETRY_DELAY
max_retry_delay: float = _DEFAULT_MAX_RETRY_DELAY
_extra_kwargs: dict[str, Any] = dataclasses.field(
default_factory=dict, repr=False, compare=False
)
Expand Down Expand Up @@ -105,6 +115,9 @@ def __init__(
temperature: float = 0.0,
max_workers: int = 10,
fence_output: bool = False,
max_retries: int = _DEFAULT_MAX_RETRIES,
retry_delay: float = _DEFAULT_RETRY_DELAY,
max_retry_delay: float = _DEFAULT_MAX_RETRY_DELAY,
**kwargs,
) -> None:
"""Initialize the Gemini language model.
Expand All @@ -123,6 +136,11 @@ def __init__(
max_workers: Maximum number of parallel API calls.
fence_output: Whether to wrap output in markdown fences (ignored,
Gemini handles this based on schema).
max_retries: Maximum number of retry attempts for transient errors
(503, 429, network errors). Set to 0 to disable retries.
retry_delay: Initial delay in seconds before first retry.
Subsequent delays increase exponentially.
max_retry_delay: Maximum delay in seconds between retries.
**kwargs: Additional Gemini API parameters. Only allowlisted keys are
forwarded to the API (response_schema, response_mime_type, tools,
safety_settings, stop_sequences, candidate_count, system_instruction).
Expand All @@ -149,6 +167,9 @@ def __init__(
self.temperature = temperature
self.max_workers = max_workers
self.fence_output = fence_output
self.max_retries = max_retries
self.retry_delay = retry_delay
self.max_retry_delay = max_retry_delay

# Extract batch config before we filter kwargs into _extra_kwargs
batch_cfg_dict = kwargs.pop('batch', None)
Expand Down Expand Up @@ -199,31 +220,127 @@ def _validate_schema_config(self) -> None:
'Set format_type=JSON or use_schema_constraints=False.'
)

def _is_retryable_error(self, error: Exception) -> bool:
"""Determine if an error is retryable (transient).

Retryable errors include:
- 503 Service Unavailable (model overloaded)
- 429 Too Many Requests (rate limiting)
- Network-related errors (connection, timeout, etc.)

Non-retryable errors include:
- 400 Bad Request (invalid input)
- 401 Unauthorized (authentication failure)
- 403 Forbidden (permission denied)
- 404 Not Found

Args:
error: The exception to check.

Returns:
True if the error is retryable, False otherwise.
"""
error_str = str(error).lower()

# Check for 503 (service unavailable / model overloaded)
if '503' in error_str or 'overloaded' in error_str:
return True

# Check for 429 (rate limiting)
if '429' in error_str or 'rate limit' in error_str or 'quota' in error_str:
return True

# Check for 500 (internal server error) - may be transient
if '500' in error_str and 'internal' in error_str:
return True

# Network-related errors are typically transient
if isinstance(error, (ConnectionError, TimeoutError, OSError)):
return True

# Check for connection/timeout keywords in error message
if any(
keyword in error_str
for keyword in ['timeout', 'connection', 'reset', 'unavailable']
):
return True

return False

def _process_single_prompt(
    self, prompt: str, config: dict
) -> core_types.ScoredOutput:
  """Process a single prompt and return a ScoredOutput.

  Implements exponential backoff retry for transient errors (503, 429,
  etc.). Each chunk is retried independently, allowing other chunks to
  succeed even if one chunk encounters temporary failures.

  Args:
    prompt: The prompt to process.
    config: Configuration dictionary for the API call. Not mutated; a
      per-attempt copy is made so retries see a clean config.

  Returns:
    A ScoredOutput containing the model's response (score fixed at 1.0).

  Raises:
    InferenceRuntimeError: If the API call fails after all retry attempts
      or encounters a non-retryable error.
  """
  last_exception: Exception | None = None
  delay = self.retry_delay

  # max_retries == 0 means a single attempt with no retries.
  for attempt in range(self.max_retries + 1):
    try:
      # Apply stored kwargs that weren't already set in config.
      # Copy first to avoid mutating the caller's dict across retries.
      call_config = dict(config)
      for key, value in self._extra_kwargs.items():
        if key not in call_config and value is not None:
          call_config[key] = value

      if self.gemini_schema:
        self._validate_schema_config()
        call_config.setdefault('response_mime_type', 'application/json')
        call_config.setdefault(
            'response_schema', self.gemini_schema.schema_dict
        )

      response = self._client.models.generate_content(
          model=self.model_id, contents=prompt, config=call_config
      )

      return core_types.ScoredOutput(score=1.0, output=response.text)

    except Exception as e:  # pylint: disable=broad-except
      last_exception = e

      # Retry only transient errors, and only while attempts remain.
      if attempt < self.max_retries and self._is_retryable_error(e):
        logging.warning(
            'Retryable error on attempt %d/%d: %s. Retrying in %.1f'
            ' seconds...',
            attempt + 1,
            self.max_retries + 1,
            str(e),
            delay,
        )
        time.sleep(delay)
        # Exponential backoff with cap.
        delay = min(delay * 2, self.max_retry_delay)
        continue

      # Non-retryable error or max retries exhausted.
      raise exceptions.InferenceRuntimeError(
          f'Gemini API error: {str(e)}', original=e
      ) from e

  # Unreachable in practice (the loop either returns or raises), but kept
  # as a defensive guard.
  raise exceptions.InferenceRuntimeError(
      f'Gemini API error after {self.max_retries + 1} attempts:'
      f' {str(last_exception)}',
      original=last_exception,
  )

def infer(
self, batch_prompts: Sequence[str], **kwargs
Expand Down
Loading