Skip to content

Commit 96d287d

Browse files
committed
Fix output Iterator[Path] case
1 parent e5c25b9 commit 96d287d

File tree

3 files changed

+75
-4
lines changed

3 files changed

+75
-4
lines changed

internal/server/runner.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ type PendingPrediction struct {
3333
response PredictionResponse
3434
lastUpdated time.Time
3535
inputPaths []string
36+
outputCache map[string]string
3637
mu sync.Mutex
3738
c chan PredictionResponse
3839
}
@@ -250,9 +251,10 @@ func (r *Runner) Predict(req PredictionRequest) (chan PredictionResponse, error)
250251
StartedAt: req.StartedAt,
251252
}
252253
pr := PendingPrediction{
253-
request: req,
254-
response: resp,
255-
inputPaths: inputPaths,
254+
request: req,
255+
response: resp,
256+
inputPaths: inputPaths,
257+
outputCache: make(map[string]string),
256258
}
257259
if req.Webhook == "" {
258260
pr.c = make(chan PredictionResponse, 1)
@@ -515,7 +517,23 @@ func (r *Runner) handleResponses() {
515517
} else if r.uploadUrl != "" {
516518
outputFn = outputToUpload(r.uploadUrl, pr.response.Id)
517519
}
518-
if output, err := handlePath(pr.response.Output, &paths, outputFn); err != nil {
520+
cachedOutputFn := func(s string, paths *[]string) (string, error) {
521+
// Cache already handled output files to avoid duplicates or deleted files in Iterator[Path]
522+
if cache, ok := pr.outputCache[s]; ok {
523+
return cache, nil
524+
}
525+
o, err := outputFn(s, paths)
526+
if err != nil {
527+
return "", err
528+
}
529+
if o != s {
530+
// Output path converted to base64 or upload URL, cache it
531+
pr.outputCache[s] = o
532+
}
533+
return o, nil
534+
}
535+
536+
if output, err := handlePath(pr.response.Output, &paths, cachedOutputFn); err != nil {
519537
log.Errorw("failed to handle output path", "id", pid, "error", err)
520538
pr.response.Status = PredictionFailed
521539
pr.response.Error = fmt.Sprintf("failed to handle output path: %s", err)

internal/tests/path_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,41 @@ func TestPredictionPathUploadUrlSucceeded(t *testing.T) {
144144
assert.NoError(t, ct.Cleanup())
145145
}
146146

147+
func TestPredictionPathUploadIterator(t *testing.T) {
148+
ct := NewCogTest(t, "path_out_iter")
149+
ct.StartWebhook()
150+
ct.AppendArgs(fmt.Sprintf("--upload-url=http://localhost:%d/upload/", ct.webhookPort))
151+
assert.NoError(t, ct.Start())
152+
153+
hc := ct.WaitForSetup()
154+
assert.Equal(t, server.StatusReady.String(), hc.Status)
155+
assert.Equal(t, server.SetupSucceeded, hc.Setup.Status)
156+
157+
ct.AsyncPrediction(map[string]any{"n": 3})
158+
wr := ct.WaitForWebhookCompletion()
159+
ul := ct.GetUploads()
160+
161+
assert.Len(t, wr, 5)
162+
assert.Equal(t, server.PredictionProcessing, wr[0].Status)
163+
assert.Nil(t, wr[0].Output)
164+
assert.Equal(t, server.PredictionProcessing, wr[1].Status)
165+
assert.Len(t, wr[1].Output.([]any), 1)
166+
assert.Equal(t, server.PredictionProcessing, wr[2].Status)
167+
assert.Len(t, wr[2].Output.([]any), 2)
168+
assert.Equal(t, server.PredictionProcessing, wr[3].Status)
169+
assert.Len(t, wr[3].Output.([]any), 3)
170+
assert.Equal(t, server.PredictionSucceeded, wr[4].Status)
171+
assert.Len(t, wr[4].Output.([]any), 3)
172+
173+
assert.Len(t, ul, 3)
174+
assert.Equal(t, "out0", string(ul[0].Body))
175+
assert.Equal(t, "out1", string(ul[1].Body))
176+
assert.Equal(t, "out2", string(ul[2].Body))
177+
178+
ct.Shutdown()
179+
assert.NoError(t, ct.Cleanup())
180+
}
181+
147182
const TestDataPrefix = "https://raw.githubusercontent.com/gabriel-vasile/mimetype/refs/heads/master/testdata/"
148183

149184
func TestPredictionPathMimeTypes(t *testing.T) {

python/tests/runners/path_out_iter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import tempfile
2+
import time
3+
from typing import Iterator
4+
5+
from cog import BasePredictor, Path
6+
7+
8+
class Predictor(BasePredictor):
9+
test_inputs = {'n': 2}
10+
11+
def predict(self, n: int) -> Iterator[Path]:
12+
for i in range(n):
13+
time.sleep(1)
14+
with tempfile.NamedTemporaryFile(
15+
mode='w', suffix='.txt', delete=False
16+
) as f:
17+
f.write(f'out{i}')
18+
yield Path(f.name)

0 commit comments

Comments
 (0)