diff --git a/src/genai_utils/gemini.py b/src/genai_utils/gemini.py index 63d41e1..d27141f 100644 --- a/src/genai_utils/gemini.py +++ b/src/genai_utils/gemini.py @@ -45,12 +45,18 @@ } -class GeminiError(Exception): +class GeminiError(RuntimeError): """ Exception raised when something goes wrong with Gemini. """ +class NoGroundingError(GeminiError): + """ + Exception raised if grounding doesn't run when asked. + """ + + class ModelConfig(BaseModel): """ Config for a Gemini model. @@ -564,18 +570,21 @@ class Movie(BaseModel): ), ) + if not (response.candidates and response.text and isinstance(response.text, str)): + raise GeminiError( + f"No model output: possible reason: {response.prompt_feedback}" + ) + if use_grounding: grounding_ran = check_grounding_ran(response) if not grounding_ran: - _logger.warning( + _logger.error( "Grounding Info: GROUNDING FAILED - see previous log messages for reason" ) + raise NoGroundingError("Grounding did not run") - if response.candidates and response.text and isinstance(response.text, str): if inline_citations and response.candidates[0].grounding_metadata: text_with_citations = add_citations(response) return text_with_citations - else: - return response.text - raise GeminiError(f"No model output: possible reason: {response.prompt_feedback}") + return response.text diff --git a/tests/genai_utils/test_gemini.py b/tests/genai_utils/test_gemini.py index 82d4be0..1df6e07 100644 --- a/tests/genai_utils/test_gemini.py +++ b/tests/genai_utils/test_gemini.py @@ -5,15 +5,17 @@ from google.genai.client import AsyncClient from google.genai.models import Models from pydantic import BaseModel, Field -from pytest import mark, param +from pytest import mark, param, raises from genai_utils.gemini import ( DEFAULT_PARAMETERS, GeminiError, ModelConfig, + NoGroundingError, generate_model_config, get_thinking_config, run_prompt_async, + validate_labels, ) @@ -147,6 +149,35 @@ async def test_error_if_citations_and_no_grounding(mock_client): assert False +@patch("genai_utils.gemini.genai.Client") +async def test_no_grounding_error_when_grounding_does_not_run(mock_client): + client = Mock(Client) + models = Mock(Models) + async_client = Mock(AsyncClient) + + async def get_no_grounding_metadata_response(): + candidate = Mock() + candidate.grounding_metadata = None + response = Mock() + response.candidates = [candidate] + response.text = "response!" + return response + + models.generate_content.return_value = get_no_grounding_metadata_response() + client.aio = async_client + async_client.models = models + mock_client.return_value = client + + with raises(NoGroundingError): + await run_prompt_async( + "do something", + use_grounding=True, + model_config=ModelConfig( + project="project", location="location", model_name="model" + ), + ) + + @mark.parametrize( "model_name,do_thinking,expected", [ @@ -175,3 +206,90 @@ def test_get_thinking_config( ): thinking_config = get_thinking_config(model_name, do_thinking) assert thinking_config == expected + + +# --- validate_labels --- + + +def test_validate_labels_valid(): + labels = {"valid-key": "valid-value", "another_key": "value-123"} + assert validate_labels(labels) == labels + + +@mark.parametrize( + "labels", + [ + param({"": "value"}, id="empty-key"), + param({"a" * 64: "value"}, id="key-too-long"), + param({"key": "a" * 64}, id="value-too-long"), + param({"1key": "value"}, id="key-starts-with-digit"), + param({"_key": "value"}, id="key-starts-with-underscore"), + param({"KEY": "value"}, id="key-uppercase"), + param({"key.dots": "value"}, id="key-with-dots"), + param({"key": "VALUE"}, id="value-uppercase"), + param({"key": "val ue"}, id="value-with-space"), + ], +) +def test_validate_labels_invalid_input_dropped(labels): + assert validate_labels(labels) == {} + + +def test_validate_labels_mixed_keeps_only_valid(): + labels = {"valid": "ok", "INVALID": "value", "": "empty"} + assert validate_labels(labels) == {"valid": "ok"} + + +# --- run_prompt_async happy path --- + + +@patch("genai_utils.gemini.genai.Client") +async def test_run_prompt_async_returns_text(mock_client): + client = Mock(Client) + models = Mock(Models) + async_client = Mock(AsyncClient) + + response = Mock() + response.candidates = ["yes!"] + response.text = "response!" + + async def get_response(): + return response + + models.generate_content.return_value = get_response() + client.aio = async_client + async_client.models = models + mock_client.return_value = client + + result = await run_prompt_async( + "do something", + model_config=ModelConfig( + project="p", location="l", model_name="gemini-2.0-flash" + ), + ) + assert result == "response!" + + +@patch("genai_utils.gemini.genai.Client") +async def test_run_prompt_async_raises_when_no_output(mock_client): + client = Mock(Client) + models = Mock(Models) + async_client = Mock(AsyncClient) + + response = Mock() + response.candidates = None + response.text = None + response.prompt_feedback = "blocked" + + async def get_response(): + return response + + models.generate_content.return_value = get_response() + client.aio = async_client + async_client.models = models + mock_client.return_value = client + + with raises(GeminiError): + await run_prompt_async( + "do something", + model_config=ModelConfig(project="p", location="l", model_name="model"), + ) diff --git a/tests/genai_utils/test_grounding.py b/tests/genai_utils/test_grounding.py index 10ae6f7..c0da80b 100644 --- a/tests/genai_utils/test_grounding.py +++ b/tests/genai_utils/test_grounding.py @@ -2,9 +2,10 @@ import requests from google.genai import types -from pytest import mark, param +from pytest import mark, param, raises from genai_utils.gemini import ( + GeminiError, add_citations, check_grounding_ran, follow_redirect, @@ -177,6 +178,12 @@ def test_insert_citation( "response, expected", [ param(types.GenerateContentResponse(candidates=None), False), + param( + types.GenerateContentResponse( + candidates=[types.Candidate(grounding_metadata=None)] + ), + False, + ), param( types.GenerateContentResponse( candidates=[types.Candidate(grounding_metadata=dummy_grounding)] @@ -194,3 +201,19 @@ def test_insert_citation( def test_check_grounding_ran(response: types.GenerateContentResponse, expected: bool): did_grounding = check_grounding_ran(response) assert did_grounding == expected + + +@mark.parametrize( + "candidates,text", + [ + param(None, None, id="no-candidates"), + param([Mock()], None, id="no-text"), + ], +) +def test_add_citations_raises_when_missing_output(candidates, text): + response = Mock(types.GenerateContentResponse) + response.candidates = candidates + response.text = text + response.prompt_feedback = "blocked" + with raises(GeminiError): + add_citations(response)