|
9 | 9 |
|
10 | 10 | import httpx |
11 | 11 |
|
12 | | -from ._utils import extract_type_var_from_base |
| 12 | +from ._utils import is_mapping, extract_type_var_from_base |
| 13 | +from ._exceptions import APIError |
13 | 14 |
|
14 | 15 | if TYPE_CHECKING: |
15 | 16 | from ._client import GradientAI, AsyncGradientAI |
@@ -55,7 +56,25 @@ def __stream__(self) -> Iterator[_T]: |
55 | 56 | iterator = self._iter_events() |
56 | 57 |
|
57 | 58 | for sse in iterator: |
58 | | - yield process_data(data=sse.json(), cast_to=cast_to, response=response) |
| 59 | + if sse.data.startswith("[DONE]"): |
| 60 | + break |
| 61 | + |
| 62 | + data = sse.json() |
| 63 | + if is_mapping(data) and data.get("error"): |
| 64 | + message = None |
| 65 | + error = data.get("error") |
| 66 | + if is_mapping(error): |
| 67 | + message = error.get("message") |
| 68 | + if not message or not isinstance(message, str): |
| 69 | + message = "An error occurred during streaming" |
| 70 | + |
| 71 | + raise APIError( |
| 72 | + message=message, |
| 73 | + request=self.response.request, |
| 74 | + body=data["error"], |
| 75 | + ) |
| 76 | + |
| 77 | + yield process_data(data=data, cast_to=cast_to, response=response) |
59 | 78 |
|
60 | 79 | # Ensure the entire stream is consumed |
61 | 80 | for _sse in iterator: |
@@ -119,7 +138,25 @@ async def __stream__(self) -> AsyncIterator[_T]: |
119 | 138 | iterator = self._iter_events() |
120 | 139 |
|
121 | 140 | async for sse in iterator: |
122 | | - yield process_data(data=sse.json(), cast_to=cast_to, response=response) |
| 141 | + if sse.data.startswith("[DONE]"): |
| 142 | + break |
| 143 | + |
| 144 | + data = sse.json() |
| 145 | + if is_mapping(data) and data.get("error"): |
| 146 | + message = None |
| 147 | + error = data.get("error") |
| 148 | + if is_mapping(error): |
| 149 | + message = error.get("message") |
| 150 | + if not message or not isinstance(message, str): |
| 151 | + message = "An error occurred during streaming" |
| 152 | + |
| 153 | + raise APIError( |
| 154 | + message=message, |
| 155 | + request=self.response.request, |
| 156 | + body=data["error"], |
| 157 | + ) |
| 158 | + |
| 159 | + yield process_data(data=data, cast_to=cast_to, response=response) |
123 | 160 |
|
124 | 161 | # Ensure the entire stream is consumed |
125 | 162 | async for _sse in iterator: |
|
0 commit comments