Skip to content

Commit 2195411

Browse files
BenjaminKazemicopybara-github
authored andcommitted
feat: GenAI SDK client(multimodal) - Support Assemble feature on the multimodal datasets.
PiperOrigin-RevId: 832506583
1 parent d1da180 commit 2195411

File tree

6 files changed

+694
-141
lines changed

6 files changed

+694
-141
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
16+
17+
from tests.unit.vertexai.genai.replays import pytest_helper
18+
from vertexai._genai import types
19+
20+
import pytest
21+
22+
METADATA_SCHEMA_URI = (
23+
"gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
24+
)
25+
BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table"
26+
DATASET = "8810841321427173376"
27+
28+
29+
def test_assemble_dataset(client):
30+
operation = client.datasets._assemble_multimodal_dataset(
31+
name=DATASET,
32+
gemini_request_read_config={
33+
"template_config": {
34+
"field_mapping": {"question": "questionColumn"},
35+
},
36+
},
37+
)
38+
assert isinstance(operation, types.MultimodalDatasetOperation)
39+
40+
41+
def test_assemble_dataset_public(client):
42+
bigquery_destination = client.datasets.assemble(
43+
name=DATASET,
44+
template_config=types.GeminiTemplateConfig(
45+
gemini_example=types.GeminiExample(
46+
model="gemini-1.5-flash",
47+
contents=[
48+
{
49+
"role": "user",
50+
"parts": [{"text": "What is the capital of {name}?"}],
51+
}
52+
],
53+
),
54+
),
55+
)
56+
assert bigquery_destination.startswith(f"bq://{BIGQUERY_TABLE_NAME}")
57+
58+
59+
pytestmark = pytest_helper.setup(
60+
file=__file__,
61+
globals_for_file=globals(),
62+
)
63+
64+
pytest_plugins = ("pytest_asyncio",)
65+
66+
67+
@pytest.mark.asyncio
68+
async def test_assemble_dataset_async(client):
69+
operation = await client.aio.datasets._assemble_multimodal_dataset(
70+
name=DATASET,
71+
gemini_request_read_config={
72+
"template_config": {
73+
"field_mapping": {"question": "questionColumn"},
74+
},
75+
},
76+
)
77+
assert isinstance(operation, types.MultimodalDatasetOperation)
78+
79+
80+
@pytest.mark.asyncio
81+
async def test_assemble_dataset_public_async(client):
82+
bigquery_destination = await client.aio.datasets.assemble(
83+
name=DATASET,
84+
template_config=types.GeminiTemplateConfig(
85+
gemini_example=types.GeminiExample(
86+
model="gemini-1.5-flash",
87+
contents=[
88+
{
89+
"role": "user",
90+
"parts": [{"text": "What is the capital of {name}?"}],
91+
}
92+
],
93+
),
94+
),
95+
)
96+
assert bigquery_destination.startswith(f"bq://{BIGQUERY_TABLE_NAME}")

tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def test_create_dataset_from_bigquery(client):
5454
)
5555
assert isinstance(dataset, types.MultimodalDataset)
5656
assert dataset.display_name == "test-from-bigquery"
57+
assert dataset.metadata.input_config.bigquery_source.uri == (
58+
f"bq://{BIGQUERY_TABLE_NAME}"
59+
)
5760

5861

5962
def test_create_dataset_from_bigquery_without_bq_prefix(client):
@@ -70,6 +73,9 @@ def test_create_dataset_from_bigquery_without_bq_prefix(client):
7073
)
7174
assert isinstance(dataset, types.MultimodalDataset)
7275
assert dataset.display_name == "test-from-bigquery"
76+
assert dataset.metadata.input_config.bigquery_source.uri == (
77+
f"bq://{BIGQUERY_TABLE_NAME}"
78+
)
7379

7480

7581
pytestmark = pytest_helper.setup(
@@ -111,6 +117,9 @@ async def test_create_dataset_from_bigquery_async(client):
111117
)
112118
assert isinstance(dataset, types.MultimodalDataset)
113119
assert dataset.display_name == "test-from-bigquery"
120+
assert dataset.metadata.input_config.bigquery_source.uri == (
121+
f"bq://{BIGQUERY_TABLE_NAME}"
122+
)
114123

115124

116125
@pytest.mark.asyncio
@@ -129,6 +138,9 @@ async def test_create_dataset_from_bigquery_async_with_timeout(client):
129138
)
130139
assert isinstance(dataset, types.MultimodalDataset)
131140
assert dataset.display_name == "test-from-bigquery"
141+
assert dataset.metadata.input_config.bigquery_source.uri == (
142+
f"bq://{BIGQUERY_TABLE_NAME}"
143+
)
132144

133145

134146
@pytest.mark.asyncio
@@ -146,3 +158,6 @@ async def test_create_dataset_from_bigquery_async_without_bq_prefix(client):
146158
)
147159
assert isinstance(dataset, types.MultimodalDataset)
148160
assert dataset.display_name == "test-from-bigquery"
161+
assert dataset.metadata.input_config.bigquery_source.uri == (
162+
f"bq://{BIGQUERY_TABLE_NAME}"
163+
)

vertexai/_genai/_datasets_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,23 @@
1414
#
1515
"""Utility functions for multimodal dataset."""
1616

17+
from typing import Any, TypeVar, Type
18+
from vertexai._genai.types import common
19+
from pydantic import BaseModel
1720

1821
METADATA_SCHEMA_URI = (
1922
"gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
2023
)
24+
25+
T = TypeVar("T", bound=BaseModel)
26+
27+
28+
def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T:
29+
"""Creates a model from a response."""
30+
model_field_names = model_type.model_fields.keys()
31+
filtered_response = {}
32+
for key, value in response.items():
33+
snake_key = common.camel_to_snake(key)
34+
if snake_key in model_field_names:
35+
filtered_response[snake_key] = value
36+
return model_type(**filtered_response)

0 commit comments

Comments
 (0)