Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 94 additions & 1 deletion src/genai_utils/gemini.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import os
import re
Expand Down Expand Up @@ -303,6 +304,98 @@ def run_prompt(
use_grounding: bool = False,
inline_citations: bool = False,
labels: dict[str, str] = {},
) -> str:
"""
A synchronous version of `run_prompt_async`.

Parameters
----------
prompt: str
The prompt given to the model
video_uri: str | None
A Google Cloud URI for a video that you want to prompt.
output_schema: types.SchemaUnion | None
A valid schema for the model output.
Generally, we'd recommend this being a pydantic BaseModel inheriting class,
which defines the desired schema of the model output.
```python
from pydantic import BaseModel, Field

class Movie(BaseModel):
title: str = Field(description="The title of the movie")
year: int = Field(description="The year the film was released in the UK")

schema = Movie
# or
schema = list[Movie]
```
Use this if you want structured JSON output.
system_instruction: str | None
An instruction to the model which essentially goes before the prompt.
For example:
```
You are a fact checker and you must base all your answers on evidence
```
generation_config: dict[str, Any]
The parameters for the generation. See the docs (`generation config`_).
safety_settings: dict[generative_models.HarmCategory, generative_models.HarmBlockThreshold]
The safety settings for generation. Determines what will be blocked.
See the docs (`safety settings`_)
model_config: ModelConfig | None
The config for the Gemini model.
Specifies project, location, and model name.
If None, will attempt to use environment variables:
`GEMINI_PROJECT`, `GEMINI_LOCATION`, and `GEMINI_MODEL`.
use_grounding: bool
Whether Gemini should perform a Google search to ground results.
This will allow it to pull from up-to-date information,
and makes the output more likely to be factual.
Does not work with structured output.
See the docs (`grounding`_).
inline_citations: bool
Whether output should include citations inline with the text.
These citations will be links to be used as evidence.
This is only possible if grounding is set to true.
labels: dict[str, str]
Optional labels to attach to the API call for tracking and monitoring purposes.
Labels are key-value pairs that can be used to organize and filter requests
in Google Cloud logs and metrics.

Returns
-------
The text output of the Gemini model.

.. _generation config: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig
.. _safety settings: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters
.. _grounding: https://ai.google.dev/gemini-api/docs/google-search
"""
return asyncio.run(
run_prompt_async(
prompt=prompt,
video_uri=video_uri,
output_schema=output_schema,
system_instruction=system_instruction,
generation_config=generation_config,
safety_settings=safety_settings,
model_config=model_config,
use_grounding=use_grounding,
inline_citations=inline_citations,
labels=labels,
)
)


