Skip to content

Commit 6bad0ee

Browse files
happy-qiao authored and copybara-github committed
feat: GA Context Cache Python SDK
FUTURE_COPYBARA_INTEGRATE_REVIEW=#4861 from googleapis:release-please--branches--main 039f2cb
PiperOrigin-RevId: 718931779
1 parent c2e7ce4 commit 6bad0ee

File tree

9 files changed

+140
-74
lines changed

9 files changed

+140
-74
lines changed

google/cloud/aiplatform/compat/__init__.py

+2-4
Original file line number | Diff line number | Diff line change
@@ -181,8 +181,7 @@
181181
services.featurestore_online_serving_service_client_v1
182182
)
183183
services.featurestore_service_client = services.featurestore_service_client_v1
184-
# TODO(b/342585299): Temporary code. Switch to v1 once v1 is available.
185-
services.gen_ai_cache_service_client = services.gen_ai_cache_service_client_v1beta1
184+
services.gen_ai_cache_service_client = services.gen_ai_cache_service_client_v1
186185
services.job_service_client = services.job_service_client_v1
187186
services.model_garden_service_client = services.model_garden_service_client_v1
188187
services.model_service_client = services.model_service_client_v1
@@ -203,8 +202,7 @@
203202
types.annotation_spec = types.annotation_spec_v1
204203
types.artifact = types.artifact_v1
205204
types.batch_prediction_job = types.batch_prediction_job_v1
206-
# TODO(b/342585299): Temporary code. Switch to v1 once v1 is available.
207-
types.cached_content = types.cached_content_v1beta1
205+
types.cached_content = types.cached_content_v1
208206
types.completion_stats = types.completion_stats_v1
209207
types.context = types.context_v1
210208
types.custom_job = types.custom_job_v1

google/cloud/aiplatform/compat/services/__init__.py

+3
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,9 @@
137137
from google.cloud.aiplatform_v1.services.featurestore_service import (
138138
client as featurestore_service_client_v1,
139139
)
140+
from google.cloud.aiplatform_v1.services.gen_ai_cache_service import (
141+
client as gen_ai_cache_service_client_v1,
142+
)
140143
from google.cloud.aiplatform_v1.services.index_service import (
141144
client as index_service_client_v1,
142145
)

google/cloud/aiplatform/compat/types/__init__.py

+1
Original file line number | Diff line number | Diff line change
@@ -118,6 +118,7 @@
118118
annotation_spec as annotation_spec_v1,
119119
artifact as artifact_v1,
120120
batch_prediction_job as batch_prediction_job_v1,
121+
cached_content as cached_content_v1,
121122
completion_stats as completion_stats_v1,
122123
context as context_v1,
123124
custom_job as custom_job_v1,

