Skip to content

Commit 13dc904

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Update v1beta1 sdk for a few new protos
PiperOrigin-RevId: 700552835
1 parent c23c62d commit 13dc904

9 files changed

+262
-80
lines changed

tests/unit/vertex_rag/test_rag_constants_preview.py

+50 -26
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,14 @@
2222
EmbeddingModelConfig,
2323
Filter,
2424
HybridSearch,
25+
LlmRanker,
2526
Pinecone,
2627
RagCorpus,
2728
RagFile,
2829
RagResource,
2930
RagRetrievalConfig,
31+
Ranking,
32+
RankService,
3033
SharePointSource,
3134
SharePointSources,
3235
SlackChannelsSource,
@@ -40,6 +43,7 @@
4043
from google.cloud.aiplatform_v1beta1 import (
4144
GoogleDriveSource,
4245
RagFileChunkingConfig,
46+
RagFileTransformationConfig,
4347
RagFileParsingConfig,
4448
ImportRagFilesConfig,
4549
ImportRagFilesRequest,
@@ -213,10 +217,20 @@
213217
TEST_RAG_FILE_JSON_ERROR = {"error": {"code": 13}}
214218
TEST_CHUNK_SIZE = 512
215219
TEST_CHUNK_OVERLAP = 100
220+
TEST_RAG_FILE_TRANSFORMATION_CONFIG = RagFileTransformationConfig(
221+
rag_file_chunking_config=RagFileChunkingConfig(
222+
fixed_length_chunking=RagFileChunkingConfig.FixedLengthChunking(
223+
chunk_size=TEST_CHUNK_SIZE,
224+
chunk_overlap=TEST_CHUNK_OVERLAP,
225+
),
226+
),
227+
)
216228
# GCS
217-
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig()
229+
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig(
230+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
231+
)
218232
TEST_IMPORT_FILES_CONFIG_GCS.gcs_source.uris = [TEST_GCS_PATH]
219-
TEST_IMPORT_FILES_CONFIG_GCS.rag_file_parsing_config.use_advanced_pdf_parsing = False
233+
TEST_IMPORT_FILES_CONFIG_GCS.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = False
220234
TEST_IMPORT_REQUEST_GCS = ImportRagFilesRequest(
221235
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
222236
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_GCS,
@@ -229,24 +243,28 @@
229243
TEST_DRIVE_FOLDER_2 = (
230244
f"https://drive.google.com/drive/folders/{TEST_DRIVE_FOLDER_ID}?resourcekey=0-eiOT3"
231245
)
232-
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig()
246+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig(
247+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
248+
)
233249
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.google_drive_source.resource_ids = [
234250
GoogleDriveSource.ResourceId(
235251
resource_id=TEST_DRIVE_FOLDER_ID,
236252
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
237253
)
238254
]
239-
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.rag_file_parsing_config.use_advanced_pdf_parsing = (
255+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = (
240256
False
241257
)
242-
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig()
258+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig(
259+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
260+
)
243261
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.google_drive_source.resource_ids = [
244262
GoogleDriveSource.ResourceId(
245263
resource_id=TEST_DRIVE_FOLDER_ID,
246264
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
247265
)
248266
]
249-
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.rag_file_parsing_config.use_advanced_pdf_parsing = (
267+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = (
250268
True
251269
)
252270
TEST_IMPORT_REQUEST_DRIVE_FOLDER = ImportRagFilesRequest(
@@ -261,11 +279,10 @@
261279
TEST_DRIVE_FILE_ID = "456"
262280
TEST_DRIVE_FILE = f"https://drive.google.com/file/d/{TEST_DRIVE_FILE_ID}"
263281
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE = ImportRagFilesConfig(
264-
rag_file_chunking_config=RagFileChunkingConfig(
265-
chunk_size=TEST_CHUNK_SIZE,
266-
chunk_overlap=TEST_CHUNK_OVERLAP,
282+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
283+
rag_file_parsing_config=RagFileParsingConfig(
284+
advanced_parser=RagFileParsingConfig.AdvancedParser(use_advanced_pdf_parsing=False)
267285
),
268-
rag_file_parsing_config=RagFileParsingConfig(use_advanced_pdf_parsing=False),
269286
)
270287
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.max_embedding_requests_per_min = 800
271288

@@ -319,12 +336,15 @@
319336
),
320337
],
321338
)
322-
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE = ImportRagFilesConfig(
323-
rag_file_chunking_config=RagFileChunkingConfig(
324-
chunk_size=TEST_CHUNK_SIZE,
325-
chunk_overlap=TEST_CHUNK_OVERLAP,
339+
TEST_RAG_FILE_PARSING_CONFIG = RagFileParsingConfig(
340+
advanced_parser=RagFileParsingConfig.AdvancedParser(
341+
use_advanced_pdf_parsing=False
326342
)
327343
)
344+
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE = ImportRagFilesConfig(
345+
rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG,
346+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
347+
)
328348
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE.slack_source.channels = [
329349
GapicSlackSource.SlackChannels(
330350
channels=[
@@ -375,10 +395,8 @@
375395
],
376396
)
377397
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE = ImportRagFilesConfig(
378-
rag_file_chunking_config=RagFileChunkingConfig(
379-
chunk_size=TEST_CHUNK_SIZE,
380-
chunk_overlap=TEST_CHUNK_OVERLAP,
381-
)
398+
rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG,
399+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
382400
)
383401
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE.jira_source.jira_queries = [
384402
GapicJiraSource.JiraQueries(
@@ -410,10 +428,8 @@
410428
],
411429
)
412430
TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE = ImportRagFilesConfig(
413-
rag_file_chunking_config=RagFileChunkingConfig(
414-
chunk_size=TEST_CHUNK_SIZE,
415-
chunk_overlap=TEST_CHUNK_OVERLAP,
416-
),
431+
rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG,
432+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
417433
share_point_sources=GapicSharePointSources(
418434
share_point_sources=[
419435
GapicSharePointSources.SharePointSource(
@@ -488,10 +504,7 @@
488504
)
489505

490506
TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE_NO_FOLDERS = ImportRagFilesConfig(
491-
rag_file_chunking_config=RagFileChunkingConfig(
492-
chunk_size=TEST_CHUNK_SIZE,
493-
chunk_overlap=TEST_CHUNK_OVERLAP,
494-
),
507+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
495508
share_point_sources=GapicSharePointSources(
496509
share_point_sources=[
497510
GapicSharePointSources.SharePointSource(
@@ -541,3 +554,14 @@
541554
filter=Filter(vector_distance_threshold=0.5),
542555
hybrid_search=HybridSearch(alpha=0.5),
543556
)
557+
TEST_RAG_RETRIEVAL_CONFIG_RANK_SERVICE = RagRetrievalConfig(
558+
top_k=2,
559+
filter=Filter(vector_distance_threshold=0.5),
560+
ranking=Ranking(rank_service=RankService(model_name="test-model-name")),
561+
)
562+
TEST_RAG_RETRIEVAL_CONFIG_LLM_RANKER = RagRetrievalConfig(
563+
top_k=2,
564+
filter=Filter(vector_distance_threshold=0.5),
565+
ranking=Ranking(llm_ranker=LlmRanker(model_name="test-model-name")),
566+
)
567+

tests/unit/vertex_rag/test_rag_data.py

+1 -1
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def test_update_corpus_failure(self):
503503
@pytest.mark.skip(reason="Need to fix the test later for v1.")
504504
@pytest.mark.usefixtures("rag_data_client_mock")
505505
def test_get_corpus_success(self):
506-
rag_corpus = rag.get_corpus(test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME)
506+
rag_corpus = rag.get_corpus(test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME)
507507
rag_corpus_eq(rag_corpus, test_rag_constants.TEST_RAG_CORPUS)
508508

509509
@pytest.mark.skip(reason="Need to fix the test later for v1.")

tests/unit/vertex_rag/test_rag_data_preview.py

+43 -24
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
prepare_import_files_request,
2222
set_embedding_model_config,
2323
)
24+
from vertexai.rag.utils.resources import (
25+
ChunkingConfig,
26+
TransformationConfig,
27+
)
2428
from google.cloud.aiplatform_v1beta1 import (
2529
VertexRagDataServiceAsyncClient,
2630
VertexRagDataServiceClient,
@@ -276,6 +280,18 @@ def list_rag_files_pager_mock():
276280
yield list_rag_files_pager_mock
277281

278282

283+
def create_transformation_config(
284+
chunk_size: int = test_rag_constants_preview.TEST_CHUNK_SIZE,
285+
chunk_overlap: int = test_rag_constants_preview.TEST_CHUNK_OVERLAP,
286+
):
287+
return TransformationConfig(
288+
chunking_config=ChunkingConfig(
289+
chunk_size=chunk_size,
290+
chunk_overlap=chunk_overlap,
291+
),
292+
)
293+
294+
279295
def rag_corpus_eq(returned_corpus, expected_corpus):
280296
assert returned_corpus.name == expected_corpus.name
281297
assert returned_corpus.display_name == expected_corpus.display_name
@@ -309,6 +325,10 @@ def import_files_request_eq(returned_request, expected_request):
309325
returned_request.import_rag_files_config.rag_file_parsing_config
310326
== expected_request.import_rag_files_config.rag_file_parsing_config
311327
)
328+
assert (
329+
returned_request.import_rag_files_config.rag_file_transformation_config
330+
== expected_request.import_rag_files_config.rag_file_transformation_config
331+
)
312332

313333

314334
@pytest.mark.usefixtures("google_auth_mock")
@@ -654,6 +674,17 @@ def test_delete_file_failure(self):
654674
e.match("Failed in RagFile deletion due to")
655675

656676
def test_prepare_import_files_request_list_gcs_uris(self):
677+
paths = [test_rag_constants_preview.TEST_GCS_PATH]
678+
request = prepare_import_files_request(
679+
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
680+
paths=paths,
681+
transformation_config=create_transformation_config(),
682+
)
683+
import_files_request_eq(
684+
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_GCS
685+
)
686+
687+
def test_prepare_import_files_request_list_gcs_uris_no_transformation_config(self):
657688
paths = [test_rag_constants_preview.TEST_GCS_PATH]
658689
request = prepare_import_files_request(
659690
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
@@ -676,8 +707,7 @@ def test_prepare_import_files_request_drive_folders(self, path):
676707
request = prepare_import_files_request(
677708
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
678709
paths=[path],
679-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
680-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
710+
transformation_config=create_transformation_config(),
681711
)
682712
import_files_request_eq(
683713
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_DRIVE_FOLDER
@@ -694,8 +724,7 @@ def test_prepare_import_files_request_drive_folders_with_pdf_parsing(self, path)
694724
request = prepare_import_files_request(
695725
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
696726
paths=[path],
697-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
698-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
727+
transformation_config=create_transformation_config(),
699728
use_advanced_pdf_parsing=True,
700729
)
701730
import_files_request_eq(
@@ -707,8 +736,7 @@ def test_prepare_import_files_request_drive_files(self):
707736
request = prepare_import_files_request(
708737
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
709738
paths=paths,
710-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
711-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
739+
transformation_config=create_transformation_config(),
712740
max_embedding_requests_per_min=800,
713741
)
714742
import_files_request_eq(
@@ -721,8 +749,7 @@ def test_prepare_import_files_request_invalid_drive_path(self):
721749
prepare_import_files_request(
722750
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
723751
paths=paths,
724-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
725-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
752+
transformation_config=create_transformation_config(),
726753
)
727754
e.match("is not a valid Google Drive url")
728755

@@ -732,17 +759,15 @@ def test_prepare_import_files_request_invalid_path(self):
732759
prepare_import_files_request(
733760
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
734761
paths=paths,
735-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
736-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
762+
transformation_config=create_transformation_config(),
737763
)
738764
e.match("path must be a Google Cloud Storage uri or a Google Drive url")
739765

740766
def test_prepare_import_files_request_slack_source(self):
741767
request = prepare_import_files_request(
742768
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
743769
source=test_rag_constants_preview.TEST_SLACK_SOURCE,
744-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
745-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
770+
transformation_config=create_transformation_config(),
746771
)
747772
import_files_request_eq(
748773
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_SLACK_SOURCE
@@ -752,8 +777,7 @@ def test_prepare_import_files_request_jira_source(self):
752777
request = prepare_import_files_request(
753778
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
754779
source=test_rag_constants_preview.TEST_JIRA_SOURCE,
755-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
756-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
780+
transformation_config=create_transformation_config(),
757781
)
758782
import_files_request_eq(
759783
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_JIRA_SOURCE
@@ -763,8 +787,7 @@ def test_prepare_import_files_request_sharepoint_source(self):
763787
request = prepare_import_files_request(
764788
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
765789
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE,
766-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
767-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
790+
transformation_config=create_transformation_config(),
768791
)
769792
import_files_request_eq(
770793
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_SHARE_POINT_SOURCE
@@ -775,8 +798,7 @@ def test_prepare_import_files_request_sharepoint_source_2_drives(self):
775798
prepare_import_files_request(
776799
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
777800
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_2_DRIVES,
778-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
779-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
801+
transformation_config=create_transformation_config(),
780802
)
781803
e.match("drive_name and drive_id cannot both be set.")
782804

@@ -785,8 +807,7 @@ def test_prepare_import_files_request_sharepoint_source_2_folders(self):
785807
prepare_import_files_request(
786808
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
787809
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_2_FOLDERS,
788-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
789-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
810+
transformation_config=create_transformation_config(),
790811
)
791812
e.match("sharepoint_folder_path and sharepoint_folder_id cannot both be set.")
792813

@@ -795,17 +816,15 @@ def test_prepare_import_files_request_sharepoint_source_no_drives(self):
795816
prepare_import_files_request(
796817
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
797818
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_NO_DRIVES,
798-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
799-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
819+
transformation_config=create_transformation_config(),
800820
)
801821
e.match("Either drive_name and drive_id must be set.")
802822

803823
def test_prepare_import_files_request_sharepoint_source_no_folders(self):
804824
request = prepare_import_files_request(
805825
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
806826
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_NO_FOLDERS,
807-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
808-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
827+
transformation_config=create_transformation_config(),
809828
)
810829
import_files_request_eq(
811830
request,

tests/unit/vertex_rag/test_rag_retrieval_preview.py

+22
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,28 @@ def test_retrieval_query_rag_corpora_config_success(self):
130130
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
131131
)
132132

133+
@pytest.mark.usefixtures("retrieve_contexts_mock")
134+
def test_retrieval_query_rag_corpora_config_rank_service_success(self):
135+
response = rag.retrieval_query(
136+
rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID],
137+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
138+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG_RANK_SERVICE,
139+
)
140+
retrieve_contexts_eq(
141+
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
142+
)
143+
144+
@pytest.mark.usefixtures("retrieve_contexts_mock")
145+
def test_retrieval_query_rag_corpora_config_llm_ranker_success(self):
146+
response = rag.retrieval_query(
147+
rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID],
148+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
149+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG_LLM_RANKER,
150+
)
151+
retrieve_contexts_eq(
152+
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
153+
)
154+
133155
@pytest.mark.usefixtures("rag_client_mock_exception")
134156
def test_retrieval_query_failure(self):
135157
with pytest.raises(RuntimeError) as e:

0 commit comments

Comments
 (0)