async def run_prompt_async(
prompt: str,
video_uri: str | None = None,
output_schema: types.SchemaUnion | None = None,
system_instruction: str | None = None,
generation_config: dict[str, Any] = DEFAULT_PARAMETERS,
safety_settings: list[types.SafetySetting] = DEFAULT_SAFETY_SETTINGS,
model_config: ModelConfig | None = None,
use_grounding: bool = False,
inline_citations: bool = False,
labels: dict[str, str] = {},
) -> str:
"""
Runs a prompt through the model.
Expand Down Expand Up @@ -405,7 +498,7 @@ class Movie(BaseModel):
merged_labels = DEFAULT_LABELS | labels
validate_labels(merged_labels)

response = client.models.generate_content(
response = await client.aio.models.generate_content(
model=model_config.model_name,
contents=types.Content(role="user", parts=parts),
config=types.GenerateContentConfig(
Expand Down
31 changes: 19 additions & 12 deletions tests/genai_utils/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import Mock, patch

from google.genai import Client
from google.genai.client import AsyncClient
from google.genai.models import Models
from pydantic import BaseModel, Field

Expand All @@ -10,7 +11,7 @@
GeminiError,
ModelConfig,
generate_model_config,
run_prompt,
run_prompt_async,
)


Expand All @@ -24,7 +25,7 @@ class DummySchema(BaseModel):
colour: str = Field(description="Colour of dog")


def get_dummy():
async def get_dummy():
return DummyResponse()


Expand Down Expand Up @@ -57,25 +58,27 @@ def test_generate_model_config_no_env_vars():


@patch("genai_utils.gemini.genai.Client")
def test_dont_overwrite_generation_config(mock_client):
async def test_dont_overwrite_generation_config(mock_client):
copy_of_params = {**DEFAULT_PARAMETERS}
client = Mock(Client)
models = Mock(Models)
async_client = Mock(AsyncClient)

models.generate_content.return_value = get_dummy()
client.models = models
client.aio = async_client
async_client.models = models
mock_client.return_value = client

assert DEFAULT_PARAMETERS == copy_of_params
run_prompt(
await run_prompt_async(
"do something",
output_schema=DummySchema,
model_config=ModelConfig(
project="project", location="location", model_name="model"
),
)
models.generate_content.return_value = get_dummy()
run_prompt(
await run_prompt_async(
"do something",
model_config=ModelConfig(
project="project", location="location", model_name="model"
Expand All @@ -89,16 +92,18 @@ def test_dont_overwrite_generation_config(mock_client):


@patch("genai_utils.gemini.genai.Client")
def test_error_if_grounding_with_schema(mock_client):
async def test_error_if_grounding_with_schema(mock_client):
client = Mock(Client)
models = Mock(Models)
async_client = Mock(AsyncClient)

models.generate_content.return_value = get_dummy()
client.models = models
client.aio = async_client
async_client.models = models
mock_client.return_value = client

try:
run_prompt(
await run_prompt_async(
"do something",
output_schema=DummySchema,
use_grounding=True,
Expand All @@ -114,16 +119,18 @@ def test_error_if_grounding_with_schema(mock_client):


@patch("genai_utils.gemini.genai.Client")
def test_error_if_citations_and_no_grounding(mock_client):
async def test_error_if_citations_and_no_grounding(mock_client):
client = Mock(Client)
models = Mock(Models)
async_client = Mock(AsyncClient)

models.generate_content.return_value = get_dummy()
client.models = models
client.aio = async_client
async_client.models = models
mock_client.return_value = client

try:
run_prompt(
await run_prompt_async(
"do something",
use_grounding=False,
inline_citations=True,
Expand Down
34 changes: 23 additions & 11 deletions tests/genai_utils/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@

import pytest
from google.genai import Client
from google.genai.client import AsyncClient
from google.genai.models import Models

from genai_utils.gemini import GeminiError, ModelConfig, run_prompt, validate_labels
from genai_utils.gemini import (
GeminiError,
ModelConfig,
run_prompt_async,
validate_labels,
)


class DummyResponse:
candidates = "yes!"
text = "response!"


def get_dummy():
async def get_dummy():
return DummyResponse()


Expand Down Expand Up @@ -101,18 +107,20 @@ def test_validate_labels_valid_special_chars():


@patch("genai_utils.gemini.genai.Client")
def test_run_prompt_with_valid_labels(mock_client):
async def test_run_prompt_with_valid_labels(mock_client):
"""Test that run_prompt accepts and uses valid labels"""
client = Mock(Client)
models = Mock(Models)
async_client = Mock(AsyncClient)

models.generate_content.return_value = get_dummy()
client.models = models
client.aio = async_client
async_client.models = models
mock_client.return_value = client

labels = {"team": "ai", "project": "test"}

run_prompt(
await run_prompt_async(
"test prompt",
labels=labels,
model_config=ModelConfig(
Expand All @@ -128,19 +136,21 @@ def test_run_prompt_with_valid_labels(mock_client):


@patch("genai_utils.gemini.genai.Client")
def test_run_prompt_with_invalid_labels(mock_client):
async def test_run_prompt_with_invalid_labels(mock_client):
"""Test that run_prompt rejects invalid labels"""
client = Mock(Client)
models = Mock(Models)
async_client = Mock(AsyncClient)

models.generate_content.return_value = get_dummy()
client.models = models
client.aio = async_client
async_client.models = models
mock_client.return_value = client

invalid_labels = {"Invalid": "value"} # uppercase key

with pytest.raises(GeminiError, match="must start with a lowercase letter"):
run_prompt(
await run_prompt_async(
"test prompt",
labels=invalid_labels,
model_config=ModelConfig(
Expand All @@ -151,7 +161,7 @@ def test_run_prompt_with_invalid_labels(mock_client):

@patch("genai_utils.gemini.genai.Client")
@patch.dict(os.environ, {"GENAI_LABEL_TEAM": "ai", "GENAI_LABEL_ENV": "test"})
def test_run_prompt_merges_env_labels(mock_client):
async def test_run_prompt_merges_env_labels(mock_client):
"""Test that run_prompt merges environment labels with request labels"""
# Need to reload the module to pick up the new environment variables
import importlib
Expand All @@ -162,14 +172,16 @@ def test_run_prompt_merges_env_labels(mock_client):

client = Mock(Client)
models = Mock(Models)
async_client = Mock(AsyncClient)

models.generate_content.return_value = get_dummy()
client.models = models
client.aio = async_client
async_client.models = models
mock_client.return_value = client

request_labels = {"project": "test"}

genai_utils.gemini.run_prompt(
await genai_utils.gemini.run_prompt_async(
"test prompt",
labels=request_labels,
model_config=ModelConfig(
Expand Down