-
Notifications
You must be signed in to change notification settings - Fork 360
/
Copy path_utils.py
357 lines (305 loc) · 12.3 KB
/
_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import json
import types
import typing
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union
import proto
from google.cloud.aiplatform import base
from google.api import httpbody_pb2
from google.protobuf import struct_pb2
from google.protobuf import json_format
try:
    # LangChain templates might not import langchain_core themselves, in which
    # case they get:
    #   PydanticUserError: `query` is not fully defined; you should define
    #   `RunnableConfig`, then call `query.model_rebuild()`.
    # Importing it here exposes `RunnableConfig` as a name of this module.
    import langchain_core.runnables.config

    RunnableConfig = langchain_core.runnables.config.RunnableConfig
except ImportError:
    # If langchain_core cannot be imported, fall back to an unconstrained type.
    RunnableConfig = Any

# Alias for JSON-like dictionaries keyed by strings.
JsonDict = Dict[str, Any]

# Module-level logger used by the helpers in this file.
_LOGGER = base.Logger(__name__)
def to_proto(
    obj: Union[JsonDict, proto.Message],
    message: Optional[proto.Message] = None,
) -> proto.Message:
    """Parses a JSON-like object into a message.

    If the object is already a message, this will return the object as-is. If
    the object is a JSON Dict, this will parse and merge the object into the
    message.

    Args:
        obj (Union[dict[str, Any], proto.Message]):
            Required. The object to convert to a proto message.
        message (proto.Message):
            Optional. A protocol buffer message to merge the obj into. It
            defaults to Struct() if unspecified.

    Returns:
        proto.Message: `obj` itself when it is already a message (in which
        case `message` is ignored); otherwise `message` (or a new Struct)
        with the contents of `obj` merged into it.
    """
    if isinstance(obj, (proto.Message, struct_pb2.Struct)):
        return obj
    if message is None:
        message = struct_pb2.Struct()
    # Parse into `message._pb` when that attribute exists, otherwise into
    # `message` itself. Resolving the target with getattr() (instead of
    # wrapping the parse in `except AttributeError`) keeps an AttributeError
    # raised *during parsing* from silently triggering a second parse.
    target = getattr(message, "_pb", message)
    json_format.ParseDict(obj, target)
    return message
def to_dict(message: proto.Message) -> JsonDict:
    """Converts the contents of the protobuf message to JSON format.

    Args:
        message (proto.Message):
            Required. The proto message to be converted to a JSON dictionary.

    Returns:
        dict[str, Any]: A dictionary containing the contents of the proto.
    """
    # Serialize `message._pb` when that attribute exists, otherwise `message`
    # itself. Using getattr() rather than `except AttributeError` around the
    # whole conversion means an AttributeError raised while serializing is
    # reported as-is instead of causing a second, misleading attempt.
    pb_message = getattr(message, "_pb", message)
    result: JsonDict = json.loads(json_format.MessageToJson(pb_message))
    return result
def yield_parsed_json(body: httpbody_pb2.HttpBody) -> Iterable[Any]:
    """Yields the JSON value(s) carried by an httpbody message.

    Args:
        body (httpbody_pb2.HttpBody):
            Required. The httpbody message whose `data` should be parsed.

    Yields:
        Any: One parsed JSON value per non-empty line of the decoded data.
        A line that fails to parse is yielded as the raw string. The original
        `body` is yielded instead when it has no JSON content type, has no
        data, or its data cannot be decoded as UTF-8. `None` is yielded when
        the decoded data is empty.
    """
    mime_type = getattr(body, "content_type", None)
    payload = getattr(body, "data", None)
    looks_like_json = (
        mime_type is not None
        and payload is not None
        and "application/json" in mime_type
    )
    if not looks_like_json:
        yield body
        return

    try:
        text = payload.decode("utf-8")
    except Exception as decode_error:
        _LOGGER.warning(
            f"Failed to decode data: {payload}. Exception: {decode_error}"
        )
        yield body
        return

    if not text:
        yield None
        return

    # The payload may hold several JSON documents, one per line; empty lines
    # are skipped.
    for chunk in filter(None, text.split("\n")):
        try:
            parsed = json.loads(chunk)
        except Exception as parse_error:
            _LOGGER.warning(
                f"failed to parse json: {chunk}. Exception: {parse_error}"
            )
            # Best effort: hand back the unparsed text.
            parsed = chunk
        yield parsed
