22
22
import json
23
23
import mock
24
24
import pytest
25
- from vertexai .preview import caching
25
+ from vertexai .caching import _caching
26
26
from google .cloud .aiplatform import initializer
27
27
import vertexai
28
28
from google .cloud .aiplatform_v1beta1 .types .cached_content import (
35
35
from google .cloud .aiplatform_v1beta1 .types .tool import (
36
36
ToolConfig as GapicToolConfig ,
37
37
)
38
- from google .cloud .aiplatform_v1beta1 .services import (
38
+ from google .cloud .aiplatform_v1 .services import (
39
39
gen_ai_cache_service ,
40
40
)
41
+ from vertexai .generative_models ._generative_models import (
42
+ Content ,
43
+ PartsType ,
44
+ Tool ,
45
+ ToolConfig ,
46
+ ContentsType ,
47
+ )
41
48
42
49
43
50
_TEST_PROJECT = "test-project"
@@ -141,7 +148,7 @@ def list_cached_contents(self, request):
141
148
142
149
@pytest .mark .usefixtures ("google_auth_mock" )
143
150
class TestCaching :
144
- """Unit tests for caching .CachedContent."""
151
+ """Unit tests for _caching .CachedContent."""
145
152
146
153
def setup_method (self ):
147
154
vertexai .init (
@@ -156,7 +163,7 @@ def test_constructor_with_full_resource_name(self, mock_get_cached_content):
156
163
full_resource_name = (
157
164
"projects/123/locations/europe-west1/cachedContents/contents-id"
158
165
)
159
- cache = caching .CachedContent (
166
+ cache = _caching .CachedContent (
160
167
cached_content_name = full_resource_name ,
161
168
)
162
169
@@ -166,7 +173,7 @@ def test_constructor_with_full_resource_name(self, mock_get_cached_content):
166
173
def test_constructor_with_only_content_id (self , mock_get_cached_content ):
167
174
partial_resource_name = "contents-id"
168
175
169
- cache = caching .CachedContent (
176
+ cache = _caching .CachedContent (
170
177
cached_content_name = partial_resource_name ,
171
178
)
172
179
@@ -179,7 +186,7 @@ def test_constructor_with_only_content_id(self, mock_get_cached_content):
179
186
def test_get_with_content_id (self , mock_get_cached_content ):
180
187
partial_resource_name = "contents-id"
181
188
182
- cache = caching .CachedContent .get (
189
+ cache = _caching .CachedContent .get (
183
190
cached_content_name = partial_resource_name ,
184
191
)
185
192
@@ -192,7 +199,7 @@ def test_get_with_content_id(self, mock_get_cached_content):
192
199
def test_create_with_real_payload (
193
200
self , mock_create_cached_content , mock_get_cached_content
194
201
):
195
- cache = caching .CachedContent .create (
202
+ cache = _caching .CachedContent .create (
196
203
model_name = "model-name" ,
197
204
system_instruction = GapicContent (
198
205
role = "system" , parts = [GapicPart (text = "system instruction" )]
@@ -219,7 +226,7 @@ def test_create_with_real_payload(
219
226
def test_create_with_real_payload_and_wrapped_type (
220
227
self , mock_create_cached_content , mock_get_cached_content
221
228
):
222
- cache = caching .CachedContent .create (
229
+ cache = _caching .CachedContent .create (
223
230
model_name = "model-name" ,
224
231
system_instruction = "Please answer my questions with cool" ,
225
232
tools = [],
@@ -239,15 +246,15 @@ def test_create_with_real_payload_and_wrapped_type(
239
246
assert cache .display_name == _TEST_DISPLAY_NAME
240
247
241
248
def test_list (self , mock_list_cached_contents ):
242
- cached_contents = caching .CachedContent .list ()
249
+ cached_contents = _caching .CachedContent .list ()
243
250
for i , cached_content in enumerate (cached_contents ):
244
251
assert cached_content .name == f"cached_content{ i + 1 } _from_list_request"
245
252
assert cached_content .model_name == f"model-name{ i + 1 } "
246
253
247
254
def test_print_a_cached_content (
248
255
self , mock_create_cached_content , mock_get_cached_content
249
256
):
250
- cached_content = caching .CachedContent .create (
257
+ cached_content = _caching .CachedContent .create (
251
258
model_name = "model-name" ,
252
259
system_instruction = "Please answer my questions with cool" ,
253
260
tools = [],
0 commit comments