Skip to content

Commit f98f285

Browse files
author
修杰
committed
Merge branch 'feat/add_bot_sdk' into 'master'
feat(ark): add bot chat sdk See merge request iaasng/volcengine-python-sdk!328
2 parents 0c2cac2 + 7f42795 commit f98f285

16 files changed

+452
-14
lines changed

meta.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
2-
"lasted": "1.0.88",
2+
"lasted": "1.0.89",
33
"meta_commit": "daee4996c2a9d1fb908eeaabcace8611ebae30f3"
44
}

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from setuptools import setup, find_packages # noqa: H301
44

55
NAME = "volcengine-python-sdk"
6-
VERSION = "1.0.88"
6+
VERSION = "1.0.89"
77
# To install the library, run the following
88
#
99
# python setup.py install

volcenginesdkarkruntime/_client.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
_DEFAULT_ADVISORY_REFRESH_TIMEOUT,
2323
_DEFAULT_MANDATORY_REFRESH_TIMEOUT,
2424
_DEFAULT_STS_TIMEOUT,
25+
_DEFAULT_RESOURCE_TYPE,
2526
DEFAULT_TIMEOUT
2627
)
2728
from ._streaming import Stream
@@ -31,6 +32,8 @@
3132

3233
class Ark(SyncAPIClient):
3334
chat: resources.Chat
35+
bot_chat: resources.BotChat
36+
embeddings: resources.Embeddings
3437