def generate_schema(
    f: Callable[..., Any],
    *,
    schema_name: Optional[str] = None,
    descriptions: Mapping[str, str] = {},
    required: Sequence[str] = [],
) -> JsonDict:
    """Generates the OpenAPI Schema for a callable object.

    Only positional and keyword arguments of the function `f` will be supported
    in the OpenAPI Schema that is generated. I.e. `*args` and `**kwargs` will
    not be present in the OpenAPI schema returned from this function. For those
    cases, you can either include it in the docstring for `f`, or modify the
    OpenAPI schema returned from this function to include additional arguments.

    Args:
        f (Callable):
            Required. The function to generate an OpenAPI Schema for.
        schema_name (str):
            Optional. The name for the OpenAPI schema. If unspecified, the name
            of the Callable will be used.
        descriptions (Mapping[str, str]):
            Optional. A `{name: description}` mapping for annotating input
            arguments of the function with user-provided descriptions. It
            defaults to an empty dictionary (i.e. there will not be any
            description for any of the inputs).
        required (Sequence[str]):
            Optional. For the user to specify the set of required arguments in
            function calls to `f`. If unspecified, it will be automatically
            inferred from `f` (every supported argument without a default
            value is treated as required).

    Returns:
        dict[str, Any]: The OpenAPI Schema for the function `f` in JSON format,
        with the keys `name`, `description` and `parameters`.

    Raises:
        ImportError: If pydantic is not installed.
    """
    pydantic = _import_pydantic_or_raise()
    # NOTE: the `{}` / `[]` defaults in the signature are only ever read here,
    # never mutated, so sharing them across calls is harmless.

    # We do not support *args or **kwargs.
    supported_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.POSITIONAL_ONLY,
    )
    params = dict(inspect.signature(f).parameters)
    fields_dict = {
        name: (
            # 1. We infer the argument type here: use Any rather than None so
            # it will not try to auto-infer the type based on the default value.
            # `empty` is a sentinel, so it is compared by identity.
            (
                param.annotation
                if param.annotation is not inspect.Parameter.empty
                else Any
            ),
            pydantic.Field(
                # 2. We do not support default values for now.
                # 3. We support user-provided descriptions.
                description=descriptions.get(name, None),
            ),
        )
        for name, param in params.items()
        if param.kind in supported_kinds
    }
    parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
    # Postprocessing
    # 4. Suppress unnecessary title generation:
    #    * https://github.com/pydantic/pydantic/issues/1051
    #    * http://cl/586221780
    parameters.pop("title", "")
    for name, function_arg in parameters.get("properties", {}).items():
        function_arg.pop("title", "")
        annotation = params[name].annotation
        # 5. Nullable fields:
        #    * https://github.com/pydantic/pydantic/issues/1270
        #    * https://stackoverflow.com/a/58841311
        #    * https://github.com/pydantic/pydantic/discussions/4872
        if typing.get_origin(annotation) is Union and type(None) in typing.get_args(
            annotation
        ):
            # for "typing.Optional" arguments, function_arg might be a
            # dictionary like
            #
            #   {'anyOf': [{'type': 'integer'}, {'type': 'null'}]
            #
            # Collapse it to the first non-null type and mark it nullable.
            for candidate in function_arg.pop("anyOf", []):
                candidate_type = candidate.get("type")
                if candidate_type and candidate_type != "null":
                    function_arg["type"] = candidate_type
                    break
            function_arg["nullable"] = True
    # 6. Annotate required fields.
    if required:
        # We use the user-provided "required" fields if specified.
        parameters["required"] = required
    else:
        # Otherwise we infer it from the function signature. Identity
        # comparison is essential here: a default value whose `==` is
        # overloaded (e.g. an array type) would otherwise make this raise.
        parameters["required"] = [
            name
            for name, param in params.items()
            if param.default is inspect.Parameter.empty
            and param.kind in supported_kinds
        ]
    schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters)
    if schema_name:
        schema["name"] = schema_name
    return schema
def is_noop_or_proxy_tracer_provider(tracer_provider) -> bool:
    """Returns True if the tracer_provider is Proxy or NoOp.

    Args:
        tracer_provider:
            Required. The object to check.

    Returns:
        bool: True if `tracer_provider` is an instance of opentelemetry's
        `NoOpTracerProvider` or `ProxyTracerProvider`. False otherwise,
        including when opentelemetry cannot be imported (a warning is logged
        by the import helper in that case).
    """
    opentelemetry = _import_opentelemetry_or_warn()
    if opentelemetry is None:
        # The import helper returns None when opentelemetry is missing.
        # Without the package neither provider class exists, so nothing can
        # be an instance of them; avoid dereferencing None.
        return False
    trace_api = opentelemetry.trace
    provider_types = (
        trace_api.NoOpTracerProvider,
        trace_api.ProxyTracerProvider,
    )
    return isinstance(tracer_provider, provider_types)
