Skip to content

Commit a6bf4ae

Browse files
committed
Accept AsyncIterables being passed to Response
Fixes pallets/flask#5322
1 parent 2fc6d4f commit a6bf4ae

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

src/quart/typing.py

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
AnyStr,
1010
AsyncContextManager,
1111
AsyncGenerator,
12+
AsyncIterator,
1213
Awaitable,
1314
Callable,
1415
Dict,

src/quart/wrappers/response.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,28 @@ async def __anext__(self) -> bytes:
102102

103103

104104
class IterableBody(ResponseBody):
105-
def __init__(self, iterable: AsyncGenerator[bytes, None] | Iterable) -> None:
105+
def __init__(self, iterable: AsyncIterable[bytes] | Iterable) -> None:
106106
self.iter: AsyncGenerator[bytes, None]
107107
if isasyncgen(iterable):
108108
self.iter = iterable
109109
elif isgenerator(iterable):
110110
self.iter = run_sync_iterable(iterable)
111-
else:
111+
elif isinstance(iterable, AsyncIterable):
112112

113113
async def _aiter() -> AsyncGenerator[bytes, None]:
114-
for data in iterable: # type: ignore
114+
async for data in iterable:
115115
yield data
116116

117117
self.iter = _aiter()
118+
elif isinstance(iterable, Iterable):
119+
120+
async def _aiter() -> AsyncGenerator[bytes, None]:
121+
for data in iterable:
122+
yield data
123+
124+
self.iter = _aiter()
125+
else:
126+
raise ValueError("unreachable?")
118127

119128
async def __aenter__(self) -> IterableBody:
120129
return self
@@ -262,7 +271,7 @@ class Response(SansIOResponse):
262271

263272
def __init__(
264273
self,
265-
response: ResponseBody | AnyStr | Iterable | None = None,
274+
response: ResponseBody | AnyStr | Iterable | AsyncIterable | None = None,
266275
status: int | None = None,
267276
headers: dict | Headers | None = None,
268277
mimetype: str | None = None,

tests/test_templating.py

+9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
g,
1010
Quart,
1111
render_template_string,
12+
Response,
1213
ResponseReturnValue,
1314
session,
1415
stream_template_string,
@@ -148,3 +149,11 @@ async def index() -> ResponseReturnValue:
148149
test_client = app.test_client()
149150
response = await test_client.get("/")
150151
assert (await response.data) == b"42"
152+
153+
@app.get("/2")
154+
async def index2() -> ResponseReturnValue:
155+
return Response(await stream_template_string("{{ config }}", config=43))
156+
157+
test_client = app.test_client()
158+
response = await test_client.get("/2")
159+
assert (await response.data) == b"43"

0 commit comments

Comments
 (0)