Skip to content

Commit 92d8b2a

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI SDK client - Enabling zero-shot prompt optimization for prompts from Android API by passing optimization_target=vertexai.types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO in the config
PiperOrigin-RevId: 825240223
1 parent 10ca56f commit 92d8b2a

File tree

5 files changed

+125
-21
lines changed

5 files changed

+125
-21
lines changed

tests/unit/vertexai/genai/replays/test_prompt_optimizer_optimize_prompt_return_type.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ def test_optimize_prompt(client):
2727
assert response.raw_text_response
2828

2929

30+
# def test_optimize_prompt_w_optimization_target(client):
31+
# """Tests the optimize request parameters method with optimization target."""
32+
# from google.genai import types as genai_types
33+
# test_prompt = "Generate system instructions for analyzing medical articles"
34+
# response = client.prompt_optimizer.optimize_prompt(
35+
# prompt=test_prompt,
36+
# config=types.OptimizeConfig(
37+
# optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
38+
# ),
39+
# )
40+
# assert isinstance(response, types.OptimizeResponse)
41+
# assert response.raw_text_response
42+
43+
3044
pytestmark = pytest_helper.setup(
3145
file=__file__,
3246
globals_for_file=globals(),

tests/unit/vertexai/genai/test_prompt_optimizer.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,57 @@ def test_prompt_optimizer_optimize(self, mock_custom_job, mock_client):
6969
def test_prompt_optimizer_optimize_prompt(
7070
self, mock_custom_optimize_prompt, mock_client
7171
):
72-
"""Test that prompt_optimizer.optimize method creates a custom job."""
72+
"""Test that prompt_optimizer.optimize_prompt method calls optimize_prompt API."""
7373
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
7474
test_client.prompt_optimizer.optimize_prompt(prompt="test_prompt")
7575
mock_client.assert_called_once()
7676
mock_custom_optimize_prompt.assert_called_once()
7777

78-
# TODO(b/415060797): add more tests for prompt_optimizer.optimize
78+
@mock.patch.object(prompt_optimizer.PromptOptimizer, "_custom_optimize_prompt")
79+
def test_prompt_optimizer_optimize_prompt_with_optimization_target(
80+
self, mock_custom_optimize_prompt
81+
):
82+
"""Test that prompt_optimizer.optimize_prompt method calls _custom_optimize_prompt with optimization_target."""
83+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
84+
config = types.OptimizeConfig(
85+
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
86+
)
87+
test_client.prompt_optimizer.optimize_prompt(
88+
prompt="test_prompt",
89+
config=config,
90+
)
91+
mock_custom_optimize_prompt.assert_called_once_with(
92+
content=mock.ANY,
93+
config=config,
94+
)
95+
96+
@pytest.mark.asyncio
97+
@mock.patch.object(prompt_optimizer.AsyncPromptOptimizer, "_custom_optimize_prompt")
98+
async def test_async_prompt_optimizer_optimize_prompt(
99+
self, mock_custom_optimize_prompt
100+
):
101+
"""Test that async prompt_optimizer.optimize_prompt method calls optimize_prompt API."""
102+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
103+
await test_client.aio.prompt_optimizer.optimize_prompt(prompt="test_prompt")
104+
mock_custom_optimize_prompt.assert_called_once()
105+
106+
@pytest.mark.asyncio
107+
@mock.patch.object(prompt_optimizer.AsyncPromptOptimizer, "_custom_optimize_prompt")
108+
async def test_async_prompt_optimizer_optimize_prompt_with_optimization_target(
109+
self, mock_custom_optimize_prompt
110+
):
111+
"""Test that async prompt_optimizer.optimize_prompt calls optimize_prompt with optimization_target."""
112+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
113+
config = types.OptimizeConfig(
114+
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
115+
)
116+
await test_client.aio.prompt_optimizer.optimize_prompt(
117+
prompt="test_prompt",
118+
config=config,
119+
)
120+
mock_custom_optimize_prompt.assert_called_once_with(
121+
content=mock.ANY,
122+
config=config,
123+
)
124+
125+
# # TODO(b/415060797): add more tests for prompt_optimizer.optimize

vertexai/_genai/prompt_optimizer.py

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,22 @@ def _GetCustomJobParameters_to_vertex(
167167
return to_object
168168

169169

170+
def _OptimizeConfig_to_vertex(
171+
from_object: Union[dict[str, Any], object],
172+
parent_object: Optional[dict[str, Any]] = None,
173+
) -> dict[str, Any]:
174+
to_object: dict[str, Any] = {}
175+
176+
if getv(from_object, ["optimization_target"]) is not None:
177+
setv(
178+
parent_object,
179+
["optimizationTarget"],
180+
getv(from_object, ["optimization_target"]),
181+
)
182+
183+
return to_object
184+
185+
170186
def _OptimizeRequestParameters_to_vertex(
171187
from_object: Union[dict[str, Any], object],
172188
parent_object: Optional[dict[str, Any]] = None,
@@ -176,7 +192,11 @@ def _OptimizeRequestParameters_to_vertex(
176192
setv(to_object, ["content"], getv(from_object, ["content"]))
177193

178194
if getv(from_object, ["config"]) is not None:
179-
setv(to_object, ["config"], getv(from_object, ["config"]))
195+
setv(
196+
to_object,
197+
["config"],
198+
_OptimizeConfig_to_vertex(getv(from_object, ["config"]), to_object),
199+
)
180200

181201
return to_object
182202

@@ -468,7 +488,10 @@ def optimize(
468488
return job
469489

470490
def optimize_prompt(
471-
self, *, prompt: str, config: Optional[types.OptimizeConfig] = None
491+
self,
492+
*,
493+
prompt: str,
494+
config: Optional[types.OptimizeConfig] = None,
472495
) -> types.OptimizeResponse:
473496
"""Makes an API request to _optimize_prompt and returns the parsed response.
474497
@@ -480,19 +503,21 @@ def optimize_prompt(
480503
481504
Args:
482505
prompt: The prompt to optimize.
483-
config: The configuration for prompt optimization. Currently, config is
484-
not supported for a single prompt optimization.
506+
config: Optional.The configuration for prompt optimization. To optimize
507+
prompts from Android API provide
508+
types.OptimizeConfig(
509+
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO
510+
)
485511
Returns:
486512
The parsed response from the API request.
487513
"""
488-
if config is not None:
489-
raise ValueError(
490-
"Currently, config is not supported for a single prompt optimization."
491-
)
492514

493515
prompt = genai_types.Content(parts=[genai_types.Part(text=prompt)], role="user")
494516
# TODO: b/435653980 - replace the custom method with a generated method.
495-
return self._custom_optimize_prompt(content=prompt)
517+
return self._custom_optimize_prompt(
518+
content=prompt,
519+
config=config,
520+
)
496521

497522
def _custom_optimize_prompt(
498523
self,
@@ -511,7 +536,6 @@ def _custom_optimize_prompt(
511536
content=content,
512537
config=config,
513538
)
514-
515539
request_url_dict: Optional[dict[str, str]]
516540
if not self._api_client.vertexai:
517541
raise ValueError("This method is only supported in the Vertex AI client.")
@@ -850,7 +874,6 @@ async def _custom_optimize_prompt(
850874
content=content,
851875
config=config,
852876
)
853-
854877
request_url_dict: Optional[dict[str, str]]
855878
if not self._api_client.vertexai:
856879
raise ValueError("This method is only supported in the Vertex AI client.")
@@ -909,7 +932,10 @@ async def _custom_optimize_prompt(
909932
return final_response
910933

911934
async def optimize_prompt(
912-
self, *, prompt: str, config: Optional[types.OptimizeConfig] = None
935+
self,
936+
*,
937+
prompt: str,
938+
config: Optional[types.OptimizeConfig] = None,
913939
) -> types.OptimizeResponse:
914940
"""Makes an async request to _optimize_prompt and returns an optimized prompt.
915941
@@ -920,16 +946,18 @@ async def optimize_prompt(
920946
921947
Args:
922948
prompt: The prompt to optimize.
923-
config: The configuration for prompt optimization. Currently, config is
924-
not supported for a single prompt optimization.
949+
config: Optional.The configuration for prompt optimization. To optimize
950+
prompts from Android API provide
951+
types.OptimizeConfig(
952+
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO
953+
)
925954
Returns:
926955
The parsed response from the API request.
927956
"""
928-
if config is not None:
929-
raise ValueError(
930-
"Currently, config is not supported for a single prompt optimization."
931-
)
932957

933958
prompt = genai_types.Content(parts=[genai_types.Part(text=prompt)], role="user")
934959
# TODO: b/435653980 - replace the custom method with a generated method.
935-
return await self._custom_optimize_prompt(content=prompt)
960+
return await self._custom_optimize_prompt(
961+
content=prompt,
962+
config=config,
963+
)

vertexai/_genai/types/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,7 @@
584584
from .common import OptimizeResponseEndpointDict
585585
from .common import OptimizeResponseEndpointOrDict
586586
from .common import OptimizeResponseOrDict
587+
from .common import OptimizeTarget
587588
from .common import PairwiseChoice
588589
from .common import PairwiseMetricInput
589590
from .common import PairwiseMetricInputDict
@@ -1828,6 +1829,7 @@
18281829
"RubricContentType",
18291830
"EvaluationRunState",
18301831
"Importance",
1832+
"OptimizeTarget",
18311833
"GenerateMemoriesResponseGeneratedMemoryAction",
18321834
"PromptData",
18331835
"PromptDataDict",

vertexai/_genai/types/common.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,13 @@ class Importance(_common.CaseInSensitiveEnum):
331331
"""Low importance."""
332332

333333

334+
class OptimizeTarget(_common.CaseInSensitiveEnum):
335+
"""None"""
336+
337+
OPTIMIZATION_TARGET_GEMINI_NANO = "OPTIMIZATION_TARGET_GEMINI_NANO"
338+
"""The data driven prompt optimizer designer for prompts from Android core API."""
339+
340+
334341
class GenerateMemoriesResponseGeneratedMemoryAction(_common.CaseInSensitiveEnum):
335342
"""The action to take."""
336343

@@ -3986,6 +3993,9 @@ class OptimizeConfig(_common.BaseModel):
39863993
http_options: Optional[genai_types.HttpOptions] = Field(
39873994
default=None, description="""Used to override HTTP request options."""
39883995
)
3996+
optimization_target: Optional[OptimizeTarget] = Field(
3997+
default=None, description=""""""
3998+
)
39893999

39904000

39914001
class OptimizeConfigDict(TypedDict, total=False):
@@ -3994,6 +4004,9 @@ class OptimizeConfigDict(TypedDict, total=False):
39944004
http_options: Optional[genai_types.HttpOptionsDict]
39954005
"""Used to override HTTP request options."""
39964006

4007+
optimization_target: Optional[OptimizeTarget]
4008+
""""""
4009+
39974010

39984011
OptimizeConfigOrDict = Union[OptimizeConfig, OptimizeConfigDict]
39994012

0 commit comments

Comments
 (0)