diff --git a/README.md b/README.md index 53f2400..88f8f25 100644 --- a/README.md +++ b/README.md @@ -109,4 +109,9 @@ prompt = ( result = run_prompt(prompt, use_grounding=True, inline_citations=True) pp(result) -``` \ No newline at end of file +``` + +## Thinking + +You can enable or disable thinking in the model by toggling the `do_thinking` parameter. +Only enable this if the task is complex enough to require it, because it makes things slow and expensive. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0d05424..c7d29ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ license = { file = "LICENSE.md" } requires-python = ">=3.10" dependencies = [ "google-cloud-core>=2.4.3", - "google-genai>=1.27.0", + "google-genai>=1.59.0", "json-repair~=0.40.0", "pydantic>=2.11.7", "rapidfuzz>=3.13.0", diff --git a/scripts/gemini_demo.py b/scripts/gemini_demo.py index cad3b2d..b24aebd 100644 --- a/scripts/gemini_demo.py +++ b/scripts/gemini_demo.py @@ -74,3 +74,12 @@ class Monarch(BaseModel): answer_inline_citations = run_prompt(prompt, use_grounding=True, inline_citations=True) print("GROUNDING W/ CITATIONS", "\n", "-" * 100, "\n", answer_inline_citations, "\n\n") + +# Thinking +print("THINKING" + "\n" + "-" * 100) + +test_prompt = """ +Why is fact checking important? +""".strip() +output = run_prompt(test_prompt, do_thinking=True) +pp(output) diff --git a/src/genai_utils/gemini.py b/src/genai_utils/gemini.py index 4e9ca3e..9c17114 100644 --- a/src/genai_utils/gemini.py +++ b/src/genai_utils/gemini.py @@ -293,6 +293,43 @@ def check_grounding_ran(response: types.GenerateContentResponse) -> bool: return bool(n_searches and n_chunks and n_supports) +def get_thinking_config( + model_name: str, do_thinking: bool +) -> types.ThinkingConfig | None: + """ + Gets the thinking config required for the current model. + Thinking is set differently before and after Gemini 3.0. 
+ Certain models, like the 2.5 and 3.0 pro models, do not allow thinking to be disabled. + """ + if "gemini-2.5-pro" in model_name: + if not do_thinking: + _logger.warning( + "It is not possible to turn off thinking with this model. Setting to minimum." + ) + return types.ThinkingConfig(thinking_budget=128) # minimum thinking + return types.ThinkingConfig(thinking_budget=-1) # dynamic budget + + if ( + model_name < "gemini-2.6" + ): # there is no 2.6, but this means it will catch all 2.5 variants + if do_thinking: + return types.ThinkingConfig(thinking_budget=-1) # dynamic budget + return types.ThinkingConfig(thinking_budget=0) # disable thinking + + if model_name >= "gemini-3": + if not do_thinking: + if "pro" in model_name: + _logger.warning( + "Cannot disable thinking in this model. Setting thinking to low." + ) + return types.ThinkingConfig(thinking_level=types.ThinkingLevel.LOW) + return types.ThinkingConfig(thinking_level=types.ThinkingLevel.MINIMAL) + return None + + _logger.warning("Did not recognise the model provided, defaulting to None") + return None + + def run_prompt( prompt: str, video_uri: str | None = None, @@ -302,6 +339,7 @@ def run_prompt( safety_settings: list[types.SafetySetting] = DEFAULT_SAFETY_SETTINGS, model_config: ModelConfig | None = None, use_grounding: bool = False, + do_thinking: bool = False, inline_citations: bool = False, labels: dict[str, str] = {}, ) -> str: @@ -352,6 +390,10 @@ class Movie(BaseModel): and makes the output more likely to be factual. Does not work with structured output. See the docs (`grounding`_). + do_thinking: bool + Whether Gemini should use a thought process. + This is more expensive but may yield better results. + Do not use for bulk tasks that don't require complex thoughts. inline_citations: bool Whether output should include citations inline with the text. These citations will be links to be used as evidence. 
@@ -379,6 +421,7 @@ class Movie(BaseModel): safety_settings=safety_settings, model_config=model_config, use_grounding=use_grounding, + do_thinking=do_thinking, inline_citations=inline_citations, labels=labels, ) @@ -394,6 +437,7 @@ async def run_prompt_async( safety_settings: list[types.SafetySetting] = DEFAULT_SAFETY_SETTINGS, model_config: ModelConfig | None = None, use_grounding: bool = False, + do_thinking: bool = False, inline_citations: bool = False, labels: dict[str, str] = {}, ) -> str: @@ -444,6 +488,10 @@ class Movie(BaseModel): and makes the output more likely to be factual. Does not work with structured output. See the docs (`grounding`_). + do_thinking: bool + Whether Gemini should use a thought process. + This is more expensive but may yield better results. + Do not use for bulk tasks that don't require complex thoughts. inline_citations: bool Whether output should include citations inline with the text. These citations will be links to be used as evidence. @@ -506,6 +554,7 @@ class Movie(BaseModel): safety_settings=safety_settings, **built_gen_config, labels=merged_labels, + thinking_config=get_thinking_config(model_config.model_name, do_thinking), ), ) diff --git a/tests/genai_utils/test_gemini.py b/tests/genai_utils/test_gemini.py index d5aa049..82d4be0 100644 --- a/tests/genai_utils/test_gemini.py +++ b/tests/genai_utils/test_gemini.py @@ -1,16 +1,18 @@ import os from unittest.mock import Mock, patch -from google.genai import Client +from google.genai import Client, types from google.genai.client import AsyncClient from google.genai.models import Models from pydantic import BaseModel, Field +from pytest import mark, param from genai_utils.gemini import ( DEFAULT_PARAMETERS, GeminiError, ModelConfig, generate_model_config, + get_thinking_config, run_prompt_async, ) @@ -143,3 +145,33 @@ async def test_error_if_citations_and_no_grounding(mock_client): return assert False + + +@mark.parametrize( + "model_name,do_thinking,expected", + [ + 
param("gemini-2.0-flash", False, types.ThinkingConfig(thinking_budget=0)), + param("gemini-2.0-flash", True, types.ThinkingConfig(thinking_budget=-1)), + param("gemini-2.5-flash-lite", False, types.ThinkingConfig(thinking_budget=0)), + param("gemini-2.5-flash-lite", True, types.ThinkingConfig(thinking_budget=-1)), + param("gemini-2.5-pro", False, types.ThinkingConfig(thinking_budget=128)), + param("gemini-2.5-pro", True, types.ThinkingConfig(thinking_budget=-1)), + param( + "gemini-3.0-flash", + False, + types.ThinkingConfig(thinking_level=types.ThinkingLevel.MINIMAL), + ), + param("gemini-3.0-flash", True, None), + param( + "gemini-3.0-pro", + False, + types.ThinkingConfig(thinking_level=types.ThinkingLevel.LOW), + ), + param("gemini-3.0-pro", True, None), + ], +) +def test_get_thinking_config( + model_name: str, do_thinking: bool, expected: types.ThinkingConfig +): + thinking_config = get_thinking_config(model_name, do_thinking) + assert thinking_config == expected diff --git a/uv.lock b/uv.lock index 5958a37..3940573 100644 --- a/uv.lock +++ b/uv.lock @@ -74,15 +74,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1b/46/863c90dcd3f9d41b109b7f19032ae0db021f0b2a81482ba0a1e28c84de86/black-25.9.0-py3-none-any.whl", hash = "sha256:474b34c1342cdc157d307b56c4c65bce916480c4a8f6551fdc6bf9b486a7c4ae", size = 203363, upload-time = "2025-09-19T00:27:35.724Z" }, ] -[[package]] -name = "cachetools" -version = "6.2.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/cc/7e/b975b5814bd36faf009faebe22c1072a1fa1168db34d285ef0ba071ad78c/cachetools-6.2.1.tar.gz", hash = "sha256:3f391e4bd8f8bf0931169baf7456cc822705f4e2a31f840d218f445b9a854201", size = 31325, upload-time = "2025-10-12T14:55:30.139Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/96/c5/1e741d26306c42e2bf6ab740b2202872727e0f606033c9dd713f8b93f5a8/cachetools-6.2.1-py3-none-any.whl", hash = 
"sha256:09868944b6dde876dfd44e1d47e18484541eaf12f26f29b7af91b26cc892d701", size = 11280, upload-time = "2025-10-12T14:55:28.382Z" }, -] - [[package]] name = "certifi" version = "2025.10.5" @@ -220,6 +211,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, ] +[[package]] +name = "distro" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722, upload-time = "2023-12-24T09:54:32.31Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, +] + [[package]] name = "exceptiongroup" version = "1.3.0" @@ -283,7 +283,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "google-cloud-core", specifier = ">=2.4.3" }, - { name = "google-genai", specifier = ">=1.27.0" }, + { name = "google-genai", specifier = ">=1.59.0" }, { name = "json-repair", specifier = "~=0.40.0" }, { name = "pydantic", specifier = ">=2.11.7" }, { name = "rapidfuzz", specifier = ">=3.13.0" }, @@ -320,16 +320,20 @@ wheels = [ [[package]] name = "google-auth" -version = "2.43.0" +version = "2.47.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cachetools" }, { name = "pyasn1-modules" }, { name = "rsa" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/ff/ef/66d14cf0e01b08d2d51ffc3c20410c4e134a1548fc246a6081eae585a4fe/google_auth-2.43.0.tar.gz", hash = "sha256:88228eee5fc21b62a1b5fe773ca15e67778cb07dc8363adcb4a8827b52d81483", size = 296359, upload-time = "2025-11-06T00:13:36.587Z" } +sdist = { url = "https://files.pythonhosted.org/packages/60/3c/ec64b9a275ca22fa1cd3b6e77fefcf837b0732c890aa32d2bd21313d9b33/google_auth-2.47.0.tar.gz", hash = "sha256:833229070a9dfee1a353ae9877dcd2dec069a8281a4e72e72f77d4a70ff945da", size = 323719, upload-time = "2026-01-06T21:55:31.045Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6f/d1/385110a9ae86d91cc14c5282c61fe9f4dc41c0b9f7d423c6ad77038c4448/google_auth-2.43.0-py2.py3-none-any.whl", hash = "sha256:af628ba6fa493f75c7e9dbe9373d148ca9f4399b5ea29976519e0a3848eddd16", size = 223114, upload-time = "2025-11-06T00:13:35.209Z" }, + { url = "https://files.pythonhosted.org/packages/db/18/79e9008530b79527e0d5f79e7eef08d3b179b7f851cfd3a2f27822fbdfa9/google_auth-2.47.0-py3-none-any.whl", hash = "sha256:c516d68336bfde7cf0da26aab674a36fedcf04b37ac4edd59c597178760c3498", size = 234867, upload-time = "2026-01-06T21:55:28.6Z" }, +] + +[package.optional-dependencies] +requests = [ + { name = "requests" }, ] [[package]] @@ -347,21 +351,23 @@ wheels = [ [[package]] name = "google-genai" -version = "1.49.0" +version = "1.59.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, - { name = "google-auth" }, + { name = "distro" }, + { name = "google-auth", extra = ["requests"] }, { name = "httpx" }, { name = "pydantic" }, { name = "requests" }, + { name = "sniffio" }, { name = "tenacity" }, { name = "typing-extensions" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/82/49/1a724ee3c3748fa50721d53a52d9fee88c67d0c43bb16eb2b10ee89ab239/google_genai-1.49.0.tar.gz", hash = "sha256:35eb16023b72e298571ae30e919c810694f258f2ba68fc77a2185c7c8829ad5a", size = 253493, upload-time = 
"2025-11-05T22:41:03.278Z" } +sdist = { url = "https://files.pythonhosted.org/packages/40/34/c03bcbc759d67ac3d96077838cdc1eac85417de6ea3b65b313fe53043eee/google_genai-1.59.0.tar.gz", hash = "sha256:0b7a2dc24582850ae57294209d8dfc2c4f5fcfde0a3f11d81dc5aca75fb619e2", size = 487374, upload-time = "2026-01-15T20:29:46.619Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d5/d3/84a152746dc7bdebb8ba0fd7d6157263044acd1d14b2a53e8df4a307b6b7/google_genai-1.49.0-py3-none-any.whl", hash = "sha256:ad49cd5be5b63397069e7aef9a4fe0a84cbdf25fcd93408e795292308db4ef32", size = 256098, upload-time = "2025-11-05T22:41:01.429Z" }, + { url = "https://files.pythonhosted.org/packages/aa/53/6d00692fe50d73409b3406ae90c71bc4499c8ae7fac377ba16e283da917c/google_genai-1.59.0-py3-none-any.whl", hash = "sha256:59fc01a225d074fe9d1e626c3433da292f33249dadce4deb34edea698305a6df", size = 719099, upload-time = "2026-01-15T20:29:44.604Z" }, ] [[package]]