@@ -13,13 +13,13 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Any, Iterable, Optional, TYPE_CHECKING
+from typing import Any, Iterable, Optional, TYPE_CHECKING, cast
 
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage, BaseMessage
+from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage
 
 if TYPE_CHECKING:
     from anthropic.types.message_param import MessageParam
@@ -71,22 +71,22 @@ def __init__(
         self.async_client = anthropic.AsyncAnthropic(**kwargs)
 
     def get_messages(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self, input: str, message_history: Optional[list[dict[str, str]]] = None
     ) -> Iterable[MessageParam]:
         messages = []
         if message_history:
             try:
-                MessageList(messages=message_history)
+                MessageList(messages=message_history)  # type: ignore
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return cast(Iterable[MessageParam], messages)
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -108,18 +108,18 @@ def invoke(
             )
             response = self.client.messages.create(
                 model=self.model_name,
-                system=system_message,
+                system=system_message,  # type: ignore
                 messages=messages,
                 **self.model_params,
             )
-            return LLMResponse(content=response.content)
+            return LLMResponse(content=response.content)  # type: ignore
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)
 
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
@@ -141,10 +141,10 @@ async def ainvoke(
             )
             response = await self.async_client.messages.create(
                 model=self.model_name,
-                system=system_message,
+                system=system_message,  # type: ignore
                 messages=messages,
                 **self.model_params,
             )
-            return LLMResponse(content=response.content)
+            return LLMResponse(content=response.content)  # type: ignore
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)
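
For reference, a minimal usage sketch of the updated interface: after this change, `message_history` is a plain list of role/content dicts rather than `BaseMessage` objects, and `get_messages` validates it against `MessageList` before sending. The model name and parameters below are illustrative only, and the snippet assumes the `anthropic` package is installed and `ANTHROPIC_API_KEY` is set in the environment.

from neo4j_graphrag.llm import AnthropicLLM

# model_params are forwarded to messages.create(); Anthropic requires max_tokens.
llm = AnthropicLLM(
    model_name="claude-3-5-sonnet-20240620",
    model_params={"max_tokens": 1000},
)

# message_history is now a list of {"role": ..., "content": ...} dicts;
# it is validated against MessageList inside get_messages.
history = [
    {"role": "user", "content": "What is a knowledge graph?"},
    {"role": "assistant", "content": "A graph of entities and relationships."},
]

response = llm.invoke("How do I build one in Neo4j?", message_history=history)
print(response.content)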