google/cloud/aiplatform/utils/__init__.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,7 @@
7777
feature_registry_service_client_v1,
7878
featurestore_online_serving_service_client_v1,
7979
featurestore_service_client_v1,
80+
gen_ai_cache_service_client_v1,
8081
index_service_client_v1,
8182
index_endpoint_service_client_v1,
8283
job_service_client_v1,
@@ -805,8 +806,7 @@ class GenAiCacheServiceClientWithOverride(ClientWithOverride):
805806
_version_map = (
806807
(
807808
compat.V1,
808-
# TODO(b/342585299): Temporary code. Switch to v1 once v1 is available.
809-
gen_ai_cache_service_client_v1beta1.GenAiCacheServiceClient,
809+
gen_ai_cache_service_client_v1.GenAiCacheServiceClient,
810810
),
811811
(
812812
compat.V1BETA1,

tests/unit/vertexai/test_caching.py

+17-10
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
import json
2323
import mock
2424
import pytest
25-
from vertexai.preview import caching
25+
from vertexai.caching import _caching
2626
from google.cloud.aiplatform import initializer
2727
import vertexai
2828
from google.cloud.aiplatform_v1beta1.types.cached_content import (
@@ -35,9 +35,16 @@
3535
from google.cloud.aiplatform_v1beta1.types.tool import (
3636
ToolConfig as GapicToolConfig,
3737
)
38-
from google.cloud.aiplatform_v1beta1.services import (
38+
from google.cloud.aiplatform_v1.services import (
3939
gen_ai_cache_service,
4040
)
41+
from vertexai.generative_models._generative_models import (
42+
Content,
43+
PartsType,
44+
Tool,
45+
ToolConfig,
46+
ContentsType,
47+
)
4148

4249

4350
_TEST_PROJECT = "test-project"
@@ -141,7 +148,7 @@ def list_cached_contents(self, request):
141148

142149
@pytest.mark.usefixtures("google_auth_mock")
143150
class TestCaching:
144-
"""Unit tests for caching.CachedContent."""
151+
"""Unit tests for _caching.CachedContent."""
145152

146153
def setup_method(self):
147154
vertexai.init(
@@ -156,7 +163,7 @@ def test_constructor_with_full_resource_name(self, mock_get_cached_content):
156163
full_resource_name = (
157164
"projects/123/locations/europe-west1/cachedContents/contents-id"
158165
)
159-
cache = caching.CachedContent(
166+
cache = _caching.CachedContent(
160167
cached_content_name=full_resource_name,
161168
)
162169

@@ -166,7 +173,7 @@ def test_constructor_with_full_resource_name(self, mock_get_cached_content):
166173
def test_constructor_with_only_content_id(self, mock_get_cached_content):
167174
partial_resource_name = "contents-id"
168175

169-
cache = caching.CachedContent(
176+
cache = _caching.CachedContent(
170177
cached_content_name=partial_resource_name,
171178
)
172179

@@ -179,7 +186,7 @@ def test_constructor_with_only_content_id(self, mock_get_cached_content):
179186
def test_get_with_content_id(self, mock_get_cached_content):
180187
partial_resource_name = "contents-id"
181188

182-
cache = caching.CachedContent.get(
189+
cache = _caching.CachedContent.get(
183190
cached_content_name=partial_resource_name,
184191
)
185192

@@ -192,7 +199,7 @@ def test_get_with_content_id(self, mock_get_cached_content):
192199
def test_create_with_real_payload(
193200
self, mock_create_cached_content, mock_get_cached_content
194201
):
195-
cache = caching.CachedContent.create(
202+
cache = _caching.CachedContent.create(
196203
model_name="model-name",
197204
system_instruction=GapicContent(
198205
role="system", parts=[GapicPart(text="system instruction")]
@@ -219,7 +226,7 @@ def test_create_with_real_payload(
219226
def test_create_with_real_payload_and_wrapped_type(
220227
self, mock_create_cached_content, mock_get_cached_content
221228
):
222-
cache = caching.CachedContent.create(
229+
cache = _caching.CachedContent.create(
223230
model_name="model-name",
224231
system_instruction="Please answer my questions with cool",
225232
tools=[],
@@ -239,15 +246,15 @@ def test_create_with_real_payload_and_wrapped_type(
239246
assert cache.display_name == _TEST_DISPLAY_NAME
240247

241248
def test_list(self, mock_list_cached_contents):
242-
cached_contents = caching.CachedContent.list()
249+
cached_contents = _caching.CachedContent.list()
243250
for i, cached_content in enumerate(cached_contents):
244251
assert cached_content.name == f"cached_content{i + 1}_from_list_request"
245252
assert cached_content.model_name == f"model-name{i + 1}"
246253

247254
def test_print_a_cached_content(
248255
self, mock_create_cached_content, mock_get_cached_content
249256
):
250-
cached_content = caching.CachedContent.create(
257+
cached_content = _caching.CachedContent.create(
251258
model_name="model-name",
252259
system_instruction="Please answer my questions with cool",
253260
tools=[],

tests/unit/vertexai/test_generative_models.py

+9-9
Original file line number | Diff line number | Diff line change
@@ -39,14 +39,14 @@
3939
gapic_content_types,
4040
gapic_tool_types,
4141
)
42-
from google.cloud.aiplatform_v1beta1.types.cached_content import (
42+
from google.cloud.aiplatform_v1.types.cached_content import (
4343
CachedContent as GapicCachedContent,
4444
)
45-
from google.cloud.aiplatform_v1beta1.services import (
45+
from google.cloud.aiplatform_v1.services import (
4646
gen_ai_cache_service,
4747
)
4848
from vertexai.generative_models import _function_calling_utils
49-
from vertexai.preview import caching
49+
from vertexai.caching import _caching
5050

5151

5252
_TEST_PROJECT = "test-project"
@@ -655,11 +655,11 @@ def test_generative_model_from_cached_content(
655655
project_location_prefix = (
656656
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/"
657657
)
658-
cached_content = caching.CachedContent(
658+
cached_content = _caching.CachedContent(
659659
"cached-content-id-in-from-cached-content-test"
660660
)
661661

662-
model = preview_generative_models.GenerativeModel.from_cached_content(
662+
model = generative_models.GenerativeModel.from_cached_content(
663663
cached_content=cached_content
664664
)
665665

@@ -690,7 +690,7 @@ def test_generative_model_from_cached_content_with_resource_name(
690690
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/"
691691
)
692692

693-
model = preview_generative_models.GenerativeModel.from_cached_content(
693+
model = generative_models.GenerativeModel.from_cached_content(
694694
cached_content="cached-content-id-in-from-cached-content-test"
695695
)
696696

@@ -848,7 +848,7 @@ def test_generate_content(
848848
assert response5.text
849849

850850
@mock.patch.object(
851-
target=prediction_service.PredictionServiceClient,
851+
target=prediction_service_v1.PredictionServiceClient,
852852
attribute="generate_content",
853853
new=lambda self, request: gapic_prediction_service_types.GenerateContentResponse(
854854
candidates=[
@@ -870,11 +870,11 @@ def test_generate_content_with_cached_content(
870870
self,
871871
mock_get_cached_content_fixture,
872872
):
873-
cached_content = caching.CachedContent(
873+
cached_content = _caching.CachedContent(
874874
"cached-content-id-in-from-cached-content-test"
875875
)
876876

877-
model = preview_generative_models.GenerativeModel.from_cached_content(
877+
model = generative_models.GenerativeModel.from_cached_content(
878878
cached_content=cached_content
879879
)
880880

vertexai/caching/__init__.py

+25
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Classes for working with the Gemini models."""
16+
17+
# We just want to re-export certain classes
18+
# pylint: disable=g-multiple-import,g-importing-member
19+
from vertexai.caching._caching import (
20+
CachedContent,
21+
)
22+
23+
__all__ = [
24+
"CachedContent",
25+
]

vertexai/caching/_caching.py

+39-7
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
from google.cloud.aiplatform.compat.types import (
2525
cached_content_v1beta1 as gca_cached_content,
2626
)
27-
from google.cloud.aiplatform_v1beta1.services import gen_ai_cache_service
27+
from google.cloud.aiplatform_v1.services import gen_ai_cache_service as gen_ai_cache_service_v1
2828
from google.cloud.aiplatform_v1beta1.types.cached_content import (
2929
CachedContent as GapicCachedContent,
3030
)
@@ -36,6 +36,7 @@
3636
GetCachedContentRequest,
3737
UpdateCachedContentRequest,
3838
)
39+
from google.cloud.aiplatform_v1 import types as types_v1
3940
from vertexai.generative_models import _generative_models
4041
from vertexai.generative_models._generative_models import (
4142
Content,
@@ -89,7 +90,7 @@ def _prepare_create_request(
8990
if ttl and expire_time:
9091
raise ValueError("Only one of ttl and expire_time can be set.")
9192

92-
request = CreateCachedContentRequest(
93+
request_v1beta1 = CreateCachedContentRequest(
9394
parent=f"projects/{project}/locations/{location}",
9495
cached_content=GapicCachedContent(
9596
model=model_name,
@@ -102,11 +103,32 @@ def _prepare_create_request(
102103
display_name=display_name,
103104
),
104105
)
105-
return request
106+
serialized_message_v1beta1 = type(request_v1beta1).serialize(request_v1beta1)
107+
try:
108+
request_v1 = types_v1.CreateCachedContentRequest.deserialize(
109+
serialized_message_v1beta1
110+
)
111+
except Exception as ex:
112+
raise ValueError(
113+
"Failed to convert CreateCachedContentRequest from v1beta1 to v1:\n"
114+
f"{serialized_message_v1beta1}"
115+
) from ex
116+
return request_v1
106117

107118

108119
def _prepare_get_cached_content_request(name: str) -> GetCachedContentRequest:
109-
return GetCachedContentRequest(name=name)
120+
request_v1beta1 = GetCachedContentRequest(name=name)
121+
serialized_message_v1beta1 = type(request_v1beta1).serialize(request_v1beta1)
122+
try:
123+
request_v1 = types_v1.GetCachedContentRequest.deserialize(
124+
serialized_message_v1beta1
125+
)
126+
except Exception as ex:
127+
raise ValueError(
128+
"Failed to convert GetCachedContentRequest from v1beta1 to v1:\n"
129+
f"{serialized_message_v1beta1}"
130+
) from ex
131+
return request_v1
110132

111133

112134
class CachedContent(aiplatform_base._VertexAiResourceNounPlus):
@@ -122,7 +144,7 @@ class CachedContent(aiplatform_base._VertexAiResourceNounPlus):
122144
client_class = aiplatform_utils.GenAiCacheServiceClientWithOverride
123145

124146
_gen_ai_cache_service_client_value: Optional[
125-
gen_ai_cache_service.GenAiCacheServiceClient
147+
gen_ai_cache_service_v1.GenAiCacheServiceClient
126148
] = None
127149

128150
def __init__(self, cached_content_name: str):
@@ -253,15 +275,25 @@ def update(
253275
update_mask.append("expire_time")
254276

255277
update_mask = field_mask_pb2.FieldMask(paths=update_mask)
256-
request = UpdateCachedContentRequest(
278+
request_v1beta1 = UpdateCachedContentRequest(
257279
cached_content=GapicCachedContent(
258280
name=self.resource_name,
259281
expire_time=expire_time,
260282
ttl=ttl,
261283
),
262284
update_mask=update_mask,
263285
)
264-
self.api_client.update_cached_content(request)
286+
serialized_message_v1beta1 = type(request_v1beta1).serialize(request_v1beta1)
287+
try:
288+
request_v1 = types_v1.UpdateCachedContentRequest.deserialize(
289+
serialized_message_v1beta1
290+
)
291+
except Exception as ex:
292+
raise ValueError(
293+
"Failed to convert UpdateCachedContentRequest from v1beta1 to v1:\n"
294+
f"{serialized_message_v1beta1}"
295+
) from ex
296+
self.api_client.update_cached_content(request_v1)
265297

266298
@property
267299
def expire_time(self) -> datetime.datetime:

0 commit comments

Comments
 (0)