Commit c42e732

Revert BaseMessage class type
* bring back `list[dict[str, str]]` type declaration for the `message_history` parameter
1 parent a749a9e commit c42e732
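
Note: in practice this revert changes what callers pass at the API boundary. A hedged before/after sketch (the role/content keys are inferred from the old formatting code in prompts.py below):

    # Before this commit, history entries were pydantic objects:
    #   message_history=[BaseMessage(role="user", content="Hi")]
    # After the revert, callers pass plain dicts; each provider still validates
    # the shape at runtime via MessageList (see the diffs below):
    message_history = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello! How can I help?"},
    ]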

File tree

9 files changed: +59 -58 lines changed

src/neo4j_graphrag/generation/graphrag.py

+2 -3

@@ -31,7 +31,6 @@
 )
 from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
 from neo4j_graphrag.llm import LLMInterface
-from neo4j_graphrag.llm.types import BaseMessage
 from neo4j_graphrag.retrievers.base import Retriever
 from neo4j_graphrag.types import RetrieverResult

@@ -88,7 +87,7 @@ def __init__(
     def search(
         self,
         query_text: str = "",
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,

@@ -148,7 +147,7 @@ def search(
         return RagResultModel(**result)

     def build_query(
-        self, query_text: str, message_history: Optional[list[BaseMessage]] = None
+        self, query_text: str, message_history: Optional[list[dict[str, str]]] = None
     ) -> str:
         if message_history:
             summarization_prompt = ChatSummaryTemplate().format(
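
For context, a hedged call-site sketch of the restored search signature; the retriever and LLM construction is assumed rather than shown in this diff, and result.answer follows RagResultModel:

    # Hypothetical usage: dict-shaped message_history flowing into GraphRAG.search.
    from neo4j_graphrag.generation.graphrag import GraphRAG

    rag = GraphRAG(retriever=my_retriever, llm=my_llm)  # both assumed to exist
    history = [
        {"role": "user", "content": "Who directed The Matrix?"},
        {"role": "assistant", "content": "The Wachowskis."},
    ]
    result = rag.search(query_text="When was it released?", message_history=history)
    print(result.answer)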

src/neo4j_graphrag/generation/prompts.py

+3 -3

@@ -17,7 +17,6 @@
 import warnings
 from typing import Any, Optional

-from neo4j_graphrag.llm.types import BaseMessage
 from neo4j_graphrag.exceptions import (
     PromptMissingInputError,
     PromptMissingPlaceholderError,

@@ -208,9 +207,10 @@ class ChatSummaryTemplate(PromptTemplate):
     EXPECTED_INPUTS = ["message_history"]
     SYSTEM_MESSAGE = "You are a summarization assistant. Summarize the given text in no more than 200 words"

-    def format(self, message_history: list[BaseMessage]) -> str:
+    def format(self, message_history: list[dict[str, str]]) -> str:
         message_list = [
-            f"{message.role}: {message.content}" for message in message_history
+            ": ".join([f"{value}" for _, value in message.items()])
+            for message in message_history
         ]
         history = "\n".join(message_list)
         return super().format(message_history=history)
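
The rewritten comprehension joins every value of each message dict with ": ", which reproduces the old "role: content" rendering as long as each dict keeps role before content in insertion order. A standalone sketch of just that logic:

    # Standalone demonstration of the restored formatting logic (no library imports).
    message_history = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi, how can I help?"},
    ]
    message_list = [
        ": ".join([f"{value}" for _, value in message.items()])
        for message in message_history
    ]
    print("\n".join(message_list))
    # user: Hello
    # assistant: Hi, how can I help?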

src/neo4j_graphrag/llm/anthropic_llm.py

+11 -11

@@ -13,13 +13,13 @@
 # limitations under the License.
 from __future__ import annotations

-from typing import Any, Iterable, Optional, TYPE_CHECKING
+from typing import Any, Iterable, Optional, TYPE_CHECKING, cast

 from pydantic import ValidationError

 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage, BaseMessage
+from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage

 if TYPE_CHECKING:
     from anthropic.types.message_param import MessageParam

@@ -71,22 +71,22 @@ def __init__(
         self.async_client = anthropic.AsyncAnthropic(**kwargs)

     def get_messages(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self, input: str, message_history: Optional[list[dict[str, str]]] = None
     ) -> Iterable[MessageParam]:
         messages = []
         if message_history:
             try:
-                MessageList(messages=message_history)
+                MessageList(messages=message_history)  # type: ignore
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return cast(Iterable[MessageParam], messages)

     def invoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.

@@ -108,18 +108,18 @@ def invoke(
             )
             response = self.client.messages.create(
                 model=self.model_name,
-                system=system_message,
+                system=system_message,  # type: ignore
                 messages=messages,
                 **self.model_params,
             )
-            return LLMResponse(content=response.content)
+            return LLMResponse(content=response.content)  # type: ignore
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)

     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.

@@ -141,10 +141,10 @@ async def ainvoke(
             )
             response = await self.async_client.messages.create(
                 model=self.model_name,
-                system=system_message,
+                system=system_message,  # type: ignore
                 messages=messages,
                 **self.model_params,
             )
-            return LLMResponse(content=response.content)
+            return LLMResponse(content=response.content)  # type: ignore
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)
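
All provider wrappers in this commit share the same pattern: accept plain dicts, validate them at runtime with the pydantic MessageList model, then cast for the SDK's static types. A self-contained sketch of that pattern; the model classes here are simplified mirrors of neo4j_graphrag.llm.types, not the library's actual definitions:

    from typing import Any, Iterable, Optional, cast

    from pydantic import BaseModel, ValidationError


    class BaseMessage(BaseModel):
        role: str
        content: str


    class UserMessage(BaseMessage):
        role: str = "user"


    class MessageList(BaseModel):
        messages: list[BaseMessage]


    def get_messages(
        input: str, message_history: Optional[list[dict[str, str]]] = None
    ) -> Iterable[dict[str, Any]]:
        messages: list[dict[str, Any]] = []
        if message_history:
            try:
                # Runtime check: pydantic coerces each dict into BaseMessage and
                # rejects entries missing a role or content key.
                MessageList(messages=message_history)  # type: ignore
            except ValidationError as e:
                raise ValueError(e.errors()) from e
            messages.extend(message_history)
        messages.append(UserMessage(content=input).model_dump())
        # cast() only narrows the static type; it is a no-op at runtime.
        return cast(Iterable[dict[str, Any]], messages)


    print(list(get_messages("hello", [{"role": "assistant", "content": "hi"}])))
    # [{'role': 'assistant', 'content': 'hi'}, {'role': 'user', 'content': 'hello'}]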

src/neo4j_graphrag/llm/base.py

+3 -3

@@ -17,7 +17,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Optional

-from .types import LLMResponse, BaseMessage
+from .types import LLMResponse


 class LLMInterface(ABC):

@@ -45,7 +45,7 @@ def __init__(
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the LLM and retrieves a response.

@@ -66,7 +66,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the LLM and retrieves a response.

src/neo4j_graphrag/llm/cohere_llm.py

+6 -7

@@ -14,7 +14,7 @@
 # limitations under the License.
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, cast
 from pydantic import ValidationError

 from neo4j_graphrag.exceptions import LLMGenerationError

@@ -24,7 +24,6 @@
     MessageList,
     SystemMessage,
     UserMessage,
-    BaseMessage,
 )

 if TYPE_CHECKING:

@@ -77,7 +76,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> ChatMessages:
         messages = []

@@ -90,17 +89,17 @@ def get_messages(
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)
+                MessageList(messages=message_history)  # type: ignore
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return cast(ChatMessages, messages)

     def invoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.

@@ -128,7 +127,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.

src/neo4j_graphrag/llm/mistralai_llm.py

+4 -5

@@ -25,7 +25,6 @@
     MessageList,
     SystemMessage,
     UserMessage,
-    BaseMessage,
 )

 try:

@@ -68,7 +67,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> list[Messages]:
         messages = []

@@ -81,7 +80,7 @@ def get_messages(
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)
+                MessageList(messages=message_history)  # type: ignore
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(message_history)

@@ -91,7 +90,7 @@ def get_messages(
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model

@@ -127,7 +126,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the MistralAI chat

src/neo4j_graphrag/llm/ollama_llm.py

+7 -7

@@ -12,14 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Sequence, TYPE_CHECKING
+from typing import Any, Optional, Sequence, TYPE_CHECKING, cast

 from pydantic import ValidationError

 from neo4j_graphrag.exceptions import LLMGenerationError

 from .base import LLMInterface
-from .types import LLMResponse, SystemMessage, UserMessage, MessageList, BaseMessage
+from .types import LLMResponse, SystemMessage, UserMessage, MessageList

 if TYPE_CHECKING:
     from ollama import Message

@@ -52,7 +52,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> Sequence[Message]:
         messages = []

@@ -65,17 +65,17 @@ def get_messages(
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)
+                MessageList(messages=message_history)  # type: ignore
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return cast(Sequence[Message], messages)

     def invoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.

@@ -102,7 +102,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat

src/neo4j_graphrag/llm/openai_llm.py

+7 -7

@@ -15,13 +15,13 @@
 from __future__ import annotations

 import abc
-from typing import TYPE_CHECKING, Any, Iterable, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, cast

 from pydantic import ValidationError

 from ..exceptions import LLMGenerationError
 from .base import LLMInterface
-from .types import LLMResponse, SystemMessage, UserMessage, MessageList, BaseMessage
+from .types import LLMResponse, SystemMessage, UserMessage, MessageList

 if TYPE_CHECKING:
     import openai

@@ -63,7 +63,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> Iterable[ChatCompletionMessageParam]:
         messages = []

@@ -76,17 +76,17 @@ def get_messages(
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)
+                MessageList(messages=message_history)  # type: ignore
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return cast(Iterable[ChatCompletionMessageParam], messages)

     def invoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model

@@ -117,7 +117,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat
