Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions nemoguardrails/integrations/langchain/runnable_rails.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
from __future__ import annotations

import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast

from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
from langchain_core.language_models.llms import BaseLLM
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.utils import Input, Output, gather_with_concurrency
Expand All @@ -33,7 +35,7 @@
message_to_dict,
)
from nemoguardrails.integrations.langchain.utils import async_wrap
from nemoguardrails.rails.llm.options import GenerationOptions
from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,7 +64,7 @@ class RunnableRails(Runnable[Input, Output]):
def __init__(
self,
config: RailsConfig,
llm: Optional[BaseLanguageModel] = None,
llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
tools: Optional[List[Tool]] = None,
passthrough: bool = True,
runnable: Optional[Runnable] = None,
Expand Down Expand Up @@ -110,7 +112,7 @@ def __init__(
if self.passthrough_runnable:
self._init_passthrough_fn()

def _init_passthrough_fn(self):
def _init_passthrough_fn(self) -> None:
"""Initialize the passthrough function for the LLM rails instance."""

async def passthrough_fn(context: dict, events: List[dict]):
Expand All @@ -134,7 +136,8 @@ async def passthrough_fn(context: dict, events: List[dict]):

return text, _output

self.rails.llm_generation_actions.passthrough_fn = passthrough_fn
# Dynamically assign passthrough_fn to avoid type checker issues
setattr(self.rails.llm_generation_actions, "passthrough_fn", passthrough_fn)

def __or__(
self, other: Union[BaseLanguageModel, Runnable[Any, Any]]
Expand Down Expand Up @@ -687,6 +690,9 @@ def _full_rails_invoke(
res = self.rails.generate(
messages=input_messages, options=GenerationOptions(output_vars=True)
)
# When using output_vars=True, rails.generate returns a GenerationResponse
if not isinstance(res, GenerationResponse):
raise Exception(f"Expected GenerationResponse, got {type(res)}")
context = res.output_data
result = res.response

Expand Down
Loading