Skip to content

Commit 98167ff

Browse files
committed
Rename method and return directly list[LLMMessage]
1 parent e09f1a1 commit 98167ff

File tree

2 files changed

+36
-36
lines changed

2 files changed

+36
-36
lines changed

src/neo4j_graphrag/llm/base.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from neo4j_graphrag.tool import Tool
3131

3232
from .rate_limit import RateLimitHandler
33-
from .utils import legacy_inputs_to_message_history
33+
from .utils import legacy_inputs_to_messages
3434

3535

3636
class LLMInterface(ABC):
@@ -65,10 +65,8 @@ def invoke(
6565
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
6666
system_instruction: Optional[str] = None,
6767
) -> LLMResponse:
68-
message_history = legacy_inputs_to_message_history(
69-
input, message_history, system_instruction
70-
)
71-
return self._invoke(message_history.messages)
68+
messages = legacy_inputs_to_messages(input, message_history, system_instruction)
69+
return self._invoke(messages)
7270

7371
@abstractmethod
7472
def _invoke(
@@ -94,10 +92,8 @@ async def ainvoke(
9492
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
9593
system_instruction: Optional[str] = None,
9694
) -> LLMResponse:
97-
message_history = legacy_inputs_to_message_history(
98-
input, message_history, system_instruction
99-
)
100-
return await self._ainvoke(message_history.messages)
95+
messages = legacy_inputs_to_messages(input, message_history, system_instruction)
96+
return await self._ainvoke(messages)
10197

10298
@abstractmethod
10399
async def _ainvoke(
@@ -142,10 +138,8 @@ def invoke_with_tools(
142138
LLMGenerationError: If anything goes wrong.
143139
NotImplementedError: If the LLM provider does not support tool calling.
144140
"""
145-
history = legacy_inputs_to_message_history(
146-
input, message_history, system_instruction
147-
)
148-
return self._invoke_with_tools(history.messages, tools)
141+
messages = legacy_inputs_to_messages(input, message_history, system_instruction)
142+
return self._invoke_with_tools(messages, tools)
149143

150144
def _invoke_with_tools(
151145
self, inputs: list[LLMMessage], tools: Sequence[Tool]
@@ -177,10 +171,8 @@ async def ainvoke_with_tools(
177171
LLMGenerationError: If anything goes wrong.
178172
NotImplementedError: If the LLM provider does not support tool calling.
179173
"""
180-
history = legacy_inputs_to_message_history(
181-
input, message_history, system_instruction
182-
)
183-
return await self._ainvoke_with_tools(history.messages, tools)
174+
messages = legacy_inputs_to_messages(input, message_history, system_instruction)
175+
return await self._ainvoke_with_tools(messages, tools)
184176

185177
async def _ainvoke_with_tools(
186178
self, inputs: list[LLMMessage], tools: Sequence[Tool]

src/neo4j_graphrag/llm/utils.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,57 @@
11
import warnings
22
from typing import Union, Optional
33

4-
from neo4j_graphrag.message_history import MessageHistory, InMemoryMessageHistory
4+
from neo4j_graphrag.message_history import MessageHistory
55
from neo4j_graphrag.types import LLMMessage
66

77

8-
def legacy_inputs_to_message_history(
8+
def system_instruction_from_messages(messages: list[LLMMessage]) -> str | None:
9+
for message in messages:
10+
if message["role"] == "system":
11+
return message["content"]
12+
return None
13+
14+
15+
def legacy_inputs_to_messages(
916
input: Union[str, list[LLMMessage], MessageHistory],
1017
message_history: Optional[Union[list[LLMMessage], MessageHistory]] = None,
1118
system_instruction: Optional[str] = None,
12-
) -> MessageHistory:
19+
) -> list[LLMMessage]:
1320
if message_history:
1421
warnings.warn(
1522
"Using message_history parameter is deprecated and will be removed in 2.0. Use a list of inputs or a MessageHistory instead.",
1623
DeprecationWarning,
1724
)
1825
if isinstance(message_history, MessageHistory):
19-
history = message_history
26+
messages = message_history.messages
2027
else: # list[LLMMessage]
21-
history = InMemoryMessageHistory(message_history)
28+
messages = []
2229
else:
23-
history = InMemoryMessageHistory()
30+
messages = []
2431
if system_instruction is not None:
2532
warnings.warn(
2633
"Using system_instruction parameter is deprecated and will be removed in 2.0. Use a list of inputs or a MessageHistory instead.",
2734
DeprecationWarning,
2835
)
29-
if history.is_empty():
30-
history.add_message(
36+
if system_instruction_from_messages(messages) is not None:
37+
warnings.warn(
38+
"system_instruction provided but ignored as the message history already contains a system message",
39+
RuntimeWarning,
40+
)
41+
else:
42+
messages.append(
3143
LLMMessage(
3244
role="system",
3345
content=system_instruction,
3446
),
3547
)
36-
else:
37-
warnings.warn(
38-
"system_instruction provided but ignored as the message history is not empty",
39-
RuntimeWarning,
40-
)
48+
4149
if isinstance(input, str):
42-
history.add_message(LLMMessage(role="user", content=input))
43-
return history
50+
messages.append(LLMMessage(role="user", content=input))
51+
return messages
4452
if isinstance(input, list):
45-
history.add_messages(input)
46-
return history
53+
messages.extend(input)
54+
return messages
4755
# input is a MessageHistory instance
48-
history.add_messages(input.messages)
49-
return history
56+
messages.extend(input.messages)
57+
return messages

0 commit comments

Comments
 (0)