diff --git a/tests/unit/vertex_rag/test_rag_constants_preview.py b/tests/unit/vertex_rag/test_rag_constants_preview.py index ebac9f1edc..a42576f95f 100644 --- a/tests/unit/vertex_rag/test_rag_constants_preview.py +++ b/tests/unit/vertex_rag/test_rag_constants_preview.py @@ -22,11 +22,14 @@ EmbeddingModelConfig, Filter, HybridSearch, + LlmRanker, Pinecone, RagCorpus, RagFile, RagResource, RagRetrievalConfig, + Ranking, + RankService, SharePointSource, SharePointSources, SlackChannelsSource, @@ -40,6 +43,7 @@ from google.cloud.aiplatform_v1beta1 import ( GoogleDriveSource, RagFileChunkingConfig, + RagFileTransformationConfig, RagFileParsingConfig, ImportRagFilesConfig, ImportRagFilesRequest, @@ -213,10 +217,20 @@ TEST_RAG_FILE_JSON_ERROR = {"error": {"code": 13}} TEST_CHUNK_SIZE = 512 TEST_CHUNK_OVERLAP = 100 +TEST_RAG_FILE_TRANSFORMATION_CONFIG = RagFileTransformationConfig( + rag_file_chunking_config=RagFileChunkingConfig( + fixed_length_chunking=RagFileChunkingConfig.FixedLengthChunking( + chunk_size=TEST_CHUNK_SIZE, + chunk_overlap=TEST_CHUNK_OVERLAP, + ), + ), +) # GCS -TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig() +TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig( + rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, +) TEST_IMPORT_FILES_CONFIG_GCS.gcs_source.uris = [TEST_GCS_PATH] -TEST_IMPORT_FILES_CONFIG_GCS.rag_file_parsing_config.use_advanced_pdf_parsing = False +TEST_IMPORT_FILES_CONFIG_GCS.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = False TEST_IMPORT_REQUEST_GCS = ImportRagFilesRequest( parent=TEST_RAG_CORPUS_RESOURCE_NAME, import_rag_files_config=TEST_IMPORT_FILES_CONFIG_GCS, @@ -229,24 +243,28 @@ TEST_DRIVE_FOLDER_2 = ( f"https://drive.google.com/drive/folders/{TEST_DRIVE_FOLDER_ID}?resourcekey=0-eiOT3" ) -TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig() +TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig( + rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, +) TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.google_drive_source.resource_ids = [ GoogleDriveSource.ResourceId( resource_id=TEST_DRIVE_FOLDER_ID, resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER, ) ] -TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.rag_file_parsing_config.use_advanced_pdf_parsing = ( +TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = ( False ) -TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig() +TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig( + rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, +) TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.google_drive_source.resource_ids = [ GoogleDriveSource.ResourceId( resource_id=TEST_DRIVE_FOLDER_ID, resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER, ) ] -TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.rag_file_parsing_config.use_advanced_pdf_parsing = ( +TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = ( True ) TEST_IMPORT_REQUEST_DRIVE_FOLDER = ImportRagFilesRequest( @@ -261,11 +279,10 @@ TEST_DRIVE_FILE_ID = "456" TEST_DRIVE_FILE = f"https://drive.google.com/file/d/{TEST_DRIVE_FILE_ID}" TEST_IMPORT_FILES_CONFIG_DRIVE_FILE = ImportRagFilesConfig( - rag_file_chunking_config=RagFileChunkingConfig( - chunk_size=TEST_CHUNK_SIZE, - chunk_overlap=TEST_CHUNK_OVERLAP, + rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, + 
rag_file_parsing_config=RagFileParsingConfig( + advanced_parser=RagFileParsingConfig.AdvancedParser(use_advanced_pdf_parsing=False) ), - rag_file_parsing_config=RagFileParsingConfig(use_advanced_pdf_parsing=False), ) TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.max_embedding_requests_per_min = 800 @@ -319,12 +336,15 @@ ), ], ) -TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE = ImportRagFilesConfig( - rag_file_chunking_config=RagFileChunkingConfig( - chunk_size=TEST_CHUNK_SIZE, - chunk_overlap=TEST_CHUNK_OVERLAP, +TEST_RAG_FILE_PARSING_CONFIG = RagFileParsingConfig( + advanced_parser=RagFileParsingConfig.AdvancedParser( + use_advanced_pdf_parsing=False ) ) +TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE = ImportRagFilesConfig( + rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG, + rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, +) TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE.slack_source.channels = [ GapicSlackSource.SlackChannels( channels=[ @@ -375,10 +395,8 @@ ], ) TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE = ImportRagFilesConfig( - rag_file_chunking_config=RagFileChunkingConfig( - chunk_size=TEST_CHUNK_SIZE, - chunk_overlap=TEST_CHUNK_OVERLAP, - ) + rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG, + rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, ) TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE.jira_source.jira_queries = [ GapicJiraSource.JiraQueries( @@ -410,10 +428,8 @@ ], ) TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE = ImportRagFilesConfig( - rag_file_chunking_config=RagFileChunkingConfig( - chunk_size=TEST_CHUNK_SIZE, - chunk_overlap=TEST_CHUNK_OVERLAP, - ), + rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG, + rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, share_point_sources=GapicSharePointSources( share_point_sources=[ GapicSharePointSources.SharePointSource( @@ -488,10 +504,7 @@ ) TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE_NO_FOLDERS = ImportRagFilesConfig( - rag_file_chunking_config=RagFileChunkingConfig( - chunk_size=TEST_CHUNK_SIZE, - chunk_overlap=TEST_CHUNK_OVERLAP, - ), + rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, share_point_sources=GapicSharePointSources( share_point_sources=[ GapicSharePointSources.SharePointSource( @@ -541,3 +554,14 @@ filter=Filter(vector_distance_threshold=0.5), hybrid_search=HybridSearch(alpha=0.5), ) +TEST_RAG_RETRIEVAL_CONFIG_RANK_SERVICE = RagRetrievalConfig( + top_k=2, + filter=Filter(vector_distance_threshold=0.5), + ranking=Ranking(rank_service=RankService(model_name="test-model-name")), +) +TEST_RAG_RETRIEVAL_CONFIG_LLM_RANKER = RagRetrievalConfig( + top_k=2, + filter=Filter(vector_distance_threshold=0.5), + ranking=Ranking(llm_ranker=LlmRanker(model_name="test-model-name")), +) + diff --git a/tests/unit/vertex_rag/test_rag_data.py b/tests/unit/vertex_rag/test_rag_data.py index c81d7a5c79..7751beca9b 100644 --- a/tests/unit/vertex_rag/test_rag_data.py +++ b/tests/unit/vertex_rag/test_rag_data.py @@ -503,7 +503,7 @@ def test_update_corpus_failure(self): @pytest.mark.skip(reason="Need to fix the test later for v1.") @pytest.mark.usefixtures("rag_data_client_mock") def test_get_corpus_success(self): - rag_corpus = rag.get_corpus(test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME) + rag_corpus = rag.get_corpus(test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME) rag_corpus_eq(rag_corpus, test_rag_constants.TEST_RAG_CORPUS) @pytest.mark.skip(reason="Need to fix the test later for v1.") diff --git a/tests/unit/vertex_rag/test_rag_data_preview.py b/tests/unit/vertex_rag/test_rag_data_preview.py index 
470281083b..658e95ae26 100644 --- a/tests/unit/vertex_rag/test_rag_data_preview.py +++ b/tests/unit/vertex_rag/test_rag_data_preview.py @@ -21,6 +21,10 @@ prepare_import_files_request, set_embedding_model_config, ) +from vertexai.rag.utils.resources import ( + ChunkingConfig, + TransformationConfig, +) from google.cloud.aiplatform_v1beta1 import ( VertexRagDataServiceAsyncClient, VertexRagDataServiceClient, @@ -276,6 +280,18 @@ def list_rag_files_pager_mock(): yield list_rag_files_pager_mock +def create_transformation_config( + chunk_size: int = test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap: int = test_rag_constants_preview.TEST_CHUNK_OVERLAP, +): + return TransformationConfig( + chunking_config=ChunkingConfig( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ), + ) + + def rag_corpus_eq(returned_corpus, expected_corpus): assert returned_corpus.name == expected_corpus.name assert returned_corpus.display_name == expected_corpus.display_name @@ -309,6 +325,10 @@ def import_files_request_eq(returned_request, expected_request): returned_request.import_rag_files_config.rag_file_parsing_config == expected_request.import_rag_files_config.rag_file_parsing_config ) + assert ( + returned_request.import_rag_files_config.rag_file_transformation_config + == expected_request.import_rag_files_config.rag_file_transformation_config + ) @pytest.mark.usefixtures("google_auth_mock") @@ -654,6 +674,17 @@ def test_delete_file_failure(self): e.match("Failed in RagFile deletion due to") def test_prepare_import_files_request_list_gcs_uris(self): + paths = [test_rag_constants_preview.TEST_GCS_PATH] + request = prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + paths=paths, + transformation_config=create_transformation_config(), + ) + import_files_request_eq( + request, test_rag_constants_preview.TEST_IMPORT_REQUEST_GCS + ) + + def test_prepare_import_files_request_list_gcs_uris_no_transformation_config(self): paths = [test_rag_constants_preview.TEST_GCS_PATH] request = prepare_import_files_request( corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, @@ -676,8 +707,7 @@ def test_prepare_import_files_request_drive_folders(self, path): request = prepare_import_files_request( corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, paths=[path], - chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, - chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + transformation_config=create_transformation_config(), ) import_files_request_eq( request, test_rag_constants_preview.TEST_IMPORT_REQUEST_DRIVE_FOLDER @@ -694,8 +724,7 @@ def test_prepare_import_files_request_drive_folders_with_pdf_parsing(self, path) request = prepare_import_files_request( corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, paths=[path], - chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, - chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + transformation_config=create_transformation_config(), use_advanced_pdf_parsing=True, ) import_files_request_eq( @@ -707,8 +736,7 @@ def test_prepare_import_files_request_drive_files(self): request = prepare_import_files_request( corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, paths=paths, - chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, - chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + transformation_config=create_transformation_config(), max_embedding_requests_per_min=800, ) import_files_request_eq( @@ -721,8 +749,7 @@ 
def test_prepare_import_files_request_invalid_drive_path(self): prepare_import_files_request( corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, paths=paths, - chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, - chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + transformation_config=create_transformation_config(), ) e.match("is not a valid Google Drive url") @@ -732,8 +759,7 @@ def test_prepare_import_files_request_invalid_path(self): prepare_import_files_request( corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, paths=paths, - chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, - chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + transformation_config=create_transformation_config(), ) e.match("path must be a Google Cloud Storage uri or a Google Drive url") @@ -741,8 +767,7 @@ def test_prepare_import_files_request_slack_source(self): request = prepare_import_files_request( corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, source=test_rag_constants_preview.TEST_SLACK_SOURCE, - chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, - chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + transformation_config=create_transformation_config(), ) import_files_request_eq( request, test_rag_constants_preview.TEST_IMPORT_REQUEST_SLACK_SOURCE @@ -752,8 +777,7 @@ def test_prepare_import_files_request_jira_source(self): request = prepare_import_files_request( corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, source=test_rag_constants_preview.TEST_JIRA_SOURCE, - chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, - chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + transformation_config=create_transformation_config(), ) import_files_request_eq( request, test_rag_constants_preview.TEST_IMPORT_REQUEST_JIRA_SOURCE @@ -763,8 +787,7 @@ def test_prepare_import_files_request_sharepoint_source(self): request = prepare_import_files_request( corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE, - chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, - chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + transformation_config=create_transformation_config(), ) import_files_request_eq( request, test_rag_constants_preview.TEST_IMPORT_REQUEST_SHARE_POINT_SOURCE @@ -775,8 +798,7 @@ def test_prepare_import_files_request_sharepoint_source_2_drives(self): prepare_import_files_request( corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_2_DRIVES, - chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, - chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + transformation_config=create_transformation_config(), ) e.match("drive_name and drive_id cannot both be set.") @@ -785,8 +807,7 @@ def test_prepare_import_files_request_sharepoint_source_2_folders(self): prepare_import_files_request( corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_2_FOLDERS, - chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, - chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + transformation_config=create_transformation_config(), ) e.match("sharepoint_folder_path and sharepoint_folder_id cannot both be set.") @@ -795,8 +816,7 @@ def test_prepare_import_files_request_sharepoint_source_no_drives(self): prepare_import_files_request( 
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_NO_DRIVES, - chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, - chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + transformation_config=create_transformation_config(), ) e.match("Either drive_name and drive_id must be set.") @@ -804,8 +824,7 @@ def test_prepare_import_files_request_sharepoint_source_no_folders(self): request = prepare_import_files_request( corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_NO_FOLDERS, - chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, - chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + transformation_config=create_transformation_config(), ) import_files_request_eq( request, diff --git a/tests/unit/vertex_rag/test_rag_retrieval_preview.py b/tests/unit/vertex_rag/test_rag_retrieval_preview.py index 9c79636a8c..d04acac1d9 100644 --- a/tests/unit/vertex_rag/test_rag_retrieval_preview.py +++ b/tests/unit/vertex_rag/test_rag_retrieval_preview.py @@ -130,6 +130,28 @@ def test_retrieval_query_rag_corpora_config_success(self): response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE ) + @pytest.mark.usefixtures("retrieve_contexts_mock") + def test_retrieval_query_rag_corpora_config_rank_service_success(self): + response = rag.retrieval_query( + rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG_RANK_SERVICE, + ) + retrieve_contexts_eq( + response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE + ) + + @pytest.mark.usefixtures("retrieve_contexts_mock") + def test_retrieval_query_rag_corpora_config_llm_ranker_success(self): + response = rag.retrieval_query( + rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG_LLM_RANKER, + ) + retrieve_contexts_eq( + response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE + ) + @pytest.mark.usefixtures("rag_client_mock_exception") def test_retrieval_query_failure(self): with pytest.raises(RuntimeError) as e: diff --git a/vertexai/preview/rag/__init__.py b/vertexai/preview/rag/__init__.py index 4065616746..bf74e2b1da 100644 --- a/vertexai/preview/rag/__init__.py +++ b/vertexai/preview/rag/__init__.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - from vertexai.preview.rag.rag_data import ( create_corpus, update_corpus, @@ -28,24 +27,26 @@ list_files, delete_file, ) - from vertexai.preview.rag.rag_retrieval import ( retrieval_query, ) - from vertexai.preview.rag.rag_store import ( Retrieval, VertexRagStore, ) from vertexai.preview.rag.utils.resources import ( + ChunkingConfig, EmbeddingModelConfig, Filter, HybridSearch, JiraQuery, JiraSource, + LlmRanker, Pinecone, RagCorpus, RagFile, + Ranking, + RankService, RagManagedDb, RagResource, RagRetrievalConfig, @@ -53,21 +54,24 @@ SharePointSources, SlackChannel, SlackChannelsSource, + TransformationConfig, VertexFeatureStore, VertexVectorSearch, Weaviate, ) - - __all__ = ( + "ChunkingConfig", "EmbeddingModelConfig", "Filter", "HybridSearch", "JiraQuery", "JiraSource", + "LlmRanker", "Pinecone", "RagCorpus", "RagFile", + "Ranking", + "RankService", "RagManagedDb", "RagResource", "RagRetrievalConfig", @@ -76,6 +80,7 @@ "SharePointSources", "SlackChannel", "SlackChannelsSource", + "TransformationConfig", "VertexFeatureStore", "VertexRagStore", "VertexVectorSearch", diff --git a/vertexai/preview/rag/rag_data.py b/vertexai/preview/rag/rag_data.py index 2050c252ae..45766500e3 100644 --- a/vertexai/preview/rag/rag_data.py +++ b/vertexai/preview/rag/rag_data.py @@ -51,6 +51,7 @@ RagManagedDb, SharePointSources, SlackChannelsSource, + TransformationConfig, VertexFeatureStore, VertexVectorSearch, Weaviate, @@ -291,6 +292,7 @@ def upload_file( path: Union[str, Sequence[str]], display_name: Optional[str] = None, description: Optional[str] = None, + transformation_config: Optional[TransformationConfig] = None, ) -> RagFile: """ Synchronous file upload to an existing RagCorpus. @@ -303,10 +305,19 @@ vertexai.init(project="my-project") + # Optional. + transformation_config = TransformationConfig( + chunking_config=ChunkingConfig( + chunk_size=1024, + chunk_overlap=200, + ), + ) + rag_file = rag.upload_file( corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", display_name="my_file.txt", path="usr/home/my_file.txt", + transformation_config=transformation_config, ) ``` @@ -318,6 +329,7 @@ "usr/home/my_file.txt". display_name: The display name of the data file. description: The description of the RagFile. + transformation_config: The config for transforming the RagFile, e.g., chunking. Returns: RagFile.
Raises: @@ -336,12 +348,24 @@ def upload_file( aiplatform.constants.base.API_BASE_PATH, corpus_name, ) + js_rag_file = {"rag_file": {"display_name": display_name}} + if description: - js_rag_file = { - "rag_file": {"display_name": display_name, "description": description} - } - else: - js_rag_file = {"rag_file": {"display_name": display_name}} + js_rag_file["rag_file"]["description"] = description + + if transformation_config and transformation_config.chunking_config: + chunk_size = transformation_config.chunking_config.chunk_size + chunk_overlap = transformation_config.chunking_config.chunk_overlap + js_rag_file["upload_rag_file_config"] = { + "rag_file_transformation_config": { + "rag_file_chunking_config": { + "fixed_length_chunking": { + "chunk_size": chunk_size, + "chunk_overlap": chunk_overlap, + } + } + } + } files = { "metadata": (None, str(js_rag_file)), "file": open(path, "rb"), @@ -372,6 +396,7 @@ def import_files( source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None, chunk_size: int = 1024, chunk_overlap: int = 200, + transformation_config: Optional[TransformationConfig] = None, timeout: int = 600, max_embedding_requests_per_min: int = 1000, use_advanced_pdf_parsing: Optional[bool] = False, @@ -396,11 +421,17 @@ def import_files( # Google Cloud Storage example paths = ["gs://my_bucket/my_files_dir", ...] + transformation_config = TransformationConfig( + chunking_config=ChunkingConfig( + chunk_size=1024, + chunk_overlap=200, + ), + ) + response = rag.import_files( corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", paths=paths, - chunk_size=512, - chunk_overlap=100, + transformation_config=transformation_config, ) # Slack example @@ -429,8 +460,7 @@ def import_files( response = rag.import_files( corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", source=source, - chunk_size=512, - chunk_overlap=100, + transformation_config=transformation_config, ) # SharePoint Example. @@ -460,8 +490,12 @@ def import_files( "https://drive.google.com/corp/drive/folders/..."). source: The source of the Slack or Jira import. Must be either a SlackChannelsSource or JiraSource. - chunk_size: The size of the chunks. - chunk_overlap: The overlap between chunks. + chunk_size: The size of the chunks. This field is deprecated. Please use + transformation_config instead. + chunk_overlap: The overlap between chunks. This field is deprecated. Please use + transformation_config instead. + transformation_config: The config for transforming the imported + RagFiles. max_embedding_requests_per_min: Optional. The max number of queries per minute that this job is allowed to make to the @@ -496,6 +530,7 @@ def import_files( source=source, chunk_size=chunk_size, chunk_overlap=chunk_overlap, + transformation_config=transformation_config, max_embedding_requests_per_min=max_embedding_requests_per_min, use_advanced_pdf_parsing=use_advanced_pdf_parsing, partial_failures_sink=partial_failures_sink, @@ -515,6 +550,7 @@ async def import_files_async( source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None, chunk_size: int = 1024, chunk_overlap: int = 200, + transformation_config: Optional[TransformationConfig] = None, max_embedding_requests_per_min: int = 1000, use_advanced_pdf_parsing: Optional[bool] = False, partial_failures_sink: Optional[str] = None, @@ -539,11 +575,17 @@ async def import_files_async( # Google Cloud Storage example paths = ["gs://my_bucket/my_files_dir", ...] 
+ transformation_config = TransformationConfig( + chunking_config=ChunkingConfig( + chunk_size=1024, + chunk_overlap=200, + ), + ) + response = await rag.import_files_async( corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", paths=paths, - chunk_size=512, - chunk_overlap=100, + transformation_config=transformation_config, ) # Slack example @@ -572,8 +614,7 @@ async def import_files_async( response = await rag.import_files_async( corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", source=source, - chunk_size=512, - chunk_overlap=100, + transformation_config=transformation_config, ) # SharePoint Example. @@ -603,8 +644,12 @@ async def import_files_async( "https://drive.google.com/corp/drive/folders/..."). source: The source of the Slack or Jira import. Must be either a SlackChannelsSource or JiraSource. - chunk_size: The size of the chunks. - chunk_overlap: The overlap between chunks. + chunk_size: The size of the chunks. This field is deprecated. Please use + transformation_config instead. + chunk_overlap: The overlap between chunks. This field is deprecated. Please use + transformation_config instead. + transformation_config: The config for transforming the imported + RagFiles. max_embedding_requests_per_min: Optional. The max number of queries per minute that this job is allowed to make to the @@ -638,6 +683,7 @@ async def import_files_async( source=source, chunk_size=chunk_size, chunk_overlap=chunk_overlap, + transformation_config=transformation_config, max_embedding_requests_per_min=max_embedding_requests_per_min, use_advanced_pdf_parsing=use_advanced_pdf_parsing, partial_failures_sink=partial_failures_sink, diff --git a/vertexai/preview/rag/rag_retrieval.py b/vertexai/preview/rag/rag_retrieval.py index 47fb57669f..d247bafc48 100644 --- a/vertexai/preview/rag/rag_retrieval.py +++ b/vertexai/preview/rag/rag_retrieval.py @@ -215,6 +215,31 @@ def retrieval_query( api_retrival_config.filter.vector_distance_threshold = ( vector_distance_threshold ) + if ( + rag_retrieval_config.ranking + and rag_retrieval_config.ranking.rank_service + and rag_retrieval_config.ranking.llm_ranker + ): + raise ValueError( + "Only one of rank_service and llm_ranker can be set." 
+ ) + if ( + rag_retrieval_config.ranking + and rag_retrieval_config.ranking.rank_service + ): + api_retrival_config.ranking = aiplatform_v1beta1.RagRetrievalConfig.Ranking( + rank_service=aiplatform_v1beta1.RagRetrievalConfig.Ranking.RankService( + model_name=rag_retrieval_config.ranking.rank_service.model_name + ) + ) + elif ( + rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker + ): + api_retrival_config.ranking = aiplatform_v1beta1.RagRetrievalConfig.Ranking( + llm_ranker=aiplatform_v1beta1.RagRetrievalConfig.Ranking.LlmRanker( + model_name=rag_retrieval_config.ranking.llm_ranker.model_name + ) + ) query = aiplatform_v1beta1.RagQuery( text=text, rag_retrieval_config=api_retrival_config, diff --git a/vertexai/preview/rag/utils/_gapic_utils.py b/vertexai/preview/rag/utils/_gapic_utils.py index b231ae99cc..29c17da5aa 100644 --- a/vertexai/preview/rag/utils/_gapic_utils.py +++ b/vertexai/preview/rag/utils/_gapic_utils.py @@ -24,6 +24,7 @@ ImportRagFilesRequest, RagFileChunkingConfig, RagFileParsingConfig, + RagFileTransformationConfig, RagCorpus as GapicRagCorpus, RagFile as GapicRagFile, SharePointSources as GapicSharePointSources, @@ -45,6 +46,7 @@ RagManagedDb, SharePointSources, SlackChannelsSource, + TransformationConfig, JiraSource, VertexFeatureStore, VertexVectorSearch, @@ -347,6 +349,7 @@ def prepare_import_files_request( source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None, chunk_size: int = 1024, chunk_overlap: int = 200, + transformation_config: Optional[TransformationConfig] = None, max_embedding_requests_per_min: int = 1000, use_advanced_pdf_parsing: bool = False, partial_failures_sink: Optional[str] = None, @@ -357,14 +360,26 @@ def prepare_import_files_request( ) rag_file_parsing_config = RagFileParsingConfig( - use_advanced_pdf_parsing=use_advanced_pdf_parsing, + advanced_parser=RagFileParsingConfig.AdvancedParser( + use_advanced_pdf_parsing=use_advanced_pdf_parsing, + ), ) - rag_file_chunking_config = RagFileChunkingConfig( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, + local_chunk_size = chunk_size + local_chunk_overlap = chunk_overlap + if transformation_config and transformation_config.chunking_config: + local_chunk_size = transformation_config.chunking_config.chunk_size + local_chunk_overlap = transformation_config.chunking_config.chunk_overlap + + rag_file_transformation_config = RagFileTransformationConfig( + rag_file_chunking_config=RagFileChunkingConfig( + fixed_length_chunking=RagFileChunkingConfig.FixedLengthChunking( + chunk_size=local_chunk_size, + chunk_overlap=local_chunk_overlap, + ), + ), ) import_rag_files_config = ImportRagFilesConfig( - rag_file_chunking_config=rag_file_chunking_config, + rag_file_transformation_config=rag_file_transformation_config, max_embedding_requests_per_min=max_embedding_requests_per_min, rag_file_parsing_config=rag_file_parsing_config, ) diff --git a/vertexai/preview/rag/utils/resources.py b/vertexai/preview/rag/utils/resources.py index 1db8f0899f..5993537cfa 100644 --- a/vertexai/preview/rag/utils/resources.py +++ b/vertexai/preview/rag/utils/resources.py @@ -323,7 +323,8 @@ class LlmRanker: """LlmRanker. Attributes: - model_name: The model name used for ranking. + model_name: The model name used for ranking. Only Gemini models are + supported for now. 
""" model_name: Optional[str] = None @@ -373,3 +374,28 @@ class RagRetrievalConfig: filter: Optional[Filter] = None hybrid_search: Optional[HybridSearch] = None ranking: Optional[Ranking] = None + + +@dataclasses.dataclass +class ChunkingConfig: + """ChunkingConfig. + + Attributes: + chunk_size: The size of each chunk. + chunk_overlap: The size of the overlap between chunks. + """ + + chunk_size: int + chunk_overlap: int + + +@dataclasses.dataclass +class TransformationConfig: + """TransformationConfig. + + Attributes: + chunking_config: The chunking config. + """ + + chunking_config: Optional[ChunkingConfig] = None +