Commit f06d0d5
feat(llm): Add async streaming support to ChatNVIDIA provider
Enables stream_async() to work with ChatNVIDIA/NIM models by implementing an async streaming decorator and an _agenerate method. Prior to this fix, stream_async() would fail with NIM engine configurations.
Parent: 847ace8
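For context, a minimal sketch of the call path this commit unblocks, assuming NeMo Guardrails' `LLMRails.stream_async()` and a NIM-backed main model; the model name and YAML values below are illustrative, not taken from the commit:

```python
import asyncio

from nemoguardrails import LLMRails, RailsConfig

# Illustrative NIM engine configuration; the model name is an assumption.
YAML_CONTENT = """
models:
  - type: main
    engine: nim
    model: meta/llama-3.1-8b-instruct
streaming: True
"""

async def main():
    rails = LLMRails(RailsConfig.from_content(yaml_content=YAML_CONTENT))
    # Before this commit, this call failed for NIM engine configurations.
    async for chunk in rails.stream_async(prompt="Hello!"):
        print(chunk, end="", flush=True)

asyncio.run(main())
```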

2 files changed: +443 −2 lines


nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py

Lines changed: 48 additions & 2 deletions
```diff
@@ -17,8 +17,15 @@
 from functools import wraps
 from typing import Any, List, Optional
 
-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
-from langchain_core.language_models.chat_models import generate_from_stream
+from langchain_core.callbacks import Callbacks
+from langchain_core.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+    agenerate_from_stream,
+    generate_from_stream,
+)
 from langchain_core.messages import BaseMessage
 from langchain_core.outputs import ChatResult
 from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
@@ -49,6 +56,28 @@ def wrapper(
     return wrapper
 
 
+def async_stream_decorator(func):
+    @wraps(func)
+    async def wrapper(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+        else:
+            return await func(self, messages, stop, run_manager, **kwargs)
+
+    return wrapper
+
+
 # NOTE: this needs to have the same name as the original class,
 # otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
 class ChatNVIDIA(ChatNVIDIAOriginal):
@@ -62,9 +91,26 @@ def _generate(
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> ChatResult:
         return super()._generate(
+            messages=messages,
+            stop=stop,
+            run_manager=run_manager,
+            callbacks=callbacks,
+            **kwargs,
+        )
+
+    @async_stream_decorator
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return await super()._agenerate(
             messages=messages, stop=stop, run_manager=run_manager, **kwargs
         )
```
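The decorator mirrors LangChain's synchronous `generate_from_stream` pattern on the async side: when streaming is requested, the wrapped `_agenerate` is bypassed and the chunks produced by `_astream` are folded into a single `ChatResult` by `agenerate_from_stream`. A self-contained sketch of that folding step, with a hypothetical `fake_astream` standing in for `ChatNVIDIA._astream`:

```python
import asyncio

from langchain_core.language_models.chat_models import agenerate_from_stream
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk

async def fake_astream():
    # Hypothetical stand-in for ChatNVIDIA._astream(): yields one
    # generation chunk per token as it arrives from the model.
    for token in ["Hello", ", ", "world"]:
        yield ChatGenerationChunk(message=AIMessageChunk(content=token))

async def main():
    # agenerate_from_stream drains the async iterator and merges the
    # chunks into the single ChatResult that _agenerate must return.
    result = await agenerate_from_stream(fake_astream())
    print(result.generations[0].message.content)  # Hello, world

asyncio.run(main())
```

This is what lets the patched `_agenerate` satisfy callers that expect one final `ChatResult` while token-level callbacks still fire through `run_manager` as `_astream` is consumed.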
70116
