 from functools import wraps
 from typing import Any, List, Optional
 
-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
-from langchain_core.language_models.chat_models import generate_from_stream
+from langchain_core.callbacks import Callbacks
+from langchain_core.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+    agenerate_from_stream,
+    generate_from_stream,
+)
 from langchain_core.messages import BaseMessage
 from langchain_core.outputs import ChatResult
 from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
@@ -49,6 +56,28 @@ def wrapper(
     return wrapper
 
 
+def async_stream_decorator(func):
+    @wraps(func)
+    async def wrapper(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+        else:
+            return await func(self, messages, stop, run_manager, **kwargs)
+
+    return wrapper
+
+
 # NOTE: this needs to have the same name as the original class,
 # otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
 class ChatNVIDIA(ChatNVIDIAOriginal):
@@ -62,9 +91,26 @@ def _generate(
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> ChatResult:
         return super()._generate(
+            messages=messages,
+            stop=stop,
+            run_manager=run_manager,
+            callbacks=callbacks,
+            **kwargs,
+        )
+
+    @async_stream_decorator
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return await super()._agenerate(
             messages=messages, stop=stop, run_manager=run_manager, **kwargs
         )
 
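For reference, a minimal usage sketch of the async path introduced above (not part of the diff). It assumes the patched ChatNVIDIA class is importable and an NVIDIA API key is configured in the environment; the model name is hypothetical. Passing `stream=True` is forwarded through `ainvoke`'s kwargs to the decorated `_agenerate`, which then delegates to `_astream` via `agenerate_from_stream`; otherwise the call falls back to the original implementation (or to the instance-level `streaming` flag, if set).

import asyncio

from langchain_core.messages import HumanMessage


async def main() -> None:
    # Hypothetical model name; any model supported by the NVIDIA endpoints works.
    llm = ChatNVIDIA(model="meta/llama3-8b-instruct")
    # `stream=True` reaches _agenerate through **kwargs, so the decorator
    # routes the call through _astream + agenerate_from_stream.
    result = await llm.ainvoke([HumanMessage(content="Hello!")], stream=True)
    print(result.content)


asyncio.run(main())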