Commit d58871f

Fix mypy errors
1 parent: f2792ff · commit: d58871f

7 files changed: +19 -19 lines changed

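Every file below applies the same fix: `chat_history` parameters that mypy flagged, either annotated as a bare `list` or lacking a default that matches how callers omit the argument, become `Optional[list[Any]] = None`. A minimal sketch of the before/after pattern (function names here are generic placeholders, not the library's API):

from typing import Any, Iterable, Optional

# Before: a bare `list` annotation with no default. mypy rejects call
# sites that omit chat_history, and `list` says nothing about elements.
def get_messages_before(input: str, chat_history: list) -> Iterable[dict]: ...

# After: explicitly Optional with a None default, which matches the
# `if chat_history:` guard these methods already use.
def get_messages_after(
    input: str, chat_history: Optional[list[Any]] = None
) -> Iterable[dict]: ...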

src/neo4j_graphrag/generation/graphrag.py

+1 -1

@@ -146,7 +146,7 @@ def search(
         result["retriever_result"] = retriever_result
         return RagResultModel(**result)
 
-    def build_query(self, query_text: str, chat_history: list[dict[str, str]]) -> str:
+    def build_query(self, query_text: str, chat_history: Optional[list[dict[str, str]]] = None) -> str:
         if chat_history:
             summarization_prompt = ChatSummaryTemplate().format(
                 chat_history=chat_history
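From a caller's perspective, the new default makes history optional. A brief hedged sketch (retriever and llm construction elided; the query strings are illustrative):

rag = GraphRAG(retriever=retriever, llm=llm)  # assumes retriever/llm exist

# Now type-checks without history: chat_history defaults to None, and the
# `if chat_history:` guard skips the summarization step entirely.
prompt = rag.build_query("What is GraphRAG?")

# With history, the method first summarizes the prior turns.
prompt = rag.build_query(
    "And how does it use Neo4j?",
    chat_history=[
        {"role": "user", "content": "What is GraphRAG?"},
        {"role": "assistant", "content": "A retrieval-augmented generation library."},
    ],
)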

src/neo4j_graphrag/llm/anthropic_llm.py

+3 -3

@@ -71,7 +71,7 @@ def __init__(
         self.client = anthropic.Anthropic(**kwargs)
         self.async_client = anthropic.AsyncAnthropic(**kwargs)
 
-    def get_messages(self, input: str, chat_history: list) -> Iterable[MessageParam]:
+    def get_messages(self, input: str, chat_history: Optional[list[Any]] = None) -> Iterable[MessageParam]:
         messages = []
         if chat_history:
             try:

@@ -83,7 +83,7 @@ def get_messages(self, input: str, chat_history: list) -> Iterable[MessageParam]
         return messages
 
     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.

@@ -107,7 +107,7 @@ def invoke(
         raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
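The message shape get_messages produces is the same idea across providers. A simplified, hedged sketch of the logic (the real methods also validate each history entry, hence the `try:` visible in the hunk above):

from typing import Any, Optional

def get_messages(
    input: str, chat_history: Optional[list[Any]] = None
) -> list[dict[str, str]]:
    # Simplified: history entries are appended before the new user turn.
    messages: list[dict[str, str]] = []
    if chat_history:
        messages.extend(chat_history)
    messages.append({"role": "user", "content": input})
    return messages

# With the None default, calling without history is now valid under mypy:
assert get_messages("Hi") == [{"role": "user", "content": "Hi"}]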

src/neo4j_graphrag/llm/cohere_llm.py

+3 -3

@@ -74,7 +74,7 @@ def __init__(
         self.client = cohere.ClientV2(**kwargs)
         self.async_client = cohere.AsyncClientV2(**kwargs)
 
-    def get_messages(self, input: str, chat_history: list) -> ChatMessages:  # type: ignore
+    def get_messages(self, input: str, chat_history: Optional[list[Any]] = None) -> ChatMessages:
         messages = []
         if self.system_instruction:
             messages.append(SystemMessage(content=self.system_instruction).model_dump())

@@ -88,7 +88,7 @@ def get_messages(self, input: str, chat_history: list) -> ChatMessages:  # type: ignore
         return messages
 
     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.

@@ -112,7 +112,7 @@ def invoke(
         )
 
     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.

src/neo4j_graphrag/llm/mistralai_llm.py

+5 -5

@@ -31,8 +31,8 @@
     from mistralai import Mistral, Messages
     from mistralai.models.sdkerror import SDKError
 except ImportError:
-    Mistral = None  # type: ignore
-    SDKError = None  # type: ignore
+    Mistral = None
+    SDKError = None
 
 
 class MistralAILLM(LLMInterface):

@@ -64,7 +64,7 @@ def __init__(
             api_key = os.getenv("MISTRAL_API_KEY", "")
         self.client = Mistral(api_key=api_key, **kwargs)
 
-    def get_messages(self, input: str, chat_history: list) -> list[Messages]:
+    def get_messages(self, input: str, chat_history: Optional[list[Any]] = None) -> list[Messages]:
         messages = []
         if self.system_instruction:
             messages.append(SystemMessage(content=self.system_instruction).model_dump())

@@ -78,7 +78,7 @@ def get_messages(self, input: str, chat_history: list) -> list[Messages]:
         return messages
 
     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model
         and returns the response's content.

@@ -110,7 +110,7 @@ def invoke(
         raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends a text input to the MistralAI chat
         completion model and returns the response's content.
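The first hunk above also drops two `# type: ignore` markers from the optional-dependency fallback, presumably because they are no longer needed under the project's mypy configuration. A minimal hedged sketch of that pattern, not the library's actual code (error message wording is illustrative):

try:
    from mistralai import Mistral
except ImportError:
    # Bind the name to None so the module still imports; the constructor
    # checks for it and fails with a clear message at instantiation time.
    Mistral = None

class MistralAILLM:
    def __init__(self) -> None:
        if Mistral is None:
            # Illustrative wording, not the package's exact error text.
            raise ImportError("Could not import the mistralai package.")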

src/neo4j_graphrag/llm/openai_llm.py

+3 -3

@@ -61,7 +61,7 @@ def __init__(
         super().__init__(model_name, model_params, system_instruction)
 
     def get_messages(
-        self, input: str, chat_history: list
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> Iterable[ChatCompletionMessageParam]:
         messages = []
         if self.system_instruction:

@@ -76,7 +76,7 @@ def get_messages(
         return messages
 
     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model
         and returns the response's content.

@@ -103,7 +103,7 @@ def invoke(
         raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat
         completion model and returns the response's content.

src/neo4j_graphrag/llm/types.py

+1 -1

@@ -7,7 +7,7 @@ class LLMResponse(BaseModel):
 
 
 class BaseMessage(BaseModel):
-    role: Literal["user", "assistant"]
+    role: Literal["user", "assistant", "system"]
     content: str
 
 
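Widening the role Literal means system messages now validate like any other turn. A short self-contained sketch mirroring the model above:

from typing import Literal
from pydantic import BaseModel

class BaseMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str

# Previously role="system" raised a pydantic ValidationError; now it parses:
msg = BaseMessage(role="system", content="You are a helpful assistant.")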

src/neo4j_graphrag/llm/vertexai_llm.py

+3 -3

@@ -76,7 +76,7 @@ def __init__(
             model_name=model_name, system_instruction=[system_instruction], **kwargs
         )
 
-    def get_messages(self, input: str, chat_history: list[str]) -> list[Content]:
+    def get_messages(self, input: str, chat_history: Optional[list[Any]] = None) -> list[Content]:
         messages = []
         if chat_history:
             try:

@@ -102,7 +102,7 @@ def get_messages(self, input: str, chat_history: list[str]) -> list[Content]:
         return messages
 
     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.

@@ -121,7 +121,7 @@ def invoke(
         raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
