Skip to content

Commit 0709372

Browse files
committed
use abc for provider
protocol is awesome for typing, but pycharm tooling doesn't fully understand it && behavior difference between 3.9 and 3.11 makes it cumbersome
1 parent c2d5b9b commit 0709372

File tree

8 files changed

+15
-16
lines changed

8 files changed

+15
-16
lines changed

src/shelloracle/providers/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import abc
34
from abc import abstractmethod
45
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
56

@@ -22,7 +23,7 @@ class ProviderError(Exception):
2223
"""LLM providers raise this error to gracefully indicate something has gone wrong."""
2324

2425

25-
class Provider(Protocol):
26+
class Provider(abc.ABC):
2627
"""
2728
LLM Provider Protocol
2829
@@ -38,6 +39,7 @@ def __init__(self, config: Configuration) -> None:
3839
:param config: the configuration object
3940
:return: none
4041
"""
42+
self.config = config
4143

4244
@abstractmethod
4345
def generate(self, prompt: str) -> AsyncIterator[str]:

src/shelloracle/providers/deepseek.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ class Deepseek(Provider):
1818
api_key = Setting(default="")
1919
model = Setting(default="deepseek-chat")
2020

21-
def __init__(self, config: Configuration) -> None:
22-
self.config = config
21+
def __init__(self, *args, **kwargs) -> None:
22+
super().__init__(*args, **kwargs)
2323
if not self.api_key:
2424
msg = "No API key provided"
2525
raise ProviderError(msg)

src/shelloracle/providers/google.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ class Google(Provider):
1818
api_key = Setting(default="")
1919
model = Setting(default="gemini-2.0-flash") # Assuming a default model name
2020

21-
def __init__(self, config: Configuration) -> None:
22-
self.config = config
21+
def __init__(self, *args, **kwargs) -> None:
22+
super().__init__(*args, **kwargs)
2323
if not self.api_key:
2424
msg = "No API key provided"
2525
raise ProviderError(msg)

src/shelloracle/providers/localai.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class LocalAI(Provider):
2323
def endpoint(self) -> str:
2424
return f"http://{self.host}:{self.port}"
2525

26-
def __init__(self, config: Configuration) -> None:
27-
self.config = config
26+
def __init__(self, *args, **kwargs) -> None:
27+
super().__init__(*args, **kwargs)
2828
# Use a placeholder API key so the client will work
2929
self.client = AsyncOpenAI(api_key="sk-xxx", base_url=self.endpoint)
3030

src/shelloracle/providers/ollama.py

-3
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,6 @@ class Ollama(Provider):
6060
port = Setting(default=11434)
6161
model = Setting(default="dolphin-mistral")
6262

63-
def __init__(self, config: Configuration) -> None:
64-
self.config = config
65-
6663
@property
6764
def endpoint(self) -> str:
6865
# computed property because python descriptors need to be bound to an instance before access

src/shelloracle/providers/openai.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ class OpenAI(Provider):
1818
api_key = Setting(default="")
1919
model = Setting(default="gpt-3.5-turbo")
2020

21-
def __init__(self, config: Configuration) -> None:
22-
self.config = config
21+
def __init__(self, *args, **kwargs) -> None:
22+
super().__init__(*args, **kwargs)
2323
if not self.api_key:
2424
msg = "No API key provided"
2525
raise ProviderError(msg)

src/shelloracle/providers/openai_compat.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ class OpenAICompat(Provider):
1919
api_key = Setting(default="")
2020
model = Setting(default="")
2121

22-
def __init__(self, config: Configuration) -> None:
23-
self.config = config
22+
def __init__(self, *args, **kwargs) -> None:
23+
super().__init__(*args, **kwargs)
2424
if not self.api_key:
2525
msg = "No API key provided. Use a dummy placeholder if no key is required"
2626
raise ProviderError(msg)

src/shelloracle/providers/xai.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ class XAI(Provider):
1818
api_key = Setting(default="")
1919
model = Setting(default="grok-beta")
2020

21-
def __init__(self, config: Configuration) -> None:
22-
self.config = config
21+
def __init__(self, *args, **kwargs) -> None:
22+
super().__init__(*args, **kwargs)
2323
if not self.api_key:
2424
msg = "No API key provided"
2525
raise ProviderError(msg)

0 commit comments

Comments
 (0)