Skip to content

Commit c2d5b9b

Browse files
committed
fix python version incompatibilities
1 parent 336c660 commit c2d5b9b

17 files changed

+2160
-46
lines changed

.github/workflows/tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ jobs:
3131
with:
3232
python-version: ${{ matrix.py }}
3333
- name: Run test suite
34-
run: hatch test
34+
run: hatch test -py=${{ matrix.py }}

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,5 @@ cython_debug/
162162
demo/
163163

164164
# Experimental scripts
165-
experiments/
165+
experiments/
166+
/.jj/

pyproject.toml

+20-8
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,26 @@ classifiers = [
2828
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
2929
]
3030
dependencies = [
31-
"httpx",
32-
"openai",
33-
"prompt-toolkit",
34-
"yaspin",
35-
"tomlkit",
36-
"tomli >= 1.1.0; python_version < '3.11'",
37-
"google-generativeai",
38-
"click>=8.1.8",
31+
"tomli ~= 2.1; python_version < '3.11'",
32+
"click~=8.1",
33+
"dspy~=2.6",
34+
"httpx~=0.28",
35+
"openai~=1.66",
36+
"prompt-toolkit~=3.0",
37+
"yaspin~=3.1",
38+
"tomlkit~=0.13",
39+
"google-generativeai~=0.8",
40+
]
41+
42+
[dependency-groups]
43+
test = [
44+
"pytest~=8.3",
45+
"pytest-cov~=5.0",
46+
"pytest-sugar~=1.0",
47+
"pytest-xdist~=3.6",
48+
"pytest-asyncio~=0.24",
49+
"pytest-httpx~=0.30",
50+
"tomli-w~=1.2.0",
3951
]
4052

4153
[project.scripts]

src/shelloracle/cli/__init__.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from __future__ import annotations
2+
13
import logging
24
import sys
3-
from pathlib import Path
5+
from typing import TYPE_CHECKING
46

57
import click
68

@@ -10,6 +12,9 @@
1012
from shelloracle.config import Configuration
1113
from shelloracle.tty_log_handler import TtyLogHandler
1214

15+
if TYPE_CHECKING:
16+
from pathlib import Path
17+
1318
logger = logging.getLogger(__name__)
1419

1520

@@ -34,7 +39,7 @@ def configure_logging(log_path: Path):
3439
@click.group(invoke_without_command=True)
3540
@click.version_option()
3641
@click.pass_context
37-
def cli(ctx):
42+
def cli(ctx: click.Context):
3843
"""ShellOracle command line interface."""
3944
app = Application()
4045
configure_logging(app.log_path)

src/shelloracle/cli/application.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
from __future__ import annotations
2+
3+
import typing
14
from pathlib import Path
25

3-
from shelloracle.config import Configuration
6+
if typing.TYPE_CHECKING:
7+
from shelloracle.config import Configuration
48

59
shelloracle_home = Path.home() / ".shelloracle"
610
shelloracle_home.mkdir(exist_ok=True)

src/shelloracle/cli/config/edit.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
15
import click
26

3-
from shelloracle.cli.application import Application
7+
if TYPE_CHECKING:
8+
from shelloracle.cli.application import Application
49

510

611
@click.command()

src/shelloracle/cli/config/init.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
15
import click
26

3-
from shelloracle.cli import Application
7+
if TYPE_CHECKING:
8+
from shelloracle.cli import Application
49

510

611
@click.command()

src/shelloracle/providers/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ class Provider(Protocol):
3333
config: Configuration
3434

3535
def __init__(self, config: Configuration) -> None:
36-
self.config = config
36+
"""Initialize the provider with the given configuration.
37+
38+
:param config: the configuration object
39+
:return: none
40+
"""
3741

3842
@abstractmethod
3943
def generate(self, prompt: str) -> AsyncIterator[str]:

src/shelloracle/providers/deepseek.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
1-
from collections.abc import AsyncIterator
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
24

35
from openai import APIError, AsyncOpenAI
46

57
from shelloracle.providers import Provider, ProviderError, Setting, system_prompt
68

9+
if TYPE_CHECKING:
10+
from collections.abc import AsyncIterator
11+
12+
from shelloracle.config import Configuration
13+
714

815
class Deepseek(Provider):
916
name = "Deepseek"
1017

1118
api_key = Setting(default="")
1219
model = Setting(default="deepseek-chat")
1320

14-
def __init__(self, *args, **kwargs):
15-
super().__init__(*args, **kwargs)
21+
def __init__(self, config: Configuration) -> None:
22+
self.config = config
1623
if not self.api_key:
1724
msg = "No API key provided"
1825
raise ProviderError(msg)
@@ -22,7 +29,10 @@ async def generate(self, prompt: str) -> AsyncIterator[str]:
2229
try:
2330
stream = await self.client.chat.completions.create(
2431
model=self.model,
25-
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
32+
messages=[
33+
{"role": "system", "content": system_prompt},
34+
{"role": "user", "content": prompt},
35+
],
2636
stream=True,
2737
)
2838
async for chunk in stream:

src/shelloracle/providers/google.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
1-
from collections.abc import AsyncIterator
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
24

35
import google.generativeai as genai
46

57
from shelloracle.providers import Provider, ProviderError, Setting, system_prompt
68

