From 955692ace8aabbbb91fcc9bbde326e09b86f89b5 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Thu, 16 Oct 2025 19:24:27 -0400 Subject: [PATCH 1/6] tests(xai): rfc, expand reasoning tests, check for `model_provider` --- .../integration_tests/test_chat_models.py | 82 +++++++++++++++++++ .../test_chat_models_standard.py | 71 ++++++---------- .../test_chat_models_standard.ambr | 2 +- .../unit_tests/test_chat_models_standard.py | 6 +- libs/partners/xai/uv.lock | 9 +- 5 files changed, 117 insertions(+), 53 deletions(-) create mode 100644 libs/partners/xai/tests/integration_tests/test_chat_models.py diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models.py b/libs/partners/xai/tests/integration_tests/test_chat_models.py new file mode 100644 index 0000000000000..a328668470007 --- /dev/null +++ b/libs/partners/xai/tests/integration_tests/test_chat_models.py @@ -0,0 +1,82 @@ +"""Integration tests for ChatXAI specific features.""" + +from __future__ import annotations + +from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk + +from langchain_xai import ChatXAI + +MODEL_NAME = "grok-4-fast-reasoning" + + +def test_reasoning() -> None: + """Test reasoning features.""" + # Test reasoning effort + chat_model = ChatXAI( + # grok-4 doesn't support reasoning_effort + model="grok-3-mini", + reasoning_effort="low", + ) + response = chat_model.invoke("What is 3^3?") + assert response.content + assert response.additional_kwargs["reasoning_content"] + + # Test default reasoning + chat_model = ChatXAI( + model=MODEL_NAME, + ) + response = chat_model.invoke("What is 3^3?") + assert response.content + assert response.additional_kwargs["reasoning_content"] + + # Test streaming + full: BaseMessageChunk | None = None + for chunk in chat_model.stream("What is 3^3?"): + full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + assert full.additional_kwargs["reasoning_content"] + + # Check that we can access reasoning content blocks + assert response.content_blocks + reasoning_content = ( + block for block in response.content_blocks if block["type"] == "reasoning" + ) + assert len(list(reasoning_content)) >= 1 + + # Test that passing message with reasoning back in works + followup = chat_model.invoke([response, "Based on your reasoning, what is 4^4?"]) + assert followup.content + assert followup.additional_kwargs["reasoning_content"] + followup_reasoning = ( + block for block in followup.content_blocks if block["type"] == "reasoning" + ) + assert len(list(followup_reasoning)) >= 1 + + # Test passing in a ReasoningContentBlock + msg_w_reasoning = AIMessage(content_blocks=response.content_blocks) + followup_2 = chat_model.invoke( + [msg_w_reasoning, "Based on your reasoning, what is 5^5?"] + ) + assert followup_2.content + assert followup_2.additional_kwargs["reasoning_content"] + + +def test_web_search() -> None: + llm = ChatXAI( + model=MODEL_NAME, + search_parameters={"mode": "on", "max_search_results": 3}, + ) + + # Test invoke + response = llm.invoke("Provide me a digest of world news in the last 24 hours.") + assert response.content + assert response.additional_kwargs["citations"] + assert len(response.additional_kwargs["citations"]) <= 3 + + # Test streaming + full = None + for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."): + full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + assert full.additional_kwargs["citations"] + assert 
len(full.additional_kwargs["citations"]) <= 3 diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py index dada7d5c4bcd5..095b4589cb68b 100644 --- a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py @@ -2,9 +2,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, override -from langchain_core.messages import AIMessageChunk, BaseMessageChunk +import pytest +from langchain_core.messages import AIMessage from langchain_core.rate_limiters import InMemoryRateLimiter from langchain_tests.integration_tests import ( # type: ignore[import-not-found] ChatModelIntegrationTests, # type: ignore[import-not-found] @@ -15,15 +16,12 @@ if TYPE_CHECKING: from langchain_core.language_models import BaseChatModel -# Initialize the rate limiter in global scope, so it can be re-used -# across tests. +# Initialize the rate limiter in global scope, so it can be re-used across tests rate_limiter = InMemoryRateLimiter( requests_per_second=0.5, ) - -# Not using Grok 4 since it doesn't support reasoning params (effort) or returns -# reasoning content. +MODEL_NAME = "grok-4-fast-reasoning" class TestXAIStandard(ChatModelIntegrationTests): @@ -33,48 +31,29 @@ def chat_model_class(self) -> type[BaseChatModel]: @property def chat_model_params(self) -> dict: - # TODO: bump to test new Grok once they implement other features return { - "model": "grok-3", + "model": MODEL_NAME, "rate_limiter": rate_limiter, "stream_usage": True, } - -def test_reasoning_content() -> None: - """Test reasoning content.""" - chat_model = ChatXAI( - model="grok-3-mini", - reasoning_effort="low", - ) - response = chat_model.invoke("What is 3^3?") - assert response.content - assert response.additional_kwargs["reasoning_content"] - - # Test streaming - full: BaseMessageChunk | None = None - for chunk in chat_model.stream("What is 3^3?"): - full = chunk if full is None else full + chunk - assert isinstance(full, AIMessageChunk) - assert full.additional_kwargs["reasoning_content"] - - -def test_web_search() -> None: - llm = ChatXAI( - model="grok-3", - search_parameters={"mode": "auto", "max_search_results": 3}, + @pytest.mark.xfail( + reason="Default model does not support stop sequences, using grok-3 instead" ) - - # Test invoke - response = llm.invoke("Provide me a digest of world news in the last 24 hours.") - assert response.content - assert response.additional_kwargs["citations"] - assert len(response.additional_kwargs["citations"]) <= 3 - - # Test streaming - full = None - for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."): - full = chunk if full is None else full + chunk - assert isinstance(full, AIMessageChunk) - assert full.additional_kwargs["citations"] - assert len(full.additional_kwargs["citations"]) <= 3 + @override + def test_stop_sequence(self, model: BaseChatModel) -> None: + """Override to use `grok-3` which supports stop sequences.""" + params = self.chat_model_params + params["model"] = "grok-3" + + grok3_model = ChatXAI(**params) + + result = grok3_model.invoke("hi", stop=["you"]) + assert isinstance(result, AIMessage) + + custom_model = ChatXAI( + **params, + stop_sequences=["you"], + ) + result = custom_model.invoke("hi") + assert isinstance(result, AIMessage) diff --git 
a/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr b/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr index 4cd1261555c90..6f1cd18c2e2dc 100644 --- a/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr +++ b/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr @@ -9,7 +9,7 @@ 'kwargs': dict({ 'max_retries': 2, 'max_tokens': 100, - 'model_name': 'grok-beta', + 'model_name': 'grok-4', 'request_timeout': 60.0, 'stop': list([ ]), diff --git a/libs/partners/xai/tests/unit_tests/test_chat_models_standard.py b/libs/partners/xai/tests/unit_tests/test_chat_models_standard.py index 7c9a947a913ed..7224655db79ed 100644 --- a/libs/partners/xai/tests/unit_tests/test_chat_models_standard.py +++ b/libs/partners/xai/tests/unit_tests/test_chat_models_standard.py @@ -7,6 +7,8 @@ from langchain_xai import ChatXAI +MODEL_NAME = "grok-4" + class TestXAIStandard(ChatModelUnitTests): @property @@ -15,7 +17,7 @@ def chat_model_class(self) -> type[BaseChatModel]: @property def chat_model_params(self) -> dict: - return {"model": "grok-beta"} + return {"model": MODEL_NAME} @property def init_from_env_params(self) -> tuple[dict, dict, dict]: @@ -24,7 +26,7 @@ def init_from_env_params(self) -> tuple[dict, dict, dict]: "XAI_API_KEY": "api_key", }, { - "model": "grok-beta", + "model": MODEL_NAME, }, { "xai_api_key": "api_key", diff --git a/libs/partners/xai/uv.lock b/libs/partners/xai/uv.lock index ece0030930e52..83d4569151fcb 100644 --- a/libs/partners/xai/uv.lock +++ b/libs/partners/xai/uv.lock @@ -357,7 +357,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } wheels = [ @@ -621,7 +621,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.0.0a6" +version = "1.0.0rc2" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, @@ -679,7 +679,7 @@ typing = [ [[package]] name = "langchain-openai" -version = "1.0.0a3" +version = "1.0.0a4" source = { editable = "../openai" } dependencies = [ { name = "langchain-core" }, @@ -699,6 +699,7 @@ dev = [{ name = "langchain-core", editable = "../../core" }] lint = [{ name = "ruff", specifier = ">=0.13.1,<0.14.0" }] test = [ { name = "freezegun", specifier = ">=1.2.2,<2.0.0" }, + { name = "langchain", editable = "../../langchain_v1" }, { name = "langchain-core", editable = "../../core" }, { name = "langchain-tests", editable = "../../standard-tests" }, { name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.4" }, @@ -728,7 +729,7 @@ typing = [ [[package]] name = "langchain-tests" -version = "1.0.0a2" +version = "1.0.0rc1" source = { editable = "../../standard-tests" } dependencies = [ { name = "httpx" }, From 32c453f443fed8ab5653ab3bc65a81219e3bf3ae Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Thu, 16 Oct 2025 19:26:31 -0400 Subject: [PATCH 2/6] fix(xai): inject `model_provider` into `response_metadata` --- libs/partners/xai/langchain_xai/chat_models.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff 
--git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py index c3e1b9bdad6d8..8fcff2efc1903 100644 --- a/libs/partners/xai/langchain_xai/chat_models.py +++ b/libs/partners/xai/langchain_xai/chat_models.py @@ -527,6 +527,11 @@ def _create_chat_result( ) -> ChatResult: rtn = super()._create_chat_result(response, generation_info) + for generation in rtn.generations: + if generation.message.response_metadata is None: + generation.message.response_metadata = {} + generation.message.response_metadata["model_provider"] = "xai" + if not isinstance(response, openai.BaseModel): return rtn @@ -553,6 +558,12 @@ def _convert_chunk_to_generation_chunk( default_chunk_class, base_generation_info, ) + + if generation_chunk: + if generation_chunk.message.response_metadata is None: + generation_chunk.message.response_metadata = {} + generation_chunk.message.response_metadata["model_provider"] = "xai" + if (choices := chunk.get("choices")) and generation_chunk: top = choices[0] if isinstance(generation_chunk.message, AIMessageChunk) and ( From 978ad9d32d2b9237e83739338f42e874729312da Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Thu, 16 Oct 2025 19:30:34 -0400 Subject: [PATCH 3/6] fix(xai): update reasoning test to clarify model capabilities --- .../tests/integration_tests/test_chat_models.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models.py b/libs/partners/xai/tests/integration_tests/test_chat_models.py index a328668470007..e709a953c476a 100644 --- a/libs/partners/xai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/xai/tests/integration_tests/test_chat_models.py @@ -10,7 +10,11 @@ def test_reasoning() -> None: - """Test reasoning features.""" + """Test reasoning features. + + Note: `grok-4` does not return `reasoning_content`, but may optionally return + encrypted reasoning content if `use_encrypted_content` is set to True. 
+ """ # Test reasoning effort chat_model = ChatXAI( # grok-4 doesn't support reasoning_effort @@ -21,14 +25,6 @@ def test_reasoning() -> None: assert response.content assert response.additional_kwargs["reasoning_content"] - # Test default reasoning - chat_model = ChatXAI( - model=MODEL_NAME, - ) - response = chat_model.invoke("What is 3^3?") - assert response.content - assert response.additional_kwargs["reasoning_content"] - # Test streaming full: BaseMessageChunk | None = None for chunk in chat_model.stream("What is 3^3?"): From 6067b6b893452e9b007992a12e2e6cd90f5771f2 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Thu, 16 Oct 2025 19:48:35 -0400 Subject: [PATCH 4/6] fix: use `override` from `typing_extensions` for Python<3.12 compatibility --- .../xai/tests/integration_tests/test_chat_models_standard.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py index 095b4589cb68b..9e28f3a0bf208 100644 --- a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, override +from typing import TYPE_CHECKING import pytest from langchain_core.messages import AIMessage @@ -10,6 +10,7 @@ from langchain_tests.integration_tests import ( # type: ignore[import-not-found] ChatModelIntegrationTests, # type: ignore[import-not-found] ) +from typing_extensions import override from langchain_xai import ChatXAI From 4bfe25df29ef98be590628162016bce37d29096a Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Fri, 17 Oct 2025 13:21:51 -0400 Subject: [PATCH 5/6] resolve comments --- libs/partners/xai/langchain_xai/chat_models.py | 4 ---- .../xai/tests/integration_tests/test_chat_models.py | 8 +++++--- .../tests/integration_tests/test_chat_models_standard.py | 3 +-- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py index 227ee714a1d04..d2de07c20ff04 100644 --- a/libs/partners/xai/langchain_xai/chat_models.py +++ b/libs/partners/xai/langchain_xai/chat_models.py @@ -530,8 +530,6 @@ def _create_chat_result( rtn = super()._create_chat_result(response, generation_info) for generation in rtn.generations: - if generation.message.response_metadata is None: - generation.message.response_metadata = {} generation.message.response_metadata["model_provider"] = "xai" if not isinstance(response, openai.BaseModel): @@ -562,8 +560,6 @@ def _convert_chunk_to_generation_chunk( ) if generation_chunk: - if generation_chunk.message.response_metadata is None: - generation_chunk.message.response_metadata = {} generation_chunk.message.response_metadata["model_provider"] = "xai" if (choices := chunk.get("choices")) and generation_chunk: diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models.py b/libs/partners/xai/tests/integration_tests/test_chat_models.py index e709a953c476a..ce1dda4fbe979 100644 --- a/libs/partners/xai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/xai/tests/integration_tests/test_chat_models.py @@ -21,13 +21,14 @@ def test_reasoning() -> None: model="grok-3-mini", reasoning_effort="low", ) - response = chat_model.invoke("What is 3^3?") + input_message = "What is 3^3?" 
+ response = chat_model.invoke(input_message) assert response.content assert response.additional_kwargs["reasoning_content"] # Test streaming full: BaseMessageChunk | None = None - for chunk in chat_model.stream("What is 3^3?"): + for chunk in chat_model.stream(input_message): full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk) assert full.additional_kwargs["reasoning_content"] @@ -40,7 +41,8 @@ def test_reasoning() -> None: assert len(list(reasoning_content)) >= 1 # Test that passing message with reasoning back in works - followup = chat_model.invoke([response, "Based on your reasoning, what is 4^4?"]) + follow_up_message = "Based on your reasoning, what is 4^4?" + followup = chat_model.invoke([input_message, response, follow_up_message]) assert followup.content assert followup.additional_kwargs["reasoning_content"] followup_reasoning = ( diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py index 9e28f3a0bf208..9b7871d6ac26f 100644 --- a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py @@ -44,8 +44,7 @@ def chat_model_params(self) -> dict: @override def test_stop_sequence(self, model: BaseChatModel) -> None: """Override to use `grok-3` which supports stop sequences.""" - params = self.chat_model_params - params["model"] = "grok-3" + params = {**self.chat_model_params, "model": "grok-3"} grok3_model = ChatXAI(**params) From 25d6359652a86c8bd3a0cad8edb267e9af859167 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Fri, 17 Oct 2025 15:14:35 -0400 Subject: [PATCH 6/6] . --- .../integration_tests/test_chat_models.py | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models.py b/libs/partners/xai/tests/integration_tests/test_chat_models.py index ce1dda4fbe979..049f51e59bc70 100644 --- a/libs/partners/xai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/xai/tests/integration_tests/test_chat_models.py @@ -2,6 +2,9 @@ from __future__ import annotations +from typing import Literal + +import pytest from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk from langchain_xai import ChatXAI @@ -9,18 +12,25 @@ MODEL_NAME = "grok-4-fast-reasoning" -def test_reasoning() -> None: +@pytest.mark.parametrize("output_version", ["", "v1"]) +def test_reasoning(output_version: Literal["", "v1"]) -> None: """Test reasoning features. Note: `grok-4` does not return `reasoning_content`, but may optionally return encrypted reasoning content if `use_encrypted_content` is set to True. """ # Test reasoning effort - chat_model = ChatXAI( - # grok-4 doesn't support reasoning_effort - model="grok-3-mini", - reasoning_effort="low", - ) + if output_version: + chat_model = ChatXAI( + model="grok-3-mini", + reasoning_effort="low", + output_version=output_version, + ) + else: + chat_model = ChatXAI( + model="grok-3-mini", + reasoning_effort="low", + ) input_message = "What is 3^3?" 
response = chat_model.invoke(input_message) assert response.content @@ -51,7 +61,13 @@ def test_reasoning() -> None: assert len(list(followup_reasoning)) >= 1 # Test passing in a ReasoningContentBlock - msg_w_reasoning = AIMessage(content_blocks=response.content_blocks) + response_metadata = {"model_provider": "xai"} + if output_version: + response_metadata["output_version"] = output_version + msg_w_reasoning = AIMessage( + content_blocks=response.content_blocks, + response_metadata=response_metadata, + ) followup_2 = chat_model.invoke( [msg_w_reasoning, "Based on your reasoning, what is 5^5?"] )
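
For quick local verification of the series' end state, the short script below exercises the new behavior directly. It is a minimal sketch rather than part of the test suite: it assumes a valid XAI_API_KEY is exported in the environment and reuses the `grok-4-fast-reasoning` model name pinned in the tests above.

    from langchain_xai import ChatXAI

    # Minimal sketch, assuming XAI_API_KEY is set in the environment.
    llm = ChatXAI(model="grok-4-fast-reasoning")

    response = llm.invoke("What is 3^3?")

    # Patch 2 injects the provider tag into response_metadata on both the
    # invoke and streaming paths; this is the field the standard tests check.
    assert response.response_metadata["model_provider"] == "xai"

    # Reasoning models also surface their chain of thought alongside the
    # answer, as asserted in test_reasoning above.
    print(response.additional_kwargs.get("reasoning_content"))

Putting the tag in `response_metadata` (rather than only in `additional_kwargs`) keeps it in the location the standard tests read from for provider identification.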