Skip to content

Commit 950ba0f

Browse files
nurikknurikk-samattt
authored
Add async_wait method to Prediction class (#225)
Use non blocking wait function --------- Signed-off-by: John Doe <[email protected]> Co-authored-by: Ainur Timerbaev <[email protected]> Co-authored-by: Mattt Zmuda <[email protected]>
1 parent cd09db0 commit 950ba0f

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

Diff for: replicate/prediction.py

+21
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import re
23
import time
34
from dataclasses import dataclass
@@ -114,6 +115,7 @@ def progress(self) -> Optional[Progress]:
114115
"""
115116
The progress of the prediction, if available.
116117
"""
118+
117119
if self.logs is None or self.logs == "":
118120
return None
119121

@@ -123,10 +125,20 @@ def wait(self) -> None:
123125
"""
124126
Wait for prediction to finish.
125127
"""
128+
126129
while self.status not in ["succeeded", "failed", "canceled"]:
127130
time.sleep(self._client.poll_interval)
128131
self.reload()
129132

133+
async def async_wait(self) -> None:
134+
"""
135+
Wait for prediction to finish asynchronously.
136+
"""
137+
138+
while self.status not in ["succeeded", "failed", "canceled"]:
139+
await asyncio.sleep(self._client.poll_interval)
140+
await self.async_reload()
141+
130142
def stream(self) -> Optional[Iterator["ServerSentEvent"]]:
131143
"""
132144
Stream the prediction output.
@@ -164,6 +176,15 @@ def reload(self) -> None:
164176
for name, value in updated.dict().items():
165177
setattr(self, name, value)
166178

179+
async def async_reload(self) -> None:
180+
"""
181+
Load this prediction from the server asynchronously.
182+
"""
183+
184+
updated = await self._client.predictions.async_get(self.id)
185+
for name, value in updated.dict().items():
186+
setattr(self, name, value)
187+
167188
def output_iterator(self) -> Iterator[Any]:
168189
"""
169190
Return an iterator of the prediction output.

Diff for: replicate/run.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ async def async_run(
8080
)
8181

8282
if not version and (owner and name and version_id):
83-
version = Versions(client, model=(owner, name)).get(version_id)
83+
version = await Versions(client, model=(owner, name)).async_get(version_id)
8484

8585
if version and (iterator := _make_output_iterator(version, prediction)):
8686
return iterator
8787

88-
prediction.wait()
88+
await prediction.async_wait()
8989

9090
if prediction.status == "failed":
9191
raise ModelError(prediction.error)

0 commit comments

Comments
 (0)