3538
def __init__(
3639
self,
@@ -83,6 +86,7 @@ def __init__(
8386
self._sts_token_manager: StsTokenManager | None = None
8487

8588
self.chat = resources.Chat(self)
89+
self.bot_chat = resources.BotChat(self)
8690
self.embeddings = resources.Embeddings(self)
8791
# self.tokenization = resources.Tokenization(self)
8892
# self.classification = resources.Classification(self)
@@ -94,6 +98,13 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
9498
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
9599
return self._sts_token_manager.get(endpoint_id)
96100

101+
def _get_bot_sts_token(self, bot_id: str):
102+
if self._sts_token_manager is None:
103+
if self.ak is None or self.sk is None:
104+
raise ArkAPIError("must set ak and sk before get endpoint token.")
105+
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
106+
return self._sts_token_manager.get(bot_id, resource_type="bot")
107+
97108
@property
98109
def auth_headers(self) -> dict[str, str]:
99110
api_key = self.api_key
@@ -102,6 +113,8 @@ def auth_headers(self) -> dict[str, str]:
102113

103114
class AsyncArk(AsyncAPIClient):
104115
chat: resources.AsyncChat
116+
bot_chat: resources.AsyncBotChat
117+
embeddings: resources.AsyncEmbeddings
105118

106119
def __init__(
107120
self,
@@ -153,6 +166,7 @@ def __init__(
153166
self._sts_token_manager: StsTokenManager | None = None
154167

155168
self.chat = resources.AsyncChat(self)
169+
self.bot_chat = resources.AsyncBotChat(self)
156170
self.embeddings = resources.AsyncEmbeddings(self)
157171
# self.tokenization = resources.AsyncTokenization(self)
158172
# self.classification = resources.AsyncClassification(self)
@@ -171,7 +185,6 @@ def auth_headers(self) -> dict[str, str]:
171185

172186

173187
class StsTokenManager(object):
174-
175188
# The time at which we'll attempt to refresh, but not
176189
# block if someone else is refreshing.
177190
_advisory_refresh_timeout: int = _DEFAULT_ADVISORY_REFRESH_TIMEOUT
@@ -200,13 +213,14 @@ def _need_refresh(self, ep: str, refresh_in: int | None = None) -> bool:
200213

201214
return self._endpoint_sts_tokens[ep][1] - time.time() < refresh_in
202215

203-
def _protected_refresh(self, ep: str, ttl: int = _DEFAULT_STS_TIMEOUT, is_mandatory: bool = False):
216+
def _protected_refresh(self, ep: str, ttl: int = _DEFAULT_STS_TIMEOUT, is_mandatory: bool = False,
217+
resource_type: str = _DEFAULT_RESOURCE_TYPE):
204218
if ttl < self._advisory_refresh_timeout * 2:
205219
raise ArkAPIError("ttl should not be under {} seconds.".format(self._advisory_refresh_timeout * 2))
206220

207221
try:
208222
api_key, expired_time = self._load_api_key(
209-
ep, ttl
223+
ep, ttl, resource_type=resource_type
210224
)
211225
self._endpoint_sts_tokens[ep] = (api_key, expired_time)
212226
except ApiException as e:
@@ -215,7 +229,7 @@ def _protected_refresh(self, ep: str, ttl: int = _DEFAULT_STS_TIMEOUT, is_mandat
215229
else:
216230
logging.error("load api key cause error: e={}".format(e))
217231

218-
def _refresh(self, ep: str):
232+
def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
219233
if not self._need_refresh(ep, self._advisory_refresh_timeout):
220234
return
221235

@@ -228,7 +242,7 @@ def _refresh(self, ep: str):
228242
ep, self._mandatory_refresh_timeout
229243
)
230244

231-
self._protected_refresh(ep, is_mandatory=is_mandatory_refresh)
245+
self._protected_refresh(ep, is_mandatory=is_mandatory_refresh, resource_type=resource_type)
232246
return
233247
finally:
234248
self._refresh_lock.release()
@@ -237,16 +251,17 @@ def _refresh(self, ep: str):
237251
if not self._need_refresh(ep, self._mandatory_refresh_timeout):
238252
return
239253

240-
self._protected_refresh(ep, is_mandatory=True)
254+
self._protected_refresh(ep, is_mandatory=True, resource_type=resource_type)
241255

242-
def get(self, ep: str) -> str:
243-
self._refresh(ep)
256+
def get(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE) -> str:
257+
self._refresh(ep, resource_type=resource_type)
244258
return self._endpoint_sts_tokens[ep][0]
245259

246-
def _load_api_key(self, ep: str, duration_seconds: int) -> Tuple[str, int]:
260+
def _load_api_key(self, ep: str, duration_seconds: int,
261+
resource_type: str = _DEFAULT_RESOURCE_TYPE) -> Tuple[str, int]:
247262
get_api_key_request = volcenginesdkark.GetApiKeyRequest(
248263
duration_seconds=duration_seconds,
249-
resource_type="endpoint",
264+
resource_type=resource_type,
250265
resource_ids=[ep],
251266
)
252267
resp: volcenginesdkark.GetApiKeyResponse = self.api_instance.get_api_key(

volcenginesdkarkruntime/_constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@
2222
_DEFAULT_MANDATORY_REFRESH_TIMEOUT = 10 * 60 # 10 min
2323
_DEFAULT_ADVISORY_REFRESH_TIMEOUT = 30 * 60 # 30 min
2424
_DEFAULT_STS_TIMEOUT = 7 * 24 * 60 * 60 # 7 days
25+
26+
_DEFAULT_RESOURCE_TYPE = "endpoint"

volcenginesdkarkruntime/_utils/_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,7 @@ def _insert_sts_token(args, kwargs):
8484
default_auth_header = {"Authorization": "Bearer " + ark_client._get_endpoint_sts_token(model)}
8585
extra_headers = kwargs.get("extra_headers") if kwargs.get("extra_headers") else {}
8686
kwargs["extra_headers"] = {**default_auth_header, **extra_headers}
87+
elif ark_client.api_key is None and model and model.startswith("bot-") and ark_client.ak and ark_client.sk:
88+
default_auth_header = {"Authorization": "Bearer " + ark_client._get_bot_sts_token(model)}
89+
extra_headers = kwargs.get("extra_headers") if kwargs.get("extra_headers") else {}
90+
kwargs["extra_headers"] = {**default_auth_header, **extra_headers}

volcenginesdkarkruntime/resources/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
from .embeddings import Embeddings, AsyncEmbeddings
33
from .tokenization import Tokenization, AsyncTokenization
44
from .classification import Classification, AsyncClassification
5+
from .bot import BotChat, AsyncBotChat
56

67
__all__ = [
78
"Chat",
9+
"BotChat",
810
"AsyncChat",
11+
"AsyncBotChat",
912
"Embeddings",
1013
"AsyncEmbeddings",
1114
"Tokenization",
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .chat import BotChat, AsyncBotChat
2+
3+
__all__ = ["BotChat", "AsyncBotChat"]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2+
3+
from __future__ import annotations
4+
5+
from .completions import Completions, AsyncCompletions
6+
from ..._compat import cached_property
7+
from ..._resource import SyncAPIResource, AsyncAPIResource
8+
9+
__all__ = ["BotChat", "AsyncBotChat"]
10+
11+
12+
class BotChat(SyncAPIResource):
13+
@cached_property
14+
def completions(self) -> Completions:
15+
return Completions(self._client)
16+
17+
18+
class AsyncBotChat(AsyncAPIResource):
19+
@cached_property
20+
def completions(self) -> AsyncCompletions:
21+
return AsyncCompletions(self._client)

0 commit comments

Comments
 (0)