Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions src/genai_utils/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,18 @@
}


class GeminiError(Exception):
class GeminiError(RuntimeError):
"""
Exception raised when something goes wrong with Gemini.
"""


class NoGroundingError(GeminiError):
    """Raised when grounding was requested but did not actually run.

    Subclasses ``GeminiError`` so callers that catch the base
    exception still handle this failure mode.
    """


class ModelConfig(BaseModel):
"""
Config for a Gemini model.
Expand Down Expand Up @@ -564,18 +570,21 @@ class Movie(BaseModel):
),
)

if not (response.candidates and response.text and isinstance(response.text, str)):
raise GeminiError(
f"No model output: possible reason: {response.prompt_feedback}"
)

if use_grounding:
grounding_ran = check_grounding_ran(response)
if not grounding_ran:
_logger.warning(
_logger.error(
"Grounding Info: GROUNDING FAILED - see previous log messages for reason"
)
raise NoGroundingError("Grounding did not run")

if response.candidates and response.text and isinstance(response.text, str):
if inline_citations and response.candidates[0].grounding_metadata:
text_with_citations = add_citations(response)
return text_with_citations
else:
return response.text

raise GeminiError(f"No model output: possible reason: {response.prompt_feedback}")
return response.text
120 changes: 119 additions & 1 deletion tests/genai_utils/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
from google.genai.client import AsyncClient
from google.genai.models import Models
from pydantic import BaseModel, Field
from pytest import mark, param
from pytest import mark, param, raises

from genai_utils.gemini import (
DEFAULT_PARAMETERS,
GeminiError,
ModelConfig,
NoGroundingError,
generate_model_config,
get_thinking_config,
run_prompt_async,
validate_labels,
)


Expand Down Expand Up @@ -147,6 +149,35 @@ async def test_error_if_citations_and_no_grounding(mock_client):
assert False


@patch("genai_utils.gemini.genai.Client")
async def test_no_grounding_error_when_grounding_does_not_run(mock_client):
    """A response whose candidate lacks grounding metadata must raise NoGroundingError."""

    async def make_ungrounded_response():
        # The candidate carries text but no grounding_metadata, which is the
        # condition the code under test treats as "grounding did not run".
        fake_candidate = Mock()
        fake_candidate.grounding_metadata = None
        fake_response = Mock()
        fake_response.candidates = [fake_candidate]
        fake_response.text = "response!"
        return fake_response

    # Wire the mock client hierarchy: client.aio.models.generate_content(...)
    fake_models = Mock(Models)
    fake_models.generate_content.return_value = make_ungrounded_response()
    fake_async = Mock(AsyncClient)
    fake_async.models = fake_models
    fake_client = Mock(Client)
    fake_client.aio = fake_async
    mock_client.return_value = fake_client

    config = ModelConfig(project="project", location="location", model_name="model")
    with raises(NoGroundingError):
        await run_prompt_async("do something", use_grounding=True, model_config=config)


@mark.parametrize(
"model_name,do_thinking,expected",
[
Expand Down Expand Up @@ -175,3 +206,90 @@ def test_get_thinking_config(
):
thinking_config = get_thinking_config(model_name, do_thinking)
assert thinking_config == expected


# --- validate_labels ---


def test_validate_labels_valid():
    """A fully valid label dict passes through validation unchanged."""
    candidate = {"valid-key": "valid-value", "another_key": "value-123"}
    result = validate_labels(candidate)
    assert result == candidate


# Each case holds exactly one label entry that breaks a validation rule
# (length, leading character, casing, or allowed characters), so the
# whole entry should be dropped and an empty dict returned.
_INVALID_LABEL_CASES = [
    param({"": "value"}, id="empty-key"),
    param({"a" * 64: "value"}, id="key-too-long"),
    param({"key": "a" * 64}, id="value-too-long"),
    param({"1key": "value"}, id="key-starts-with-digit"),
    param({"_key": "value"}, id="key-starts-with-underscore"),
    param({"KEY": "value"}, id="key-uppercase"),
    param({"key.dots": "value"}, id="key-with-dots"),
    param({"key": "VALUE"}, id="value-uppercase"),
    param({"key": "val ue"}, id="value-with-space"),
]


@mark.parametrize("bad_labels", _INVALID_LABEL_CASES)
def test_validate_labels_invalid_input_dropped(bad_labels):
    assert validate_labels(bad_labels) == {}


def test_validate_labels_mixed_keeps_only_valid():
    """Invalid entries are filtered out while the valid one survives."""
    mixed = {"valid": "ok", "INVALID": "value", "": "empty"}
    filtered = validate_labels(mixed)
    assert filtered == {"valid": "ok"}


# --- run_prompt_async happy path ---


@patch("genai_utils.gemini.genai.Client")
async def test_run_prompt_async_returns_text(mock_client):
    """Happy path: a response with candidates and text yields that text."""
    fake_response = Mock()
    fake_response.candidates = ["yes!"]
    fake_response.text = "response!"

    async def produce_response():
        return fake_response

    # Wire the mock client hierarchy: client.aio.models.generate_content(...)
    fake_models = Mock(Models)
    fake_models.generate_content.return_value = produce_response()
    fake_async = Mock(AsyncClient)
    fake_async.models = fake_models
    fake_client = Mock(Client)
    fake_client.aio = fake_async
    mock_client.return_value = fake_client

    config = ModelConfig(project="p", location="l", model_name="gemini-2.0-flash")
    result = await run_prompt_async("do something", model_config=config)
    assert result == "response!"


@patch("genai_utils.gemini.genai.Client")
async def test_run_prompt_async_raises_when_no_output(mock_client):
    """A response with neither candidates nor text must raise GeminiError."""
    fake_response = Mock()
    fake_response.candidates = None
    fake_response.text = None
    fake_response.prompt_feedback = "blocked"

    async def produce_response():
        return fake_response

    # Wire the mock client hierarchy: client.aio.models.generate_content(...)
    fake_models = Mock(Models)
    fake_models.generate_content.return_value = produce_response()
    fake_async = Mock(AsyncClient)
    fake_async.models = fake_models
    fake_client = Mock(Client)
    fake_client.aio = fake_async
    mock_client.return_value = fake_client

    with raises(GeminiError):
        await run_prompt_async(
            "do something",
            model_config=ModelConfig(project="p", location="l", model_name="model"),
        )
25 changes: 24 additions & 1 deletion tests/genai_utils/test_grounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import requests
from google.genai import types
from pytest import mark, param
from pytest import mark, param, raises

from genai_utils.gemini import (
GeminiError,
add_citations,
check_grounding_ran,
follow_redirect,
Expand Down Expand Up @@ -177,6 +178,12 @@ def test_insert_citation(
"response, expected",
[
param(types.GenerateContentResponse(candidates=None), False),
param(
types.GenerateContentResponse(
candidates=[types.Candidate(grounding_metadata=None)]
),
False,
),
param(
types.GenerateContentResponse(
candidates=[types.Candidate(grounding_metadata=dummy_grounding)]
Expand All @@ -194,3 +201,19 @@ def test_insert_citation(
def test_check_grounding_ran(response: types.GenerateContentResponse, expected: bool):
did_grounding = check_grounding_ran(response)
assert did_grounding == expected


@mark.parametrize(
    "candidates,text",
    [
        param(None, None, id="no-candidates"),
        param([Mock()], None, id="no-text"),
    ],
)
def test_add_citations_raises_when_missing_output(candidates, text):
    """add_citations must raise GeminiError when the response lacks usable output."""
    fake_response = Mock(types.GenerateContentResponse)
    fake_response.candidates = candidates
    fake_response.text = text
    fake_response.prompt_feedback = "blocked"

    with raises(GeminiError):
        add_citations(fake_response)