22
22
_DEFAULT_ADVISORY_REFRESH_TIMEOUT ,
23
23
_DEFAULT_MANDATORY_REFRESH_TIMEOUT ,
24
24
_DEFAULT_STS_TIMEOUT ,
25
+ _DEFAULT_RESOURCE_TYPE ,
25
26
DEFAULT_TIMEOUT
26
27
)
27
28
from ._streaming import Stream
31
32
32
33
class Ark (SyncAPIClient ):
33
34
chat : resources .Chat
35
+ bot_chat : resources .BotChat
36
+ embeddings : resources .Embeddings
34
37
35
38
def __init__ (
36
39
self ,
@@ -83,6 +86,7 @@ def __init__(
83
86
self ._sts_token_manager : StsTokenManager | None = None
84
87
85
88
self .chat = resources .Chat (self )
89
+ self .bot_chat = resources .BotChat (self )
86
90
self .embeddings = resources .Embeddings (self )
87
91
# self.tokenization = resources.Tokenization(self)
88
92
# self.classification = resources.Classification(self)
@@ -94,6 +98,13 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
94
98
self ._sts_token_manager = StsTokenManager (self .ak , self .sk , self .region )
95
99
return self ._sts_token_manager .get (endpoint_id )
96
100
101
+ def _get_bot_sts_token (self , bot_id : str ):
102
+ if self ._sts_token_manager is None :
103
+ if self .ak is None or self .sk is None :
104
+ raise ArkAPIError ("must set ak and sk before get endpoint token." )
105
+ self ._sts_token_manager = StsTokenManager (self .ak , self .sk , self .region )
106
+ return self ._sts_token_manager .get (bot_id , resource_type = "bot" )
107
+
97
108
@property
98
109
def auth_headers (self ) -> dict [str , str ]:
99
110
api_key = self .api_key
@@ -102,6 +113,8 @@ def auth_headers(self) -> dict[str, str]:
102
113
103
114
class AsyncArk (AsyncAPIClient ):
104
115
chat : resources .AsyncChat
116
+ bot_chat : resources .AsyncBotChat
117
+ embeddings : resources .AsyncEmbeddings
105
118
106
119
def __init__ (
107
120
self ,
@@ -153,6 +166,7 @@ def __init__(
153
166
self ._sts_token_manager : StsTokenManager | None = None
154
167
155
168
self .chat = resources .AsyncChat (self )
169
+ self .bot_chat = resources .AsyncBotChat (self )
156
170
self .embeddings = resources .AsyncEmbeddings (self )
157
171
# self.tokenization = resources.AsyncTokenization(self)
158
172
# self.classification = resources.AsyncClassification(self)
@@ -171,7 +185,6 @@ def auth_headers(self) -> dict[str, str]:
171
185
172
186
173
187
class StsTokenManager (object ):
174
-
175
188
# The time at which we'll attempt to refresh, but not
176
189
# block if someone else is refreshing.
177
190
_advisory_refresh_timeout : int = _DEFAULT_ADVISORY_REFRESH_TIMEOUT
@@ -200,13 +213,14 @@ def _need_refresh(self, ep: str, refresh_in: int | None = None) -> bool:
200
213
201
214
return self ._endpoint_sts_tokens [ep ][1 ] - time .time () < refresh_in
202
215
203
- def _protected_refresh (self , ep : str , ttl : int = _DEFAULT_STS_TIMEOUT , is_mandatory : bool = False ):
216
+ def _protected_refresh (self , ep : str , ttl : int = _DEFAULT_STS_TIMEOUT , is_mandatory : bool = False ,
217
+ resource_type : str = _DEFAULT_RESOURCE_TYPE ):
204
218
if ttl < self ._advisory_refresh_timeout * 2 :
205
219
raise ArkAPIError ("ttl should not be under {} seconds." .format (self ._advisory_refresh_timeout * 2 ))
206
220
207
221
try :
208
222
api_key , expired_time = self ._load_api_key (
209
- ep , ttl
223
+ ep , ttl , resource_type = resource_type
210
224
)
211
225
self ._endpoint_sts_tokens [ep ] = (api_key , expired_time )
212
226
except ApiException as e :
@@ -215,7 +229,7 @@ def _protected_refresh(self, ep: str, ttl: int = _DEFAULT_STS_TIMEOUT, is_mandat
215
229
else :
216
230
logging .error ("load api key cause error: e={}" .format (e ))
217
231
218
- def _refresh (self , ep : str ):
232
+ def _refresh (self , ep : str , resource_type : str = _DEFAULT_RESOURCE_TYPE ):
219
233
if not self ._need_refresh (ep , self ._advisory_refresh_timeout ):
220
234
return
221
235
@@ -228,7 +242,7 @@ def _refresh(self, ep: str):
228
242
ep , self ._mandatory_refresh_timeout
229
243
)
230
244
231
- self ._protected_refresh (ep , is_mandatory = is_mandatory_refresh )
245
+ self ._protected_refresh (ep , is_mandatory = is_mandatory_refresh , resource_type = resource_type )
232
246
return
233
247
finally :
234
248
self ._refresh_lock .release ()
@@ -237,16 +251,17 @@ def _refresh(self, ep: str):
237
251
if not self ._need_refresh (ep , self ._mandatory_refresh_timeout ):
238
252
return
239
253
240
- self ._protected_refresh (ep , is_mandatory = True )
254
+ self ._protected_refresh (ep , is_mandatory = True , resource_type = resource_type )
241
255
242
- def get (self , ep : str ) -> str :
243
- self ._refresh (ep )
256
+ def get (self , ep : str , resource_type : str = _DEFAULT_RESOURCE_TYPE ) -> str :
257
+ self ._refresh (ep , resource_type = resource_type )
244
258
return self ._endpoint_sts_tokens [ep ][0 ]
245
259
246
- def _load_api_key (self , ep : str , duration_seconds : int ) -> Tuple [str , int ]:
260
+ def _load_api_key (self , ep : str , duration_seconds : int ,
261
+ resource_type : str = _DEFAULT_RESOURCE_TYPE ) -> Tuple [str , int ]:
247
262
get_api_key_request = volcenginesdkark .GetApiKeyRequest (
248
263
duration_seconds = duration_seconds ,
249
- resource_type = "endpoint" ,
264
+ resource_type = resource_type ,
250
265
resource_ids = [ep ],
251
266
)
252
267
resp : volcenginesdkark .GetApiKeyResponse = self .api_instance .get_api_key (
0 commit comments