9+
if TYPE_CHECKING:
10+
from collections.abc import AsyncIterator
11+
12+
from shelloracle.config import Configuration
13+
714

815
class Google(Provider):
916
name = "Google"
1017

1118
api_key = Setting(default="")
1219
model = Setting(default="gemini-2.0-flash") # Assuming a default model name
1320

14-
def __init__(self, *args, **kwargs):
15-
super().__init__(*args, **kwargs)
21+
def __init__(self, config: Configuration) -> None:
22+
self.config = config
1623
if not self.api_key:
1724
msg = "No API key provided"
1825
raise ProviderError(msg)

src/shelloracle/providers/localai.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1-
from collections.abc import AsyncIterator
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
24

35
from openai import APIError, AsyncOpenAI
46

57
from shelloracle.providers import Provider, ProviderError, Setting, system_prompt
68

9+
if TYPE_CHECKING:
10+
from collections.abc import AsyncIterator
11+
12+
from shelloracle.config import Configuration
13+
714

815
class LocalAI(Provider):
916
name = "LocalAI"
@@ -16,8 +23,8 @@ class LocalAI(Provider):
1623
def endpoint(self) -> str:
1724
return f"http://{self.host}:{self.port}"
1825

19-
def __init__(self, *args, **kwargs):
20-
super().__init__(*args, **kwargs)
26+
def __init__(self, config: Configuration) -> None:
27+
self.config = config
2128
# Use a placeholder API key so the client will work
2229
self.client = AsyncOpenAI(api_key="sk-xxx", base_url=self.endpoint)
2330

src/shelloracle/providers/ollama.py

+5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
if TYPE_CHECKING:
1212
from collections.abc import AsyncIterator
1313

14+
from shelloracle.config import Configuration
15+
1416

1517
def dataclass_to_json(obj: Any) -> dict[str, Any]:
1618
"""Convert dataclass to a json dict
@@ -58,6 +60,9 @@ class Ollama(Provider):
5860
port = Setting(default=11434)
5961
model = Setting(default="dolphin-mistral")
6062

63+
def __init__(self, config: Configuration) -> None:
64+
self.config = config
65+
6166
@property
6267
def endpoint(self) -> str:
6368
# computed property because python descriptors need to be bound to an instance before access

src/shelloracle/providers/openai.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
1-
from collections.abc import AsyncIterator
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
24

35
from openai import APIError, AsyncOpenAI
46

57
from shelloracle.providers import Provider, ProviderError, Setting, system_prompt
68

9+
if TYPE_CHECKING:
10+
from collections.abc import AsyncIterator
11+
12+
from shelloracle.config import Configuration
13+
714

815
class OpenAI(Provider):
916
name = "OpenAI"
1017

1118
api_key = Setting(default="")
1219
model = Setting(default="gpt-3.5-turbo")
1320

14-
def __init__(self, *args, **kwargs):
15-
super().__init__(*args, **kwargs)
21+
def __init__(self, config: Configuration) -> None:
22+
self.config = config
1623
if not self.api_key:
1724
msg = "No API key provided"
1825
raise ProviderError(msg)

src/shelloracle/providers/openai_compat.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1-
from collections.abc import AsyncIterator
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
24

35
from openai import APIError, AsyncOpenAI
46

57
from shelloracle.providers import Provider, ProviderError, Setting, system_prompt
68

9+
if TYPE_CHECKING:
10+
from collections.abc import AsyncIterator
11+
12+
from shelloracle.config import Configuration
13+
714

815
class OpenAICompat(Provider):
916
name = "OpenAICompat"
@@ -12,8 +19,8 @@ class OpenAICompat(Provider):
1219
api_key = Setting(default="")
1320
model = Setting(default="")
1421

15-
def __init__(self, *args, **kwargs):
16-
super().__init__(*args, **kwargs)
22+
def __init__(self, config: Configuration) -> None:
23+
self.config = config
1724
if not self.api_key:
1825
msg = "No API key provided. Use a dummy placeholder if no key is required"
1926
raise ProviderError(msg)

src/shelloracle/providers/xai.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
1-
from collections.abc import AsyncIterator
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
24

35
from openai import APIError, AsyncOpenAI
46

57
from shelloracle.providers import Provider, ProviderError, Setting, system_prompt
68

9+
if TYPE_CHECKING:
10+
from collections.abc import AsyncIterator
11+
12+
from shelloracle.config import Configuration
13+
714

815
class XAI(Provider):
916
name = "XAI"
1017

1118
api_key = Setting(default="")
1219
model = Setting(default="grok-beta")
1320

14-
def __init__(self, *args, **kwargs):
15-
super().__init__(*args, **kwargs)
21+
def __init__(self, config: Configuration) -> None:
22+
self.config = config
1623
if not self.api_key:
1724
msg = "No API key provided"
1825
raise ProviderError(msg)

src/shelloracle/tty_log_handler.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import logging
24

35
from prompt_toolkit import print_formatted_text
@@ -6,7 +8,7 @@
68

79

810
class TtyLogHandler(logging.Handler):
9-
def emit(self, record):
11+
def emit(self, record: logging.LogRecord):
1012
if record.levelno >= logging.ERROR:
1113
color = "ansired"
1214
elif record.levelno == logging.WARNING:

0 commit comments

Comments
 (0)