Skip to content

Commit cd210bb

Browse files
vertex-sdk-botcopybara-github
authored andcommittedJan 23, 2025·
feat: GA Context Cache Python SDK
FUTURE_COPYBARA_INTEGRATE_REVIEW=#4861 from googleapis:release-please--branches--main 039f2cb PiperOrigin-RevId: 718130866
1 parent c2e7ce4 commit cd210bb

File tree

9 files changed

+124
-74
lines changed

9 files changed

+124
-74
lines changed
 

‎google/cloud/aiplatform/compat/__init__.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,7 @@
181181
services.featurestore_online_serving_service_client_v1
182182
)
183183
services.featurestore_service_client = services.featurestore_service_client_v1
184-
# TODO(b/342585299): Temporary code. Switch to v1 once v1 is available.
185-
services.gen_ai_cache_service_client = services.gen_ai_cache_service_client_v1beta1
184+
services.gen_ai_cache_service_client = services.gen_ai_cache_service_client_v1
186185
services.job_service_client = services.job_service_client_v1
187186
services.model_garden_service_client = services.model_garden_service_client_v1
188187
services.model_service_client = services.model_service_client_v1
@@ -203,8 +202,7 @@
203202
types.annotation_spec = types.annotation_spec_v1
204203
types.artifact = types.artifact_v1
205204
types.batch_prediction_job = types.batch_prediction_job_v1
206-
# TODO(b/342585299): Temporary code. Switch to v1 once v1 is available.
207-
types.cached_content = types.cached_content_v1beta1
205+
types.cached_content = types.cached_content_v1
208206
types.completion_stats = types.completion_stats_v1
209207
types.context = types.context_v1
210208
types.custom_job = types.custom_job_v1

‎google/cloud/aiplatform/compat/services/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@
137137
from google.cloud.aiplatform_v1.services.featurestore_service import (
138138
client as featurestore_service_client_v1,
139139
)
140+
from google.cloud.aiplatform_v1.services.gen_ai_cache_service import (
141+
client as gen_ai_cache_service_client_v1,
142+
)
140143
from google.cloud.aiplatform_v1.services.index_service import (
141144
client as index_service_client_v1,
142145
)

‎google/cloud/aiplatform/compat/types/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@
118118
annotation_spec as annotation_spec_v1,
119119
artifact as artifact_v1,
120120
batch_prediction_job as batch_prediction_job_v1,
121+
cached_content as cached_content_v1,
121122
completion_stats as completion_stats_v1,
122123
context as context_v1,
123124
custom_job as custom_job_v1,