def _import_cloud_storage_or_raise() -> types.ModuleType:
    """Returns the `google.cloud.storage` module.

    Raises:
        ImportError: If the module cannot be imported; the message includes
            the pip command to install it.
    """
    try:
        from google.cloud import storage as storage_module
    except ImportError as import_error:
        message = (
            "Cloud Storage is not installed. Please call "
            "'pip install google-cloud-aiplatform[reasoningengine]'."
        )
        raise ImportError(message) from import_error
    else:
        return storage_module
def _import_cloudpickle_or_raise() -> types.ModuleType:
    """Returns the `cloudpickle` module.

    Raises:
        ImportError: If the module cannot be imported; the message includes
            the pip command to install it.
    """
    try:
        import cloudpickle as cloudpickle_module  # noqa:F401
    except ImportError as import_error:
        message = (
            "cloudpickle is not installed. Please call "
            "'pip install google-cloud-aiplatform[reasoningengine]'."
        )
        raise ImportError(message) from import_error
    else:
        return cloudpickle_module
def _import_pydantic_or_raise() -> types.ModuleType:
    """Returns the pydantic module to use.

    The top-level `pydantic` package is returned when it exposes `Field`;
    otherwise `pydantic.v1` is returned in its place.

    Raises:
        ImportError: If pydantic cannot be imported; the message includes the
            pip command to install it.
    """
    try:
        import pydantic as pydantic_module
    except ImportError as import_error:
        message = (
            "pydantic is not installed. Please call "
            "'pip install google-cloud-aiplatform[reasoningengine]'."
        )
        raise ImportError(message) from import_error
    if not hasattr(pydantic_module, "Field"):
        # Fall back to the `v1` namespace when `Field` is not exposed.
        from pydantic import v1 as pydantic_module
    return pydantic_module
def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]:
    """Returns the `opentelemetry` module, or None if it cannot be imported.

    A warning with installation instructions is logged when the import fails.
    """
    try:
        import opentelemetry as otel_module  # noqa:F401
    except ImportError:
        # NOTE(review): the message names "opentelemetry-sdk" although the
        # import above is the top-level `opentelemetry` package - confirm.
        _LOGGER.warning(
            "opentelemetry-sdk is not installed. Please call "
            "'pip install google-cloud-aiplatform[reasoningengine]'."
        )
        return None
    return otel_module
def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
    """Returns the `opentelemetry.sdk.trace` module, or None if unavailable.

    A warning with installation instructions is logged when the import fails.
    """
    try:
        import opentelemetry.sdk.trace as sdk_trace_module  # noqa:F401
    except ImportError:
        _LOGGER.warning(
            "opentelemetry-sdk is not installed. Please call "
            "'pip install google-cloud-aiplatform[reasoningengine]'."
        )
        return None
    return sdk_trace_module
def _import_cloud_trace_v2_or_warn() -> Optional[types.ModuleType]:
    """Returns the `google.cloud.trace_v2` module, or None if unavailable.

    A warning with installation instructions is logged when the import fails.
    """
    try:
        import google.cloud.trace_v2 as trace_v2_module
    except ImportError:
        _LOGGER.warning(
            "google-cloud-trace is not installed. Please call "
            "'pip install google-cloud-aiplatform[reasoningengine]'."
        )
        return None
    return trace_v2_module
def _import_cloud_trace_exporter_or_warn() -> Optional[types.ModuleType]:
    """Returns `opentelemetry.exporter.cloud_trace`, or None if unavailable.

    A warning with installation instructions is logged when the import fails.
    """
    try:
        import opentelemetry.exporter.cloud_trace as exporter_module  # noqa:F401
    except ImportError:
        _LOGGER.warning(
            "opentelemetry-exporter-gcp-trace is not installed. Please "
            "call 'pip install google-cloud-aiplatform[langchain]'."
        )
        return None
    return exporter_module
def _import_openinference_langchain_or_warn() -> Optional[types.ModuleType]:
    """Returns `openinference.instrumentation.langchain`, or None if unavailable.

    A warning with installation instructions is logged when the import fails.
    """
    try:
        import openinference.instrumentation.langchain as instrumentation_module  # noqa:F401
    except ImportError:
        _LOGGER.warning(
            "openinference-instrumentation-langchain is not installed. Please "
            "call 'pip install google-cloud-aiplatform[langchain]'."
        )
        return None
    return instrumentation_module