Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Update v1beta1 sdk for a few new protos #4719

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 50 additions & 26 deletions tests/unit/vertex_rag/test_rag_constants_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
EmbeddingModelConfig,
Filter,
HybridSearch,
LlmRanker,
Pinecone,
RagCorpus,
RagFile,
RagResource,
RagRetrievalConfig,
Ranking,
RankService,
SharePointSource,
SharePointSources,
SlackChannelsSource,
Expand All @@ -40,6 +43,7 @@
from google.cloud.aiplatform_v1beta1 import (
GoogleDriveSource,
RagFileChunkingConfig,
RagFileTransformationConfig,
RagFileParsingConfig,
ImportRagFilesConfig,
ImportRagFilesRequest,
Expand Down Expand Up @@ -213,10 +217,20 @@
TEST_RAG_FILE_JSON_ERROR = {"error": {"code": 13}}
TEST_CHUNK_SIZE = 512
TEST_CHUNK_OVERLAP = 100
# Shared transformation config used by the import-files test fixtures below:
# fixed-length chunking with the canonical test chunk size/overlap.
# NOTE(review): replaces the deprecated top-level RagFileChunkingConfig
# chunk_size/chunk_overlap fields with the nested FixedLengthChunking message.
TEST_RAG_FILE_TRANSFORMATION_CONFIG = RagFileTransformationConfig(
    rag_file_chunking_config=RagFileChunkingConfig(
        fixed_length_chunking=RagFileChunkingConfig.FixedLengthChunking(
            chunk_size=TEST_CHUNK_SIZE,
            chunk_overlap=TEST_CHUNK_OVERLAP,
        ),
    ),
)
# GCS
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig()
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig(
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
)
TEST_IMPORT_FILES_CONFIG_GCS.gcs_source.uris = [TEST_GCS_PATH]
TEST_IMPORT_FILES_CONFIG_GCS.rag_file_parsing_config.use_advanced_pdf_parsing = False
TEST_IMPORT_FILES_CONFIG_GCS.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = False
TEST_IMPORT_REQUEST_GCS = ImportRagFilesRequest(
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_GCS,
Expand All @@ -229,24 +243,28 @@
TEST_DRIVE_FOLDER_2 = (
f"https://drive.google.com/drive/folders/{TEST_DRIVE_FOLDER_ID}?resourcekey=0-eiOT3"
)
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig()
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig(
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
)
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.google_drive_source.resource_ids = [
GoogleDriveSource.ResourceId(
resource_id=TEST_DRIVE_FOLDER_ID,
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
)
]
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.rag_file_parsing_config.use_advanced_pdf_parsing = (
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = (
False
)
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig()
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig(
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
)
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.google_drive_source.resource_ids = [
GoogleDriveSource.ResourceId(
resource_id=TEST_DRIVE_FOLDER_ID,
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
)
]
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.rag_file_parsing_config.use_advanced_pdf_parsing = (
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = (
True
)
TEST_IMPORT_REQUEST_DRIVE_FOLDER = ImportRagFilesRequest(
Expand All @@ -261,11 +279,10 @@
TEST_DRIVE_FILE_ID = "456"
TEST_DRIVE_FILE = f"https://drive.google.com/file/d/{TEST_DRIVE_FILE_ID}"
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE = ImportRagFilesConfig(
rag_file_chunking_config=RagFileChunkingConfig(
chunk_size=TEST_CHUNK_SIZE,
chunk_overlap=TEST_CHUNK_OVERLAP,
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
rag_file_parsing_config=RagFileParsingConfig(
advanced_parser=RagFileParsingConfig.AdvancedParser(use_advanced_pdf_parsing=False)
),
rag_file_parsing_config=RagFileParsingConfig(use_advanced_pdf_parsing=False),
)
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.max_embedding_requests_per_min = 800

Expand Down Expand Up @@ -319,12 +336,15 @@
),
],
)
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE = ImportRagFilesConfig(
rag_file_chunking_config=RagFileChunkingConfig(
chunk_size=TEST_CHUNK_SIZE,
chunk_overlap=TEST_CHUNK_OVERLAP,
TEST_RAG_FILE_PARSING_CONFIG = RagFileParsingConfig(
advanced_parser=RagFileParsingConfig.AdvancedParser(
use_advanced_pdf_parsing=False
)
)
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE = ImportRagFilesConfig(
rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG,
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
)
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE.slack_source.channels = [
GapicSlackSource.SlackChannels(
channels=[
Expand Down Expand Up @@ -375,10 +395,8 @@
],
)
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE = ImportRagFilesConfig(
rag_file_chunking_config=RagFileChunkingConfig(
chunk_size=TEST_CHUNK_SIZE,
chunk_overlap=TEST_CHUNK_OVERLAP,
)
rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG,
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
)
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE.jira_source.jira_queries = [
GapicJiraSource.JiraQueries(
Expand Down Expand Up @@ -410,10 +428,8 @@
],
)
TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE = ImportRagFilesConfig(
rag_file_chunking_config=RagFileChunkingConfig(
chunk_size=TEST_CHUNK_SIZE,
chunk_overlap=TEST_CHUNK_OVERLAP,
),
rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG,
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
share_point_sources=GapicSharePointSources(
share_point_sources=[
GapicSharePointSources.SharePointSource(
Expand Down Expand Up @@ -488,10 +504,7 @@
)

TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE_NO_FOLDERS = ImportRagFilesConfig(
rag_file_chunking_config=RagFileChunkingConfig(
chunk_size=TEST_CHUNK_SIZE,
chunk_overlap=TEST_CHUNK_OVERLAP,
),
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
share_point_sources=GapicSharePointSources(
share_point_sources=[
GapicSharePointSources.SharePointSource(
Expand Down Expand Up @@ -541,3 +554,14 @@
filter=Filter(vector_distance_threshold=0.5),
hybrid_search=HybridSearch(alpha=0.5),
)
# Retrieval config exercising the rank-service reranker path.
TEST_RAG_RETRIEVAL_CONFIG_RANK_SERVICE = RagRetrievalConfig(
    top_k=2,
    filter=Filter(vector_distance_threshold=0.5),
    ranking=Ranking(rank_service=RankService(model_name="test-model-name")),
)
# Retrieval config exercising the LLM-ranker reranker path.
TEST_RAG_RETRIEVAL_CONFIG_LLM_RANKER = RagRetrievalConfig(
    top_k=2,
    filter=Filter(vector_distance_threshold=0.5),
    ranking=Ranking(llm_ranker=LlmRanker(model_name="test-model-name")),
)

2 changes: 1 addition & 1 deletion tests/unit/vertex_rag/test_rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def test_update_corpus_failure(self):
@pytest.mark.skip(reason="Need to fix the test later for v1.")
@pytest.mark.usefixtures("rag_data_client_mock")
def test_get_corpus_success(self):
rag_corpus = rag.get_corpus(test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME)
rag_corpus = rag.get_corpus(test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME)
rag_corpus_eq(rag_corpus, test_rag_constants.TEST_RAG_CORPUS)

@pytest.mark.skip(reason="Need to fix the test later for v1.")
Expand Down
67 changes: 43 additions & 24 deletions tests/unit/vertex_rag/test_rag_data_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
prepare_import_files_request,
set_embedding_model_config,
)
from vertexai.rag.utils.resources import (
ChunkingConfig,
TransformationConfig,
)
from google.cloud.aiplatform_v1beta1 import (
VertexRagDataServiceAsyncClient,
VertexRagDataServiceClient,
Expand Down Expand Up @@ -276,6 +280,18 @@ def list_rag_files_pager_mock():
yield list_rag_files_pager_mock


def create_transformation_config(
    chunk_size: int = test_rag_constants_preview.TEST_CHUNK_SIZE,
    chunk_overlap: int = test_rag_constants_preview.TEST_CHUNK_OVERLAP,
):
    """Return a TransformationConfig with fixed-size chunking.

    Defaults mirror the chunk size/overlap constants used throughout the
    preview RAG test fixtures, so tests can call this with no arguments.
    """
    chunking = ChunkingConfig(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return TransformationConfig(chunking_config=chunking)


def rag_corpus_eq(returned_corpus, expected_corpus):
assert returned_corpus.name == expected_corpus.name
assert returned_corpus.display_name == expected_corpus.display_name
Expand Down Expand Up @@ -309,6 +325,10 @@ def import_files_request_eq(returned_request, expected_request):
returned_request.import_rag_files_config.rag_file_parsing_config
== expected_request.import_rag_files_config.rag_file_parsing_config
)
assert (
returned_request.import_rag_files_config.rag_file_transformation_config
== expected_request.import_rag_files_config.rag_file_transformation_config
)


@pytest.mark.usefixtures("google_auth_mock")
Expand Down Expand Up @@ -654,6 +674,17 @@ def test_delete_file_failure(self):
e.match("Failed in RagFile deletion due to")

def test_prepare_import_files_request_list_gcs_uris(self):
# Import from a list of GCS URIs with an explicit transformation config;
# the built request must match the GCS fixture (including its
# rag_file_transformation_config).
paths = [test_rag_constants_preview.TEST_GCS_PATH]
request = prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
paths=paths,
transformation_config=create_transformation_config(),
)
import_files_request_eq(
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_GCS
)

def test_prepare_import_files_request_list_gcs_uris_no_transformation_config(self):
paths = [test_rag_constants_preview.TEST_GCS_PATH]
request = prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
Expand All @@ -676,8 +707,7 @@ def test_prepare_import_files_request_drive_folders(self, path):
request = prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
paths=[path],
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
transformation_config=create_transformation_config(),
)
import_files_request_eq(
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_DRIVE_FOLDER
Expand All @@ -694,8 +724,7 @@ def test_prepare_import_files_request_drive_folders_with_pdf_parsing(self, path)
request = prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
paths=[path],
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
transformation_config=create_transformation_config(),
use_advanced_pdf_parsing=True,
)
import_files_request_eq(
Expand All @@ -707,8 +736,7 @@ def test_prepare_import_files_request_drive_files(self):
request = prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
paths=paths,
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
transformation_config=create_transformation_config(),
max_embedding_requests_per_min=800,
)
import_files_request_eq(
Expand All @@ -721,8 +749,7 @@ def test_prepare_import_files_request_invalid_drive_path(self):
prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
paths=paths,
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
transformation_config=create_transformation_config(),
)
e.match("is not a valid Google Drive url")

Expand All @@ -732,17 +759,15 @@ def test_prepare_import_files_request_invalid_path(self):
prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
paths=paths,
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
transformation_config=create_transformation_config(),
)
e.match("path must be a Google Cloud Storage uri or a Google Drive url")

def test_prepare_import_files_request_slack_source(self):
request = prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
source=test_rag_constants_preview.TEST_SLACK_SOURCE,
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
transformation_config=create_transformation_config(),
)
import_files_request_eq(
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_SLACK_SOURCE
Expand All @@ -752,8 +777,7 @@ def test_prepare_import_files_request_jira_source(self):
request = prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
source=test_rag_constants_preview.TEST_JIRA_SOURCE,
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
transformation_config=create_transformation_config(),
)
import_files_request_eq(
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_JIRA_SOURCE
Expand All @@ -763,8 +787,7 @@ def test_prepare_import_files_request_sharepoint_source(self):
request = prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE,
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
transformation_config=create_transformation_config(),
)
import_files_request_eq(
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_SHARE_POINT_SOURCE
Expand All @@ -775,8 +798,7 @@ def test_prepare_import_files_request_sharepoint_source_2_drives(self):
prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_2_DRIVES,
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
transformation_config=create_transformation_config(),
)
e.match("drive_name and drive_id cannot both be set.")

Expand All @@ -785,8 +807,7 @@ def test_prepare_import_files_request_sharepoint_source_2_folders(self):
prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_2_FOLDERS,
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
transformation_config=create_transformation_config(),
)
e.match("sharepoint_folder_path and sharepoint_folder_id cannot both be set.")

Expand All @@ -795,17 +816,15 @@ def test_prepare_import_files_request_sharepoint_source_no_drives(self):
prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_NO_DRIVES,
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
transformation_config=create_transformation_config(),
)
e.match("Either drive_name and drive_id must be set.")

def test_prepare_import_files_request_sharepoint_source_no_folders(self):
request = prepare_import_files_request(
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_NO_FOLDERS,
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
transformation_config=create_transformation_config(),
)
import_files_request_eq(
request,
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/vertex_rag/test_rag_retrieval_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,28 @@ def test_retrieval_query_rag_corpora_config_success(self):
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
)

@pytest.mark.usefixtures("retrieve_contexts_mock")
def test_retrieval_query_rag_corpora_config_rank_service_success(self):
# Query with a retrieval config whose ranking uses a RankService;
# the mocked backend should still return the canned retrieval response.
response = rag.retrieval_query(
rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG_RANK_SERVICE,
)
retrieve_contexts_eq(
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
)

@pytest.mark.usefixtures("retrieve_contexts_mock")
def test_retrieval_query_rag_corpora_config_llm_ranker_success(self):
# Same as the rank-service test above but exercising the LlmRanker
# variant of the Ranking config.
response = rag.retrieval_query(
rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG_LLM_RANKER,
)
retrieve_contexts_eq(
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
)

@pytest.mark.usefixtures("rag_client_mock_exception")
def test_retrieval_query_failure(self):
with pytest.raises(RuntimeError) as e:
Expand Down
Loading