‎google/cloud/aiplatform/utils/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
feature_registry_service_client_v1,
7878
featurestore_online_serving_service_client_v1,
7979
featurestore_service_client_v1,
80+
gen_ai_cache_service_client_v1,
8081
index_service_client_v1,
8182
index_endpoint_service_client_v1,
8283
job_service_client_v1,
@@ -805,8 +806,7 @@ class GenAiCacheServiceClientWithOverride(ClientWithOverride):
805806
_version_map = (
806807
(
807808
compat.V1,
808-
# TODO(b/342585299): Temporary code. Switch to v1 once v1 is available.
809-
gen_ai_cache_service_client_v1beta1.GenAiCacheServiceClient,
809+
gen_ai_cache_service_client_v1.GenAiCacheServiceClient,
810810
),
811811
(
812812
compat.V1BETA1,

‎tests/unit/vertexai/test_caching.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import json
2323
import mock
2424
import pytest
25-
from vertexai.preview import caching
25+
from vertexai.caching import _caching
2626
from google.cloud.aiplatform import initializer
2727
import vertexai
2828
from google.cloud.aiplatform_v1beta1.types.cached_content import (
@@ -35,7 +35,7 @@
3535
from google.cloud.aiplatform_v1beta1.types.tool import (
3636
ToolConfig as GapicToolConfig,
3737
)
38-
from google.cloud.aiplatform_v1beta1.services import (
38+
from google.cloud.aiplatform_v1.services import (
3939
gen_ai_cache_service,
4040
)
4141

@@ -141,7 +141,7 @@ def list_cached_contents(self, request):
141141

142142
@pytest.mark.usefixtures("google_auth_mock")
143143
class TestCaching:
144-
"""Unit tests for caching.CachedContent."""
144+
"""Unit tests for _caching.CachedContent."""
145145

146146
def setup_method(self):
147147
vertexai.init(
@@ -156,7 +156,7 @@ def test_constructor_with_full_resource_name(self, mock_get_cached_content):
156156
full_resource_name = (
157157
"projects/123/locations/europe-west1/cachedContents/contents-id"
158158
)
159-
cache = caching.CachedContent(
159+
cache = _caching.CachedContent(
160160
cached_content_name=full_resource_name,
161161
)
162162

@@ -166,7 +166,7 @@ def test_constructor_with_full_resource_name(self, mock_get_cached_content):
166166
def test_constructor_with_only_content_id(self, mock_get_cached_content):
167167
partial_resource_name = "contents-id"
168168

169-
cache = caching.CachedContent(
169+
cache = _caching.CachedContent(
170170
cached_content_name=partial_resource_name,
171171
)
172172

@@ -179,7 +179,7 @@ def test_constructor_with_only_content_id(self, mock_get_cached_content):
179179
def test_get_with_content_id(self, mock_get_cached_content):
180180
partial_resource_name = "contents-id"
181181

182-
cache = caching.CachedContent.get(
182+
cache = _caching.CachedContent.get(
183183
cached_content_name=partial_resource_name,
184184
)
185185

@@ -192,7 +192,7 @@ def test_get_with_content_id(self, mock_get_cached_content):
192192
def test_create_with_real_payload(
193193
self, mock_create_cached_content, mock_get_cached_content
194194
):
195-
cache = caching.CachedContent.create(
195+
cache = _caching.CachedContent.create(
196196
model_name="model-name",
197197
system_instruction=GapicContent(
198198
role="system", parts=[GapicPart(text="system instruction")]
@@ -219,7 +219,7 @@ def test_create_with_real_payload(
219219
def test_create_with_real_payload_and_wrapped_type(
220220
self, mock_create_cached_content, mock_get_cached_content
221221
):
222-
cache = caching.CachedContent.create(
222+
cache = _caching.CachedContent.create(
223223
model_name="model-name",
224224
system_instruction="Please answer my questions with cool",
225225
tools=[],
@@ -239,15 +239,15 @@ def test_create_with_real_payload_and_wrapped_type(
239239
assert cache.display_name == _TEST_DISPLAY_NAME
240240

241241
def test_list(self, mock_list_cached_contents):
242-
cached_contents = caching.CachedContent.list()
242+
cached_contents = _caching.CachedContent.list()
243243
for i, cached_content in enumerate(cached_contents):
244244
assert cached_content.name == f"cached_content{i + 1}_from_list_request"
245245
assert cached_content.model_name == f"model-name{i + 1}"
246246

247247
def test_print_a_cached_content(
248248
self, mock_create_cached_content, mock_get_cached_content
249249
):
250-
cached_content = caching.CachedContent.create(
250+
cached_content = _caching.CachedContent.create(
251251
model_name="model-name",
252252
system_instruction="Please answer my questions with cool",
253253
tools=[],

‎tests/unit/vertexai/test_generative_models.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@
3939
gapic_content_types,
4040
gapic_tool_types,
4141
)
42-
from google.cloud.aiplatform_v1beta1.types.cached_content import (
42+
from google.cloud.aiplatform_v1.types.cached_content import (
4343
CachedContent as GapicCachedContent,
4444
)
45-
from google.cloud.aiplatform_v1beta1.services import (
45+
from google.cloud.aiplatform_v1.services import (
4646
gen_ai_cache_service,
4747
)
4848
from vertexai.generative_models import _function_calling_utils
49-
from vertexai.preview import caching
49+
from vertexai.caching import _caching
5050

5151

5252
_TEST_PROJECT = "test-project"
@@ -655,11 +655,11 @@ def test_generative_model_from_cached_content(
655655
project_location_prefix = (
656656
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/"
657657
)
658-
cached_content = caching.CachedContent(
658+
cached_content = _caching.CachedContent(
659659
"cached-content-id-in-from-cached-content-test"
660660
)
661661

662-
model = preview_generative_models.GenerativeModel.from_cached_content(
662+
model = generative_models.GenerativeModel.from_cached_content(
663663
cached_content=cached_content
664664
)
665665

@@ -690,7 +690,7 @@ def test_generative_model_from_cached_content_with_resource_name(
690690
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/"
691691
)
692692

693-
model = preview_generative_models.GenerativeModel.from_cached_content(
693+
model = generative_models.GenerativeModel.from_cached_content(
694694
cached_content="cached-content-id-in-from-cached-content-test"
695695
)
696696

@@ -848,7 +848,7 @@ def test_generate_content(
848848
assert response5.text
849849

850850
@mock.patch.object(
851-
target=prediction_service.PredictionServiceClient,
851+
target=prediction_service_v1.PredictionServiceClient,
852852
attribute="generate_content",
853853
new=lambda self, request: gapic_prediction_service_types.GenerateContentResponse(
854854
candidates=[
@@ -870,11 +870,11 @@ def test_generate_content_with_cached_content(
870870
self,
871871
mock_get_cached_content_fixture,
872872
):
873-
cached_content = caching.CachedContent(
873+
cached_content = _caching.CachedContent(
874874
"cached-content-id-in-from-cached-content-test"
875875
)
876876

877-
model = preview_generative_models.GenerativeModel.from_cached_content(
877+
model = generative_models.GenerativeModel.from_cached_content(
878878
cached_content=cached_content
879879
)
880880

‎vertexai/caching/__init__.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Classes for working with the Gemini models."""
16+
17+
# We just want to re-export certain classes
18+
# pylint: disable=g-multiple-import,g-importing-member
19+
from vertexai.caching._caching import (
20+
CachedContent,
21+
)
22+
23+
__all__ = [
24+
"CachedContent",
25+
]

‎vertexai/caching/_caching.py

+30-7
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
from google.cloud.aiplatform.compat.types import (
2525
cached_content_v1beta1 as gca_cached_content,
2626
)
27-
from google.cloud.aiplatform_v1beta1.services import gen_ai_cache_service
27+
from google.cloud.aiplatform_v1.services import (
28+
gen_ai_cache_service as gen_ai_cache_service_v1,
29+
)
2830
from google.cloud.aiplatform_v1beta1.types.cached_content import (
2931
CachedContent as GapicCachedContent,
3032
)
@@ -36,6 +38,7 @@
3638
GetCachedContentRequest,
3739
UpdateCachedContentRequest,
3840
)
41+
from google.cloud.aiplatform_v1 import types as types_v1
3942
from vertexai.generative_models import _generative_models
4043
from vertexai.generative_models._generative_models import (
4144
Content,
@@ -89,7 +92,7 @@ def _prepare_create_request(
8992
if ttl and expire_time:
9093
raise ValueError("Only one of ttl and expire_time can be set.")
9194

92-
request = CreateCachedContentRequest(
95+
request_v1beta1 = CreateCachedContentRequest(
9396
parent=f"projects/{project}/locations/{location}",
9497
cached_content=GapicCachedContent(
9598
model=model_name,
@@ -102,11 +105,21 @@ def _prepare_create_request(
102105
display_name=display_name,
103106
),
104107
)
105-
return request
108+
serialized_message_v1beta1 = type(request_v1beta1).serialize(request_v1beta1)
109+
try:
110+
request_v1 = types_v1.CreateCachedContentRequest.deserialize(
111+
serialized_message_v1beta1
112+
)
113+
except Exception as ex:
114+
raise ValueError(
115+
"Failed to convert CreateCachedContentRequest from v1beta1 to v1:\n"
116+
f"{serialized_message_v1beta1}"
117+
) from ex
118+
return request_v1
106119

107120

108121
def _prepare_get_cached_content_request(name: str) -> GetCachedContentRequest:
109-
return GetCachedContentRequest(name=name)
122+
return types_v1.GetCachedContentRequest(name=name)
110123

111124

112125
class CachedContent(aiplatform_base._VertexAiResourceNounPlus):
@@ -122,7 +135,7 @@ class CachedContent(aiplatform_base._VertexAiResourceNounPlus):
122135
client_class = aiplatform_utils.GenAiCacheServiceClientWithOverride
123136

124137
_gen_ai_cache_service_client_value: Optional[
125-
gen_ai_cache_service.GenAiCacheServiceClient
138+
gen_ai_cache_service_v1.GenAiCacheServiceClient
126139
] = None
127140

128141
def __init__(self, cached_content_name: str):
@@ -253,15 +266,25 @@ def update(
253266
update_mask.append("expire_time")
254267

255268
update_mask = field_mask_pb2.FieldMask(paths=update_mask)
256-
request = UpdateCachedContentRequest(
269+
request_v1beta1 = UpdateCachedContentRequest(
257270
cached_content=GapicCachedContent(
258271
name=self.resource_name,
259272
expire_time=expire_time,
260273
ttl=ttl,
261274
),
262275
update_mask=update_mask,
263276
)
264-
self.api_client.update_cached_content(request)
277+
serialized_message_v1beta1 = type(request_v1beta1).serialize(request_v1beta1)
278+
try:
279+
request_v1 = types_v1.UpdateCachedContentRequest.deserialize(
280+
serialized_message_v1beta1
281+
)
282+
except Exception as ex:
283+
raise ValueError(
284+
"Failed to convert UpdateCachedContentRequest from v1beta1 to v1:\n"
285+
f"{serialized_message_v1beta1}"
286+
) from ex
287+
self.api_client.update_cached_content(request_v1)
265288

266289
@property
267290
def expire_time(self) -> datetime.datetime:

0 commit comments

Comments
 (0)
Please sign in to comment.