
Commit 7c10077

Chat with message history (draft)
1 parent bc6dd9c commit 7c10077

File tree: 10 files changed, +182 -24 lines

.pre-commit-config.yaml (-14)

@@ -18,17 +18,3 @@ repos:
         language: system
         types: [ python ]
         stages: [ commit, push ]
-      - id: mypy
-        name: Mypy Type Check
-        entry: mypy .
-        language: system
-        types: [ python ]
-        stages: [ commit, push ]
-        pass_filenames: false
-        args: [
-          --strict,
-          --ignore-missing-imports,
-          --allow-untyped-calls,
-          --allow-subclassing-any,
-          --exclude='./docs/'
-        ]

src/neo4j_graphrag/generation/graphrag.py (+6 -1)

@@ -83,6 +83,7 @@ def __init__(
     def search(
         self,
         query_text: str = "",
+        chat_history: Optional[list[str]] = None,
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,
@@ -100,6 +101,7 @@ def search(

         Args:
             query_text (str): The user question
+            chat_history (Optional[list[str]]): A list of previous messages in the conversation
             examples (str): Examples added to the LLM prompt.
             retriever_config (Optional[dict]): Parameters passed to the retriever
                 search method; e.g.: top_k
@@ -134,7 +136,10 @@ def search(
         )
         logger.debug(f"RAG: retriever_result={retriever_result}")
         logger.debug(f"RAG: prompt={prompt}")
-        answer = self.llm.invoke(prompt)
+        if chat_history is not None:
+            answer = self.llm.chat(prompt, chat_history)
+        else:
+            answer = self.llm.invoke(prompt)
         result: dict[str, Any] = {"answer": answer.content}
         if return_context:
             result["retriever_result"] = retriever_result

src/neo4j_graphrag/llm/anthropic_llm.py (+3)

@@ -110,3 +110,6 @@ async def ainvoke(self, input: str) -> LLMResponse:
             return LLMResponse(content=response.content)
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)
+
+    def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
+        pass

src/neo4j_graphrag/llm/base.py (+17)

@@ -33,10 +33,12 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         self.model_name = model_name
         self.model_params = model_params or {}
+        self.system_instruction = system_instruction

     @abstractmethod
     def invoke(self, input: str) -> LLMResponse:
@@ -52,6 +54,21 @@ def invoke(self, input: str) -> LLMResponse:
             LLMGenerationError: If anything goes wrong.
         """

+    @abstractmethod
+    def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
+        """Sends a text input and a conversation history to the LLM and retrieves a response.
+
+        Args:
+            input (str): Text sent to the LLM
+            chat_history (list[str]): A list of previous messages in the conversation
+
+        Returns:
+            LLMResponse: The response from the LLM.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+        """
+
     @abstractmethod
     async def ainvoke(self, input: str) -> LLMResponse:
         """Asynchronously sends a text input to the LLM and retrieves a response.

src/neo4j_graphrag/llm/cohere_llm.py (+3)

@@ -102,3 +102,6 @@ async def ainvoke(self, input: str) -> LLMResponse:
         return LLMResponse(
             content=res.text,
         )
+
+    def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
+        pass

src/neo4j_graphrag/llm/mistralai_llm.py (+3)

@@ -118,3 +118,6 @@ async def ainvoke(self, input: str) -> LLMResponse:
             return LLMResponse(content=content)
         except SDKError as e:
             raise LLMGenerationError(e)
+
+    def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
+        pass

src/neo4j_graphrag/llm/openai_llm.py (+32 -3)

@@ -36,6 +36,7 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
     ):
         """
         Base class for OpenAI LLM.
@@ -54,7 +55,7 @@ def __init__(
                 "Please install it with `pip install openai`."
             )
         self.openai = openai
-        super().__init__(model_name, model_params)
+        super().__init__(model_name, model_params, system_instruction)

     def get_messages(
         self,
@@ -64,6 +65,32 @@ def get_messages(
             {"role": "system", "content": input},
         ]

+    def get_conversation_history(
+        self,
+        input: str,
+        chat_history: list[str],
+    ) -> Iterable[ChatCompletionMessageParam]:
+        messages = [{"role": "system", "content": self.system_instruction}]
+        for i, message in enumerate(chat_history):
+            if i % 2 == 0:
+                messages.append({"role": "user", "content": message})
+            else:
+                messages.append({"role": "assistant", "content": message})
+        messages.append({"role": "user", "content": input})
+        return messages
+
+    def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
+        try:
+            response = self.client.chat.completions.create(
+                messages=self.get_conversation_history(input, chat_history),
+                model=self.model_name,
+                **self.model_params,
+            )
+            content = response.choices[0].message.content or ""
+            return LLMResponse(content=content)
+        except self.openai.OpenAIError as e:
+            raise LLMGenerationError(e)
+
     def invoke(self, input: str) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model
         and returns the response's content.
@@ -118,6 +145,7 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         """OpenAI LLM
@@ -129,7 +157,7 @@ def __init__(
             model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
             kwargs: All other parameters will be passed to the openai.OpenAI init.
         """
-        super().__init__(model_name, model_params)
+        super().__init__(model_name, model_params, system_instruction)
         self.client = self.openai.OpenAI(**kwargs)
         self.async_client = self.openai.AsyncOpenAI(**kwargs)

@@ -139,6 +167,7 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         """Azure OpenAI LLM. Use this class when using an OpenAI model
@@ -149,6 +178,6 @@ def __init__(
             model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
             kwargs: All other parameters will be passed to the openai.OpenAI init.
         """
-        super().__init__(model_name, model_params)
+        super().__init__(model_name, model_params, system_instruction)
         self.client = self.openai.AzureOpenAI(**kwargs)
         self.async_client = self.openai.AsyncAzureOpenAI(**kwargs)
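
A minimal sketch of how the new OpenAI chat path might be called, assuming a valid API key and treating the model name and messages as placeholders (as the unit tests below confirm, even-indexed history entries map to the "user" role and odd-indexed entries to "assistant"):

    from neo4j_graphrag.llm.openai_llm import OpenAILLM

    llm = OpenAILLM(
        api_key="...",  # placeholder; use a real key or environment configuration
        model_name="gpt-4o-mini",
        model_params={"temperature": 0},
        system_instruction="You are a helpful assistant.",
    )

    history = [
        "When does the sun come up in the summer?",  # role: user
        "Usually around 6am.",                       # role: assistant
    ]
    response = llm.chat("And when does it set?", history)
    print(response.content)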

src/neo4j_graphrag/llm/vertexai_llm.py (+44 -4)

@@ -20,7 +20,12 @@
 from neo4j_graphrag.llm.types import LLMResponse

 try:
-    from vertexai.generative_models import GenerativeModel, ResponseValidationError
+    from vertexai.generative_models import (
+        GenerativeModel,
+        ResponseValidationError,
+        Part,
+        Content,
+    )
 except ImportError:
     GenerativeModel = None
     ResponseValidationError = None
@@ -55,15 +60,18 @@ def __init__(
         self,
         model_name: str = "gemini-1.5-flash-001",
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         if GenerativeModel is None or ResponseValidationError is None:
             raise ImportError(
                 "Could not import Vertex AI Python client. "
                 "Please install it with `pip install google-cloud-aiplatform`."
             )
-        super().__init__(model_name, model_params)
-        self.model = GenerativeModel(model_name=model_name, **kwargs)
+        super().__init__(model_name, model_params, system_instruction)
+        self.model = GenerativeModel(
+            model_name=model_name, system_instruction=[system_instruction], **kwargs
+        )

     def invoke(self, input: str) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -80,7 +88,25 @@ def invoke(self, input: str) -> LLMResponse:
         except ResponseValidationError as e:
             raise LLMGenerationError(e)

-    async def ainvoke(self, input: str) -> LLMResponse:
+    def chat(self, input: str, chat_history: list[str] = []) -> LLMResponse:
+        """Sends text to the LLM and returns a response.
+
+        Args:
+            input (str): The text to send to the LLM.
+
+        Returns:
+            LLMResponse: The response from the LLM.
+        """
+        try:
+            messages = self.get_conversation_history(input, chat_history)
+            response = self.model.generate_content(messages, **self.model_params)
+            return LLMResponse(content=response.text)
+        except ResponseValidationError as e:
+            raise LLMGenerationError(e)
+
+    async def ainvoke(
+        self, input: str, chat_history: Optional[list[str]] = []
+    ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.

         Args:
@@ -96,3 +122,17 @@ async def ainvoke(self, input: str) -> LLMResponse:
             return LLMResponse(content=response.text)
         except ResponseValidationError as e:
             raise LLMGenerationError(e)
+
+    def get_conversation_history(
+        self,
+        input: str,
+        chat_history: list[str],
+    ) -> list[Content]:
+        messages = []
+        for i, message in enumerate(chat_history):
+            if i % 2 == 0:
+                messages.append(Content(role="user", parts=[Part.from_text(message)]))
+            else:
+                messages.append(Content(role="model", parts=[Part.from_text(message)]))
+        messages.append(Content(role="user", parts=[Part.from_text(input)]))
+        return messages
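
Unlike the OpenAI path, Vertex AI labels assistant turns with the "model" role and receives the system instruction through the GenerativeModel constructor rather than as a message. A minimal usage sketch, assuming Google Cloud credentials and a Vertex AI project are already configured (the messages are placeholders):

    from neo4j_graphrag.llm.vertexai_llm import VertexAILLM

    llm = VertexAILLM(
        model_name="gemini-1.5-flash-001",
        system_instruction="You are a helpful assistant.",
    )

    history = [
        "When does the sun come up in the summer?",  # role: user
        "Usually around 6am.",                       # role: model
    ]
    response = llm.chat("And when does it set?", history)
    print(response.content)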

tests/unit/llm/test_openai_llm.py (+40 -1)

@@ -33,7 +33,7 @@ def test_openai_llm_missing_dependency(mock_import: Mock) -> None:


 @patch("builtins.__import__")
-def test_openai_llm_happy_path(mock_import: Mock) -> None:
+def test_openai_llm_invoke_happy_path(mock_import: Mock) -> None:
     mock_openai = get_mock_openai()
     mock_import.return_value = mock_openai
     mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
@@ -46,6 +46,20 @@ def test_openai_llm_happy_path(mock_import: Mock) -> None:
     assert res.content == "openai chat response"


+@patch("builtins.__import__")
+def test_openai_llm_chat_happy_path(mock_import: Mock) -> None:
+    mock_openai = get_mock_openai()
+    mock_import.return_value = mock_openai
+    mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
+        choices=[MagicMock(message=MagicMock(content="openai chat response"))],
+    )
+    llm = OpenAILLM(api_key="my key", model_name="gpt")
+
+    res = llm.chat("my question", ["user message", "assistant message"])
+    assert isinstance(res, LLMResponse)
+    assert res.content == "openai chat response"
+
+
 @patch("builtins.__import__", side_effect=ImportError)
 def test_azure_openai_llm_missing_dependency(mock_import: Mock) -> None:
     with pytest.raises(ImportError):
@@ -71,3 +85,28 @@ def test_azure_openai_llm_happy_path(mock_import: Mock) -> None:
     res = llm.invoke("my text")
     assert isinstance(res, LLMResponse)
     assert res.content == "openai chat response"
+
+
+def test_openai_llm_get_conversation_history() -> None:
+    system_instruction = "You are a helpful assistant."
+    question = "When does it set?"
+    chat_history = [
+        "When does the sun come up in the summer?",
+        "Usually around 6am.",
+        "What about next season?",
+        "Around 8am.",
+    ]
+    expected_response = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+        {"role": "user", "content": "What about next season?"},
+        {"role": "assistant", "content": "Around 8am."},
+        {"role": "user", "content": "When does it set?"},
+    ]
+
+    llm = OpenAILLM(
+        api_key="my key", model_name="gpt", system_instruction=system_instruction
+    )
+    response = llm.get_conversation_history(question, chat_history)
+    assert response == expected_response

tests/unit/llm/test_vertexai_llm.py (+34 -1)

@@ -16,7 +16,7 @@
 from unittest.mock import AsyncMock, MagicMock, Mock, patch

 import pytest
-from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
+from neo4j_graphrag.llm.vertexai_llm import VertexAILLM, Part, Content


 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel", None)
@@ -52,3 +52,36 @@ async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> None:
     response = await llm.ainvoke(input_text)
     assert response.content == "Return text"
     llm.model.generate_content_async.assert_called_once_with(input_text, **model_params)
+
+
+@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
+def test_vertexai_get_conversation_history(GenerativeModelMock: MagicMock) -> None:
+    system_instruction = "You are a helpful assistant."
+    question = "When does it set?"
+    chat_history = [
+        "When does the sun come up in the summer?",
+        "Usually around 6am.",
+        "What about next season?",
+        "Around 8am.",
+    ]
+    expected_response = [
+        Content(
+            role="user",
+            parts=[Part.from_text("When does the sun come up in the summer?")],
+        ),
+        Content(role="model", parts=[Part.from_text("Usually around 6am.")]),
+        Content(role="user", parts=[Part.from_text("What about next season?")]),
+        Content(role="model", parts=[Part.from_text("Around 8am.")]),
+        Content(role="user", parts=[Part.from_text("When does it set?")]),
+    ]
+
+    llm = VertexAILLM(
+        model_name="gemini-1.5-flash-001", system_instruction=system_instruction
+    )
+    response = llm.get_conversation_history(question, chat_history)
+
+    assert llm.system_instruction == system_instruction
+    assert len(response) == len(expected_response)
+    for actual, expected in zip(response, expected_response):
+        assert actual.role == expected.role
+        assert actual.parts[0].text == expected.parts[0].text
