From 09d0f461c9e4fc1fbc1e0e1d48b3386a9d4ee5b5 Mon Sep 17 00:00:00 2001
From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com>
Date: Tue, 18 Nov 2025 13:12:06 +0100
Subject: [PATCH 1/4] feat(llm): Add async streaming support to ChatNVIDIA provider

Enables stream_async() to work with ChatNVIDIA/NIM models by implementing
an async streaming decorator and an _agenerate method. Prior to this fix,
stream_async() would fail with NIM engine configurations.
---
 .../_langchain_nvidia_ai_endpoints_patch.py   |  50 ++-
 ...est_langchain_nvidia_ai_endpoints_patch.py | 395 ++++++++++++++++++
 2 files changed, 443 insertions(+), 2 deletions(-)
 create mode 100644 tests/llm_providers/test_langchain_nvidia_ai_endpoints_patch.py

diff --git a/nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py b/nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py
index 5eef9d144..b88a56368 100644
--- a/nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py
+++ b/nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py
@@ -18,8 +18,15 @@
 from functools import wraps
 from typing import Any, Dict, List, Optional

-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
-from langchain_core.language_models.chat_models import generate_from_stream
+from langchain_core.callbacks import Callbacks
+from langchain_core.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+    agenerate_from_stream,
+    generate_from_stream,
+)
 from langchain_core.messages import BaseMessage
 from langchain_core.outputs import ChatResult
 from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
@@ -50,6 +57,28 @@ def wrapper(
     return wrapper


+def async_stream_decorator(func):
+    @wraps(func)
+    async def wrapper(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+        else:
+            return await func(self, messages, stop, run_manager, **kwargs)
+
+    return wrapper
+
+
 # NOTE: this needs to have the same name as the original class,
 # otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
 class ChatNVIDIA(ChatNVIDIAOriginal):  # pragma: no cover
@@ -105,9 +134,26 @@ def _generate(
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> ChatResult:
         return super()._generate(
+            messages=messages,
+            stop=stop,
+            run_manager=run_manager,
+            callbacks=callbacks,
+            **kwargs,
+        )
+
+    @async_stream_decorator
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return await super()._agenerate(
             messages=messages, stop=stop, run_manager=run_manager, **kwargs
         )

diff --git a/tests/llm_providers/test_langchain_nvidia_ai_endpoints_patch.py b/tests/llm_providers/test_langchain_nvidia_ai_endpoints_patch.py
new file mode 100644
index 000000000..2f78f059f
--- /dev/null
+++ b/tests/llm_providers/test_langchain_nvidia_ai_endpoints_patch.py
@@ -0,0 +1,395 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+import time
+from unittest.mock import patch
+
+import pytest
+from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+
+from nemoguardrails.llm.providers._langchain_nvidia_ai_endpoints_patch import ChatNVIDIA
+
+langchain_nvidia_ai_endpoints = pytest.importorskip("langchain_nvidia_ai_endpoints")
+
+LIVE_TEST_MODE = os.environ.get("LIVE_TEST_MODE")
+
+
+class FakeCallbackHandler:
+    def __init__(self):
+        self.llm_streams = 0
+        self.tokens = []
+
+    async def on_llm_new_token(self, token: str, **kwargs):
+        self.llm_streams += 1
+        self.tokens.append(token)
+
+
+class TestAsyncStreamDecorator:
+    @pytest.mark.asyncio
+    async def test_decorator_with_streaming_enabled(self):
+        chat = ChatNVIDIA(
+            model="meta/llama-3.3-70b-instruct",
+            base_url="http://localhost:8000/v1",
+            streaming=True,
+        )
+
+        messages = [HumanMessage(content="Hello")]
+
+        with patch.object(chat, "_astream") as mock_astream:
+            mock_chunk = ChatGenerationChunk(message=AIMessageChunk(content="Hi there"))
+            mock_astream.return_value = AsyncIteratorMock([mock_chunk])
+
+            result = await chat._agenerate(messages)
+
+            assert isinstance(result, ChatResult)
+            assert len(result.generations) == 1
+            assert result.generations[0].message.content == "Hi there"
+            mock_astream.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_decorator_with_streaming_disabled(self):
+        chat = ChatNVIDIA(
+            model="meta/llama-3.3-70b-instruct",
+            base_url="http://localhost:8000/v1",
+            streaming=False,
+        )
+
+        messages = [HumanMessage(content="Hello")]
+
+        with patch(
+            "langchain_nvidia_ai_endpoints.ChatNVIDIA._agenerate"
+        ) as mock_parent_agenerate:
+            expected_result = ChatResult(
+                generations=[
+                    ChatGeneration(message=AIMessage(content="Response from parent"))
+                ]
+            )
+            mock_parent_agenerate.return_value = expected_result
+
+            result = await chat._agenerate(messages)
+
+            assert result == expected_result
+            mock_parent_agenerate.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_decorator_preserves_function_metadata(self):
+        chat = ChatNVIDIA(
+            model="meta/llama-3.3-70b-instruct",
+            base_url="http://localhost:8000/v1",
+        )
+
+        assert chat._agenerate.__name__ == "_agenerate"
+        assert asyncio.iscoroutinefunction(chat._agenerate)
+
+    @pytest.mark.asyncio
+    async def test_streaming_aggregates_multiple_chunks(self):
+        chat = ChatNVIDIA(
+            model="meta/llama-3.3-70b-instruct",
+            base_url="http://localhost:8000/v1",
+            streaming=True,
+        )
+
+        messages = [HumanMessage(content="Hello")]
+
+        with patch.object(chat, "_astream") as mock_astream:
+            chunks = [
+                ChatGenerationChunk(message=AIMessageChunk(content="Hello ")),
+                ChatGenerationChunk(message=AIMessageChunk(content="world")),
+                ChatGenerationChunk(message=AIMessageChunk(content="!")),
+            ]
+            mock_astream.return_value = AsyncIteratorMock(chunks)
+
+            result = await chat._agenerate(messages)
+
+            assert isinstance(result, ChatResult)
+            assert len(result.generations) == 1
+            assert result.generations[0].message.content == "Hello world!"
+            mock_astream.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_streaming_with_empty_chunks(self):
+        chat = ChatNVIDIA(
+            model="meta/llama-3.3-70b-instruct",
+            base_url="http://localhost:8000/v1",
+            streaming=True,
+        )
+
+        messages = [HumanMessage(content="Hello")]
+
+        with patch.object(chat, "_astream") as mock_astream:
+            chunks = [
+                ChatGenerationChunk(message=AIMessageChunk(content="")),
+                ChatGenerationChunk(message=AIMessageChunk(content="Hello")),
+                ChatGenerationChunk(message=AIMessageChunk(content="")),
+            ]
+            mock_astream.return_value = AsyncIteratorMock(chunks)
+
+            result = await chat._agenerate(messages)
+
+            assert isinstance(result, ChatResult)
+            assert len(result.generations) == 1
+            assert result.generations[0].message.content == "Hello"
+
+
+class TestChatNVIDIAPatch:
+    @pytest.mark.asyncio
+    async def test_agenerate_calls_patched_agenerate(self):
+        chat = ChatNVIDIA(
+            model="meta/llama-3.3-70b-instruct",
+            base_url="http://localhost:8000/v1",
+            streaming=False,
+        )
+
+        messages = [[HumanMessage(content="Hello")], [HumanMessage(content="Hi")]]
+
+        with patch(
+            "langchain_nvidia_ai_endpoints.ChatNVIDIA._agenerate"
+        ) as mock_parent:
+            mock_parent.return_value = ChatResult(
+                generations=[ChatGeneration(message=AIMessage(content="Response"))]
+            )
+
+            result = await chat.agenerate(messages)
+
+            assert isinstance(result.generations, list)
+            assert len(result.generations) == 2
+            for generation_list in result.generations:
+                assert len(generation_list) == 1
+                assert generation_list[0].message.content == "Response"
+            assert mock_parent.call_count == 2
+
+    @pytest.mark.asyncio
+    async def test_agenerate_with_streaming_enabled(self):
+        chat = ChatNVIDIA(
+            model="meta/llama-3.3-70b-instruct",
+            base_url="http://localhost:8000/v1",
+            streaming=True,
+        )
+
+        messages = [[HumanMessage(content="Hello")]]
+
+        with patch.object(chat, "_astream") as mock_astream:
+            chunks = [
+                ChatGenerationChunk(message=AIMessageChunk(content="Hello ")),
+                ChatGenerationChunk(message=AIMessageChunk(content="world")),
+            ]
+            mock_astream.return_value = AsyncIteratorMock(chunks)
+
+            result = await chat.agenerate(messages)
+
+            assert isinstance(result.generations, list)
+            assert len(result.generations) == 1
+            assert len(result.generations[0]) == 1
+            assert result.generations[0][0].message.content == "Hello world"
+            mock_astream.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_streaming_field_exists(self):
+        chat = ChatNVIDIA(
+            model="meta/llama-3.3-70b-instruct",
+            base_url="http://localhost:8000/v1",
+        )
+
+        assert hasattr(chat, "streaming")
+        assert chat.streaming == False
+
+        chat_with_streaming = ChatNVIDIA(
+            model="meta/llama-3.3-70b-instruct",
+            base_url="http://localhost:8000/v1",
+            streaming=True,
+        )
+        assert chat_with_streaming.streaming == True
+
+    @pytest.mark.asyncio
+    async def test_backward_compatibility_sync_generate(self):
+        chat = ChatNVIDIA(
+            model="meta/llama-3.3-70b-instruct",
+            base_url="http://localhost:8000/v1",
+            streaming=False,
+        )
+
+        messages = [[HumanMessage(content="Hello")]]
+
+        with patch("langchain_nvidia_ai_endpoints.ChatNVIDIA._generate") as mock_parent:
+            mock_parent.return_value = ChatResult(
+                generations=[ChatGeneration(message=AIMessage(content="Response"))]
+            )
+
+            result = chat.generate(messages)
+
+            assert isinstance(result.generations, list)
+            assert len(result.generations[0]) == 1
+            assert result.generations[0][0].message.content == "Response"
+            mock_parent.assert_called()
+
+    @pytest.mark.asyncio
+    async def test_streaming_handles_multiple_message_batches(self):
+        chat = ChatNVIDIA(
+            model="meta/llama-3.3-70b-instruct",
+            base_url="http://localhost:8000/v1",
+            streaming=True,
+        )
+
+        messages = [
+            [HumanMessage(content="First message")],
+            [HumanMessage(content="Second message")],
+        ]
+
+        with patch.object(chat, "_astream") as mock_astream:
+            mock_astream.side_effect = [
+                AsyncIteratorMock(
+                    [
+                        ChatGenerationChunk(message=AIMessageChunk(content="First ")),
+                        ChatGenerationChunk(message=AIMessageChunk(content="response")),
+                    ]
+                ),
+                AsyncIteratorMock(
+                    [
+                        ChatGenerationChunk(message=AIMessageChunk(content="Second ")),
+                        ChatGenerationChunk(message=AIMessageChunk(content="response")),
+                    ]
+                ),
+            ]
+
+            result = await chat.agenerate(messages)
+
+            assert len(result.generations) == 2
+            assert result.generations[0][0].message.content == "First response"
+            assert result.generations[1][0].message.content == "Second response"
+            assert mock_astream.call_count == 2
+
+
+class TestIntegrationWithLLMRails:
+    @pytest.mark.asyncio
+    async def test_chatnvidia_with_llmrails_async(self):
+        from nemoguardrails import LLMRails, RailsConfig
+
+        config = RailsConfig.from_content(
+            config={
+                "models": [
+                    {
+                        "type": "main",
+                        "engine": "nim",
+                        "model": "meta/llama-3.3-70b-instruct",
+                    }
+                ]
+            }
+        )
+
+        with patch(
+            "nemoguardrails.llm.providers._langchain_nvidia_ai_endpoints_patch.ChatNVIDIA._agenerate"
+        ) as mock_agenerate:
+            mock_agenerate.return_value = ChatResult(
+                generations=[ChatGeneration(message=AIMessage(content="Test response"))]
+            )
+
+            rails = LLMRails(config)
+
+            result = await rails.generate_async(
+                messages=[{"role": "user", "content": "Hello"}]
+            )
+
+            assert result is not None
+            assert "content" in result
+            assert result["content"] == "Test response"
+
+    @pytest.mark.asyncio
+    async def test_chatnvidia_streaming_with_llmrails(self):
+        from nemoguardrails import LLMRails, RailsConfig
+
+        config = RailsConfig.from_content(
+            config={
+                "models": [
+                    {
+                        "type": "main",
+                        "engine": "nim",
+                        "model": "meta/llama-3.3-70b-instruct",
+                        "parameters": {"streaming": True},
+                    }
+                ],
+                "streaming": True,
+            }
+        )
+
+        rails = LLMRails(config)
+
+        chat_model = rails.llm
+
+        assert hasattr(chat_model, "streaming")
+        assert chat_model.streaming == True
+
+
+class AsyncIteratorMock:
+    def __init__(self, items):
+        self.items = items
+        self.index = 0
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.index >= len(self.items):
+            raise StopAsyncIteration
+        item = self.items[self.index]
+        self.index += 1
+        return item
+
+
+@pytest.mark.skipif(
+    not LIVE_TEST_MODE,
+    reason="This test requires LIVE_TEST_MODE environment variable to be set for live testing",
+)
+class TestChatNVIDIAStreamingE2E:
+    @pytest.mark.asyncio
+    async def test_stream_async_ttft_with_nim(self):
+        from nemoguardrails import LLMRails, RailsConfig
+        from nemoguardrails.actions.llm.utils import LLMCallException
+
+        yaml_content = """
+models:
+  - type: main
+    engine: nim
+    model: meta/llama-3.3-70b-instruct
+
+streaming: True
+"""
+        config = RailsConfig.from_content(yaml_content=yaml_content)
+        rails = LLMRails(config)
+
+        chunk_times = [time.time()]
+        chunks = []
+
+        async for chunk in rails.stream_async(
+            messages=[
+                {"role": "user", "content": "Count to 20 by 2s, e.g. 2 4 6 8 ..."}
+            ]
+        ):
+            chunks.append(chunk)
+            chunk_times.append(time.time())
+
+        ttft = chunk_times[1] - chunk_times[0]
+        total_time = chunk_times[-1] - chunk_times[0]
+
+        assert len(chunks) > 0, "Should receive at least one chunk"
+        assert ttft < (total_time / 2), (
+            f"TTFT ({ttft:.3f}s) should be less than half of total time ({total_time:.3f}s)"
+        )
+        assert len(chunk_times) > 2, "Should receive multiple chunks for streaming"
+
+        full_response = "".join(chunks)
+        assert len(full_response) > 0, "Full response should not be empty"

From 19ddac446526b46ba9c0aed05889566d2aeea384 Mon Sep 17 00:00:00 2001
From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com>
Date: Tue, 18 Nov 2025 14:43:06 +0100
Subject: [PATCH 2/4] fix style

---
 .../test_langchain_nvidia_ai_endpoints_patch.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/llm_providers/test_langchain_nvidia_ai_endpoints_patch.py b/tests/llm_providers/test_langchain_nvidia_ai_endpoints_patch.py
index 2f78f059f..e714b68d5 100644
--- a/tests/llm_providers/test_langchain_nvidia_ai_endpoints_patch.py
+++ b/tests/llm_providers/test_langchain_nvidia_ai_endpoints_patch.py
@@ -386,9 +386,9 @@ async def test_stream_async_ttft_with_nim(self):
         total_time = chunk_times[-1] - chunk_times[0]

         assert len(chunks) > 0, "Should receive at least one chunk"
-        assert ttft < (total_time / 2), (
-            f"TTFT ({ttft:.3f}s) should be less than half of total time ({total_time:.3f}s)"
-        )
+        assert ttft < (
+            total_time / 2
+        ), f"TTFT ({ttft:.3f}s) should be less than half of total time ({total_time:.3f}s)"
         assert len(chunk_times) > 2, "Should receive multiple chunks for streaming"

         full_response = "".join(chunks)

From d75187230ac05e9ceadbb3ac836ecaea128c1020 Mon Sep 17 00:00:00 2001
From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com>
Date: Tue, 18 Nov 2025 14:51:11 +0100
Subject: [PATCH 3/4] modify import order to skip tests

---
 .../test_langchain_nvidia_ai_endpoints_patch.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/llm_providers/test_langchain_nvidia_ai_endpoints_patch.py b/tests/llm_providers/test_langchain_nvidia_ai_endpoints_patch.py
index e714b68d5..61db66cfa 100644
--- a/tests/llm_providers/test_langchain_nvidia_ai_endpoints_patch.py
+++ b/tests/llm_providers/test_langchain_nvidia_ai_endpoints_patch.py
@@ -19,13 +19,14 @@
 from unittest.mock import patch

 import pytest
+
+langchain_nvidia_ai_endpoints = pytest.importorskip("langchain_nvidia_ai_endpoints")
+
 from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

 from nemoguardrails.llm.providers._langchain_nvidia_ai_endpoints_patch import ChatNVIDIA

-langchain_nvidia_ai_endpoints = pytest.importorskip("langchain_nvidia_ai_endpoints")
-
 LIVE_TEST_MODE = os.environ.get("LIVE_TEST_MODE")

From 71f6e7c2a381efa4809cffd8b5860fa6f01f73cf Mon Sep 17 00:00:00 2001
From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com>
Date: Wed, 19 Nov 2025 11:59:43 +0100
Subject: [PATCH 4/4] add pragma no cover to async decorator

---
 .../llm/providers/_langchain_nvidia_ai_endpoints_patch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py b/nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py
index b88a56368..9b2f9027d 100644
--- a/nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py
+++ b/nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py
@@ -57,7 +57,7 @@ def wrapper(
     return wrapper


-def async_stream_decorator(func):
+def async_stream_decorator(func):  # pragma: no cover
     @wraps(func)
     async def wrapper(
         self,
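
Usage sketch: a minimal example of how the stream_async() path enabled by PATCH 1/4 is expected to be consumed from application code. This is not part of the patches above; the model id and base_url are illustrative placeholders borrowed from the tests, and the NIM endpoint URL is an assumption to be adjusted (or the base_url parameter dropped) for your deployment.

import asyncio

from nemoguardrails import LLMRails, RailsConfig

# Streaming is enabled both on the model and at the top level so LLMRails
# drives the async streaming path (_agenerate/_astream) patched in PATCH 1/4.
# The model id and base_url below are placeholders for a local NIM deployment.
YAML_CONFIG = """
models:
  - type: main
    engine: nim
    model: meta/llama-3.3-70b-instruct
    parameters:
      base_url: http://localhost:8000/v1

streaming: True
"""


async def main() -> None:
    config = RailsConfig.from_content(yaml_content=YAML_CONFIG)
    rails = LLMRails(config)

    # stream_async() yields text chunks as the model produces them; before
    # this patch series it failed for nim engine configurations.
    async for chunk in rails.stream_async(
        messages=[{"role": "user", "content": "Count to 20 by 2s."}]
    ):
        print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())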