diff --git a/src/genai_utils/gemini.py b/src/genai_utils/gemini.py index f9f3edd..4e9ca3e 100644 --- a/src/genai_utils/gemini.py +++ b/src/genai_utils/gemini.py @@ -1,3 +1,4 @@ +import asyncio import logging import os import re @@ -303,6 +304,98 @@ def run_prompt( use_grounding: bool = False, inline_citations: bool = False, labels: dict[str, str] = {}, +) -> str: + """ + A synchronous version of `run_prompt_async`. + + Parameters + ---------- + prompt: str + The prompt given to the model + video_uri: str | None + A Google Cloud URI for a video that you want to prompt. + output_schema: types.SchemaUnion | None + A valid schema for the model output. + Generally, we'd recommend this being a pydantic BaseModel inheriting class, + which defines the desired schema of the model output. + ```python + from pydantic import BaseModel, Field + + class Movie(BaseModel): + title: str = Field(description="The title of the movie") + year: int = Field(description="The year the film was released in the UK") + + schema = Movie + # or + schema = list[Movie] + ``` + Use this if you want structured JSON output. + system_instruction: str | None + An instruction to the model which essentially goes before the prompt. + For example: + ``` + You are a fact checker and you must base all your answers on evidence + ``` + generation_config: dict[str, Any] + The parameters for the generation. See the docs (`generation config`_). + safety_settings: dict[generative_models.HarmCategory, generative_models.HarmBlockThreshold] + The safety settings for generation. Determines what will be blocked. + See the docs (`safety settings`_) + model_config: ModelConfig | None + The config for the Gemini model. + Specifies project, location, and model name. + If None, will attempt to use environment variables: + `GEMINI_PROJECT`, `GEMINI_LOCATION`, and `GEMINI_MODEL`. + use_grounding: bool + Whether Gemini should perform a Google search to ground results. + This will allow it to pull from up-to-date information, + and makes the output more likely to be factual. + Does not work with structured output. + See the docs (`grounding`_). + inline_citations: bool + Whether output should include citations inline with the text. + These citations will be links to be used as evidence. + This is only possible if grounding is set to true. + labels: dict[str, str] + Optional labels to attach to the API call for tracking and monitoring purposes. + Labels are key-value pairs that can be used to organize and filter requests + in Google Cloud logs and metrics. + + Returns + ------- + The text output of the Gemini model. + + .. _generation config: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig + .. _safety settings: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters + .. _grounding: https://ai.google.dev/gemini-api/docs/google-search + """ + return asyncio.run( + run_prompt_async( + prompt=prompt, + video_uri=video_uri, + output_schema=output_schema, + system_instruction=system_instruction, + generation_config=generation_config, + safety_settings=safety_settings, + model_config=model_config, + use_grounding=use_grounding, + inline_citations=inline_citations, + labels=labels, + ) + ) + + +async def run_prompt_async( + prompt: str, + video_uri: str | None = None, + output_schema: types.SchemaUnion | None = None, + system_instruction: str | None = None, + generation_config: dict[str, Any] = DEFAULT_PARAMETERS, + safety_settings: list[types.SafetySetting] = DEFAULT_SAFETY_SETTINGS, + model_config: ModelConfig | None = None, + use_grounding: bool = False, + inline_citations: bool = False, + labels: dict[str, str] = {}, ) -> str: """ Runs a prompt through the model. @@ -405,7 +498,7 @@ class Movie(BaseModel): merged_labels = DEFAULT_LABELS | labels validate_labels(merged_labels) - response = client.models.generate_content( + response = await client.aio.models.generate_content( model=model_config.model_name, contents=types.Content(role="user", parts=parts), config=types.GenerateContentConfig( diff --git a/tests/genai_utils/test_gemini.py b/tests/genai_utils/test_gemini.py index 138e25f..d5aa049 100644 --- a/tests/genai_utils/test_gemini.py +++ b/tests/genai_utils/test_gemini.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch from google.genai import Client +from google.genai.client import AsyncClient from google.genai.models import Models from pydantic import BaseModel, Field @@ -10,7 +11,7 @@ GeminiError, ModelConfig, generate_model_config, - run_prompt, + run_prompt_async, ) @@ -24,7 +25,7 @@ class DummySchema(BaseModel): colour: str = Field(description="Colour of dog") -def get_dummy(): +async def get_dummy(): return DummyResponse() @@ -57,17 +58,19 @@ def test_generate_model_config_no_env_vars(): @patch("genai_utils.gemini.genai.Client") -def test_dont_overwrite_generation_config(mock_client): +async def test_dont_overwrite_generation_config(mock_client): copy_of_params = {**DEFAULT_PARAMETERS} client = Mock(Client) models = Mock(Models) + async_client = Mock(AsyncClient) models.generate_content.return_value = get_dummy() - client.models = models + client.aio = async_client + async_client.models = models mock_client.return_value = client assert DEFAULT_PARAMETERS == copy_of_params - run_prompt( + await run_prompt_async( "do something", output_schema=DummySchema, model_config=ModelConfig( @@ -75,7 +78,7 @@ def test_dont_overwrite_generation_config(mock_client): ), ) models.generate_content.return_value = get_dummy() - run_prompt( + await run_prompt_async( "do something", model_config=ModelConfig( project="project", location="location", model_name="model" @@ -89,16 +92,18 @@ def test_dont_overwrite_generation_config(mock_client): @patch("genai_utils.gemini.genai.Client") -def test_error_if_grounding_with_schema(mock_client): +async def test_error_if_grounding_with_schema(mock_client): client = Mock(Client) models = Mock(Models) + async_client = Mock(AsyncClient) models.generate_content.return_value = get_dummy() - client.models = models + client.aio = async_client + async_client.models = models mock_client.return_value = client try: - run_prompt( + await run_prompt_async( "do something", output_schema=DummySchema, use_grounding=True, @@ -114,16 +119,18 @@ def test_error_if_grounding_with_schema(mock_client): @patch("genai_utils.gemini.genai.Client") -def test_error_if_citations_and_no_grounding(mock_client): +async def test_error_if_citations_and_no_grounding(mock_client): client = Mock(Client) models = Mock(Models) + async_client = Mock(AsyncClient) models.generate_content.return_value = get_dummy() - client.models = models + client.aio = async_client + async_client.models = models mock_client.return_value = client try: - run_prompt( + await run_prompt_async( "do something", use_grounding=False, inline_citations=True, diff --git a/tests/genai_utils/test_labels.py b/tests/genai_utils/test_labels.py index cdd16be..5866dfc 100644 --- a/tests/genai_utils/test_labels.py +++ b/tests/genai_utils/test_labels.py @@ -3,9 +3,15 @@ import pytest from google.genai import Client +from google.genai.client import AsyncClient from google.genai.models import Models -from genai_utils.gemini import GeminiError, ModelConfig, run_prompt, validate_labels +from genai_utils.gemini import ( + GeminiError, + ModelConfig, + run_prompt_async, + validate_labels, +) class DummyResponse: @@ -13,7 +19,7 @@ class DummyResponse: text = "response!" -def get_dummy(): +async def get_dummy(): return DummyResponse() @@ -101,18 +107,20 @@ def test_validate_labels_valid_special_chars(): @patch("genai_utils.gemini.genai.Client") -def test_run_prompt_with_valid_labels(mock_client): +async def test_run_prompt_with_valid_labels(mock_client): """Test that run_prompt accepts and uses valid labels""" client = Mock(Client) models = Mock(Models) + async_client = Mock(AsyncClient) models.generate_content.return_value = get_dummy() - client.models = models + client.aio = async_client + async_client.models = models mock_client.return_value = client labels = {"team": "ai", "project": "test"} - run_prompt( + await run_prompt_async( "test prompt", labels=labels, model_config=ModelConfig( @@ -128,19 +136,21 @@ def test_run_prompt_with_valid_labels(mock_client): @patch("genai_utils.gemini.genai.Client") -def test_run_prompt_with_invalid_labels(mock_client): +async def test_run_prompt_with_invalid_labels(mock_client): """Test that run_prompt rejects invalid labels""" client = Mock(Client) models = Mock(Models) + async_client = Mock(AsyncClient) models.generate_content.return_value = get_dummy() - client.models = models + client.aio = async_client + async_client.models = models mock_client.return_value = client invalid_labels = {"Invalid": "value"} # uppercase key with pytest.raises(GeminiError, match="must start with a lowercase letter"): - run_prompt( + await run_prompt_async( "test prompt", labels=invalid_labels, model_config=ModelConfig( @@ -151,7 +161,7 @@ def test_run_prompt_with_invalid_labels(mock_client): @patch("genai_utils.gemini.genai.Client") @patch.dict(os.environ, {"GENAI_LABEL_TEAM": "ai", "GENAI_LABEL_ENV": "test"}) -def test_run_prompt_merges_env_labels(mock_client): +async def test_run_prompt_merges_env_labels(mock_client): """Test that run_prompt merges environment labels with request labels""" # Need to reload the module to pick up the new environment variables import importlib @@ -162,14 +172,16 @@ def test_run_prompt_merges_env_labels(mock_client): client = Mock(Client) models = Mock(Models) + async_client = Mock(AsyncClient) models.generate_content.return_value = get_dummy() - client.models = models + client.aio = async_client + async_client.models = models mock_client.return_value = client request_labels = {"project": "test"} - genai_utils.gemini.run_prompt( + await genai_utils.gemini.run_prompt_async( "test prompt", labels=request_labels, model_config=ModelConfig(