@@ -18,8 +18,15 @@
 from functools import wraps
 from typing import Any, Dict, List, Optional
 
-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
-from langchain_core.language_models.chat_models import generate_from_stream
+from langchain_core.callbacks import Callbacks
+from langchain_core.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+    agenerate_from_stream,
+    generate_from_stream,
+)
 from langchain_core.messages import BaseMessage
 from langchain_core.outputs import ChatResult
 from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
@@ -50,6 +57,28 @@ def wrapper(
     return wrapper
 
 
+def async_stream_decorator(func):  # pragma: no cover
+    @wraps(func)
+    async def wrapper(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+        else:
+            return await func(self, messages, stop, run_manager, **kwargs)
+
+    return wrapper
+
+
 # NOTE: this needs to have the same name as the original class,
 # otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
 class ChatNVIDIA(ChatNVIDIAOriginal):  # pragma: no cover
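
Note: to see the routing in async_stream_decorator in isolation, here is a minimal, self-contained sketch of the same pattern. ToyChat and its method names are illustrative stand-ins, not part of this PR, and joining chunks stands in for agenerate_from_stream():

import asyncio
from functools import wraps


def toy_async_stream_decorator(func):
    """Same routing idea as async_stream_decorator above: if streaming is
    requested (a per-call flag wins over the instance default), consume the
    streaming path; otherwise await the wrapped coroutine unchanged."""

    @wraps(func)
    async def wrapper(self, prompt, stream=None, **kwargs):
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            # Stand-in for agenerate_from_stream(): drain the async
            # generator and assemble a final result from the chunks.
            return "".join([chunk async for chunk in self._astream(prompt)])
        return await func(self, prompt, **kwargs)

    return wrapper


class ToyChat:
    def __init__(self, streaming=False):
        self.streaming = streaming

    async def _astream(self, prompt):
        for token in prompt.split():
            yield token.upper() + " "

    @toy_async_stream_decorator
    async def generate(self, prompt, stream=None):
        return f"(non-streaming) {prompt}"


async def main():
    chat = ToyChat(streaming=True)
    print(await chat.generate("hello world"))                 # "HELLO WORLD "
    print(await chat.generate("hello world", stream=False))   # "(non-streaming) hello world"


asyncio.run(main())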
@@ -105,9 +134,26 @@ def _generate(
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> ChatResult:
         return super()._generate(
+            messages=messages,
+            stop=stop,
+            run_manager=run_manager,
+            callbacks=callbacks,
+            **kwargs,
+        )
+
+    @async_stream_decorator
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return await super()._agenerate(
+            messages=messages, stop=stop, run_manager=run_manager, **kwargs
+        )
Comment on lines +149 to 158
Contributor
logic: The async method _agenerate is missing the callbacks parameter that was added to the sync _generate (line 137). This inconsistency may cause issues if callers expect a uniform API.

Suggested change
-    async def _agenerate(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> ChatResult:
-        return await super()._agenerate(
-            messages=messages, stop=stop, run_manager=run_manager, **kwargs
-        )
+    @async_stream_decorator
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return await super()._agenerate(
+            messages=messages, stop=stop, run_manager=run_manager, callbacks=callbacks, **kwargs
+        )
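
For context, a hypothetical usage sketch of the patched class. The import path, model name, and the streaming constructor flag are assumptions (the decorator's use of self.streaming suggests the flag exists), not details taken from this PR:

# Hypothetical usage sketch; import path and model name are assumptions.
import asyncio

from patched_module import ChatNVIDIA  # wherever this patched class lives


async def main():
    llm = ChatNVIDIA(model="meta/llama3-8b-instruct", streaming=True)
    # With streaming enabled, the decorator routes _agenerate through
    # _astream and agenerate_from_stream, so a plain ainvoke() is served
    # from the stream but still returns a single message.
    msg = await llm.ainvoke("Reply with one word.")
    print(msg.content)


asyncio.run(main())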

