
Commit b7e314b

feat(llm): Add async streaming support to ChatNVIDIA provider
Enables stream_async() to work with ChatNVIDIA/NIM models by implementing an async streaming decorator and an _agenerate method. Prior to this fix, stream_async() would fail with NIM engine configurations.
Parent: d2bfaea
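
After this change, stream_async() can be driven end to end with a NIM model. A minimal usage sketch, assuming a reachable NIM endpoint and a configured NVIDIA API key; the model name and YAML values are illustrative assumptions, not part of this commit:

import asyncio

from nemoguardrails import LLMRails, RailsConfig

# Illustrative config: engine "nim" with an assumed model name.
YAML_CONFIG = """
models:
  - type: main
    engine: nim
    model: meta/llama-3.1-8b-instruct
streaming: True
"""

async def main():
    config = RailsConfig.from_content(yaml_content=YAML_CONFIG)
    rails = LLMRails(config)

    # Before this commit, stream_async() failed for NIM engine configs;
    # with the patched _agenerate it yields chunks as they arrive.
    async for chunk in rails.stream_async(
        messages=[{"role": "user", "content": "Tell me a short fact."}]
    ):
        print(chunk, end="")

asyncio.run(main())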

2 files changed: +443 −2 lines


nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py

Lines changed: 48 additions & 2 deletions
@@ -18,8 +18,15 @@
 from functools import wraps
 from typing import Any, Dict, List, Optional

-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
-from langchain_core.language_models.chat_models import generate_from_stream
+from langchain_core.callbacks import Callbacks
+from langchain_core.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+    agenerate_from_stream,
+    generate_from_stream,
+)
 from langchain_core.messages import BaseMessage
 from langchain_core.outputs import ChatResult
 from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
@@ -50,6 +57,28 @@ def wrapper(
     return wrapper


+def async_stream_decorator(func):
+    @wraps(func)
+    async def wrapper(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+        else:
+            return await func(self, messages, stop, run_manager, **kwargs)
+
+    return wrapper
+
+
 # NOTE: this needs to have the same name as the original class,
 # otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
 class ChatNVIDIA(ChatNVIDIAOriginal):  # pragma: no cover
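
The decorator mirrors the synchronous stream_decorator already in this file: when streaming is requested it folds the chunk iterator from _astream() into a single ChatResult via agenerate_from_stream(); otherwise it awaits the wrapped method unchanged. A self-contained toy sketch of the same routing, with hypothetical names and a plain string standing in for ChatResult:

import asyncio
from functools import wraps

def async_stream_decorator(func):
    @wraps(func)
    async def wrapper(self, prompt, stream=None, **kwargs):
        # Fall back to the instance-level streaming flag, as in the patch.
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            # Collect the async chunk stream into one result, the same
            # role agenerate_from_stream() plays for real ChatResult objects.
            chunks = [c async for c in self._astream(prompt)]
            return "".join(chunks)
        return await func(self, prompt, **kwargs)
    return wrapper

class ToyModel:
    streaming = True

    async def _astream(self, prompt):
        for word in ("streamed", " ", "reply"):
            yield word

    @async_stream_decorator
    async def _agenerate(self, prompt):
        return "non-streamed reply"

print(asyncio.run(ToyModel()._agenerate("hi")))                # streamed reply
print(asyncio.run(ToyModel()._agenerate("hi", stream=False)))  # non-streamed reply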
@@ -105,9 +134,26 @@ def _generate(
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> ChatResult:
         return super()._generate(
+            messages=messages,
+            stop=stop,
+            run_manager=run_manager,
+            callbacks=callbacks,
+            **kwargs,
+        )
+
+    @async_stream_decorator
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return await super()._agenerate(
             messages=messages, stop=stop, run_manager=run_manager, **kwargs
         )

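
With _agenerate defined, the patched class can also be exercised through LangChain's standard async entry points. A hedged sketch of direct use, assuming NVIDIA_API_KEY is set in the environment and the model name is valid for the target endpoint:

import asyncio

from nemoguardrails.llm.providers._langchain_nvidia_ai_endpoints_patch import ChatNVIDIA

async def main():
    # Model name is an assumption for illustration.
    llm = ChatNVIDIA(model="meta/llama-3.1-8b-instruct")

    # ainvoke() dispatches to the patched _agenerate, which either streams
    # (when streaming is enabled) or awaits the parent implementation.
    result = await llm.ainvoke("Say hello in five words.")
    print(result.content)

    # astream() yields chunks as they arrive from the NIM endpoint.
    async for chunk in llm.astream("Count to three."):
        print(chunk.content, end="")

asyncio.run(main())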
