Skip to content

Commit 344600f

Browse files
yeesiancopybara-github
authored andcommitted
feat: Use TypeAliasType to define aliases for union types in generative models
This is based on the original PR in #4701, just wrapping the typealiases in a try-catch block. PiperOrigin-RevId: 704506046
1 parent 5a4e9c0 commit 344600f

File tree

3 files changed

+49
-3
lines changed

3 files changed

+49
-3
lines changed

setup.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123

124124
genai_requires = (
125125
"pydantic < 3",
126+
"typing_extensions",
126127
"docstring_parser < 1",
127128
)
128129

@@ -143,7 +144,8 @@
143144
"google-cloud-trace < 2",
144145
"opentelemetry-sdk < 2",
145146
"opentelemetry-exporter-gcp-trace < 2",
146-
"pydantic >= 2.6.3, < 2.10",
147+
"pydantic >= 2.6.3, < 3",
148+
"typing_extensions",
147149
]
148150

149151
evaluation_extra_require = [

testing/constraints-langchain.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
langchain
22
langchain-core
3-
langchain-google-vertexai
4-
pydantic<2.10
3+
langchain-google-vertexai

vertexai/generative_models/_generative_models.py

+45
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,51 @@
117117
],
118118
]
119119

120+
try:
121+
# For Pydantic to resolve the forward references inside these aliases.
122+
from typing_extensions import TypeAliasType
123+
124+
PartsType = TypeAliasType(
125+
"PartsType",
126+
Union[
127+
str,
128+
"Image",
129+
"Part",
130+
List[Union[str, "Image", "Part"]],
131+
],
132+
)
133+
ContentsType = TypeAliasType(
134+
"ContentsType",
135+
Union[
136+
List["Content"],
137+
List[ContentDict],
138+
str,
139+
"Image",
140+
"Part",
141+
List[Union[str, "Image", "Part"]],
142+
],
143+
)
144+
GenerationConfigType = TypeAliasType(
145+
"GenerationConfigType",
146+
Union[
147+
"GenerationConfig",
148+
GenerationConfigDict,
149+
],
150+
)
151+
SafetySettingsType = TypeAliasType(
152+
"SafetySettingsType",
153+
Union[
154+
List["SafetySetting"],
155+
Dict[
156+
gapic_content_types.HarmCategory,
157+
gapic_content_types.SafetySetting.HarmBlockThreshold,
158+
],
159+
],
160+
)
161+
except ImportError:
162+
# Use existing definitions if typing_extensions is not available.
163+
pass
164+
120165

121166
def _reconcile_model_name(model_name: str, project: str, location: str) -> str:
122167
"""Returns a model name that's one of the following:

0 commit comments

Comments
 (0)