diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py
index 22701217153f9..d2de07c20ff04 100644
--- a/libs/partners/xai/langchain_xai/chat_models.py
+++ b/libs/partners/xai/langchain_xai/chat_models.py
@@ -529,6 +529,9 @@ def _create_chat_result(
     ) -> ChatResult:
         rtn = super()._create_chat_result(response, generation_info)
 
+        for generation in rtn.generations:
+            generation.message.response_metadata["model_provider"] = "xai"
+
         if not isinstance(response, openai.BaseModel):
             return rtn
 
@@ -555,6 +558,10 @@ def _convert_chunk_to_generation_chunk(
             default_chunk_class,
             base_generation_info,
         )
+
+        if generation_chunk:
+            generation_chunk.message.response_metadata["model_provider"] = "xai"
+
         if (choices := chunk.get("choices")) and generation_chunk:
             top = choices[0]
             if isinstance(generation_chunk.message, AIMessageChunk) and (
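Note: the two hunks above tag every generation with `response_metadata["model_provider"] = "xai"`, on both the batch path (`_create_chat_result`) and the streaming path (`_convert_chunk_to_generation_chunk`). A minimal sketch of the observable effect, assuming `XAI_API_KEY` is set in the environment and using `grok-3` purely as an illustrative model name:

```python
from langchain_xai import ChatXAI

llm = ChatXAI(model="grok-3")  # assumes XAI_API_KEY in the environment

# Batch path: the result message carries the provider tag.
response = llm.invoke("What is 3^3?")
assert response.response_metadata["model_provider"] == "xai"

# Streaming path: every chunk is tagged, so the aggregated message keeps it.
full = None
for chunk in llm.stream("What is 3^3?"):
    full = chunk if full is None else full + chunk
assert full.response_metadata["model_provider"] == "xai"
```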
diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models.py b/libs/partners/xai/tests/integration_tests/test_chat_models.py
new file mode 100644
index 0000000000000..049f51e59bc70
--- /dev/null
+++ b/libs/partners/xai/tests/integration_tests/test_chat_models.py
@@ -0,0 +1,96 @@
+"""Integration tests for ChatXAI specific features."""
+
+from __future__ import annotations
+
+from typing import Literal
+
+import pytest
+from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
+
+from langchain_xai import ChatXAI
+
+MODEL_NAME = "grok-4-fast-reasoning"
+
+
+@pytest.mark.parametrize("output_version", ["", "v1"])
+def test_reasoning(output_version: Literal["", "v1"]) -> None:
+    """Test reasoning features.
+
+    Note: `grok-4` does not return `reasoning_content`, but may optionally return
+    encrypted reasoning content if `use_encrypted_content` is set to True.
+    """
+    # Test reasoning effort
+    if output_version:
+        chat_model = ChatXAI(
+            model="grok-3-mini",
+            reasoning_effort="low",
+            output_version=output_version,
+        )
+    else:
+        chat_model = ChatXAI(
+            model="grok-3-mini",
+            reasoning_effort="low",
+        )
+    input_message = "What is 3^3?"
+    response = chat_model.invoke(input_message)
+    assert response.content
+    assert response.additional_kwargs["reasoning_content"]
+
+    # Test streaming
+    full: BaseMessageChunk | None = None
+    for chunk in chat_model.stream(input_message):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.additional_kwargs["reasoning_content"]
+
+    # Check that we can access reasoning content blocks
+    assert response.content_blocks
+    reasoning_content = (
+        block for block in response.content_blocks if block["type"] == "reasoning"
+    )
+    assert len(list(reasoning_content)) >= 1
+
+    # Test that passing message with reasoning back in works
+    follow_up_message = "Based on your reasoning, what is 4^4?"
+    followup = chat_model.invoke([input_message, response, follow_up_message])
+    assert followup.content
+    assert followup.additional_kwargs["reasoning_content"]
+    followup_reasoning = (
+        block for block in followup.content_blocks if block["type"] == "reasoning"
+    )
+    assert len(list(followup_reasoning)) >= 1
+
+    # Test passing in a ReasoningContentBlock
+    response_metadata = {"model_provider": "xai"}
+    if output_version:
+        response_metadata["output_version"] = output_version
+    msg_w_reasoning = AIMessage(
+        content_blocks=response.content_blocks,
+        response_metadata=response_metadata,
+    )
+    followup_2 = chat_model.invoke(
+        [msg_w_reasoning, "Based on your reasoning, what is 5^5?"]
+    )
+    assert followup_2.content
+    assert followup_2.additional_kwargs["reasoning_content"]
+
+
+def test_web_search() -> None:
+    llm = ChatXAI(
+        model=MODEL_NAME,
+        search_parameters={"mode": "on", "max_search_results": 3},
+    )
+
+    # Test invoke
+    response = llm.invoke("Provide me a digest of world news in the last 24 hours.")
+    assert response.content
+    assert response.additional_kwargs["citations"]
+    assert len(response.additional_kwargs["citations"]) <= 3
+
+    # Test streaming
+    full = None
+    for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.additional_kwargs["citations"]
+    assert len(full.additional_kwargs["citations"]) <= 3
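Note: the last case in `test_reasoning` rebuilds an `AIMessage` from `content_blocks` and sets `model_provider` in `response_metadata`, presumably so the provider-aware content-block handling applies. A sketch of the shapes involved, with made-up block payloads (only the `model_provider` key and the `type` fields are taken from the patch):

```python
from langchain_core.messages import AIMessage

# Hand-built message mirroring the test's msg_w_reasoning construction;
# the reasoning/text strings here are illustrative, not real model output.
msg = AIMessage(
    content_blocks=[
        {"type": "reasoning", "reasoning": "3^3 = 3 * 3 * 3 = 27"},
        {"type": "text", "text": "The answer is 27."},
    ],
    response_metadata={"model_provider": "xai"},
)

# Same filter the test applies to response.content_blocks.
reasoning_blocks = [b for b in msg.content_blocks if b["type"] == "reasoning"]
assert len(reasoning_blocks) >= 1
```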
diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py
index dada7d5c4bcd5..9b7871d6ac26f 100644
--- a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py
+++ b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py
@@ -4,26 +4,25 @@
 
 from typing import TYPE_CHECKING
 
-from langchain_core.messages import AIMessageChunk, BaseMessageChunk
+import pytest
+from langchain_core.messages import AIMessage
 from langchain_core.rate_limiters import InMemoryRateLimiter
 from langchain_tests.integration_tests import (  # type: ignore[import-not-found]
     ChatModelIntegrationTests,  # type: ignore[import-not-found]
 )
+from typing_extensions import override
 
 from langchain_xai import ChatXAI
 
 if TYPE_CHECKING:
     from langchain_core.language_models import BaseChatModel
 
-# Initialize the rate limiter in global scope, so it can be re-used
-# across tests.
+# Initialize the rate limiter in global scope, so it can be re-used across tests
 rate_limiter = InMemoryRateLimiter(
     requests_per_second=0.5,
 )
 
-
-# Not using Grok 4 since it doesn't support reasoning params (effort) or returns
-# reasoning content.
+MODEL_NAME = "grok-4-fast-reasoning"
 
 
 class TestXAIStandard(ChatModelIntegrationTests):
@@ -33,48 +32,28 @@ def chat_model_class(self) -> type[BaseChatModel]:
 
     @property
     def chat_model_params(self) -> dict:
-        # TODO: bump to test new Grok once they implement other features
         return {
-            "model": "grok-3",
+            "model": MODEL_NAME,
             "rate_limiter": rate_limiter,
             "stream_usage": True,
         }
 
-
-def test_reasoning_content() -> None:
-    """Test reasoning content."""
-    chat_model = ChatXAI(
-        model="grok-3-mini",
-        reasoning_effort="low",
-    )
-    response = chat_model.invoke("What is 3^3?")
-    assert response.content
-    assert response.additional_kwargs["reasoning_content"]
-
-    # Test streaming
-    full: BaseMessageChunk | None = None
-    for chunk in chat_model.stream("What is 3^3?"):
-        full = chunk if full is None else full + chunk
-    assert isinstance(full, AIMessageChunk)
-    assert full.additional_kwargs["reasoning_content"]
-
-
-def test_web_search() -> None:
-    llm = ChatXAI(
-        model="grok-3",
-        search_parameters={"mode": "auto", "max_search_results": 3},
+    @pytest.mark.xfail(
+        reason="Default model does not support stop sequences, using grok-3 instead"
     )
-
-    # Test invoke
-    response = llm.invoke("Provide me a digest of world news in the last 24 hours.")
-    assert response.content
-    assert response.additional_kwargs["citations"]
-    assert len(response.additional_kwargs["citations"]) <= 3
-
-    # Test streaming
-    full = None
-    for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."):
-        full = chunk if full is None else full + chunk
-    assert isinstance(full, AIMessageChunk)
-    assert full.additional_kwargs["citations"]
-    assert len(full.additional_kwargs["citations"]) <= 3
+    @override
+    def test_stop_sequence(self, model: BaseChatModel) -> None:
+        """Override to use `grok-3` which supports stop sequences."""
+        params = {**self.chat_model_params, "model": "grok-3"}
+
+        grok3_model = ChatXAI(**params)
+
+        result = grok3_model.invoke("hi", stop=["you"])
+        assert isinstance(result, AIMessage)
+
+        custom_model = ChatXAI(
+            **params,
+            stop_sequences=["you"],
+        )
+        result = custom_model.invoke("hi")
+        assert isinstance(result, AIMessage)
diff --git a/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr b/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr
index 4cd1261555c90..6f1cd18c2e2dc 100644
--- a/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr
+++ b/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr
@@ -9,7 +9,7 @@
     'kwargs': dict({
       'max_retries': 2,
       'max_tokens': 100,
-      'model_name': 'grok-beta',
+      'model_name': 'grok-4',
       'request_timeout': 60.0,
       'stop': list([
       ]),
diff --git a/libs/partners/xai/tests/unit_tests/test_chat_models_standard.py b/libs/partners/xai/tests/unit_tests/test_chat_models_standard.py
index 7c9a947a913ed..7224655db79ed 100644
--- a/libs/partners/xai/tests/unit_tests/test_chat_models_standard.py
+++ b/libs/partners/xai/tests/unit_tests/test_chat_models_standard.py
@@ -7,6 +7,8 @@
 
 from langchain_xai import ChatXAI
 
+MODEL_NAME = "grok-4"
+
 
 class TestXAIStandard(ChatModelUnitTests):
     @property
@@ -15,7 +17,7 @@ def chat_model_class(self) -> type[BaseChatModel]:
 
     @property
     def chat_model_params(self) -> dict:
-        return {"model": "grok-beta"}
+        return {"model": MODEL_NAME}
 
     @property
     def init_from_env_params(self) -> tuple[dict, dict, dict]:
@@ -24,7 +26,7 @@ def init_from_env_params(self) -> tuple[dict, dict, dict]:
         return (
             {
                 "XAI_API_KEY": "api_key",
             },
             {
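Note (placed after the test-suite diffs below conclude with the unit-test hunk): the overridden `test_stop_sequence` exercises both ways `ChatXAI` accepts stop tokens, pinned to `grok-3` because, per the xfail reason, the default model rejects stop sequences. Condensed from the hunk:

```python
from langchain_xai import ChatXAI

# Stop token supplied per invocation.
llm = ChatXAI(model="grok-3")
result = llm.invoke("hi", stop=["you"])

# Or bound once at construction time.
llm_with_stop = ChatXAI(model="grok-3", stop_sequences=["you"])
result = llm_with_stop.invoke("hi")
```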
"grok-beta", + "model": MODEL_NAME, }, { "xai_api_key": "api_key", diff --git a/libs/partners/xai/uv.lock b/libs/partners/xai/uv.lock index 0e0024f50f844..ac0ffb5c3bc01 100644 --- a/libs/partners/xai/uv.lock +++ b/libs/partners/xai/uv.lock @@ -621,7 +621,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.0.0a8" +version = "1.0.0rc2" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, @@ -699,6 +699,7 @@ dev = [{ name = "langchain-core", editable = "../../core" }] lint = [{ name = "ruff", specifier = ">=0.13.1,<0.14.0" }] test = [ { name = "freezegun", specifier = ">=1.2.2,<2.0.0" }, + { name = "langchain", editable = "../../langchain_v1" }, { name = "langchain-core", editable = "../../core" }, { name = "langchain-tests", editable = "../../standard-tests" }, { name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.4" }, @@ -728,7 +729,7 @@ typing = [ [[package]] name = "langchain-tests" -version = "1.0.0a2" +version = "1.0.0rc1" source = { editable = "../../standard-tests" } dependencies = [ { name = "httpx" },