18 | 18 | from functools import wraps |
19 | 19 | from typing import Any, Dict, List, Optional |
20 | 20 |
21 | | -from langchain_core.callbacks.manager import CallbackManagerForLLMRun |
22 | | -from langchain_core.language_models.chat_models import generate_from_stream |
| 21 | +from langchain_core.callbacks import Callbacks |
| 22 | +from langchain_core.callbacks.manager import ( |
| 23 | + AsyncCallbackManagerForLLMRun, |
| 24 | + CallbackManagerForLLMRun, |
| 25 | +) |
| 26 | +from langchain_core.language_models.chat_models import ( |
| 27 | + agenerate_from_stream, |
| 28 | + generate_from_stream, |
| 29 | +) |
23 | 30 | from langchain_core.messages import BaseMessage |
24 | 31 | from langchain_core.outputs import ChatResult |
25 | 32 | from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal |
@@ -50,6 +57,28 @@ def wrapper( |
50 | 57 | return wrapper |
51 | 58 |
52 | 59 |
| 60 | +def async_stream_decorator(func): |
| 61 | + @wraps(func) |
| 62 | + async def wrapper( |
| 63 | + self, |
| 64 | + messages: List[BaseMessage], |
| 65 | + stop: Optional[List[str]] = None, |
| 66 | + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, |
| 67 | + stream: Optional[bool] = None, |
| 68 | + **kwargs: Any, |
| 69 | + ) -> ChatResult: |
| 70 | + should_stream = stream if stream is not None else self.streaming |
| 71 | + if should_stream: |
| 72 | + stream_iter = self._astream( |
| 73 | + messages, stop=stop, run_manager=run_manager, **kwargs |
| 74 | + ) |
| 75 | + return await agenerate_from_stream(stream_iter) |
| 76 | + else: |
| 77 | + return await func(self, messages, stop, run_manager, **kwargs) |
| 78 | + |
| 79 | + return wrapper |
| 80 | + |
| 81 | + |
53 | 82 | # NOTE: this needs to have the same name as the original class, |
54 | 83 | # otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail. |
55 | 84 | class ChatNVIDIA(ChatNVIDIAOriginal): # pragma: no cover |
@@ -105,9 +134,26 @@ def _generate( |
105 | 134 | messages: List[BaseMessage], |
106 | 135 | stop: Optional[List[str]] = None, |
107 | 136 | run_manager: Optional[CallbackManagerForLLMRun] = None, |
| 137 | + callbacks: Callbacks = None, |
108 | 138 | **kwargs: Any, |
109 | 139 | ) -> ChatResult: |
110 | 140 | return super()._generate( |
| 141 | + messages=messages, |
| 142 | + stop=stop, |
| 143 | + run_manager=run_manager, |
| 144 | + callbacks=callbacks, |
| 145 | + **kwargs, |
| 146 | + ) |
| 147 | + |
| 148 | + @async_stream_decorator |
| 149 | + async def _agenerate( |
| 150 | + self, |
| 151 | + messages: List[BaseMessage], |
| 152 | + stop: Optional[List[str]] = None, |
| 153 | + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, |
| 154 | + **kwargs: Any, |
| 155 | + ) -> ChatResult: |
| 156 | + return await super()._agenerate( |
111 | 157 | messages=messages, stop=stop, run_manager=run_manager, **kwargs |
112 | 158 | ) |
113 | 159 |