
Commit 37225fd

Rename chat_history to message_history
1 parent 07038dd commit 37225fd

16 files changed: +129 −127 lines

src/neo4j_graphrag/generation/graphrag.py (+7 −7)

@@ -87,7 +87,7 @@ def __init__(
     def search(
         self,
         query_text: str = "",
-        chat_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,
@@ -105,7 +105,7 @@ def search(

         Args:
             query_text (str): The user question.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
             examples (str): Examples added to the LLM prompt.
             retriever_config (Optional[dict]): Parameters passed to the retriever.
                 search method; e.g.: top_k
@@ -130,7 +130,7 @@ def search(
             )
         except ValidationError as e:
             raise SearchValidationError(e.errors())
-        query = self.build_query(validated_data.query_text, chat_history)
+        query = self.build_query(validated_data.query_text, message_history)
        retriever_result: RetrieverResult = self.retriever.search(
            query_text=query, **validated_data.retriever_config
        )
@@ -140,18 +140,18 @@ def search(
         )
         logger.debug(f"RAG: retriever_result={retriever_result}")
         logger.debug(f"RAG: prompt={prompt}")
-        answer = self.llm.invoke(prompt, chat_history)
+        answer = self.llm.invoke(prompt, message_history)
         result: dict[str, Any] = {"answer": answer.content}
         if return_context:
             result["retriever_result"] = retriever_result
         return RagResultModel(**result)

     def build_query(
-        self, query_text: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, query_text: str, message_history: Optional[list[dict[str, str]]] = None
     ) -> str:
-        if chat_history:
+        if message_history:
             summarization_prompt = ChatSummaryTemplate().format(
-                chat_history=chat_history
+                message_history=message_history
             )
             summary = self.llm.invoke(
                 input=summarization_prompt,
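
For callers, only the keyword argument changes. A minimal usage sketch of the renamed parameter, assuming an already configured retriever and llm (both placeholders here, not defined by this commit):

    from neo4j_graphrag.generation import GraphRAG

    rag = GraphRAG(retriever=retriever, llm=llm)  # assumed pre-built instances

    # Each message is a dict with "role" and "content" keys.
    message_history = [
        {"role": "user", "content": "Who directed The Matrix?"},
        {"role": "assistant", "content": "The Wachowskis directed The Matrix."},
    ]

    # `message_history` replaces the former `chat_history` keyword.
    response = rag.search(
        query_text="What else did they direct?",
        message_history=message_history,
    )
    print(response.answer)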

src/neo4j_graphrag/generation/prompts.py (+7 −7)

@@ -200,25 +200,25 @@ def format(

 class ChatSummaryTemplate(PromptTemplate):
     DEFAULT_TEMPLATE = """
-Summarize the chat history:
+Summarize the message history:

-{chat_history}
+{message_history}
 """
-    EXPECTED_INPUTS = ["chat_history"]
+    EXPECTED_INPUTS = ["message_history"]
     SYSTEM_MESSAGE = "You are a summarization assistant. Summarize the given text in no more than 200 words"

-    def format(self, chat_history: list[dict[str, str]]) -> str:
+    def format(self, message_history: list[dict[str, str]]) -> str:
         message_list = [
             ": ".join([f"{value}" for _, value in message.items()])
-            for message in chat_history
+            for message in message_history
         ]
         history = "\n".join(message_list)
-        return super().format(chat_history=history)
+        return super().format(message_history=history)


 class ConversationTemplate(PromptTemplate):
     DEFAULT_TEMPLATE = """
-Chat Summary:
+Message Summary:
 {summary}

 Current Query:
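
The renamed format argument flattens each message into a "role: content" line before interpolating it into the template. A quick, self-contained check of the new name (expected output sketched in the trailing comments):

    from neo4j_graphrag.generation.prompts import ChatSummaryTemplate

    history = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello, how can I help?"},
    ]
    prompt = ChatSummaryTemplate().format(message_history=history)
    print(prompt)
    # Summarize the message history:
    #
    # user: Hi
    # assistant: Hello, how can I help?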

src/neo4j_graphrag/llm/anthropic_llm.py (+10 −10)

@@ -71,32 +71,32 @@ def __init__(
         self.async_client = anthropic.AsyncAnthropic(**kwargs)

     def get_messages(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> Iterable[MessageParam]:
         messages = []
-        if chat_history:
+        if message_history:
             try:
-                MessageList(messages=chat_history)
+                MessageList(messages=message_history)
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(chat_history)
+            messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
         return messages

     def invoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.

         Args:
             input (str): The text to send to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.

         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
-            messages = self.get_messages(input, chat_history)
+            messages = self.get_messages(input, message_history)
             response = self.client.messages.create(
                 model=self.model_name,
                 system=self.system_instruction,
@@ -108,19 +108,19 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.

         Args:
             input (str): The text to send to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.

         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
-            messages = self.get_messages(input, chat_history)
+            messages = self.get_messages(input, message_history)
             response = await self.async_client.messages.create(
                 model=self.model_name,
                 system=self.system_instruction,
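
In the Anthropic wrapper, get_messages validates the renamed history and appends the new user turn; the system prompt stays out of the list and is passed separately via the system= argument. An illustrative sketch (assumes the anthropic extra is installed and an API key is configured; the model name is a placeholder):

    from neo4j_graphrag.llm import AnthropicLLM

    llm = AnthropicLLM(model_name="claude-3-5-sonnet-20240620")  # placeholder model
    messages = llm.get_messages(
        "What is a knowledge graph?",
        message_history=[
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
        ],
    )
    # Prior turns first, then the new user message:
    # [{"role": "user", "content": "Hi"},
    #  {"role": "assistant", "content": "Hello!"},
    #  {"role": "user", "content": "What is a knowledge graph?"}]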

src/neo4j_graphrag/llm/base.py (+4 −4)

@@ -45,14 +45,14 @@ def __init__(
     def invoke(
         self,
         input: str,
-        chat_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the LLM and retrieves a response.

         Args:
             input (str): Text sent to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
             system_instruction (Optional[str]): An option to override the llm system message for this invokation.

         Returns:
@@ -64,13 +64,13 @@ def invoke(

     @abstractmethod
     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, message_history: Optional[list[dict[str, str]]] = None
     ) -> LLMResponse:
         """Asynchronously sends a text input to the LLM and retrieves a response.

         Args:
             input (str): Text sent to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.


         Returns:
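
Custom implementations of LLMInterface must adopt the new parameter name in both methods. A minimal sketch of a hypothetical subclass (the echo behavior is illustrative only, not part of this commit):

    from typing import Optional

    from neo4j_graphrag.llm import LLMInterface, LLMResponse

    class EchoLLM(LLMInterface):  # hypothetical implementation
        def invoke(
            self,
            input: str,
            message_history: Optional[list[dict[str, str]]] = None,
            system_instruction: Optional[str] = None,
        ) -> LLMResponse:
            # A real backend would forward `message_history` to its chat API.
            turns = len(message_history or [])
            return LLMResponse(content=f"echo ({turns} prior turns): {input}")

        async def ainvoke(
            self, input: str, message_history: Optional[list[dict[str, str]]] = None
        ) -> LLMResponse:
            return self.invoke(input, message_history)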

src/neo4j_graphrag/llm/cohere_llm.py (+10 −10)

@@ -74,34 +74,34 @@ def __init__(
         self.async_client = cohere.AsyncClientV2(**kwargs)

     def get_messages(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> ChatMessages:
         messages = []
         if self.system_instruction:
             messages.append(SystemMessage(content=self.system_instruction).model_dump())
-        if chat_history:
+        if message_history:
             try:
-                MessageList(messages=chat_history)
+                MessageList(messages=message_history)
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(chat_history)
+            messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
         return messages

     def invoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.

         Args:
             input (str): The text to send to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.

         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
-            messages = self.get_messages(input, chat_history)
+            messages = self.get_messages(input, message_history)
             res = self.client.chat(
                 messages=messages,
                 model=self.model_name,
@@ -113,19 +113,19 @@ def invoke(
             )

     async def ainvoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.

         Args:
             input (str): The text to send to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.

         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
-            messages = self.get_messages(input, chat_history)
+            messages = self.get_messages(input, message_history)
             res = self.async_client.chat(
                 messages=messages,
                 model=self.model_name,
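
Unlike the Anthropic wrapper above, the Cohere client prepends the system instruction to the message list itself, so the renamed history sits between it and the new user turn. A sketch under the same assumptions (cohere extra installed; the model name is a placeholder):

    from neo4j_graphrag.llm import CohereLLM

    llm = CohereLLM(model_name="command-r", system_instruction="Answer briefly.")
    messages = llm.get_messages(
        "What is Cypher?",
        message_history=[
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
        ],
    )
    # [{"role": "system", "content": "Answer briefly."},
    #  {"role": "user", "content": "Hi"},
    #  {"role": "assistant", "content": "Hello!"},
    #  {"role": "user", "content": "What is Cypher?"}]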

src/neo4j_graphrag/llm/mistralai_llm.py (+10 −10)

@@ -65,29 +65,29 @@ def __init__(
         self.client = Mistral(api_key=api_key, **kwargs)

     def get_messages(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> list[Messages]:
         messages = []
         if self.system_instruction:
             messages.append(SystemMessage(content=self.system_instruction).model_dump())
-        if chat_history:
+        if message_history:
             try:
-                MessageList(messages=chat_history)
+                MessageList(messages=message_history)
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(chat_history)
+            messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
         return messages

     def invoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model
         and returns the response's content.

         Args:
             input (str): Text sent to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.

         Returns:
             LLMResponse: The response from MistralAI.
@@ -96,7 +96,7 @@ def invoke(
             LLMGenerationError: If anything goes wrong.
         """
         try:
-            messages = self.get_messages(input, chat_history)
+            messages = self.get_messages(input, message_history)
             response = self.client.chat.complete(
                 model=self.model_name,
                 messages=messages,
@@ -112,14 +112,14 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends a text input to the MistralAI chat
         completion model and returns the response's content.

         Args:
             input (str): Text sent to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.

         Returns:
             LLMResponse: The response from MistralAI.
@@ -128,7 +128,7 @@ async def ainvoke(
             LLMGenerationError: If anything goes wrong.
         """
         try:
-            messages = self.get_messages(input, chat_history)
+            messages = self.get_messages(input, message_history)
             response = await self.client.chat.complete_async(
                 model=self.model_name,
                 messages=messages,

src/neo4j_graphrag/llm/ollama_llm.py (+8 −8)

@@ -50,27 +50,27 @@ def __init__(
         )

     def get_messages(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> Sequence[Message]:
         messages = []
         if self.system_instruction:
             messages.append(SystemMessage(content=self.system_instruction).model_dump())
-        if chat_history:
+        if message_history:
             try:
-                MessageList(messages=chat_history)
+                MessageList(messages=message_history)
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(chat_history)
+            messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
         return messages

     def invoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         try:
             response = self.client.chat(
                 model=self.model_name,
-                messages=self.get_messages(input, chat_history),
+                messages=self.get_messages(input, message_history),
                 options=self.model_params,
             )
             content = response.message.content or ""
@@ -79,12 +79,12 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         try:
             response = await self.async_client.chat(
                 model=self.model_name,
-                messages=self.get_messages(input, chat_history),
+                messages=self.get_messages(input, message_history),
                 options=self.model_params,
             )
             content = response.message.content or ""
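
End to end, the renamed keyword threads straight through invoke to the underlying client.chat call. A sketch assuming a local Ollama server with a pulled model (the model name is a placeholder):

    from neo4j_graphrag.llm import OllamaLLM

    llm = OllamaLLM(model_name="llama3.1")  # placeholder local model
    response = llm.invoke(
        "Summarize our conversation so far.",
        message_history=[
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
        ],
    )
    print(response.content)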
