From b0faf8292cb4cae3c66b807aaf46de48f5cbb011 Mon Sep 17 00:00:00 2001 From: davide-andreoli Date: Wed, 21 May 2025 17:32:24 +0200 Subject: [PATCH 01/10] feat: added vendor id --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 4390bc7d6b..c852429cd4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -273,9 +273,12 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse: 'Content field missing from Gemini response', str(response) ) parts = response['candidates'][0]['content']['parts'] + vendor_id = response.get('vendor_id', None) usage = _metadata_as_usage(response) usage.requests = 1 - return _process_response_from_parts(parts, response.get('model_version', self._model_name), usage) + return _process_response_from_parts( + parts, response.get('model_version', self._model_name), usage, vendor_id=vendor_id + ) async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" @@ -597,7 +600,7 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart def _process_response_from_parts( - parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, usage: usage.Usage + parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, usage: usage.Usage, vendor_id: str | None ) -> ModelResponse: items: list[ModelResponsePart] = [] for part in parts: @@ -609,7 +612,7 @@ def _process_response_from_parts( raise UnexpectedModelBehavior( f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}' ) - return ModelResponse(parts=items, usage=usage, model_name=model_name) + return ModelResponse(parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id) class _GeminiFunctionCall(TypedDict): @@ -721,6 +724,7 @@ class _GeminiResponse(TypedDict): usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]] prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]] model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]] + vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]] class _GeminiCandidates(TypedDict): From f486d622184bcb99d90940f5439cd2f137559412 Mon Sep 17 00:00:00 2001 From: davide-andreoli Date: Wed, 21 May 2025 17:46:59 +0200 Subject: [PATCH 02/10] feat: added finish_reason --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index c852429cd4..89f432b0e2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -274,10 +274,15 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse: ) parts = response['candidates'][0]['content']['parts'] vendor_id = response.get('vendor_id', None) + vendor_details = {'finish_reason': response['candidates'][0].get('finish_reason')} usage = _metadata_as_usage(response) usage.requests = 1 return _process_response_from_parts( - parts, response.get('model_version', self._model_name), usage, vendor_id=vendor_id + parts, + response.get('model_version', self._model_name), + usage, + vendor_id=vendor_id, + vendor_details=vendor_details, ) async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse: @@ -600,7 +605,11 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart def _process_response_from_parts( - parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, usage: usage.Usage, vendor_id: str | None + parts: Sequence[_GeminiPartUnion], + model_name: GeminiModelName, + usage: usage.Usage, + vendor_id: str | None, + vendor_details: dict[str, Any] | None, ) -> ModelResponse: items: list[ModelResponsePart] = [] for part in parts: @@ -612,7 +621,9 @@ def _process_response_from_parts( raise UnexpectedModelBehavior( f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}' ) - return ModelResponse(parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id) + return ModelResponse( + parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details + ) class _GeminiFunctionCall(TypedDict): From 273e78367bf44ec16c4c918df1435ffb687b7043 Mon Sep 17 00:00:00 2001 From: davide-andreoli Date: Thu, 22 May 2025 16:32:40 +0200 Subject: [PATCH 03/10] fix: added default for vendor details --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 89f432b0e2..9504a04e65 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -263,6 +263,8 @@ async def _make_request( yield r def _process_response(self, response: _GeminiResponse) -> ModelResponse: + vendor_details: dict[str, Any] | None = None + if len(response['candidates']) != 1: raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response') # pragma: no cover if 'content' not in response['candidates'][0]: @@ -274,7 +276,9 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse: ) parts = response['candidates'][0]['content']['parts'] vendor_id = response.get('vendor_id', None) - vendor_details = {'finish_reason': response['candidates'][0].get('finish_reason')} + finish_reason = response['candidates'][0].get('finish_reason') + if finish_reason: + vendor_details = {'finish_reason': finish_reason} usage = _metadata_as_usage(response) usage.requests = 1 return _process_response_from_parts( @@ -609,7 +613,7 @@ def _process_response_from_parts( model_name: GeminiModelName, usage: usage.Usage, vendor_id: str | None, - vendor_details: dict[str, Any] | None, + vendor_details: dict[str, Any] | None = None, ) -> ModelResponse: items: list[ModelResponsePart] = [] for part in parts: From 8ae19b7c12e89580dff06d30b456ff7ce37f01cd Mon Sep 17 00:00:00 2001 From: davide-andreoli Date: Fri, 23 May 2025 09:25:29 +0200 Subject: [PATCH 04/10] test: removed repr from vendor details and fixed gemini tests --- pydantic_ai_slim/pydantic_ai/messages.py | 2 +- tests/models/test_gemini.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index a0ffe38ff6..e64c0d1a80 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -589,7 +589,7 @@ class ModelResponse: kind: Literal['response'] = 'response' """Message type identifier, this is available on all parts as a discriminator.""" - vendor_details: dict[str, Any] | None = field(default=None, repr=False) + vendor_details: dict[str, Any] | None = field(default=None) """Additional vendor-specific details in a serializable format. This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields. diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index ef21c5ac0b..2a3cef85d5 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -540,6 +540,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient): usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), + vendor_details={'finish_reason': 'STOP'}, ), ] ) @@ -555,6 +556,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient): usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), + vendor_details={'finish_reason': 'STOP'}, ), ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( @@ -562,6 +564,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient): usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), + vendor_details={'finish_reason': 'STOP'}, ), ] ) @@ -585,6 +588,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient): usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), + vendor_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -647,6 +651,7 @@ async def get_location(loc_name: str) -> str: usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), + vendor_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -666,6 +671,7 @@ async def get_location(loc_name: str) -> str: usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), + vendor_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -688,6 +694,7 @@ async def get_location(loc_name: str) -> str: usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), + vendor_details={'finish_reason': 'STOP'}, ), ] ) @@ -1099,6 +1106,7 @@ async def get_image() -> BinaryContent: usage=Usage(requests=1, request_tokens=38, response_tokens=28, total_tokens=427, details={}), model_name='gemini-2.5-pro-preview-03-25', timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -1122,6 +1130,7 @@ async def get_image() -> BinaryContent: usage=Usage(requests=1, request_tokens=360, response_tokens=11, total_tokens=572, details={}), model_name='gemini-2.5-pro-preview-03-25', timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, ), ] ) @@ -1232,6 +1241,7 @@ async def test_gemini_model_instructions(allow_model_requests: None, gemini_api_ usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}), model_name='gemini-1.5-flash', timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, ), ] ) From 09af582795c0125dfe0bfc5329fd39742288ab44 Mon Sep 17 00:00:00 2001 From: davide-andreoli Date: Fri, 23 May 2025 10:51:52 +0200 Subject: [PATCH 05/10] feat: added id and vendor details to google --- pydantic_ai_slim/pydantic_ai/models/google.py | 23 +++++++++++++++---- tests/models/test_google.py | 6 +++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 63ba1741de..d613f79ae5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -6,7 +6,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field, replace from datetime import datetime -from typing import Literal, Union, cast, overload +from typing import Any, Literal, Union, cast, overload from uuid import uuid4 from typing_extensions import assert_never @@ -287,9 +287,16 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse: 'Content field missing from Gemini response', str(response) ) # pragma: no cover parts = response.candidates[0].content.parts or [] + vendor_id = response.response_id or None + vendor_details: dict[str, Any] | None = None + finish_reason = response.candidates[0].finish_reason + if finish_reason: + vendor_details = {'finish_reason': finish_reason.value} usage = _metadata_as_usage(response) usage.requests = 1 - return _process_response_from_parts(parts, response.model_version or self._model_name, usage) + return _process_response_from_parts( + parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details + ) async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" @@ -435,7 +442,13 @@ def _content_model_response(m: ModelResponse) -> ContentDict: return ContentDict(role='model', parts=parts) -def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName, usage: usage.Usage) -> ModelResponse: +def _process_response_from_parts( + parts: list[Part], + model_name: GoogleModelName, + usage: usage.Usage, + vendor_id: str | None, + vendor_details: dict[str, Any] | None = None, +) -> ModelResponse: items: list[ModelResponsePart] = [] for part in parts: if part.text: @@ -450,7 +463,9 @@ def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName, raise UnexpectedModelBehavior( f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}' ) - return ModelResponse(parts=items, model_name=model_name, usage=usage) + return ModelResponse( + parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details + ) def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict: diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 9e7010dfda..7cd17256c3 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -85,6 +85,7 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP usage=Usage(requests=1, request_tokens=7, response_tokens=11, total_tokens=18, details={}), model_name='gemini-1.5-flash', timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, ), ] ) @@ -138,6 +139,7 @@ async def temperature(city: str, date: datetime.date) -> str: usage=Usage(requests=1, request_tokens=101, response_tokens=14, total_tokens=115, details={}), model_name='gemini-1.5-flash', timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -157,6 +159,7 @@ async def temperature(city: str, date: datetime.date) -> str: usage=Usage(requests=1, request_tokens=123, response_tokens=21, total_tokens=144, details={}), model_name='gemini-1.5-flash', timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -215,6 +218,7 @@ async def get_capital(country: str) -> str: ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -235,6 +239,7 @@ async def get_capital(country: str) -> str: usage=Usage(requests=1, request_tokens=104, response_tokens=18, total_tokens=122, details={}), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, ), ] ) @@ -469,6 +474,7 @@ def instructions() -> str: usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}), model_name='gemini-2.0-flash', timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, ), ] ) From 518532a30ef735105ee091390bb7ae1f3d17e614 Mon Sep 17 00:00:00 2001 From: davide-andreoli Date: Fri, 23 May 2025 11:19:37 +0200 Subject: [PATCH 06/10] test: added tests for coverage --- .../test_google_no_finish_reason.yaml | 66 +++++++++++++++++++ tests/models/test_gemini.py | 10 +++ tests/models/test_google.py | 10 +++ 3 files changed, 86 insertions(+) create mode 100644 tests/models/cassettes/test_google/test_google_no_finish_reason.yaml diff --git a/tests/models/cassettes/test_google/test_google_no_finish_reason.yaml b/tests/models/cassettes/test_google/test_google_no_finish_reason.yaml new file mode 100644 index 0000000000..9c984f4ced --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_no_finish_reason.yaml @@ -0,0 +1,66 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '169' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: Hello! + role: user + generationConfig: {} + systemInstruction: + parts: + - text: You are a chatbot. + role: user + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '644' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=322 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.0009223055941137401 + content: + parts: + - text: | + Hello there! How can I help you today? + role: model + modelVersion: gemini-1.5-flash + usageMetadata: + candidatesTokenCount: 11 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 11 + promptTokenCount: 7 + promptTokensDetails: + - modality: TEXT + tokenCount: 7 + totalTokenCount: 18 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 2a3cef85d5..a2ac82a9f3 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -1282,3 +1282,13 @@ async def get_temperature(location: dict[str, CurrentLocation]) -> float: # pra assert result.output == snapshot( 'I need a location dictionary to use the `get_temperature` function. I cannot provide the temperature in Tokyo without more information.\n' ) + +async def test_gemini_no_finish_reason(get_gemini_client: GetGeminiClient): + response = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello world')])), finish_reason=None) + gemini_client = get_gemini_client(response) + m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client)) + agent = Agent(m) + + result = await agent.run('Hello World') + if result.all_messages()[1].vendor_details: + assert result.all_messages()[1].vendor_details.get("finish_reason") == None \ No newline at end of file diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 7cd17256c3..43bbbfa7ce 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -513,3 +513,13 @@ async def test_google_model_safety_settings(allow_model_requests: None, google_p with pytest.raises(UnexpectedModelBehavior, match='Safety settings triggered'): await agent.run('Tell me a joke about a Brazilians.') + +async def test_google_no_finish_reason(allow_model_requests: None, google_provider: GoogleProvider): + + model = GoogleModel('gemini-1.5-flash', provider=google_provider) + agent = Agent(model=model, system_prompt='You are a chatbot.') + + result = await agent.run('Hello!') + + if result.all_messages()[1].vendor_details: + assert result.all_messages()[1].vendor_details.get("finish_reason") == None From 3b6c7cb1e5d096d442b2cc5c092f08f7f6968b0e Mon Sep 17 00:00:00 2001 From: davide-andreoli Date: Fri, 23 May 2025 11:26:21 +0200 Subject: [PATCH 07/10] fix: linted --- tests/models/test_gemini.py | 2 +- tests/models/test_google.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index a2ac82a9f3..f3ed1a64d7 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -1291,4 +1291,4 @@ async def test_gemini_no_finish_reason(get_gemini_client: GetGeminiClient): result = await agent.run('Hello World') if result.all_messages()[1].vendor_details: - assert result.all_messages()[1].vendor_details.get("finish_reason") == None \ No newline at end of file + assert result.all_messages()[1].vendor_details.get('finish_reason') == None \ No newline at end of file diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 43bbbfa7ce..ae31fd80eb 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -522,4 +522,4 @@ async def test_google_no_finish_reason(allow_model_requests: None, google_provid result = await agent.run('Hello!') if result.all_messages()[1].vendor_details: - assert result.all_messages()[1].vendor_details.get("finish_reason") == None + assert result.all_messages()[1].vendor_details.get('finish_reason') == None From 21bd74fcfb27acbb771f0225e0425a9ee77c8509 Mon Sep 17 00:00:00 2001 From: davide-andreoli Date: Fri, 23 May 2025 11:27:51 +0200 Subject: [PATCH 08/10] fix: linted --- tests/models/test_gemini.py | 7 +++++-- tests/models/test_google.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index f3ed1a64d7..aca41ef041 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -1283,12 +1283,15 @@ async def get_temperature(location: dict[str, CurrentLocation]) -> float: # pra 'I need a location dictionary to use the `get_temperature` function. I cannot provide the temperature in Tokyo without more information.\n' ) + async def test_gemini_no_finish_reason(get_gemini_client: GetGeminiClient): - response = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello world')])), finish_reason=None) + response = gemini_response( + _content_model_response(ModelResponse(parts=[TextPart('Hello world')])), finish_reason=None + ) gemini_client = get_gemini_client(response) m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client)) agent = Agent(m) result = await agent.run('Hello World') if result.all_messages()[1].vendor_details: - assert result.all_messages()[1].vendor_details.get('finish_reason') == None \ No newline at end of file + assert result.all_messages()[1].vendor_details.get('finish_reason') is None diff --git a/tests/models/test_google.py b/tests/models/test_google.py index ae31fd80eb..bf9c51ff15 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -514,12 +514,12 @@ async def test_google_model_safety_settings(allow_model_requests: None, google_p with pytest.raises(UnexpectedModelBehavior, match='Safety settings triggered'): await agent.run('Tell me a joke about a Brazilians.') + async def test_google_no_finish_reason(allow_model_requests: None, google_provider: GoogleProvider): - model = GoogleModel('gemini-1.5-flash', provider=google_provider) agent = Agent(model=model, system_prompt='You are a chatbot.') result = await agent.run('Hello!') if result.all_messages()[1].vendor_details: - assert result.all_messages()[1].vendor_details.get('finish_reason') == None + assert result.all_messages()[1].vendor_details.get('finish_reason') is None From 636d8e0b4b349abdb14125bbc8ef991b74935510 Mon Sep 17 00:00:00 2001 From: davide-andreoli Date: Fri, 23 May 2025 11:59:15 +0200 Subject: [PATCH 09/10] fix: test coverage --- tests/models/test_gemini.py | 6 ++++-- tests/models/test_google.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index aca41ef041..9edf8cef6b 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -1293,5 +1293,7 @@ async def test_gemini_no_finish_reason(get_gemini_client: GetGeminiClient): agent = Agent(m) result = await agent.run('Hello World') - if result.all_messages()[1].vendor_details: - assert result.all_messages()[1].vendor_details.get('finish_reason') is None + + for message in result.all_messages(): + if isinstance(message, ModelResponse): + assert message.vendor_details is None diff --git a/tests/models/test_google.py b/tests/models/test_google.py index bf9c51ff15..19d6d687c5 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -521,5 +521,6 @@ async def test_google_no_finish_reason(allow_model_requests: None, google_provid result = await agent.run('Hello!') - if result.all_messages()[1].vendor_details: - assert result.all_messages()[1].vendor_details.get('finish_reason') is None + for message in result.all_messages(): + if isinstance(message, ModelResponse): + assert message.vendor_details is None From e398c5f2b157936faa5cb7a5a823839860330cf0 Mon Sep 17 00:00:00 2001 From: davide-andreoli Date: Mon, 26 May 2025 20:39:55 +0200 Subject: [PATCH 10/10] fix: removed test for reproducibility --- pydantic_ai_slim/pydantic_ai/models/google.py | 2 +- .../test_google_no_finish_reason.yaml | 66 ------------------- tests/models/test_google.py | 11 ---- 3 files changed, 1 insertion(+), 78 deletions(-) delete mode 100644 tests/models/cassettes/test_google/test_google_no_finish_reason.yaml diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index d613f79ae5..3f2b4b9c1e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -290,7 +290,7 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse: vendor_id = response.response_id or None vendor_details: dict[str, Any] | None = None finish_reason = response.candidates[0].finish_reason - if finish_reason: + if finish_reason: # pragma: no branch vendor_details = {'finish_reason': finish_reason.value} usage = _metadata_as_usage(response) usage.requests = 1 diff --git a/tests/models/cassettes/test_google/test_google_no_finish_reason.yaml b/tests/models/cassettes/test_google/test_google_no_finish_reason.yaml deleted file mode 100644 index 9c984f4ced..0000000000 --- a/tests/models/cassettes/test_google/test_google_no_finish_reason.yaml +++ /dev/null @@ -1,66 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '169' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: Hello! - role: user - generationConfig: {} - systemInstruction: - parts: - - text: You are a chatbot. - role: user - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '644' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=322 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.0009223055941137401 - content: - parts: - - text: | - Hello there! How can I help you today? - role: model - modelVersion: gemini-1.5-flash - usageMetadata: - candidatesTokenCount: 11 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 11 - promptTokenCount: 7 - promptTokensDetails: - - modality: TEXT - tokenCount: 7 - totalTokenCount: 18 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 19d6d687c5..7cd17256c3 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -513,14 +513,3 @@ async def test_google_model_safety_settings(allow_model_requests: None, google_p with pytest.raises(UnexpectedModelBehavior, match='Safety settings triggered'): await agent.run('Tell me a joke about a Brazilians.') - - -async def test_google_no_finish_reason(allow_model_requests: None, google_provider: GoogleProvider): - model = GoogleModel('gemini-1.5-flash', provider=google_provider) - agent = Agent(model=model, system_prompt='You are a chatbot.') - - result = await agent.run('Hello!') - - for message in result.all_messages(): - if isinstance(message, ModelResponse): - assert message.vendor_details is None