1
- import pytest
2
1
import asyncio
3
2
from unittest .mock import MagicMock
4
- from litellm .llms .vertex_ai .gemini .vertex_and_google_ai_studio_gemini import VertexGeminiConfig
3
+
4
+ import pytest
5
+
5
6
import litellm
6
7
from litellm import ModelResponse
8
+ from litellm .llms .vertex_ai .gemini .vertex_and_google_ai_studio_gemini import (
9
+ VertexGeminiConfig ,
10
+ )
11
+ from litellm .types .utils import ChoiceLogprobs
12
+
7
13
8
14
@pytest .mark .asyncio
9
15
async def test_transform_response_with_avglogprobs ():
@@ -13,41 +19,39 @@ async def test_transform_response_with_avglogprobs():
13
19
"""
14
20
# Create a mock response with avgLogprobs
15
21
response_json = {
16
- "candidates" : [{
17
- "content" : {"parts" : [{"text" : "Test response" }], "role" : "model" },
18
- "finishReason" : "STOP" ,
19
- "avgLogprobs" : - 0.3445799010140555
20
- }],
22
+ "candidates" : [
23
+ {
24
+ "content" : {"parts" : [{"text" : "Test response" }], "role" : "model" },
25
+ "finishReason" : "STOP" ,
26
+ "avgLogprobs" : - 0.3445799010140555 ,
27
+ }
28
+ ],
21
29
"usageMetadata" : {
22
30
"promptTokenCount" : 10 ,
23
31
"candidatesTokenCount" : 5 ,
24
- "totalTokenCount" : 15
25
- }
32
+ "totalTokenCount" : 15 ,
33
+ },
26
34
}
27
-
35
+
28
36
# Create a mock HTTP response
29
37
mock_response = MagicMock ()
30
38
mock_response .json .return_value = response_json
31
-
39
+
32
40
# Create a mock logging object
33
41
mock_logging = MagicMock ()
34
-
42
+
35
43
# Create an instance of VertexGeminiConfig
36
44
config = VertexGeminiConfig ()
37
-
45
+
38
46
# Create a ModelResponse object
39
47
model_response = ModelResponse (
40
48
id = "test-id" ,
41
49
choices = [],
42
50
created = 1234567890 ,
43
51
model = "gemini-2.0-flash" ,
44
- usage = {
45
- "prompt_tokens" : 10 ,
46
- "completion_tokens" : 5 ,
47
- "total_tokens" : 15
48
- }
52
+ usage = {"prompt_tokens" : 10 , "completion_tokens" : 5 , "total_tokens" : 15 },
49
53
)
50
-
54
+
51
55
# Call the transform_response method
52
56
transformed_response = config .transform_response (
53
57
model = "gemini-2.0-flash" ,
@@ -58,9 +62,63 @@ async def test_transform_response_with_avglogprobs():
58
62
messages = [],
59
63
optional_params = {},
60
64
litellm_params = {},
61
- encoding = None
65
+ encoding = None ,
62
66
)
63
-
67
+
64
68
# Assert that the avgLogprobs was correctly added to the model response
65
69
assert len (transformed_response .choices ) == 1
70
+ assert isinstance (transformed_response .choices [0 ].logprobs , ChoiceLogprobs )
66
71
assert transformed_response .choices [0 ].logprobs == - 0.3445799010140555
72
+
73
+
74
def test_top_logprobs():
    """map_openai_params translates OpenAI logprob options to Gemini keys.

    OpenAI's ``logprobs`` flag becomes Gemini's ``responseLogprobs`` and
    OpenAI's ``top_logprobs`` count becomes Gemini's ``logprobs``.
    """
    requested = {
        "top_logprobs": 2,
        "logprobs": True,
    }

    mapped = VertexGeminiConfig().map_openai_params(
        non_default_params=requested,
        optional_params={},
        model="gemini",
        drop_params=False,
    )

    # Identity checks (`is`) confirm the values are passed through untouched.
    assert mapped["responseLogprobs"] is requested["logprobs"]
    assert mapped["logprobs"] is requested["top_logprobs"]
90
+
91
+
92
def test_get_model_for_vertex_ai_url():
    """get_model_for_vertex_ai_url strips the ``gemini/`` spec prefix when present."""
    # Case 1: a plain model name passes through unchanged.
    assert VertexGeminiConfig.get_model_for_vertex_ai_url("gemini-pro") == "gemini-pro"

    # Case 2: a gemini-spec name (e.g. a fine-tuned model UUID) loses its prefix.
    assert (
        VertexGeminiConfig.get_model_for_vertex_ai_url("gemini/ft-uuid-123")
        == "ft-uuid-123"
    )
102
+
103
+
104
def test_is_model_gemini_spec_model():
    """_is_model_gemini_spec_model is truthy only for ``gemini/<name>`` spec models.

    Fix: replaced ``== False`` / ``== True`` comparisons with direct boolean
    assertions (PEP 8 / flake8 E712) — equality comparison against boolean
    literals is non-idiomatic and can mask truthy-but-not-True returns.
    """
    # Test case 1: None input is handled gracefully (no exception, falsy result)
    assert not VertexGeminiConfig._is_model_gemini_spec_model(None)

    # Test case 2: a regular model name is not a spec model
    assert not VertexGeminiConfig._is_model_gemini_spec_model("gemini-pro")

    # Test case 3: a `gemini/`-prefixed name is recognized as a spec model
    assert VertexGeminiConfig._is_model_gemini_spec_model("gemini/custom-model")
113
+
114
+
115
def test_get_model_name_from_gemini_spec_model():
    """_get_model_name_from_gemini_spec_model returns the part after ``gemini/``."""
    # Map each input model string to its expected extracted name.
    cases = {
        "gemini-pro": "gemini-pro",  # regular name: returned unchanged
        "gemini/ft-uuid-123": "ft-uuid-123",  # spec model: prefix stripped
    }
    for model, expected in cases.items():
        assert (
            VertexGeminiConfig._get_model_name_from_gemini_spec_model(model)
            == expected
        )
0 commit comments