1
- from collections .abc import AsyncIterator
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
2
4
3
5
from openai import APIError , AsyncOpenAI
4
6
5
7
from shelloracle .providers import Provider , ProviderError , Setting , system_prompt
6
8
9
+ if TYPE_CHECKING :
10
+ from collections .abc import AsyncIterator
11
+
12
+ from shelloracle .config import Configuration
13
+
7
14
8
15
class Deepseek (Provider ):
9
16
name = "Deepseek"
10
17
11
18
api_key = Setting (default = "" )
12
19
model = Setting (default = "deepseek-chat" )
13
20
14
- def __init__ (self , * args , ** kwargs ) :
15
- super (). __init__ ( * args , ** kwargs )
21
+ def __init__ (self , config : Configuration ) -> None :
22
+ self . config = config
16
23
if not self .api_key :
17
24
msg = "No API key provided"
18
25
raise ProviderError (msg )
@@ -22,7 +29,10 @@ async def generate(self, prompt: str) -> AsyncIterator[str]:
22
29
try :
23
30
stream = await self .client .chat .completions .create (
24
31
model = self .model ,
25
- messages = [{"role" : "system" , "content" : system_prompt }, {"role" : "user" , "content" : prompt }],
32
+ messages = [
33
+ {"role" : "system" , "content" : system_prompt },
34
+ {"role" : "user" , "content" : prompt },
35
+ ],
26
36
stream = True ,
27
37
)
28
38
async for chunk in stream :
0 commit comments