1
+ import asyncio
1
2
import re
2
3
import time
3
4
from dataclasses import dataclass
@@ -114,6 +115,7 @@ def progress(self) -> Optional[Progress]:
114
115
"""
115
116
The progress of the prediction, if available.
116
117
"""
118
+
117
119
if self .logs is None or self .logs == "" :
118
120
return None
119
121
@@ -123,10 +125,20 @@ def wait(self) -> None:
123
125
"""
124
126
Wait for prediction to finish.
125
127
"""
128
+
126
129
while self .status not in ["succeeded" , "failed" , "canceled" ]:
127
130
time .sleep (self ._client .poll_interval )
128
131
self .reload ()
129
132
133
+ async def async_wait (self ) -> None :
134
+ """
135
+ Wait for prediction to finish asynchronously.
136
+ """
137
+
138
+ while self .status not in ["succeeded" , "failed" , "canceled" ]:
139
+ await asyncio .sleep (self ._client .poll_interval )
140
+ await self .async_reload ()
141
+
130
142
def stream (self ) -> Optional [Iterator ["ServerSentEvent" ]]:
131
143
"""
132
144
Stream the prediction output.
@@ -164,6 +176,15 @@ def reload(self) -> None:
164
176
for name , value in updated .dict ().items ():
165
177
setattr (self , name , value )
166
178
179
+ async def async_reload (self ) -> None :
180
+ """
181
+ Load this prediction from the server asynchronously.
182
+ """
183
+
184
+ updated = await self ._client .predictions .async_get (self .id )
185
+ for name , value in updated .dict ().items ():
186
+ setattr (self , name , value )
187
+
167
188
def output_iterator (self ) -> Iterator [Any ]:
168
189
"""
169
190
Return an iterator of the prediction output.
0 commit comments