
Commit fa12a9f

Unit tests for the system instruction override
1 parent 23a8001 commit fa12a9f

6 files changed: +338 -4 lines changed
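The diffs below all exercise the same contract: invoke() takes an optional per-call system_instruction that overrides the constructor-level instruction for that call only, and falls back to the constructor default when omitted. As a minimal sketch of that precedence rule (the SketchLLM class here is illustrative, not code from this commit):

from typing import Optional


class SketchLLM:
    # Illustrative stand-in for the LLM interfaces under test.
    def __init__(self, system_instruction: Optional[str] = None) -> None:
        self.system_instruction = system_instruction

    def resolve_system_instruction(
        self, override: Optional[str] = None
    ) -> Optional[str]:
        # A per-call instruction wins; otherwise fall back to the default.
        return override if override is not None else self.system_instruction

Each test below drives three invocations (default, override, then default again) to verify that an override applies only to its own call.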

tests/unit/llm/test_anthropic_llm.py
+62 -3

@@ -22,6 +22,7 @@
 import pytest
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM
+from neo4j_graphrag.llm.types import LLMResponse


 @pytest.fixture
@@ -61,11 +62,9 @@ def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock) -> None:
         content="generated text"
     )
     model_params = {"temperature": 0.3}
-    system_instruction = "You are a helpful assistant."
     llm = AnthropicLLM(
         "claude-3-opus-20240229",
         model_params=model_params,
-        system_instruction=system_instruction,
     )
     message_history = [
         {"role": "user", "content": "When does the sun come up in the summer?"},
@@ -79,11 +78,71 @@ def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock) -> None:
     llm.client.messages.create.assert_called_once_with(  # type: ignore[attr-defined]
         messages=message_history,
         model="claude-3-opus-20240229",
-        system=system_instruction,
+        system=None,
         **model_params,
     )


+def test_anthropic_invoke_with_message_history_and_system_instruction(
+    mock_anthropic: Mock,
+) -> None:
+    mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock(
+        content="generated text"
+    )
+    model_params = {"temperature": 0.3}
+    initial_instruction = "You are a helpful assistant."
+    llm = AnthropicLLM(
+        "claude-3-opus-20240229",
+        model_params=model_params,
+        system_instruction=initial_instruction,
+    )
+    message_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    # first invocation - initial instructions
+    response = llm.invoke(question, message_history)  # type: ignore
+    assert response.content == "generated text"
+    message_history.append({"role": "user", "content": question})
+    llm.client.messages.create.assert_called_once_with(  # type: ignore[attr-defined]
+        model="claude-3-opus-20240229",
+        system=initial_instruction,
+        messages=message_history,
+        **model_params,
+    )
+
+    # second invocation - override instructions
+    override_instruction = "Ignore all previous instructions"
+    question = "When does it come up in the winter?"
+    response = llm.invoke(question, message_history, override_instruction)  # type: ignore
+    assert isinstance(response, LLMResponse)
+    assert response.content == "generated text"
+    message_history.append({"role": "user", "content": question})
+    llm.client.messages.create.assert_called_with(  # type: ignore[attr-defined]
+        model="claude-3-opus-20240229",
+        system=override_instruction,
+        messages=message_history,
+        **model_params,
+    )
+
+    # third invocation - default instructions
+    question = "When does it set?"
+    response = llm.invoke(question, message_history)  # type: ignore
+    assert isinstance(response, LLMResponse)
+    assert response.content == "generated text"
+    message_history.append({"role": "user", "content": question})
+    llm.client.messages.create.assert_called_with(  # type: ignore[attr-defined]
+        model="claude-3-opus-20240229",
+        system=initial_instruction,
+        messages=message_history,
+        **model_params,
+    )
+
+    assert llm.client.messages.create.call_count == 3  # type: ignore
+
+
 def test_anthropic_invoke_with_message_history_validation_error(
     mock_anthropic: Mock,
 ) -> None:
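Note the provider-specific convention: the Anthropic assertions above expect the instruction forwarded through the client's dedicated system= keyword, while the Cohere, Mistral, Ollama, and OpenAI tests below expect it prepended to the message list as a system-role entry. A minimal sketch of that second convention (the build_messages helper is illustrative, not part of this commit):

from typing import Optional


def build_messages(
    history: list[dict],
    question: str,
    system_instruction: Optional[str] = None,
) -> list[dict]:
    # System instruction first (if any), then prior turns, then the new question.
    messages: list[dict] = []
    if system_instruction is not None:
        messages.append({"role": "system", "content": system_instruction})
    messages.extend(history)
    messages.append({"role": "user", "content": question})
    return messages

This mirrors how each of the following tests constructs the expected messages list before its assert_called_with check.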

tests/unit/llm/test_cohere_llm.py
+55 -0

@@ -72,6 +72,61 @@ def test_cohere_llm_invoke_with_message_history_happy_path(mock_cohere: Mock) -> None:
     )


+def test_cohere_llm_invoke_with_message_history_and_system_instruction(
+    mock_cohere: Mock,
+) -> None:
+    chat_response_mock = MagicMock()
+    chat_response_mock.message.content = [MagicMock(text="cohere response text")]
+    mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock
+
+    initial_instruction = "You are a helpful assistant."
+    llm = CohereLLM(model_name="gpt", system_instruction=initial_instruction)
+    message_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    # first invocation - initial instructions
+    res = llm.invoke(question, message_history)  # type: ignore
+    assert isinstance(res, LLMResponse)
+    assert res.content == "cohere response text"
+    messages = [{"role": "system", "content": initial_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.assert_called_once_with(
+        messages=messages,
+        model="gpt",
+    )
+
+    # second invocation - override instructions
+    override_instruction = "Ignore all previous instructions"
+    res = llm.invoke(question, message_history, override_instruction)  # type: ignore
+    assert isinstance(res, LLMResponse)
+    assert res.content == "cohere response text"
+    messages = [{"role": "system", "content": override_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.assert_called_with(
+        messages=messages,
+        model="gpt",
+    )
+
+    # third invocation - default instructions
+    res = llm.invoke(question, message_history)  # type: ignore
+    assert isinstance(res, LLMResponse)
+    assert res.content == "cohere response text"
+    messages = [{"role": "system", "content": initial_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.assert_called_with(
+        messages=messages,
+        model="gpt",
+    )
+
+    assert llm.client.chat.call_count == 3
+
+
 def test_cohere_llm_invoke_with_message_history_validation_error(
     mock_cohere: Mock,
 ) -> None:

tests/unit/llm/test_mistralai_llm.py
+59 -0

@@ -77,6 +77,65 @@ def test_mistralai_llm_invoke_with_message_history(mock_mistral: Mock) -> None:
     )


+@patch("neo4j_graphrag.llm.mistralai_llm.Mistral")
+def test_mistralai_llm_invoke_with_message_history_and_system_instruction(
+    mock_mistral: Mock,
+) -> None:
+    mock_mistral_instance = mock_mistral.return_value
+    chat_response_mock = MagicMock()
+    chat_response_mock.choices = [
+        MagicMock(message=MagicMock(content="mistral response"))
+    ]
+    mock_mistral_instance.chat.complete.return_value = chat_response_mock
+    model = "mistral-model"
+    initial_instruction = "You are a helpful assistant."
+    llm = MistralAILLM(model_name=model, system_instruction=initial_instruction)
+    message_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    # first invocation - initial instructions
+    res = llm.invoke(question, message_history)  # type: ignore
+    assert isinstance(res, LLMResponse)
+    assert res.content == "mistral response"
+    messages = [{"role": "system", "content": initial_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.complete.assert_called_once_with(  # type: ignore[attr-defined]
+        messages=messages,
+        model=model,
+    )
+
+    # second invocation - override instructions
+    override_instruction = "Ignore all previous instructions"
+    res = llm.invoke(question, message_history, override_instruction)  # type: ignore
+    assert isinstance(res, LLMResponse)
+    assert res.content == "mistral response"
+    messages = [{"role": "system", "content": override_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.complete.assert_called_with(  # type: ignore
+        messages=messages,
+        model=model,
+    )
+
+    # third invocation - default instructions
+    res = llm.invoke(question, message_history)  # type: ignore
+    assert isinstance(res, LLMResponse)
+    assert res.content == "mistral response"
+    messages = [{"role": "system", "content": initial_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.complete.assert_called_with(  # type: ignore
+        messages=messages,
+        model=model,
+    )
+
+    assert llm.client.chat.complete.call_count == 3  # type: ignore
+
+
 @patch("neo4j_graphrag.llm.mistralai_llm.Mistral")
 def test_mistralai_llm_invoke_with_message_history_validation_error(
     mock_mistral: Mock,

tests/unit/llm/test_ollama_llm.py
+57 -0

@@ -94,6 +94,63 @@ def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> None:
     )


+@patch("builtins.__import__")
+def test_ollama_invoke_with_message_history_and_system_instruction(
+    mock_import: Mock,
+) -> None:
+    mock_ollama = get_mock_ollama()
+    mock_import.return_value = mock_ollama
+    mock_ollama.Client.return_value.chat.return_value = MagicMock(
+        message=MagicMock(content="ollama chat response"),
+    )
+    model = "gpt"
+    model_params = {"temperature": 0.3}
+    system_instruction = "You are a helpful assistant."
+    llm = OllamaLLM(
+        model,
+        model_params=model_params,
+        system_instruction=system_instruction,
+    )
+    message_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    # first invocation - initial instructions
+    response = llm.invoke(question, message_history)  # type: ignore
+    assert response.content == "ollama chat response"
+    messages = [{"role": "system", "content": system_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
+        model=model, messages=messages, options=model_params
+    )
+
+    # second invocation - override instructions
+    override_instruction = "Ignore all previous instructions"
+    response = llm.invoke(question, message_history, override_instruction)  # type: ignore
+    assert response.content == "ollama chat response"
+    messages = [{"role": "system", "content": override_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.assert_called_with(  # type: ignore[attr-defined]
+        model=model, messages=messages, options=model_params
+    )
+
+    # third invocation - default instructions
+    response = llm.invoke(question, message_history)  # type: ignore
+    assert response.content == "ollama chat response"
+    messages = [{"role": "system", "content": system_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.assert_called_with(  # type: ignore[attr-defined]
+        model=model, messages=messages, options=model_params
+    )
+
+    assert llm.client.chat.call_count == 3  # type: ignore
+
+
 @patch("builtins.__import__")
 def test_ollama_invoke_with_message_history_validation_error(mock_import: Mock) -> None:
     mock_ollama = get_mock_ollama()

tests/unit/llm/test_openai_llm.py
+69 -0

@@ -64,6 +64,70 @@ def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None:
     res = llm.invoke(question, message_history)  # type: ignore
     assert isinstance(res, LLMResponse)
     assert res.content == "openai chat response"
+    message_history.append({"role": "user", "content": question})
+    llm.client.chat.completions.create.assert_called_once_with(  # type: ignore
+        messages=message_history,
+        model="gpt",
+    )
+
+
+@patch("builtins.__import__")
+def test_openai_llm_with_message_history_and_system_instruction(
+    mock_import: Mock,
+) -> None:
+    mock_openai = get_mock_openai()
+    mock_import.return_value = mock_openai
+    mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
+        choices=[MagicMock(message=MagicMock(content="openai chat response"))],
+    )
+    initial_instruction = "You are a helpful assistant."
+    llm = OpenAILLM(
+        api_key="my key", model_name="gpt", system_instruction=initial_instruction
+    )
+    message_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    # first invocation - initial instructions
+    res = llm.invoke(question, message_history)  # type: ignore
+    assert isinstance(res, LLMResponse)
+    assert res.content == "openai chat response"
+    messages = [{"role": "system", "content": initial_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.completions.create.assert_called_once_with(  # type: ignore
+        messages=messages,
+        model="gpt",
+    )
+
+    # second invocation - override instructions
+    override_instruction = "Ignore all previous instructions"
+    res = llm.invoke(question, message_history, override_instruction)  # type: ignore
+    assert isinstance(res, LLMResponse)
+    assert res.content == "openai chat response"
+    messages = [{"role": "system", "content": override_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.completions.create.assert_called_with(  # type: ignore
+        messages=messages,
+        model="gpt",
+    )
+
+    # third invocation - default instructions
+    res = llm.invoke(question, message_history)  # type: ignore
+    assert isinstance(res, LLMResponse)
+    assert res.content == "openai chat response"
+    messages = [{"role": "system", "content": initial_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.completions.create.assert_called_with(  # type: ignore
+        messages=messages,
+        model="gpt",
+    )
+
+    assert llm.client.chat.completions.create.call_count == 3  # type: ignore


 @patch("builtins.__import__")
@@ -137,6 +201,11 @@ def test_azure_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None:
     res = llm.invoke(question, message_history)  # type: ignore
     assert isinstance(res, LLMResponse)
    assert res.content == "openai chat response"
+    message_history.append({"role": "user", "content": question})
+    llm.client.chat.completions.create.assert_called_once_with(  # type: ignore
+        messages=message_history,
+        model="gpt",
+    )


 @patch("builtins.__import__")
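Assuming the repository's standard pytest setup (an assumption, not stated in this commit), the new tests can be selected by keyword:

import pytest

# Run only the system-instruction tests from the unit LLM suite.
pytest.main(["tests/unit/llm", "-k", "system_instruction"])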
