Skip to content

Commit e9f4e12

Browse files
authored
Add foundational tests (#60)
1 parent 3f56e60 commit e9f4e12

File tree

11 files changed

+238
-10
lines changed

11 files changed

+238
-10
lines changed

pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@ classifiers = [
3535
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
3636
]
3737

38+
[project.optional-dependencies]
39+
tests = [
40+
"tox",
41+
"pytest",
42+
"pytest-sugar",
43+
"pytest-asyncio",
44+
"pytest-httpx"
45+
]
46+
3847
[project.scripts]
3948
shor = "shelloracle.__main__:main"
4049

src/shelloracle/providers/localai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections.abc import AsyncIterator
22

33
from openai import APIError
4-
from openai import AsyncOpenAI as OpenAIClient
4+
from openai import AsyncOpenAI
55

66
from . import Provider, ProviderError, Setting, system_prompt
77

@@ -19,7 +19,7 @@ def endpoint(self) -> str:
1919

2020
def __init__(self):
2121
# Use a placeholder API key so the client will work
22-
self.client = OpenAIClient(api_key="sk-xxx", base_url=self.endpoint)
22+
self.client = AsyncOpenAI(api_key="sk-xxx", base_url=self.endpoint)
2323

2424
async def generate(self, prompt: str) -> AsyncIterator[str]:
2525
try:

src/shelloracle/providers/openai.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from collections.abc import AsyncIterator
22

3-
from openai import APIError
4-
from openai import AsyncOpenAI as OpenAIClient
3+
from openai import AsyncOpenAI, APIError
54

65
from . import Provider, ProviderError, Setting, system_prompt
76

@@ -15,7 +14,7 @@ class OpenAI(Provider):
1514
def __init__(self):
1615
if not self.api_key:
1716
raise ProviderError("No API key provided")
18-
self.client = OpenAIClient(api_key=self.api_key)
17+
self.client = AsyncOpenAI(api_key=self.api_key)
1918

2019
async def generate(self, prompt: str) -> AsyncIterator[str]:
2120
try:

tests/conftest.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import pytest
2+
import tomlkit
3+
4+
from shelloracle.config import Configuration
5+
6+
7+
@pytest.fixture(autouse=True)
8+
def tmp_shelloracle_home(monkeypatch, tmp_path):
9+
monkeypatch.setattr("shelloracle.settings.Settings.shelloracle_home", tmp_path)
10+
return tmp_path
11+
12+
13+
@pytest.fixture
14+
def set_config(monkeypatch, tmp_shelloracle_home):
15+
config_path = tmp_shelloracle_home / "config.toml"
16+
17+
def _set_config(config: dict) -> Configuration:
18+
with config_path.open("w") as f:
19+
tomlkit.dump(config, f)
20+
configuration = Configuration(config_path)
21+
monkeypatch.setattr("shelloracle.config._config", configuration)
22+
23+
yield _set_config
24+
25+
config_path.unlink()
26+

tests/providers/conftest.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
5+
6+
def split_with_delimiter(string, delim):
7+
result = []
8+
last_split = 0
9+
for index, character in enumerate(string):
10+
if character == delim:
11+
result.append(string[last_split:index + 1])
12+
last_split = index + 1
13+
if last_split != len(string):
14+
result.append(string[last_split:])
15+
return result
16+
17+
18+
@pytest.fixture
19+
def mock_asyncopenai(monkeypatch):
20+
class AsyncChatCompletionIterator:
21+
def __init__(self, answer: str):
22+
self.answer_index = 0
23+
self.answer_deltas = split_with_delimiter(answer, " ")
24+
25+
def __aiter__(self):
26+
return self
27+
28+
async def __anext__(self):
29+
if self.answer_index < len(self.answer_deltas):
30+
answer_chunk = self.answer_deltas[self.answer_index]
31+
self.answer_index += 1
32+
choice = MagicMock()
33+
choice.delta.content = answer_chunk
34+
chunk = MagicMock()
35+
chunk.choices = [choice]
36+
return chunk
37+
else:
38+
raise StopAsyncIteration
39+
40+
async def mock_acreate(*args, **kwargs):
41+
return AsyncChatCompletionIterator("head -c 100 /dev/urandom | hexdump -C")
42+
43+
monkeypatch.setattr("openai.resources.chat.AsyncCompletions.create", mock_acreate)

tests/providers/test_localai.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import pytest
2+
3+
from shelloracle.providers.localai import LocalAI
4+
5+
6+
class TestOpenAI:
7+
@pytest.fixture
8+
def localai_config(self, set_config):
9+
config = {'shelloracle': {'provider': 'LocalAI'}, 'provider': {
10+
'LocalAI': {'host': 'localhost', 'port': 8080, 'model': 'mistral-openorca'}}}
11+
set_config(config)
12+
13+
@pytest.fixture
14+
def localai_instance(self, localai_config):
15+
return LocalAI()
16+
17+
def test_name(self):
18+
assert LocalAI.name == "LocalAI"
19+
20+
def test_model(self, localai_instance):
21+
assert localai_instance.model == "mistral-openorca"
22+
23+
@pytest.mark.asyncio
24+
async def test_generate(self, mock_asyncopenai, localai_instance):
25+
result = ""
26+
async for response in localai_instance.generate(""):
27+
result += response
28+
assert result == "head -c 100 /dev/urandom | hexdump -C"

tests/providers/test_ollama.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,43 @@
1+
import pytest
2+
from pytest_httpx import IteratorStream
3+
14
from shelloracle.providers.ollama import Ollama
25

36

4-
def test_name():
5-
assert Ollama.name == "Ollama"
7+
class TestOllama:
8+
@pytest.fixture
9+
def ollama_config(self, set_config):
10+
config = {'shelloracle': {'provider': 'Ollama'},
11+
'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}}
12+
set_config(config)
13+
14+
@pytest.fixture
15+
def ollama_instance(self, ollama_config):
16+
return Ollama()
17+
18+
def test_name(self):
19+
assert Ollama.name == "Ollama"
20+
21+
def test_host(self, ollama_instance):
22+
assert ollama_instance.host == "localhost"
23+
24+
def test_port(self, ollama_instance):
25+
assert ollama_instance.port == 11434
26+
27+
def test_model(self, ollama_instance):
28+
assert ollama_instance.model == "dolphin-mistral"
29+
30+
def test_endpoint(self, ollama_instance):
31+
assert ollama_instance.endpoint == "http://localhost:11434/api/generate"
32+
33+
@pytest.mark.asyncio
34+
async def test_generate(self, ollama_instance, httpx_mock):
35+
responses = [
36+
b'{"response": "cat"}\n', b'{"response": " test"}\n', b'{"response": "."}\n', b'{"response": "py"}\n',
37+
b'{"response": ""}\n'
38+
]
39+
httpx_mock.add_response(stream=IteratorStream(responses))
40+
result = ""
41+
async for response in ollama_instance.generate(""):
42+
result += response
43+
assert result == "cat test.py"

tests/providers/test_openai.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,31 @@
1+
import pytest
2+
13
from shelloracle.providers.openai import OpenAI
24

35

4-
def test_name():
5-
assert OpenAI.name == "OpenAI"
6+
class TestOpenAI:
7+
@pytest.fixture
8+
def openai_config(self, set_config):
9+
config = {'shelloracle': {'provider': 'OpenAI'}, 'provider': {
10+
'OpenAI': {'api_key': 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', 'model': 'gpt-3.5-turbo'}}}
11+
set_config(config)
12+
13+
@pytest.fixture
14+
def openai_instance(self, openai_config):
15+
return OpenAI()
16+
17+
def test_name(self):
18+
assert OpenAI.name == "OpenAI"
19+
20+
def test_api_key(self, openai_instance):
21+
assert openai_instance.api_key == "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
22+
23+
def test_model(self, openai_instance):
24+
assert openai_instance.model == "gpt-3.5-turbo"
25+
26+
@pytest.mark.asyncio
27+
async def test_generate(self, mock_asyncopenai, openai_instance):
28+
result = ""
29+
async for response in openai_instance.generate(""):
30+
result += response
31+
assert result == "head -c 100 /dev/urandom | hexdump -C"

tests/test_config.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from shelloracle.config import get_config, initialize_config
6+
7+
8+
class TestConfiguration:
9+
@pytest.fixture
10+
def default_config(self, set_config):
11+
config = {'shelloracle': {'provider': 'Ollama', 'spinner_style': 'earth'},
12+
'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}}
13+
set_config(config)
14+
return config
15+
16+
def test_initialize_config(self, default_config):
17+
with pytest.raises(RuntimeError):
18+
initialize_config()
19+
20+
def test_from_file(self, default_config):
21+
assert get_config() == default_config
22+
23+
def test_getitem(self, default_config):
24+
for key in default_config:
25+
assert default_config[key] == get_config()[key]
26+
27+
def test_len(self, default_config):
28+
assert len(default_config) == len(get_config())
29+
30+
def test_iter(self, default_config):
31+
assert list(iter(default_config)) == list(iter(get_config()))
32+
33+
def test_str(self, default_config):
34+
assert str(get_config()) == f"Configuration({default_config})"
35+
36+
def test_repr(self, default_config):
37+
assert repr(default_config) == str(default_config)
38+
39+
def test_provider(self, default_config):
40+
assert get_config().provider == "Ollama"
41+
42+
def test_spinner_style(self, default_config):
43+
assert get_config().spinner_style == "earth"
44+
45+
def test_no_spinner_style(self, caplog, set_config):
46+
config_dict = {'shelloracle': {'provider': 'Ollama'},
47+
'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}}
48+
set_config(config_dict)
49+
assert get_config().spinner_style is None
50+
assert "invalid spinner style" not in caplog.text
51+
52+
def test_invalid_spinner_style(self, caplog, set_config):
53+
config_dict = {'shelloracle': {'provider': 'Ollama', 'spinner_style': 'invalid'},
54+
'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}}
55+
set_config(config_dict)
56+
assert get_config().spinner_style is None
57+
assert "invalid spinner style" in caplog.text

tests/test_shelloracle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import os
44
import sys
5-
from unittest.mock import MagicMock, call
5+
from unittest.mock import call, MagicMock
66

77
import pytest
88
from yaspin.spinners import Spinners

0 commit comments

Comments
 (0)