Skip to content

Commit a24b10d

Browse files
authored
add xai grok llm provider (#75)
1 parent cf620d5 commit a24b10d

File tree

3 files changed

+81
-1
lines changed

3 files changed

+81
-1
lines changed

src/shelloracle/providers/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,9 @@ def _providers() -> dict[str, type[Provider]]:
7777
from shelloracle.providers.localai import LocalAI
7878
from shelloracle.providers.ollama import Ollama
7979
from shelloracle.providers.openai import OpenAI
80+
from shelloracle.providers.xai import XAI
8081

81-
return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI}
82+
return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI, XAI.name: XAI}
8283

8384

8485
def get_provider(name: str) -> type[Provider]:

src/shelloracle/providers/xai.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from collections.abc import AsyncIterator
2+
3+
from openai import APIError, AsyncOpenAI
4+
5+
from shelloracle.providers import Provider, ProviderError, Setting, system_prompt
6+
7+
8+
class XAI(Provider):
9+
name = "XAI"
10+
11+
api_key = Setting(default="")
12+
model = Setting(default="grok-beta")
13+
14+
def __init__(self):
15+
if not self.api_key:
16+
msg = "No API key provided"
17+
raise ProviderError(msg)
18+
self.client = AsyncOpenAI(
19+
api_key=self.api_key,
20+
base_url="https://api.x.ai/v1",
21+
)
22+
23+
async def generate(self, prompt: str) -> AsyncIterator[str]:
24+
try:
25+
stream = await self.client.chat.completions.create(
26+
model=self.model,
27+
messages=[
28+
{"role": "system", "content": system_prompt},
29+
{"role": "user", "content": prompt},
30+
],
31+
stream=True,
32+
)
33+
async for chunk in stream:
34+
if chunk.choices[0].delta.content is not None:
35+
yield chunk.choices[0].delta.content
36+
except APIError as e:
37+
msg = f"Something went wrong while querying XAI: {e}"
38+
raise ProviderError(msg) from e

tests/providers/test_xai.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import pytest
2+
3+
from shelloracle.providers.xai import XAI
4+
5+
6+
class TestOpenAI:
7+
@pytest.fixture
8+
def xai_config(self, set_config):
9+
config = {
10+
"shelloracle": {"provider": "XAI"},
11+
"provider": {
12+
"XAI": {
13+
"api_key": "xai-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
14+
"model": "grok-beta",
15+
}
16+
},
17+
}
18+
set_config(config)
19+
20+
@pytest.fixture
21+
def xai_instance(self, xai_config):
22+
return XAI()
23+
24+
def test_name(self):
25+
assert XAI.name == "XAI"
26+
27+
def test_api_key(self, xai_instance):
28+
assert (
29+
xai_instance.api_key
30+
== "xai-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
31+
)
32+
33+
def test_model(self, xai_instance):
34+
assert xai_instance.model == "grok-beta"
35+
36+
@pytest.mark.asyncio
37+
async def test_generate(self, mock_asyncopenai, xai_instance):
38+
result = ""
39+
async for response in xai_instance.generate(""):
40+
result += response
41+
assert result == "head -c 100 /dev/urandom | hexdump -C"

0 commit comments

Comments
 (0)