Skip to content

Commit 7a90ab4

Browse files
yeesian and copybara-github
authored and committed
fix: Handle the streaming of JSON delimited by newlines
FUTURE_COPYBARA_INTEGRATE_REVIEW=#4861 from googleapis:release-please--branches--main 039f2cb PiperOrigin-RevId: 719423860
1 parent 713ffac commit 7a90ab4

File tree

3 files changed: 38 additions (+38) and 32 deletions (-32).

tests/unit/vertex_langchain/test_reasoning_engines.py

+14-16
Original file line numberDiff line numberDiff line change
@@ -2234,60 +2234,58 @@ class ToParsedJsonTest(parameterized.TestCase):
22342234
obj=httpbody_pb2.HttpBody(
22352235
content_type="application/json", data=b'{"a": 1, "b": "hello"}'
22362236
),
2237-
expected={"a": 1, "b": "hello"},
2237+
expected=[{"a": 1, "b": "hello"}],
22382238
),
22392239
dict(
22402240
testcase_name="invalid_json",
22412241
obj=httpbody_pb2.HttpBody(
22422242
content_type="application/json", data=b'{"a": 1, "b": "hello"'
22432243
),
2244-
expected=httpbody_pb2.HttpBody(
2245-
content_type="application/json", data=b'{"a": 1, "b": "hello"'
2246-
),
2244+
expected=['{"a": 1, "b": "hello"'], # returns the unparsed string
22472245
),
22482246
dict(
22492247
testcase_name="missing_content_type",
22502248
obj=httpbody_pb2.HttpBody(data=b'{"a": 1}'),
2251-
expected=httpbody_pb2.HttpBody(data=b'{"a": 1}'),
2249+
expected=[httpbody_pb2.HttpBody(data=b'{"a": 1}')],
22522250
),
22532251
dict(
22542252
testcase_name="missing_data",
22552253
obj=httpbody_pb2.HttpBody(content_type="application/json"),
2256-
expected=None,
2254+
expected=[None],
22572255
),
22582256
dict(
22592257
testcase_name="wrong_content_type",
22602258
obj=httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello"),
2261-
expected=httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello"),
2259+
expected=[httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello")],
22622260
),
22632261
dict(
22642262
testcase_name="empty_data",
22652263
obj=httpbody_pb2.HttpBody(content_type="application/json", data=b""),
2266-
expected=None,
2264+
expected=[None],
22672265
),
22682266
dict(
22692267
testcase_name="unicode_data",
22702268
obj=httpbody_pb2.HttpBody(
22712269
content_type="application/json", data='{"a": "你好"}'.encode("utf-8")
22722270
),
2273-
expected={"a": "你好"},
2271+
expected=[{"a": "你好"}],
22742272
),
22752273
dict(
22762274
testcase_name="nested_json",
22772275
obj=httpbody_pb2.HttpBody(
22782276
content_type="application/json", data=b'{"a": {"b": 1}}'
22792277
),
2280-
expected={"a": {"b": 1}},
2278+
expected=[{"a": {"b": 1}}],
22812279
),
22822280
dict(
2283-
testcase_name="error_handling",
2281+
testcase_name="multiline_json",
22842282
obj=httpbody_pb2.HttpBody(
2285-
content_type="application/json", data=b'{"a": 1, "b": "hello"'
2286-
),
2287-
expected=httpbody_pb2.HttpBody(
2288-
content_type="application/json", data=b'{"a": 1, "b": "hello"'
2283+
content_type="application/json",
2284+
data=b'{"a": {"b": 1}}\n{"a": {"b": 2}}'
22892285
),
2286+
expected=[{"a": {"b": 1}}, {"a": {"b": 2}}],
22902287
),
22912288
)
22922289
def test_to_parsed_json(self, obj, expected):
    """Checks that yield_parsed_json(obj) yields exactly the items in `expected`.

    The generator is materialized into a list before comparing. Pairing the
    two sequences with zip() would stop at the shorter one, so a generator
    that yields too few (or zero) items would pass without running a single
    assertion.
    """
    self.assertEqual(list(_utils.yield_parsed_json(obj)), expected)

vertexai/reasoning_engines/_reasoning_engines.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -840,9 +840,9 @@ def _method(self, **kwargs) -> Iterable[Any]:
840840
),
841841
)
842842
for chunk in response:
843-
parsed_json = _utils.to_parsed_json(chunk)
844-
if parsed_json is not None:
845-
yield parsed_json
843+
for parsed_json in _utils.yield_parsed_json(chunk):
844+
if parsed_json is not None:
845+
yield parsed_json
846846

847847
_method.__name__ = method_name
848848
_method.__doc__ = doc

vertexai/reasoning_engines/_utils.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import json
1818
import types
1919
import typing
20-
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
20+
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union
2121

2222
import proto
2323

@@ -90,36 +90,44 @@ def to_dict(message: proto.Message) -> JsonDict:
9090
return result
9191

9292

93-
def to_parsed_json(body: httpbody_pb2.HttpBody) -> Any:
93+
def yield_parsed_json(body: httpbody_pb2.HttpBody) -> Iterable[Any]:
    """Yields the JSON value(s) contained in an httpbody message.

    The payload may hold a single JSON document or several documents
    delimited by newlines; every non-blank line is parsed and yielded
    separately.

    Args:
        body (httpbody_pb2.HttpBody):
            Required. The httpbody body to be converted to a JSON.

    Yields:
        Any: Depending on the contents of the body:
            * the original body, unchanged, if it has no content type, no
              data, a content type that does not contain "application/json",
              or data that cannot be decoded as UTF-8;
            * None if the decoded data is an empty string;
            * otherwise, for each non-blank line of the decoded data, the
              parsed JSON value, or the raw line (a str) if that line
              fails to parse.
    """
    content_type = getattr(body, "content_type", None)
    data = getattr(body, "data", None)

    if content_type is None or data is None or "application/json" not in content_type:
        yield body
        return

    try:
        utf8_data = data.decode("utf-8")
    except Exception as e:  # Best effort: hand the body back to the caller.
        _LOGGER.warning("Failed to decode data: %s. Exception: %s", data, e)
        yield body
        return

    if not utf8_data:
        yield None
        return

    # Handle the case of multiple JSON documents delimited by newlines.
    # str.split("\n") is used instead of str.splitlines() because the latter
    # also breaks on characters such as U+2028, which may legally appear
    # unescaped inside a JSON string.
    for line in utf8_data.split("\n"):
        # Skip blank fragments, including the bare "\r" left behind by an
        # empty CRLF-terminated line; they carry no JSON document.
        if not line.strip():
            continue
        try:
            parsed = json.loads(line)
        except Exception as e:  # Best effort: surface the raw line instead.
            _LOGGER.warning("failed to parse json: %s. Exception: %s", line, e)
            parsed = line
        yield parsed
123131

124132

125133
def generate_schema(
@@ -195,7 +203,7 @@ def generate_schema(
195203
# * https://github.com/pydantic/pydantic/issues/1270
196204
# * https://stackoverflow.com/a/58841311
197205
# * https://github.com/pydantic/pydantic/discussions/4872
198-
if typing.get_origin(annotation) is typing.Union and type(
206+
if typing.get_origin(annotation) is Union and type(
199207
None
200208
) in typing.get_args(annotation):
201209
# for "typing.Optional" arguments, function_arg might be a

0 commit comments

Comments
 (0)