Skip to content

Commit 4f1b0f6

Browse files
Use the BaseMessage class for the type declaration of the `message_history` parameter
1 parent 37225fd commit 4f1b0f6

File tree

8 files changed

+33
-30
lines changed

8 files changed

+33
-30
lines changed

src/neo4j_graphrag/generation/graphrag.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
3333
from neo4j_graphrag.llm import LLMInterface
34+
from neo4j_graphrag.llm.types import BaseMessage
3435
from neo4j_graphrag.retrievers.base import Retriever
3536
from neo4j_graphrag.types import RetrieverResult
3637

@@ -87,7 +88,7 @@ def __init__(
8788
def search(
8889
self,
8990
query_text: str = "",
90-
message_history: Optional[list[dict[str, str]]] = None,
91+
message_history: Optional[list[BaseMessage]] = None,
9192
examples: str = "",
9293
retriever_config: Optional[dict[str, Any]] = None,
9394
return_context: bool | None = None,
@@ -155,7 +156,7 @@ def build_query(
155156
)
156157
summary = self.llm.invoke(
157158
input=summarization_prompt,
158-
system_instruction=summarization_prompt.SYSTEM_MESSAGE,
159+
system_instruction=ChatSummaryTemplate().SYSTEM_MESSAGE,
159160
).content
160161
return ConversationTemplate().format(
161162
summary=summary, current_query=query_text

src/neo4j_graphrag/llm/anthropic_llm.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from neo4j_graphrag.exceptions import LLMGenerationError
2121
from neo4j_graphrag.llm.base import LLMInterface
22-
from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage
22+
from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage, BaseMessage
2323

2424
if TYPE_CHECKING:
2525
from anthropic.types.message_param import MessageParam
@@ -71,7 +71,7 @@ def __init__(
7171
self.async_client = anthropic.AsyncAnthropic(**kwargs)
7272

7373
def get_messages(
74-
self, input: str, message_history: Optional[list[Any]] = None
74+
self, input: str, message_history: Optional[list[BaseMessage]] = None
7575
) -> Iterable[MessageParam]:
7676
messages = []
7777
if message_history:
@@ -84,7 +84,7 @@ def get_messages(
8484
return messages
8585

8686
def invoke(
87-
self, input: str, message_history: Optional[list[Any]] = None
87+
self, input: str, message_history: Optional[list[BaseMessage]] = None
8888
) -> LLMResponse:
8989
"""Sends text to the LLM and returns a response.
9090
@@ -108,7 +108,7 @@ def invoke(
108108
raise LLMGenerationError(e)
109109

110110
async def ainvoke(
111-
self, input: str, message_history: Optional[list[Any]] = None
111+
self, input: str, message_history: Optional[list[BaseMessage]] = None
112112
) -> LLMResponse:
113113
"""Asynchronously sends text to the LLM and returns a response.
114114

src/neo4j_graphrag/llm/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from abc import ABC, abstractmethod
1818
from typing import Any, Optional
1919

20-
from .types import LLMResponse
20+
from .types import LLMResponse, BaseMessage
2121

2222

2323
class LLMInterface(ABC):
@@ -45,7 +45,7 @@ def __init__(
4545
def invoke(
4646
self,
4747
input: str,
48-
message_history: Optional[list[dict[str, str]]] = None,
48+
message_history: Optional[list[BaseMessage]] = None,
4949
system_instruction: Optional[str] = None,
5050
) -> LLMResponse:
5151
"""Sends a text input to the LLM and retrieves a response.

src/neo4j_graphrag/llm/cohere_llm.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
MessageList,
2525
SystemMessage,
2626
UserMessage,
27+
BaseMessage,
2728
)
2829

2930
if TYPE_CHECKING:
@@ -74,7 +75,7 @@ def __init__(
7475
self.async_client = cohere.AsyncClientV2(**kwargs)
7576

7677
def get_messages(
77-
self, input: str, message_history: Optional[list[Any]] = None
78+
self, input: str, message_history: Optional[list[BaseMessage]] = None
7879
) -> ChatMessages:
7980
messages = []
8081
if self.system_instruction:
@@ -89,7 +90,7 @@ def get_messages(
8990
return messages
9091

9192
def invoke(
92-
self, input: str, message_history: Optional[list[Any]] = None
93+
self, input: str, message_history: Optional[list[BaseMessage]] = None
9394
) -> LLMResponse:
9495
"""Sends text to the LLM and returns a response.
9596
@@ -113,7 +114,7 @@ def invoke(
113114
)
114115

115116
async def ainvoke(
116-
self, input: str, message_history: Optional[list[Any]] = None
117+
self, input: str, message_history: Optional[list[BaseMessage]] = None
117118
) -> LLMResponse:
118119
"""Asynchronously sends text to the LLM and returns a response.
119120

src/neo4j_graphrag/llm/mistralai_llm.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
MessageList,
2626
SystemMessage,
2727
UserMessage,
28+
BaseMessage,
2829
)
2930

3031
try:
@@ -65,7 +66,7 @@ def __init__(
6566
self.client = Mistral(api_key=api_key, **kwargs)
6667

6768
def get_messages(
68-
self, input: str, message_history: Optional[list[Any]] = None
69+
self, input: str, message_history: Optional[list[BaseMessage]] = None
6970
) -> list[Messages]:
7071
messages = []
7172
if self.system_instruction:
@@ -80,7 +81,7 @@ def get_messages(
8081
return messages
8182

8283
def invoke(
83-
self, input: str, message_history: Optional[list[Any]] = None
84+
self, input: str, message_history: Optional[list[BaseMessage]] = None
8485
) -> LLMResponse:
8586
"""Sends a text input to the Mistral chat completion model
8687
and returns the response's content.
@@ -112,7 +113,7 @@ def invoke(
112113
raise LLMGenerationError(e)
113114

114115
async def ainvoke(
115-
self, input: str, message_history: Optional[list[Any]] = None
116+
self, input: str, message_history: Optional[list[BaseMessage]] = None
116117
) -> LLMResponse:
117118
"""Asynchronously sends a text input to the MistralAI chat
118119
completion model and returns the response's content.

src/neo4j_graphrag/llm/ollama_llm.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from neo4j_graphrag.exceptions import LLMGenerationError
2020

2121
from .base import LLMInterface
22-
from .types import LLMResponse, SystemMessage, UserMessage, MessageList
22+
from .types import LLMResponse, SystemMessage, UserMessage, MessageList, BaseMessage
2323

2424
if TYPE_CHECKING:
2525
from ollama import Message
@@ -50,7 +50,7 @@ def __init__(
5050
)
5151

5252
def get_messages(
53-
self, input: str, message_history: Optional[list[Any]] = None
53+
self, input: str, message_history: Optional[list[BaseMessage]] = None
5454
) -> Sequence[Message]:
5555
messages = []
5656
if self.system_instruction:
@@ -65,7 +65,7 @@ def get_messages(
6565
return messages
6666

6767
def invoke(
68-
self, input: str, message_history: Optional[list[Any]] = None
68+
self, input: str, message_history: Optional[list[BaseMessage]] = None
6969
) -> LLMResponse:
7070
try:
7171
response = self.client.chat(
@@ -79,7 +79,7 @@ def invoke(
7979
raise LLMGenerationError(e)
8080

8181
async def ainvoke(
82-
self, input: str, message_history: Optional[list[Any]] = None
82+
self, input: str, message_history: Optional[list[BaseMessage]] = None
8383
) -> LLMResponse:
8484
try:
8585
response = await self.async_client.chat(

src/neo4j_graphrag/llm/openai_llm.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from ..exceptions import LLMGenerationError
2323
from .base import LLMInterface
24-
from .types import LLMResponse, SystemMessage, UserMessage, MessageList
24+
from .types import LLMResponse, SystemMessage, UserMessage, MessageList, BaseMessage
2525

2626
if TYPE_CHECKING:
2727
import openai
@@ -63,7 +63,7 @@ def __init__(
6363
def get_messages(
6464
self,
6565
input: str,
66-
message_history: Optional[list[Any]] = None,
66+
message_history: Optional[list[BaseMessage]] = None,
6767
system_instruction: Optional[str] = None,
6868
) -> Iterable[ChatCompletionMessageParam]:
6969
messages = []
@@ -86,7 +86,7 @@ def get_messages(
8686
def invoke(
8787
self,
8888
input: str,
89-
message_history: Optional[list[Any]] = None,
89+
message_history: Optional[list[BaseMessage]] = None,
9090
system_instruction: Optional[str] = None,
9191
) -> LLMResponse:
9292
"""Sends a text input to the OpenAI chat completion model
@@ -115,7 +115,7 @@ def invoke(
115115
raise LLMGenerationError(e)
116116

117117
async def ainvoke(
118-
self, input: str, message_history: Optional[list[Any]] = None
118+
self, input: str, message_history: Optional[list[BaseMessage]] = None
119119
) -> LLMResponse:
120120
"""Asynchronously sends a text input to the OpenAI chat
121121
completion model and returns the response's content.

src/neo4j_graphrag/llm/vertexai_llm.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from neo4j_graphrag.exceptions import LLMGenerationError
2121
from neo4j_graphrag.llm.base import LLMInterface
22-
from neo4j_graphrag.llm.types import LLMResponse, MessageList
22+
from neo4j_graphrag.llm.types import LLMResponse, MessageList, BaseMessage
2323

2424
try:
2525
from vertexai.generative_models import (
@@ -77,7 +77,7 @@ def __init__(
7777
self.model_params = kwargs
7878

7979
def get_messages(
80-
self, input: str, message_history: Optional[list[Any]] = None
80+
self, input: str, message_history: Optional[list[BaseMessage]] = None
8181
) -> list[Content]:
8282
messages = []
8383
if message_history:
@@ -87,16 +87,16 @@ def get_messages(
8787
raise LLMGenerationError(e.errors()) from e
8888

8989
for message in message_history:
90-
if message.get("role") == "user":
90+
if message.role == "user":
9191
messages.append(
9292
Content(
93-
role="user", parts=[Part.from_text(message.get("content"))]
93+
role="user", parts=[Part.from_text(message.content)]
9494
)
9595
)
96-
elif message.get("role") == "assistant":
96+
elif message.role == "assistant":
9797
messages.append(
9898
Content(
99-
role="model", parts=[Part.from_text(message.get("content"))]
99+
role="model", parts=[Part.from_text(message.content)]
100100
)
101101
)
102102

@@ -106,7 +106,7 @@ def get_messages(
106106
def invoke(
107107
self,
108108
input: str,
109-
message_history: Optional[list[Any]] = None,
109+
message_history: Optional[list[BaseMessage]] = None,
110110
system_instruction: Optional[str] = None,
111111
) -> LLMResponse:
112112
"""Sends text to the LLM and returns a response.
@@ -137,7 +137,7 @@ def invoke(
137137
raise LLMGenerationError(e)
138138

139139
async def ainvoke(
140-
self, input: str, message_history: Optional[list[Any]] = None
140+
self, input: str, message_history: Optional[list[BaseMessage]] = None
141141
) -> LLMResponse:
142142
"""Asynchronously sends text to the LLM and returns a response.
143143

0 commit comments

Comments (0)