Skip to content

Commit 64b837d

Browse files
committed
Fix mypy and tests
1 parent 3b070bf commit 64b837d

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

src/neo4j_graphrag/llm/cohere_llm.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def invoke(
124124
except self.cohere_api_error as e:
125125
raise LLMGenerationError(e)
126126
return LLMResponse(
127-
content=res.message.content[0].text,
127+
content=res.message.content[0].text if res.message.content else "",
128128
)
129129

130130
async def ainvoke(
@@ -148,12 +148,12 @@ async def ainvoke(
148148
if isinstance(message_history, MessageHistory):
149149
message_history = message_history.messages
150150
messages = self.get_messages(input, message_history, system_instruction)
151-
res = self.async_client.chat(
151+
res = await self.async_client.chat(
152152
messages=messages,
153153
model=self.model_name,
154154
)
155155
except self.cohere_api_error as e:
156156
raise LLMGenerationError(e)
157157
return LLMResponse(
158-
content=res.message.content[0].text,
158+
content=res.message.content[0].text if res.message.content else "",
159159
)

tests/unit/llm/test_cohere_llm.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def test_cohere_llm_happy_path(mock_cohere: Mock) -> None:
5050
def test_cohere_llm_invoke_with_message_history_happy_path(mock_cohere: Mock) -> None:
5151
chat_response_mock = MagicMock()
5252
chat_response_mock.message.content = [MagicMock(text="cohere response text")]
53-
mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock
53+
mock_cohere_client_chat = mock_cohere.ClientV2.return_value.chat
54+
mock_cohere_client_chat.return_value = chat_response_mock
5455

5556
system_instruction = "You are a helpful assistant."
5657
llm = CohereLLM(model_name="something")
@@ -66,7 +67,7 @@ def test_cohere_llm_invoke_with_message_history_happy_path(mock_cohere: Mock) ->
6667
messages = [{"role": "system", "content": system_instruction}]
6768
messages.extend(message_history)
6869
messages.append({"role": "user", "content": question})
69-
llm.client.chat.assert_called_once_with(
70+
mock_cohere_client_chat.assert_called_once_with(
7071
messages=messages,
7172
model="something",
7273
)
@@ -77,7 +78,8 @@ def test_cohere_llm_invoke_with_message_history_and_system_instruction(
7778
) -> None:
7879
chat_response_mock = MagicMock()
7980
chat_response_mock.message.content = [MagicMock(text="cohere response text")]
80-
mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock
81+
mock_cohere_client_chat = mock_cohere.ClientV2.return_value.chat
82+
mock_cohere_client_chat.return_value = chat_response_mock
8183

8284
system_instruction = "You are a helpful assistant."
8385
llm = CohereLLM(model_name="gpt")
@@ -93,11 +95,10 @@ def test_cohere_llm_invoke_with_message_history_and_system_instruction(
9395
messages = [{"role": "system", "content": system_instruction}]
9496
messages.extend(message_history)
9597
messages.append({"role": "user", "content": question})
96-
llm.client.chat.assert_called_once_with(
98+
mock_cohere_client_chat.assert_called_once_with(
9799
messages=messages,
98100
model="gpt",
99101
)
100-
assert llm.client.chat.call_count == 1
101102

102103

103104
def test_cohere_llm_invoke_with_message_history_validation_error(
@@ -122,9 +123,12 @@ def test_cohere_llm_invoke_with_message_history_validation_error(
122123

123124
@pytest.mark.asyncio
124125
async def test_cohere_llm_happy_path_async(mock_cohere: Mock) -> None:
125-
chat_response_mock = AsyncMock()
126-
chat_response_mock.message.content = [AsyncMock(text="cohere response text")]
127-
mock_cohere.AsyncClientV2.return_value.chat.return_value = chat_response_mock
126+
chat_response_mock = MagicMock(
127+
message=MagicMock(content=[MagicMock(text="cohere response text")])
128+
)
129+
mock_cohere.AsyncClientV2.return_value.chat = AsyncMock(
130+
return_value=chat_response_mock
131+
)
128132

129133
llm = CohereLLM(model_name="something")
130134
res = await llm.ainvoke("my text")

0 commit comments

Comments
 (0)