diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py
index a0ffe38ff..e64c0d1a8 100644
--- a/pydantic_ai_slim/pydantic_ai/messages.py
+++ b/pydantic_ai_slim/pydantic_ai/messages.py
@@ -589,7 +589,7 @@ class ModelResponse:
     kind: Literal['response'] = 'response'
     """Message type identifier, this is available on all parts as a discriminator."""
 
-    vendor_details: dict[str, Any] | None = field(default=None, repr=False)
+    vendor_details: dict[str, Any] | None = field(default=None)
     """Additional vendor-specific details in a serializable format.
 
     This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.
diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index 4390bc7d6..9504a04e6 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -263,6 +263,8 @@ async def _make_request(
             yield r
 
     def _process_response(self, response: _GeminiResponse) -> ModelResponse:
+        vendor_details: dict[str, Any] | None = None
+
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')  # pragma: no cover
         if 'content' not in response['candidates'][0]:
@@ -273,9 +275,19 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse:
                 'Content field missing from Gemini response', str(response)
             )
         parts = response['candidates'][0]['content']['parts']
+        vendor_id = response.get('vendor_id', None)
+        finish_reason = response['candidates'][0].get('finish_reason')
+        if finish_reason:
+            vendor_details = {'finish_reason': finish_reason}
         usage = _metadata_as_usage(response)
         usage.requests = 1
-        return _process_response_from_parts(parts, response.get('model_version', self._model_name), usage)
+        return _process_response_from_parts(
+            parts,
+            response.get('model_version', self._model_name),
+            usage,
+            vendor_id=vendor_id,
+            vendor_details=vendor_details,
+        )
 
     async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -597,7 +609,11 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
 
 
 def _process_response_from_parts(
-    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, usage: usage.Usage
+    parts: Sequence[_GeminiPartUnion],
+    model_name: GeminiModelName,
+    usage: usage.Usage,
+    vendor_id: str | None,
+    vendor_details: dict[str, Any] | None = None,
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
@@ -609,7 +625,9 @@ def _process_response_from_parts(
         raise UnexpectedModelBehavior(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
-    return ModelResponse(parts=items, usage=usage, model_name=model_name)
+    return ModelResponse(
+        parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details
+    )
 
 
 class _GeminiFunctionCall(TypedDict):
@@ -721,6 +739,7 @@ class _GeminiResponse(TypedDict):
    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
    model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
+   vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]]
 
 
 class _GeminiCandidates(TypedDict):
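Note on the messages.py hunk above: removing repr=False means vendor_details now appears in the repr of a ModelResponse, presumably so the inline test snapshots below can show the field. A minimal sketch of the effect (the constructor values are illustrative only, not taken from this diff):

    from pydantic_ai.messages import ModelResponse, TextPart

    # vendor_details is no longer hidden from the dataclass repr
    response = ModelResponse(parts=[TextPart(content='Hello world')], vendor_details={'finish_reason': 'STOP'})
    print(repr(response))  # now includes vendor_details={'finish_reason': 'STOP'}
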
diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py
index 63ba1741d..3f2b4b9c1 100644
--- a/pydantic_ai_slim/pydantic_ai/models/google.py
+++ b/pydantic_ai_slim/pydantic_ai/models/google.py
@@ -6,7 +6,7 @@
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field, replace
 from datetime import datetime
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload
 from uuid import uuid4
 
 from typing_extensions import assert_never
@@ -287,9 +287,16 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
                 'Content field missing from Gemini response', str(response)
             )  # pragma: no cover
         parts = response.candidates[0].content.parts or []
+        vendor_id = response.response_id or None
+        vendor_details: dict[str, Any] | None = None
+        finish_reason = response.candidates[0].finish_reason
+        if finish_reason:  # pragma: no branch
+            vendor_details = {'finish_reason': finish_reason.value}
         usage = _metadata_as_usage(response)
         usage.requests = 1
-        return _process_response_from_parts(parts, response.model_version or self._model_name, usage)
+        return _process_response_from_parts(
+            parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
+        )
 
     async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -435,7 +442,13 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
     return ContentDict(role='model', parts=parts)
 
 
-def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName, usage: usage.Usage) -> ModelResponse:
+def _process_response_from_parts(
+    parts: list[Part],
+    model_name: GoogleModelName,
+    usage: usage.Usage,
+    vendor_id: str | None,
+    vendor_details: dict[str, Any] | None = None,
+) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
         if part.text:
@@ -450,7 +463,9 @@ def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName,
         raise UnexpectedModelBehavior(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
-    return ModelResponse(parts=items, model_name=model_name, usage=usage)
+    return ModelResponse(
+        parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details
+    )
 
 
 def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:
diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py
index ef21c5ac0..9edf8cef6 100644
--- a/tests/models/test_gemini.py
+++ b/tests/models/test_gemini.py
@@ -540,6 +540,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
+                vendor_details={'finish_reason': 'STOP'},
             ),
         ]
     )
@@ -555,6 +556,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
+                vendor_details={'finish_reason': 'STOP'},
             ),
             ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
@@ -562,6 +564,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
                 parts=[TextPart(content='Hello world')],
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
+                vendor_details={'finish_reason': 'STOP'},
             ),
         ]
     )
@@ -585,6 +588,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
+                vendor_details={'finish_reason': 'STOP'},
             ),
             ModelRequest(
                 parts=[
@@ -647,6 +651,7 @@ async def get_location(loc_name: str) -> str:
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
+                vendor_details={'finish_reason': 'STOP'},
             ),
             ModelRequest(
                 parts=[
@@ -666,6 +671,7 @@ async def get_location(loc_name: str) -> str:
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
+                vendor_details={'finish_reason': 'STOP'},
             ),
             ModelRequest(
                 parts=[
@@ -688,6 +694,7 @@ async def get_location(loc_name: str) -> str:
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
+                vendor_details={'finish_reason': 'STOP'},
             ),
         ]
     )
@@ -1099,6 +1106,7 @@ async def get_image() -> BinaryContent:
                 usage=Usage(requests=1, request_tokens=38, response_tokens=28, total_tokens=427, details={}),
                 model_name='gemini-2.5-pro-preview-03-25',
                 timestamp=IsDatetime(),
+                vendor_details={'finish_reason': 'STOP'},
             ),
             ModelRequest(
                 parts=[
@@ -1122,6 +1130,7 @@ async def get_image() -> BinaryContent:
                 usage=Usage(requests=1, request_tokens=360, response_tokens=11, total_tokens=572, details={}),
                 model_name='gemini-2.5-pro-preview-03-25',
                 timestamp=IsDatetime(),
+                vendor_details={'finish_reason': 'STOP'},
             ),
         ]
     )
@@ -1232,6 +1241,7 @@ async def test_gemini_model_instructions(allow_model_requests: None, gemini_api_
                 usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
+                vendor_details={'finish_reason': 'STOP'},
             ),
         ]
     )
@@ -1272,3 +1282,18 @@ async def get_temperature(location: dict[str, CurrentLocation]) -> float:  # pra
     assert result.output == snapshot(
         'I need a location dictionary to use the `get_temperature` function. I cannot provide the temperature in Tokyo without more information.\n'
     )
+
+
+async def test_gemini_no_finish_reason(get_gemini_client: GetGeminiClient):
+    response = gemini_response(
+        _content_model_response(ModelResponse(parts=[TextPart('Hello world')])), finish_reason=None
+    )
+    gemini_client = get_gemini_client(response)
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
+    agent = Agent(m)
+
+    result = await agent.run('Hello World')
+
+    for message in result.all_messages():
+        if isinstance(message, ModelResponse):
+            assert message.vendor_details is None
diff --git a/tests/models/test_google.py b/tests/models/test_google.py
index 9e7010dfd..7cd17256c 100644
--- a/tests/models/test_google.py
+++ b/tests/models/test_google.py
@@ -85,6 +85,7 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP
                 usage=Usage(requests=1, request_tokens=7, response_tokens=11, total_tokens=18, details={}),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
+                vendor_details={'finish_reason': 'STOP'},
             ),
         ]
     )
@@ -138,6 +139,7 @@ async def temperature(city: str, date: datetime.date) -> str:
                 usage=Usage(requests=1, request_tokens=101, response_tokens=14, total_tokens=115, details={}),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
+                vendor_details={'finish_reason': 'STOP'},
             ),
             ModelRequest(
                 parts=[
@@ -157,6 +159,7 @@ async def temperature(city: str, date: datetime.date) -> str:
                 usage=Usage(requests=1, request_tokens=123, response_tokens=21, total_tokens=144, details={}),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
+                vendor_details={'finish_reason': 'STOP'},
             ),
             ModelRequest(
                 parts=[
@@ -215,6 +218,7 @@ async def get_capital(country: str) -> str:
                 ),
                 model_name='models/gemini-2.5-pro-preview-05-06',
                 timestamp=IsDatetime(),
+                vendor_details={'finish_reason': 'STOP'},
             ),
             ModelRequest(
                 parts=[
@@ -235,6 +239,7 @@ async def get_capital(country: str) -> str:
                 usage=Usage(requests=1, request_tokens=104, response_tokens=18, total_tokens=122, details={}),
                 model_name='models/gemini-2.5-pro-preview-05-06',
                 timestamp=IsDatetime(),
+                vendor_details={'finish_reason': 'STOP'},
             ),
         ]
     )
@@ -469,6 +474,7 @@ def instructions() -> str:
                 usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
                 model_name='gemini-2.0-flash',
                 timestamp=IsDatetime(),
+                vendor_details={'finish_reason': 'STOP'},
             ),
         ]
     )
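
Taken together, the change surfaces the Gemini finish reason (and the responseId, as vendor_id) on every ModelResponse. Note that the google.py path reads finish_reason.value because the google-genai SDK returns an enum, while the legacy gemini.py path stores the raw string from the JSON payload. A minimal usage sketch (assumes a configured Gemini API key; the model name is only an example):

    from pydantic_ai import Agent
    from pydantic_ai.messages import ModelResponse

    agent = Agent('google-gla:gemini-1.5-flash')
    result = agent.run_sync('Hello')

    for message in result.all_messages():
        if isinstance(message, ModelResponse):
            print(message.vendor_id)       # responseId from the API, if present
            print(message.vendor_details)  # {'finish_reason': 'STOP'}, or None when no finish reason was returned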