Skip to content

Commit 01bbb14

Browse files
committed
Allow threaded code within include context managers
1 parent 39c1bc5 commit 01bbb14

File tree

2 files changed

+102
-18
lines changed

2 files changed

+102
-18
lines changed

Diff for: replicate/include.py

+54-18
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import sys
3+
import threading
34
from contextlib import contextmanager
4-
from contextvars import ContextVar
55
from dataclasses import dataclass
66
from typing import Any, Callable, Dict, Literal, Optional, Tuple
77

@@ -13,38 +13,75 @@
1313
from .run import _has_output_iterator_array_type
1414
from .version import Version
1515

16-
__all__ = ["include"]
16+
__all__ = ["get_run_state", "get_run_token", "include", "run_state", "run_token"]
1717

1818

19-
_RUN_STATE: ContextVar[Literal["load", "setup", "run"] | None] = ContextVar(
20-
"run_state",
21-
default=None,
22-
)
23-
_RUN_TOKEN: ContextVar[str | None] = ContextVar("run_token", default=None)
19+
_run_state: Optional[Literal["load", "setup", "run"]] = None
20+
_run_token: Optional[str] = None
21+
22+
_state_stack = []
23+
_token_stack = []
24+
25+
_state_lock = threading.RLock()
26+
_token_lock = threading.RLock()
27+
28+
29+
def get_run_state() -> Optional[Literal["load", "setup", "run"]]:
30+
"""
31+
Get the current run state.
32+
"""
33+
return _run_state
34+
35+
36+
def get_run_token() -> Optional[str]:
37+
"""
38+
Get the current API token.
39+
"""
40+
return _run_token
2441

2542

2643
@contextmanager
2744
def run_state(state: Literal["load", "setup", "run"]) -> Any:
2845
"""
29-
Internal context manager for execution state.
46+
Context manager for setting the current run state.
3047
"""
31-
s = _RUN_STATE.set(state)
48+
global _run_state
49+
50+
if threading.current_thread() is not threading.main_thread():
51+
raise RuntimeError("Only the main thread can modify run state")
52+
53+
with _state_lock:
54+
_state_stack.append(_run_state)
55+
56+
_run_state = state
57+
3258
try:
3359
yield
3460
finally:
35-
_RUN_STATE.reset(s)
61+
with _state_lock:
62+
_run_state = _state_stack.pop()
3663

3764

3865
@contextmanager
3966
def run_token(token: str) -> Any:
4067
"""
41-
Sets the API token for the current context.
68+
Context manager for setting the current API token.
4269
"""
43-
t = _RUN_TOKEN.set(token)
70+
global _run_token
71+
72+
if threading.current_thread() is not threading.main_thread():
73+
raise RuntimeError("Only the main thread can modify API token")
74+
75+
with _token_lock:
76+
_token_stack.append(_run_token)
77+
78+
_run_token = token
79+
4480
try:
4581
yield
4682
finally:
47-
_RUN_TOKEN.reset(t)
83+
with _token_lock:
84+
_run_token = _token_stack.pop()
4885

4986

5087
def _find_api_token() -> str:
@@ -53,12 +90,11 @@ def _find_api_token() -> str:
5390
print("Using Replicate API token from environment", file=sys.stderr)
5491
return token
5592

56-
token = _RUN_TOKEN.get()
57-
58-
if not token:
93+
current_token = get_run_token()
94+
if current_token is None:
5995
raise ValueError("No run token found")
6096

61-
return token
97+
return current_token
6298

6399

64100
@dataclass
@@ -158,7 +194,7 @@ def include(function_ref: str) -> Callable[..., Any]:
158194
159195
This function can only be called at the top level.
160196
"""
161-
if _RUN_STATE.get() != "load":
197+
if get_run_state() != "load":
162198
raise RuntimeError(
163199
"You may only call replicate.include at the top level."
164200
)

Diff for: tests/test_include.py

+48
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import threading
23
import unittest.mock as mock
34

45
import pytest
@@ -7,6 +8,8 @@
78
from replicate.include import (
89
Function,
910
Run,
11+
get_run_state,
12+
get_run_token,
1013
include,
1114
run_state,
1215
run_token,
@@ -291,3 +294,48 @@ def test_run_logs(prediction, version):
291294

292295
prediction.reload.assert_called_once()
293296
assert logs == "log content"
297+
298+
299+
def test_thread_safety_concepts():
300+
with run_state("load"), run_token("test-token"):
301+
assert get_run_state() == "load"
302+
assert get_run_token() == "test-token"
303+
304+
results = []
305+
306+
def worker_thread_fn():
307+
thread_sees_state = get_run_state() == "load"
308+
thread_sees_token = get_run_token() == "test-token"
309+
310+
can_modify = True
311+
try:
312+
with run_state("setup"):
313+
pass
314+
can_modify = True
315+
except RuntimeError:
316+
can_modify = False
317+
318+
results.append(
319+
{
320+
"reads_state": thread_sees_state,
321+
"reads_token": thread_sees_token,
322+
"can_modify": can_modify,
323+
}
324+
)
325+
326+
threads = []
327+
for _ in range(3):
328+
t = threading.Thread(target=worker_thread_fn)
329+
threads.append(t)
330+
t.start()
331+
332+
for t in threads:
333+
t.join()
334+
335+
for result in results:
336+
assert result["reads_state"]
337+
assert result["reads_token"]
338+
assert not result["can_modify"]
339+
340+
assert get_run_state() is None
341+
assert get_run_token() is None

0 commit comments

Comments
 (0)