7 changes: 7 additions & 0 deletions libs/partners/xai/langchain_xai/chat_models.py
@@ -529,6 +529,9 @@ def _create_chat_result(
    ) -> ChatResult:
        rtn = super()._create_chat_result(response, generation_info)

        for generation in rtn.generations:
            generation.message.response_metadata["model_provider"] = "xai"

        if not isinstance(response, openai.BaseModel):
            return rtn

@@ -555,6 +558,10 @@ def _convert_chunk_to_generation_chunk(
            default_chunk_class,
            base_generation_info,
        )

        if generation_chunk:
            generation_chunk.message.response_metadata["model_provider"] = "xai"

        if (choices := chunk.get("choices")) and generation_chunk:
            top = choices[0]
            if isinstance(generation_chunk.message, AIMessageChunk) and (
96 changes: 96 additions & 0 deletions libs/partners/xai/tests/integration_tests/test_chat_models.py
@@ -0,0 +1,96 @@
"""Integration tests for ChatXAI specific features."""

from __future__ import annotations

from typing import Literal

import pytest
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk

from langchain_xai import ChatXAI

MODEL_NAME = "grok-4-fast-reasoning"


@pytest.mark.parametrize("output_version", ["", "v1"])
Review comment (Collaborator). Suggested change (original line, then proposed replacement):

@pytest.mark.parametrize("output_version", ["", "v1"])
@pytest.mark.parametrize("output_version", ["v0", "v1"])

None also works for v0.
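For context, a minimal sketch of the reviewer's point (an illustration, not part of the diff; it assumes "v0" names the default output format, so leaving `output_version` unset exercises the same path):

from langchain_xai import ChatXAI

# Hypothetical illustration: "v0" is the explicit spelling of the default
# output format, so these two configurations are expected to behave the same,
# while "v1" selects the alternative format exercised by the other parametrization.
default_style = ChatXAI(model="grok-3-mini", reasoning_effort="low")
explicit_v0 = ChatXAI(model="grok-3-mini", reasoning_effort="low", output_version="v0")
v1_style = ChatXAI(model="grok-3-mini", reasoning_effort="low", output_version="v1")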

def test_reasoning(output_version: Literal["", "v1"]) -> None:
    """Test reasoning features.

    Note: `grok-4` does not return `reasoning_content`, but may optionally return
    encrypted reasoning content if `use_encrypted_content` is set to True.
    """
    # Test reasoning effort
    if output_version:
        chat_model = ChatXAI(
            model="grok-3-mini",
            reasoning_effort="low",
            output_version=output_version,
        )
    else:
        chat_model = ChatXAI(
            model="grok-3-mini",
            reasoning_effort="low",
        )
    input_message = "What is 3^3?"
    response = chat_model.invoke(input_message)
    assert response.content
    assert response.additional_kwargs["reasoning_content"]

    # Test streaming
    full: BaseMessageChunk | None = None
    for chunk in chat_model.stream(input_message):
        full = chunk if full is None else full + chunk
    assert isinstance(full, AIMessageChunk)
    assert full.additional_kwargs["reasoning_content"]

    # Check that we can access reasoning content blocks
    assert response.content_blocks
    reasoning_content = (
        block for block in response.content_blocks if block["type"] == "reasoning"
    )
    assert len(list(reasoning_content)) >= 1

    # Test that passing a message with reasoning back in works
    follow_up_message = "Based on your reasoning, what is 4^4?"
    followup = chat_model.invoke([input_message, response, follow_up_message])
    assert followup.content
    assert followup.additional_kwargs["reasoning_content"]
    followup_reasoning = (
        block for block in followup.content_blocks if block["type"] == "reasoning"
    )
    assert len(list(followup_reasoning)) >= 1

    # Test passing in a ReasoningContentBlock
    response_metadata = {"model_provider": "xai"}
    if output_version:
        response_metadata["output_version"] = output_version
    msg_w_reasoning = AIMessage(
        content_blocks=response.content_blocks,
        response_metadata=response_metadata,
    )
    followup_2 = chat_model.invoke(
        [msg_w_reasoning, "Based on your reasoning, what is 5^5?"]
    )
    assert followup_2.content
    assert followup_2.additional_kwargs["reasoning_content"]


def test_web_search() -> None:
    llm = ChatXAI(
        model=MODEL_NAME,
        search_parameters={"mode": "on", "max_search_results": 3},
    )

    # Test invoke
    response = llm.invoke("Provide me a digest of world news in the last 24 hours.")
    assert response.content
    assert response.additional_kwargs["citations"]
    assert len(response.additional_kwargs["citations"]) <= 3

    # Test streaming
    full = None
    for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."):
        full = chunk if full is None else full + chunk
    assert isinstance(full, AIMessageChunk)
    assert full.additional_kwargs["citations"]
    assert len(full.additional_kwargs["citations"]) <= 3
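As an aside on the docstring note above about `grok-4`: a minimal sketch of requesting encrypted reasoning content. The `use_encrypted_content` flag name comes from the docstring; routing it through `extra_body` is an assumption here, not something this diff shows.

from langchain_xai import ChatXAI

# Assumption: the flag is forwarded to the xAI API via `extra_body`, and grok-4
# then returns encrypted reasoning content in place of plain `reasoning_content`.
llm = ChatXAI(
    model="grok-4",
    extra_body={"use_encrypted_content": True},
)
response = llm.invoke("What is 3^3?")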
@@ -4,26 +4,25 @@

from typing import TYPE_CHECKING

from langchain_core.messages import AIMessageChunk, BaseMessageChunk
import pytest
from langchain_core.messages import AIMessage
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_tests.integration_tests import (  # type: ignore[import-not-found]
    ChatModelIntegrationTests,  # type: ignore[import-not-found]
)
from typing_extensions import override

from langchain_xai import ChatXAI

if TYPE_CHECKING:
    from langchain_core.language_models import BaseChatModel

# Initialize the rate limiter in global scope, so it can be re-used
# across tests.
# Initialize the rate limiter in global scope, so it can be re-used across tests
rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.5,
)


# Not using Grok 4 since it doesn't support reasoning params (effort) or return
# reasoning content.
MODEL_NAME = "grok-4-fast-reasoning"


class TestXAIStandard(ChatModelIntegrationTests):
@@ -33,48 +32,28 @@ def chat_model_class(self) -> type[BaseChatModel]:

    @property
    def chat_model_params(self) -> dict:
        # TODO: bump to test new Grok once they implement other features
        return {
            "model": "grok-3",
            "model": MODEL_NAME,
            "rate_limiter": rate_limiter,
            "stream_usage": True,
        }


def test_reasoning_content() -> None:
    """Test reasoning content."""
    chat_model = ChatXAI(
        model="grok-3-mini",
        reasoning_effort="low",
    )
    response = chat_model.invoke("What is 3^3?")
    assert response.content
    assert response.additional_kwargs["reasoning_content"]

    # Test streaming
    full: BaseMessageChunk | None = None
    for chunk in chat_model.stream("What is 3^3?"):
        full = chunk if full is None else full + chunk
    assert isinstance(full, AIMessageChunk)
    assert full.additional_kwargs["reasoning_content"]


def test_web_search() -> None:
    llm = ChatXAI(
        model="grok-3",
        search_parameters={"mode": "auto", "max_search_results": 3},
    )

    # Test invoke
    response = llm.invoke("Provide me a digest of world news in the last 24 hours.")
    assert response.content
    assert response.additional_kwargs["citations"]
    assert len(response.additional_kwargs["citations"]) <= 3

    # Test streaming
    full = None
    for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."):
        full = chunk if full is None else full + chunk
    assert isinstance(full, AIMessageChunk)
    assert full.additional_kwargs["citations"]
    assert len(full.additional_kwargs["citations"]) <= 3
    @pytest.mark.xfail(
        reason="Default model does not support stop sequences, using grok-3 instead"
    )
    @override
    def test_stop_sequence(self, model: BaseChatModel) -> None:
        """Override to use `grok-3` which supports stop sequences."""
        params = {**self.chat_model_params, "model": "grok-3"}

        grok3_model = ChatXAI(**params)

        result = grok3_model.invoke("hi", stop=["you"])
        assert isinstance(result, AIMessage)

        custom_model = ChatXAI(
            **params,
            stop_sequences=["you"],
        )
        result = custom_model.invoke("hi")
        assert isinstance(result, AIMessage)
@@ -9,7 +9,7 @@
'kwargs': dict({
  'max_retries': 2,
  'max_tokens': 100,
  'model_name': 'grok-beta',
  'model_name': 'grok-4',
  'request_timeout': 60.0,
  'stop': list([
  ]),
@@ -7,6 +7,8 @@

from langchain_xai import ChatXAI

MODEL_NAME = "grok-4"


class TestXAIStandard(ChatModelUnitTests):
    @property
@@ -15,7 +17,7 @@ def chat_model_class(self) -> type[BaseChatModel]:

    @property
    def chat_model_params(self) -> dict:
        return {"model": "grok-beta"}
        return {"model": MODEL_NAME}

    @property
    def init_from_env_params(self) -> tuple[dict, dict, dict]:
@@ -24,7 +26,7 @@ def init_from_env_params(self) -> tuple[dict, dict, dict]:
                "XAI_API_KEY": "api_key",
            },
            {
                "model": "grok-beta",
                "model": MODEL_NAME,
            },
            {
                "xai_api_key": "api_key",
5 changes: 3 additions & 2 deletions libs/partners/xai/uv.lock

Some generated files are not rendered by default.