Skip to content

Commit 3e2a3d0

Browse files
test: initial commit fixing gemini logprobs
Fixes #9888
1 parent 7d383fc commit 3e2a3d0

File tree

4 files changed

+84
-92
lines changed

4 files changed

+84
-92
lines changed

litellm/types/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,7 @@ def __init__(
724724
finish_reason=None,
725725
index=0,
726726
message: Optional[Union[Message, dict]] = None,
727-
logprobs=None,
727+
logprobs: Optional[Union[ChoiceLogprobs, dict]] = None,
728728
enhancements=None,
729729
**params,
730730
):
@@ -746,7 +746,7 @@ def __init__(
746746
if logprobs is not None:
747747
if isinstance(logprobs, dict):
748748
self.logprobs = ChoiceLogprobs(**logprobs)
749-
else:
749+
elif isinstance(logprobs, ChoiceLogprobs):
750750
self.logprobs = logprobs
751751
if enhancements is not None:
752752
self.enhancements = enhancements

tests/litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio.py

Lines changed: 0 additions & 68 deletions
This file was deleted.
Lines changed: 79 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1-
import pytest
21
import asyncio
32
from unittest.mock import MagicMock
4-
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
3+
4+
import pytest
5+
56
import litellm
67
from litellm import ModelResponse
8+
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
9+
VertexGeminiConfig,
10+
)
11+
from litellm.types.utils import ChoiceLogprobs
12+
713

814
@pytest.mark.asyncio
async def test_transform_response_with_avglogprobs():
    """
    Ensure Gemini's ``avgLogprobs`` field on a candidate is surfaced on the
    transformed ``ModelResponse`` as a typed ``ChoiceLogprobs`` object.

    Regression test for the gemini logprobs fix: ``Choices.logprobs`` is now
    typed as ``Optional[Union[ChoiceLogprobs, dict]]`` instead of a bare float.
    """
    # Raw Gemini response carrying an avgLogprobs value for the candidate.
    response_json = {
        "candidates": [
            {
                "content": {"parts": [{"text": "Test response"}], "role": "model"},
                "finishReason": "STOP",
                "avgLogprobs": -0.3445799010140555,
            }
        ],
        "usageMetadata": {
            "promptTokenCount": 10,
            "candidatesTokenCount": 5,
            "totalTokenCount": 15,
        },
    }

    # Mock the HTTP response so transform_response sees our payload.
    mock_response = MagicMock()
    mock_response.json.return_value = response_json

    # Logging object is only passed through; a bare mock suffices.
    mock_logging = MagicMock()

    config = VertexGeminiConfig()

    # Pre-built ModelResponse that transform_response fills in.
    model_response = ModelResponse(
        id="test-id",
        choices=[],
        created=1234567890,
        model="gemini-2.0-flash",
        usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
    )

    # NOTE(review): the raw_response/model_response/logging_obj/request_data
    # keywords were elided in the diff context — reconstructed from the mocks
    # created above; confirm against VertexGeminiConfig.transform_response.
    transformed_response = config.transform_response(
        model="gemini-2.0-flash",
        raw_response=mock_response,
        model_response=model_response,
        logging_obj=mock_logging,
        request_data={},
        messages=[],
        optional_params={},
        litellm_params={},
        encoding=None,
    )

    # avgLogprobs must be converted into a typed ChoiceLogprobs object.
    # BUG FIX: the previous version also asserted
    # ``choices[0].logprobs == -0.3445799010140555``, which contradicts the
    # isinstance check — a ChoiceLogprobs object never equals a raw float,
    # so the test could not pass. The stale float comparison is removed.
    assert len(transformed_response.choices) == 1
    assert isinstance(transformed_response.choices[0].logprobs, ChoiceLogprobs)
72+
73+
74+
def test_top_logprobs():
    """
    Verify that the OpenAI-style ``logprobs``/``top_logprobs`` params are
    mapped onto Gemini's ``responseLogprobs``/``logprobs`` request fields.
    """
    non_default_params = {
        "top_logprobs": 2,
        "logprobs": True,
    }
    optional_params = {}
    model = "gemini"

    v = VertexGeminiConfig().map_openai_params(
        non_default_params=non_default_params,
        optional_params=optional_params,
        model=model,
        drop_params=False,
    )
    # BUG FIX: the original asserted with `is` against the source dict values.
    # Identity comparison with the int 2 only works because CPython caches
    # small ints — an implementation detail, not a guarantee. Compare by
    # value (and pin the expected values explicitly).
    assert v["responseLogprobs"] is True
    assert v["logprobs"] == 2
90+
91+
92+
def test_get_model_for_vertex_ai_url():
    """
    ``get_model_for_vertex_ai_url`` passes plain model names through and
    strips the ``gemini/`` prefix from gemini-spec model names.
    """
    cases = [
        # (input model, expected result)
        ("gemini-pro", "gemini-pro"),        # regular model name: unchanged
        ("gemini/ft-uuid-123", "ft-uuid-123"),  # spec model: prefix stripped
    ]
    for raw_model, expected in cases:
        result = VertexGeminiConfig.get_model_for_vertex_ai_url(raw_model)
        assert result == expected
102+
103+
104+
def test_is_model_gemini_spec_model():
    """
    ``_is_model_gemini_spec_model`` is True only for ``gemini/``-prefixed
    model names; None and plain names are not spec models.
    """
    # IDIOM FIX: replaced `== False` / `== True` comparisons (PEP 8 / E712)
    # with bare truthiness asserts.
    # Test case 1: None input
    assert not VertexGeminiConfig._is_model_gemini_spec_model(None)

    # Test case 2: Regular model name
    assert not VertexGeminiConfig._is_model_gemini_spec_model("gemini-pro")

    # Test case 3: Gemini spec model
    assert VertexGeminiConfig._is_model_gemini_spec_model("gemini/custom-model")
113+
114+
115+
def test_get_model_name_from_gemini_spec_model():
    """
    ``_get_model_name_from_gemini_spec_model`` returns plain model names
    unchanged and extracts the bare name from ``gemini/``-prefixed ones.
    """
    cases = {
        "gemini-pro": "gemini-pro",           # regular name: returned as-is
        "gemini/ft-uuid-123": "ft-uuid-123",  # spec model: prefix removed
    }
    for spec_model, expected in cases.items():
        assert (
            VertexGeminiConfig._get_model_name_from_gemini_spec_model(spec_model)
            == expected
        )

tests/llm_translation/test_gemini.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,6 @@ def test_gemini_image_generation():
8282
messages=[{"role": "user", "content": "Generate an image of a cat"}],
8383
modalities=["image", "text"],
8484
)
85-
assert response.choices[0].message.content is not None
85+
assert response.choices[0].message.content is not None
86+
87+

0 commit comments

Comments (0)