1
1
import os
2
2
import sys
3
+ import threading
3
4
from contextlib import contextmanager
4
- from contextvars import ContextVar
5
5
from dataclasses import dataclass
6
6
from typing import Any , Callable , Dict , Literal , Optional , Tuple
7
7
13
13
from .run import _has_output_iterator_array_type
14
14
from .version import Version
15
15
16
- __all__ = ["include" ]
16
+ __all__ = ["get_run_state" , "get_run_token" , " include" , "run_state" , "run_token " ]
17
17
18
18
19
- _RUN_STATE : ContextVar [Literal ["load" , "setup" , "run" ] | None ] = ContextVar (
20
- "run_state" ,
21
- default = None ,
22
- )
23
- _RUN_TOKEN : ContextVar [str | None ] = ContextVar ("run_token" , default = None )
19
+ _run_state : Optional [Literal ["load" , "setup" , "run" ]] = None
20
+ _run_token : Optional [str ] = None
21
+
22
+ _state_stack = []
23
+ _token_stack = []
24
+
25
+ _state_lock = threading .RLock ()
26
+ _token_lock = threading .RLock ()
27
+
28
+
29
+ def get_run_state () -> Optional [Literal ["load" , "setup" , "run" ]]:
30
+ """
31
+ Get the current run state.
32
+ """
33
+ return _run_state
34
+
35
+
36
+ def get_run_token () -> Optional [str ]:
37
+ """
38
+ Get the current API token.
39
+ """
40
+ return _run_token
24
41
25
42
26
43
@contextmanager
27
44
def run_state (state : Literal ["load" , "setup" , "run" ]) -> Any :
28
45
"""
29
- Internal context manager for execution state.
46
+ Context manager for setting the current run state.
30
47
"""
31
- s = _RUN_STATE .set (state )
48
+ global _run_state
49
+
50
+ if threading .current_thread () is not threading .main_thread ():
51
+ raise RuntimeError ("Only the main thread can modify run state" )
52
+
53
+ with _state_lock :
54
+ _state_stack .append (_run_state )
55
+
56
+ _run_state = state
57
+
32
58
try :
33
59
yield
34
60
finally :
35
- _RUN_STATE .reset (s )
61
+ with _state_lock :
62
+ _run_state = _state_stack .pop ()
36
63
37
64
38
65
@contextmanager
39
66
def run_token (token : str ) -> Any :
40
67
"""
41
- Sets the API token for the current context .
68
+ Context manager for setting the current API token .
42
69
"""
43
- t = _RUN_TOKEN .set (token )
70
+ global _run_token
71
+
72
+ if threading .current_thread () is not threading .main_thread ():
73
+ raise RuntimeError ("Only the main thread can modify API token" )
74
+
75
+ with _token_lock :
76
+ _token_stack .append (_run_token )
77
+
78
+ _run_token = token
79
+
44
80
try :
45
81
yield
46
82
finally :
47
- _RUN_TOKEN .reset (t )
83
+ with _token_lock :
84
+ _run_token = _token_stack .pop ()
48
85
49
86
50
87
def _find_api_token () -> str :
@@ -53,12 +90,11 @@ def _find_api_token() -> str:
53
90
print ("Using Replicate API token from environment" , file = sys .stderr )
54
91
return token
55
92
56
- token = _RUN_TOKEN .get ()
57
-
58
- if not token :
93
+ current_token = get_run_token ()
94
+ if current_token is None :
59
95
raise ValueError ("No run token found" )
60
96
61
- return token
97
+ return current_token
62
98
63
99
64
100
@dataclass
@@ -158,7 +194,7 @@ def include(function_ref: str) -> Callable[..., Any]:
158
194
159
195
This function can only be called at the top level.
160
196
"""
161
- if _RUN_STATE . get () != "load" :
197
+ if get_run_state () != "load" :
162
198
raise RuntimeError (
163
199
"You may only call replicate.include at the top level."
164
200
)
0 commit comments