From 8c65ed36bd37174ed7092326ec113d1450178021 Mon Sep 17 00:00:00 2001
From: Matthew Penn
Date: Mon, 22 Sep 2025 21:30:59 -0400
Subject: [PATCH 1/3] Enhance error handling and event tracking in Context and
 LangchainProfilerHandler

- Introduced standardized error-detail capture in IntermediateStepPayload via the new ErrorDetails model.
- Updated the Context class to emit success or error events based on function execution outcomes, including detailed error metadata.
- Enhanced LangchainProfilerHandler to handle LLM and tool errors, pushing error events with relevant details to the step manager.
- Added an EventStatus enumeration to track the status of intermediate steps, improving observability in processing flows.

This ensures proper nesting of spans, even in the event of tool call/LLM/function invocation failures.

Signed-off-by: Matthew Penn
---
 src/nat/builder/context.py                    |  36 +++++--
 src/nat/data_models/intermediate_step.py      |  21 ++++
 .../observability/exporter/span_exporter.py   |   5 +-
 .../callbacks/langchain_callback_handler.py   | 100 ++++++++++++++++--
 4 files changed, 146 insertions(+), 16 deletions(-)

diff --git a/src/nat/builder/context.py b/src/nat/builder/context.py
index 072d86f11..280bddcc9 100644
--- a/src/nat/builder/context.py
+++ b/src/nat/builder/context.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
+import traceback
 import typing
 import uuid
 from collections.abc import Awaitable
@@ -27,6 +29,8 @@
 from nat.data_models.authentication import AuthProviderBaseConfig
 from nat.data_models.interactive import HumanResponse
 from nat.data_models.interactive import InteractionPrompt
+from nat.data_models.intermediate_step import ErrorDetails
+from nat.data_models.intermediate_step import EventStatus
 from nat.data_models.intermediate_step import IntermediateStep
 from nat.data_models.intermediate_step import IntermediateStepPayload
 from nat.data_models.intermediate_step import IntermediateStepType
@@ -203,7 +207,7 @@ def push_active_function(self,
                              metadata: dict[str, typing.Any] | TraceMetadata | None = None):
         """
         Set the 'active_function' in context, push an invocation node,
-        AND create an OTel child span for that function call.
+        AND create a child step for that function call.
         
""" parent_function_node = self._context_state.active_function.get() current_function_id = str(uuid.uuid4()) @@ -220,25 +224,41 @@ def push_active_function(self, step_manager.push_intermediate_step( IntermediateStepPayload(UUID=current_function_id, event_type=IntermediateStepType.FUNCTION_START, + status=EventStatus.SUCCESS, name=function_name, data=StreamEventData(input=input_data), metadata=metadata)) manager = ActiveFunctionContextManager() - try: - yield manager # run the function body - finally: - # 3) Record function end - - data = StreamEventData(input=input_data, output=manager.output) + start_time = time.time() + def _emit_end(status: EventStatus, + output_value: typing.Any | None = None, + trace_metadata: dict[str, typing.Any] | TraceMetadata | None = None) -> None: step_manager.push_intermediate_step( IntermediateStepPayload(UUID=current_function_id, event_type=IntermediateStepType.FUNCTION_END, + status=status, + span_event_timestamp=start_time, name=function_name, - data=data)) + data=StreamEventData(input=input_data, output=output_value), + metadata=trace_metadata)) + try: + yield manager # run the function body + except Exception as e: + # 3) Record function end + # push failure event and re-raise + tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + error_metadata = TraceMetadata( + error_details=ErrorDetails(message=str(e), exception_type=type(e).__name__, traceback=tb_str)) + _emit_end(EventStatus.ERROR, None, error_metadata) + raise + else: + # 3) Record function end + _emit_end(EventStatus.SUCCESS, manager.output, None) # push success event + finally: # 4) Unset the function contextvar self._context_state.active_function.reset(fn_token) diff --git a/src/nat/data_models/intermediate_step.py b/src/nat/data_models/intermediate_step.py index 0ee2d25cc..469e3b16a 100644 --- a/src/nat/data_models/intermediate_step.py +++ b/src/nat/data_models/intermediate_step.py @@ -103,6 +103,15 @@ class ToolSchema(BaseModel): function: ToolDetails = Field(..., description="The function details.") +class ErrorDetails(BaseModel): + """ + Standardized error details captured for failed intermediate steps. + """ + message: str | None = Field(default=None, description="Human-readable error message.") + exception_type: str | None = Field(default=None, description="Exception class name (e.g., ValueError).") + traceback: str | None = Field(default=None, description="Formatted traceback string, if available.") + + class TraceMetadata(BaseModel): chat_responses: typing.Any | None = None chat_inputs: typing.Any | None = None @@ -114,11 +123,22 @@ class TraceMetadata(BaseModel): provided_metadata: typing.Any | None = None tools_schema: list[ToolSchema] = Field(default_factory=list, description="The schema of tools used in a tool calling request.") + error_details: ErrorDetails | None = Field(default=None, + description="Standardized error details if the step failed.") # Allow extra fields in the model_config to support derived models model_config = ConfigDict(extra="allow") +class EventStatus(str, Enum): + """ + The status of the intermediate step payload, useful to track when a step was successful or not. + """ + SUCCESS = "success" + ERROR = "error" + UNKNOWN = "unknown" + + class IntermediateStepPayload(BaseModel): """ IntermediateStep is a data model that represents an intermediate step in the NAT. 
Intermediate steps are @@ -140,6 +160,7 @@ class IntermediateStepPayload(BaseModel): data: StreamEventData | None = None usage_info: UsageInfo | None = None UUID: str = Field(default_factory=lambda: str(uuid.uuid4())) + status: EventStatus = Field(default=EventStatus.SUCCESS) @property def event_category(self) -> IntermediateStepCategory: diff --git a/src/nat/observability/exporter/span_exporter.py b/src/nat/observability/exporter/span_exporter.py index 68a389d06..cdb682234 100644 --- a/src/nat/observability/exporter/span_exporter.py +++ b/src/nat/observability/exporter/span_exporter.py @@ -237,6 +237,10 @@ def _process_end_event(self, event: IntermediateStep): sub_span.set_attribute(SpanAttributes.LLM_TOKEN_COUNT_TOTAL.value, usage_info.token_usage.total_tokens if usage_info.token_usage else 0) + # Set the status of the span + sub_span.set_attribute(f"{self._span_prefix}.status", + event.payload.status.value if event.payload.status else "unknown") + if event.payload.data and event.payload.data.output is not None: serialized_output, is_json = self._serialize_payload(event.payload.data.output) sub_span.set_attribute(SpanAttributes.OUTPUT_VALUE.value, serialized_output) @@ -264,7 +268,6 @@ def _process_end_event(self, event: IntermediateStep): sub_span.set_attribute(f"{self._span_prefix}.metadata", serialized_metadata) sub_span.set_attribute(f"{self._span_prefix}.metadata.mime_type", MimeTypes.JSON.value if is_json else MimeTypes.TEXT.value) - end_ns = ns_timestamp(event.payload.event_timestamp) # End the subspan diff --git a/src/nat/profiler/callbacks/langchain_callback_handler.py b/src/nat/profiler/callbacks/langchain_callback_handler.py index 4adb847b7..f85fe9125 100644 --- a/src/nat/profiler/callbacks/langchain_callback_handler.py +++ b/src/nat/profiler/callbacks/langchain_callback_handler.py @@ -19,6 +19,7 @@ import logging import threading import time +import traceback from typing import Any from uuid import UUID from uuid import uuid4 @@ -31,6 +32,8 @@ from nat.builder.context import Context from nat.builder.framework_enum import LLMFrameworkEnum +from nat.data_models.intermediate_step import ErrorDetails +from nat.data_models.intermediate_step import EventStatus from nat.data_models.intermediate_step import IntermediateStepPayload from nat.data_models.intermediate_step import IntermediateStepType from nat.data_models.intermediate_step import StreamEventData @@ -112,6 +115,7 @@ async def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **k self._run_id_to_model_name[run_id] = model_name stats = IntermediateStepPayload(event_type=IntermediateStepType.LLM_START, + status=EventStatus.SUCCESS, framework=LLMFrameworkEnum.LANGCHAIN, name=model_name, UUID=run_id, @@ -151,6 +155,7 @@ async def on_chat_model_start( stats = IntermediateStepPayload( event_type=IntermediateStepType.LLM_START, + status=EventStatus.SUCCESS, framework=LLMFrameworkEnum.LANGCHAIN, name=model_name, UUID=run_id, @@ -183,6 +188,7 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: stats = IntermediateStepPayload( event_type=IntermediateStepType.LLM_NEW_TOKEN, + status=EventStatus.SUCCESS, framework=LLMFrameworkEnum.LANGCHAIN, name=model_name, UUID=str(kwargs.get("run_id", str(uuid4()))), @@ -200,11 +206,13 @@ async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: usage_metadata = {} model_name = "" + run_id_str = str(kwargs.get("run_id", "")) + try: model_name = response.llm_output["model_name"] except Exception as e: try: - model_name = 
self._run_id_to_model_name.get(str(kwargs.get("run_id", "")), "") + model_name = self._run_id_to_model_name.get(run_id_str, "") except Exception as e_inner: logger.exception("Error getting model name: %s from outer error %s", e_inner, e) @@ -235,13 +243,13 @@ async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: # update shared state behind lock with self._lock: usage_stat = IntermediateStepPayload( - span_event_timestamp=self._run_id_to_start_time.get(str(kwargs.get("run_id", "")), time.time()), + span_event_timestamp=self._run_id_to_start_time.get(run_id_str, time.time()), event_type=IntermediateStepType.LLM_END, + status=EventStatus.SUCCESS, framework=LLMFrameworkEnum.LANGCHAIN, name=model_name, UUID=str(kwargs.get("run_id", str(uuid4()))), - data=StreamEventData(input=self._run_id_to_llm_input.get(str(kwargs.get("run_id", "")), ""), - output=llm_text_output), + data=StreamEventData(input=self._run_id_to_llm_input.get(run_id_str, ""), output=llm_text_output), usage_info=UsageInfo(token_usage=self._extract_token_base_model(usage_metadata)), metadata=TraceMetadata(chat_responses=[generation] if generation else [])) @@ -249,6 +257,44 @@ async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self._state = IntermediateStepType.LLM_END + # Cleanup LLM state to prevent memory growth + self._run_id_to_model_name.pop(run_id_str, None) + self._run_id_to_llm_input.pop(run_id_str, None) + self._run_id_to_start_time.pop(run_id_str, None) + + async def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> Any: + + run_id_str = str(run_id) + model_name = self._run_id_to_model_name.get(run_id_str, "") + + tb_str = "".join(traceback.format_exception(type(error), error, error.__traceback__)) + + stats = IntermediateStepPayload( + event_type=IntermediateStepType.LLM_END, + status=EventStatus.ERROR, + span_event_timestamp=self._run_id_to_start_time.get(run_id_str, time.time()), + framework=LLMFrameworkEnum.LANGCHAIN, + name=model_name, + UUID=run_id_str, + data=StreamEventData(input=self._run_id_to_llm_input.get(run_id_str, "")), + metadata=TraceMetadata( + error_details=ErrorDetails(message=str(error), exception_type=type(error).__name__, traceback=tb_str)), + usage_info=UsageInfo(token_usage=TokenUsageBaseModel())) + + self.step_manager.push_intermediate_step(stats) + + # Cleanup LLM state to prevent memory growth + self._run_id_to_model_name.pop(run_id_str, None) + self._run_id_to_llm_input.pop(run_id_str, None) + self._run_id_to_start_time.pop(run_id_str, None) + async def on_tool_start( self, serialized: dict[str, Any], @@ -263,6 +309,7 @@ async def on_tool_start( ) -> Any: stats = IntermediateStepPayload(event_type=IntermediateStepType.TOOL_START, + status=EventStatus.SUCCESS, framework=LLMFrameworkEnum.LANGCHAIN, name=serialized.get("name", ""), UUID=str(run_id), @@ -284,14 +331,53 @@ async def on_tool_end( **kwargs: Any, ) -> Any: + run_id_str = str(run_id) + stats = IntermediateStepPayload(event_type=IntermediateStepType.TOOL_END, - span_event_timestamp=self._run_id_to_start_time.get(str(run_id), time.time()), + status=EventStatus.SUCCESS, + span_event_timestamp=self._run_id_to_start_time.get(run_id_str, time.time()), framework=LLMFrameworkEnum.LANGCHAIN, name=kwargs.get("name", ""), - UUID=str(run_id), + UUID=run_id_str, metadata=TraceMetadata(tool_outputs=output), usage_info=UsageInfo(token_usage=TokenUsageBaseModel()), - 
data=StreamEventData(input=self._run_id_to_tool_input.get(str(run_id), ""), + data=StreamEventData(input=self._run_id_to_tool_input.get(run_id_str, ""), output=output)) self.step_manager.push_intermediate_step(stats) + + # Cleanup tool state to prevent memory growth + self._run_id_to_tool_input.pop(run_id_str, None) + self._run_id_to_start_time.pop(run_id_str, None) + + async def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> Any: + + run_id_str = str(run_id) + + tb_str = "".join(traceback.format_exception(type(error), error, error.__traceback__)) + + stats = IntermediateStepPayload( + event_type=IntermediateStepType.TOOL_END, + status=EventStatus.ERROR, + span_event_timestamp=self._run_id_to_start_time.get(run_id_str, time.time()), + framework=LLMFrameworkEnum.LANGCHAIN, + name=kwargs.get("name", ""), + UUID=run_id_str, + metadata=TraceMetadata( + error_details=ErrorDetails(message=str(error), exception_type=type(error).__name__, traceback=tb_str)), + usage_info=UsageInfo(token_usage=TokenUsageBaseModel()), + data=StreamEventData(input=self._run_id_to_tool_input.get(run_id_str, ""))) + + # push error event + self.step_manager.push_intermediate_step(stats) + + # Cleanup tool state to prevent memory growth + self._run_id_to_tool_input.pop(run_id_str, None) + self._run_id_to_start_time.pop(run_id_str, None) From a5e307bb160ec6c7402d3b5dc195bb381bbeedcd Mon Sep 17 00:00:00 2001 From: Matthew Penn Date: Mon, 22 Sep 2025 21:40:47 -0400 Subject: [PATCH 2/3] Update test_span_exporter.py tests Signed-off-by: Matthew Penn --- .../exporter/test_span_exporter.py | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/tests/nat/observability/exporter/test_span_exporter.py b/tests/nat/observability/exporter/test_span_exporter.py index f4c3d708c..21e79791d 100644 --- a/tests/nat/observability/exporter/test_span_exporter.py +++ b/tests/nat/observability/exporter/test_span_exporter.py @@ -16,11 +16,15 @@ import os import uuid from datetime import datetime +from typing import Any +from typing import cast from unittest.mock import patch import pytest from nat.builder.framework_enum import LLMFrameworkEnum +from nat.data_models.intermediate_step import ErrorDetails +from nat.data_models.intermediate_step import EventStatus from nat.data_models.intermediate_step import IntermediateStep from nat.data_models.intermediate_step import IntermediateStepPayload from nat.data_models.intermediate_step import IntermediateStepType @@ -375,7 +379,7 @@ def test_process_end_event_invalid_metadata(self, span_exporter): # Test when end_metadata is not a dict or TraceMetadata after creation end_event = start_event.model_copy() end_event.payload.event_type = IntermediateStepType.LLM_END - end_event.payload.metadata = "invalid_metadata_string" # This is invalid type + end_event.payload.metadata = cast(Any, "invalid_metadata_string") # This is invalid type span_exporter.export(end_event) mock_logger.warning.assert_called() @@ -433,6 +437,50 @@ async def test_cleanup_no_outstanding_spans(self, span_exporter): assert len(span_exporter._span_stack) == 0 assert len(span_exporter._metadata_stack) == 0 + async def test_process_end_event_error_no_output(self, span_exporter): + """Error END events set status="error", avoid OUTPUT_VALUE, and include error_details in metadata.""" + event_id = str(uuid.uuid4()) + + # Start event + start_event = create_intermediate_step(UUID=event_id, + event_type=IntermediateStepType.TOOL_START, 
+ framework=LLMFrameworkEnum.LANGCHAIN, + name="tool_call", + event_timestamp=datetime.now().timestamp(), + data=StreamEventData(input={"arg": 1}), + metadata={"start": True}) + + # Error END event (no output) + error_meta = TraceMetadata(error_details=ErrorDetails(message="boom", exception_type="ValueError")) + end_event = create_intermediate_step(UUID=event_id, + event_type=IntermediateStepType.TOOL_END, + framework=LLMFrameworkEnum.LANGCHAIN, + name="tool_call", + event_timestamp=datetime.now().timestamp(), + span_event_timestamp=datetime.now().timestamp(), + status=EventStatus.ERROR, + data=StreamEventData(), + metadata=error_meta) + + async with span_exporter.start(): + span_exporter.export(start_event) + span_exporter.export(end_event) + await span_exporter.wait_for_tasks() + + assert len(span_exporter.exported_spans) == 1 + exported_span = span_exporter.exported_spans[0] + + # status attribute reflects error + assert exported_span.attributes.get("nat.status") == "error" + + # no OUTPUT_VALUE attribute for error case without output + assert SpanAttributes.OUTPUT_VALUE.value not in exported_span.attributes + + # metadata contains error_details + serialized_meta = exported_span.attributes.get("nat.metadata", "") + assert "error_details" in serialized_meta + assert "\"message\": \"boom\"" in serialized_meta + def test_span_attribute_setting(self, span_exporter, sample_start_event): """Test various span attribute settings.""" # Test with different input formats From 22d0d67e2c4b407ec266dfc75f8a21432df200c0 Mon Sep 17 00:00:00 2001 From: Matthew Penn Date: Tue, 23 Sep 2025 00:38:29 -0400 Subject: [PATCH 3/3] Added tests with coverage for context.py, new methods in LangchainCallbackHandler Signed-off-by: Matthew Penn --- .../callbacks/langchain_callback_handler.py | 57 +- tests/nat/builder/test_context.py | 673 ++++++++++++++++++ .../{ => callbacks}/test_callback_handler.py | 306 +++++++- 3 files changed, 998 insertions(+), 38 deletions(-) create mode 100644 tests/nat/builder/test_context.py rename tests/nat/profiler/{ => callbacks}/test_callback_handler.py (69%) diff --git a/src/nat/profiler/callbacks/langchain_callback_handler.py b/src/nat/profiler/callbacks/langchain_callback_handler.py index f85fe9125..343cbc6e0 100644 --- a/src/nat/profiler/callbacks/langchain_callback_handler.py +++ b/src/nat/profiler/callbacks/langchain_callback_handler.py @@ -132,17 +132,15 @@ async def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **k self.last_call_ts = time.time() self._run_id_to_start_time[run_id] = time.time() - async def on_chat_model_start( - self, - serialized: dict[str, Any], - messages: list[list[BaseMessage]], - *, - run_id: UUID, - parent_run_id: UUID | None = None, - tags: list[str] | None = None, - metadata: dict[str, Any] | None = None, - **kwargs: Any, - ) -> Any: + async def on_chat_model_start(self, + serialized: dict[str, Any], + messages: list[list[BaseMessage]], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + **kwargs: Any) -> Any: model_name = "" try: @@ -262,14 +260,12 @@ async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self._run_id_to_llm_input.pop(run_id_str, None) self._run_id_to_start_time.pop(run_id_str, None) - async def on_llm_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: UUID | None = None, - **kwargs: Any, - ) -> Any: + async def on_llm_error(self, + error: BaseException, + *, + run_id: UUID, + 
parent_run_id: UUID | None = None, + **kwargs: Any) -> Any: run_id_str = str(run_id) model_name = self._run_id_to_model_name.get(run_id_str, "") @@ -322,14 +318,7 @@ async def on_tool_start( self._run_id_to_tool_input[str(run_id)] = input_str self._run_id_to_start_time[str(run_id)] = time.time() - async def on_tool_end( - self, - output: Any, - *, - run_id: UUID, - parent_run_id: UUID | None = None, - **kwargs: Any, - ) -> Any: + async def on_tool_end(self, output: Any, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any) -> Any: run_id_str = str(run_id) @@ -350,14 +339,12 @@ async def on_tool_end( self._run_id_to_tool_input.pop(run_id_str, None) self._run_id_to_start_time.pop(run_id_str, None) - async def on_tool_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: UUID | None = None, - **kwargs: Any, - ) -> Any: + async def on_tool_error(self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any) -> Any: run_id_str = str(run_id) diff --git a/tests/nat/builder/test_context.py b/tests/nat/builder/test_context.py new file mode 100644 index 000000000..2efac9735 --- /dev/null +++ b/tests/nat/builder/test_context.py @@ -0,0 +1,673 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import threading +import time +from unittest.mock import MagicMock + +import pytest + +from nat.builder.context import ActiveFunctionContextManager +from nat.builder.context import AIQContext +from nat.builder.context import AIQContextState +from nat.builder.context import Context +from nat.builder.context import ContextState +from nat.builder.context import Singleton +from nat.builder.intermediate_step_manager import IntermediateStepManager +from nat.builder.user_interaction_manager import UserInteractionManager +from nat.data_models.authentication import AuthenticatedContext +from nat.data_models.authentication import AuthProviderBaseConfig +from nat.data_models.interactive import HumanResponse +from nat.data_models.interactive import InteractionPrompt +from nat.data_models.invocation_node import InvocationNode +from nat.runtime.user_metadata import RequestAttributes +from nat.utils.reactive.subject import Subject + +# --------------------------------------------------------------------------- # +# Test Fixtures and Helpers +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def context_state(): + """Create a fresh ContextState instance for testing.""" + # Reset singleton instance to ensure clean state + if hasattr(ContextState, 'instance'): + ContextState.instance = None + return ContextState() + + +@pytest.fixture +def context(context_state): + """Create a Context instance with a fresh ContextState.""" + return Context(context_state) + + +@pytest.fixture +def mock_auth_config(): + """Create a mock AuthProviderBaseConfig.""" + return MagicMock(spec=AuthProviderBaseConfig) + + +@pytest.fixture +def mock_authenticated_context(): + """Create a mock AuthenticatedContext.""" + return MagicMock(spec=AuthenticatedContext) + + +@pytest.fixture +def mock_interaction_prompt(): + """Create a mock InteractionPrompt.""" + return MagicMock(spec=InteractionPrompt) + + +@pytest.fixture +def mock_human_response(): + """Create a mock HumanResponse.""" + return MagicMock(spec=HumanResponse) + + +# --------------------------------------------------------------------------- # +# Test Singleton Metaclass +# --------------------------------------------------------------------------- # + + +class TestSingleton: + """Test the Singleton metaclass behavior.""" + + def test_singleton_creates_single_instance(self): + """Test that Singleton metaclass creates only one instance.""" + + class TestClass(metaclass=Singleton): + + def __init__(self, value=None): + self.value = value + + instance1 = TestClass("first") + instance2 = TestClass("second") + + assert instance1 is instance2 + assert instance1.value == "first" # First initialization wins + assert instance2.value == "first" + + def test_singleton_different_classes_different_instances(self): + """Test that different classes with Singleton metaclass have different instances.""" + + class TestClass1(metaclass=Singleton): + pass + + class TestClass2(metaclass=Singleton): + pass + + instance1 = TestClass1() + instance2 = TestClass2() + + assert instance1 is not instance2 + assert type(instance1) is not type(instance2) + + +# --------------------------------------------------------------------------- # +# Test ActiveFunctionContextManager +# --------------------------------------------------------------------------- # + + +class TestActiveFunctionContextManager: + """Test the ActiveFunctionContextManager class.""" + + def test_initialization(self): + """Test ActiveFunctionContextManager 
initialization.""" + manager = ActiveFunctionContextManager() + assert manager.output is None + + def test_set_and_get_output(self): + """Test setting and getting output.""" + manager = ActiveFunctionContextManager() + test_output = {"result": "test"} + + manager.set_output(test_output) + assert manager.output == test_output + + def test_output_can_be_any_type(self): + """Test that output can be any type.""" + manager = ActiveFunctionContextManager() + + # Test with different types + test_values = ["string", 42, [1, 2, 3], {"key": "value"}, None, True] + + for value in test_values: + manager.set_output(value) + assert manager.output == value + + +# --------------------------------------------------------------------------- # +# Test ContextState +# --------------------------------------------------------------------------- # + + +class TestContextState: + """Test the ContextState class.""" + + def test_singleton_behavior(self, context_state): + """Test that ContextState follows singleton pattern.""" + state1 = ContextState.get() + state2 = ContextState.get() + assert state1 is state2 + + def test_initialization_sets_context_vars(self, context_state): + """Test that initialization creates all required ContextVars.""" + assert hasattr(context_state, 'conversation_id') + assert hasattr(context_state, 'user_message_id') + assert hasattr(context_state, 'input_message') + assert hasattr(context_state, 'user_manager') + assert hasattr(context_state, '_metadata') + assert hasattr(context_state, '_event_stream') + assert hasattr(context_state, '_active_function') + assert hasattr(context_state, '_active_span_id_stack') + assert hasattr(context_state, 'user_input_callback') + assert hasattr(context_state, 'user_auth_callback') + + def test_context_vars_default_values(self, context_state): + """Test default values of ContextVars.""" + assert context_state.conversation_id.get() is None + assert context_state.user_message_id.get() is None + assert context_state.input_message.get() is None + assert context_state.user_manager.get() is None + assert context_state._metadata.get() is None + assert context_state._event_stream.get() is None + assert context_state._active_function.get() is None + assert context_state._active_span_id_stack.get() is None + assert context_state.user_input_callback.get() is not None # Has default + assert context_state.user_auth_callback.get() is None + + def test_metadata_property_lazy_initialization(self, context_state): + """Test that metadata property creates RequestAttributes when None.""" + # Initially None + assert context_state._metadata.get() is None + + # Accessing property should create instance + metadata = context_state.metadata.get() + assert isinstance(metadata, RequestAttributes) + + # Should be the same instance on subsequent calls + metadata2 = context_state.metadata.get() + assert metadata is metadata2 + + def test_active_function_property_lazy_initialization(self, context_state): + """Test that active_function property creates root InvocationNode when None.""" + # Initially None + assert context_state._active_function.get() is None + + # Accessing property should create root node + active_func = context_state.active_function.get() + assert isinstance(active_func, InvocationNode) + assert active_func.function_id == "root" + assert active_func.function_name == "root" + + # Should be the same instance on subsequent calls + active_func2 = context_state.active_function.get() + assert active_func is active_func2 + + def 
test_event_stream_property_lazy_initialization(self, context_state): + """Test that event_stream property creates Subject when None.""" + # Initially None + assert context_state._event_stream.get() is None + + # Accessing property should create Subject + stream = context_state.event_stream.get() + assert isinstance(stream, Subject) + + # Should be the same instance on subsequent calls + stream2 = context_state.event_stream.get() + assert stream is stream2 + + def test_active_span_id_stack_property_lazy_initialization(self, context_state): + """Test that active_span_id_stack property creates default stack when None.""" + # Initially None + assert context_state._active_span_id_stack.get() is None + + # Accessing property should create default stack + stack = context_state.active_span_id_stack.get() + assert isinstance(stack, list) + assert stack == ["root"] + + # Should be the same instance on subsequent calls + stack2 = context_state.active_span_id_stack.get() + assert stack is stack2 + + +# --------------------------------------------------------------------------- # +# Test Context +# --------------------------------------------------------------------------- # + + +class TestContext: + """Test the Context class.""" + + def test_initialization(self, context_state): + """Test Context initialization.""" + context = Context(context_state) + assert context._context_state is context_state + + def test_input_message_property(self, context): + """Test input_message property.""" + # Initially None + assert context.input_message is None + + # Set value and test + test_message = "test message" + context._context_state.input_message.set(test_message) + assert context.input_message == test_message + + def test_user_manager_property(self, context): + """Test user_manager property.""" + # Initially None + assert context.user_manager is None + + # Set value and test + test_manager = MagicMock() + context._context_state.user_manager.set(test_manager) + assert context.user_manager is test_manager + + def test_metadata_property(self, context): + """Test metadata property.""" + metadata = context.metadata + assert isinstance(metadata, RequestAttributes) + + def test_user_interaction_manager_property(self, context): + """Test user_interaction_manager property.""" + manager = context.user_interaction_manager + assert isinstance(manager, UserInteractionManager) + + def test_intermediate_step_manager_property(self, context): + """Test intermediate_step_manager property.""" + manager = context.intermediate_step_manager + assert isinstance(manager, IntermediateStepManager) + + def test_conversation_id_property(self, context): + """Test conversation_id property.""" + # Initially None + assert context.conversation_id is None + + # Set value and test + test_id = "conv-123" + context._context_state.conversation_id.set(test_id) + assert context.conversation_id == test_id + + def test_user_message_id_property(self, context): + """Test user_message_id property.""" + # Initially None + assert context.user_message_id is None + + # Set value and test + test_id = "msg-456" + context._context_state.user_message_id.set(test_id) + assert context.user_message_id == test_id + + def test_active_function_property(self, context): + """Test active_function property.""" + active_func = context.active_function + assert isinstance(active_func, InvocationNode) + assert active_func.function_id == "root" + assert active_func.function_name == "root" + + def test_active_span_id_property(self, context): + """Test active_span_id property.""" 
+ span_id = context.active_span_id + assert span_id == "root" + + # Test with modified stack + context._context_state.active_span_id_stack.get().append("child-span") + assert context.active_span_id == "child-span" + + def test_user_auth_callback_property_success(self, context, mock_auth_config, mock_authenticated_context): + """Test user_auth_callback property when callback is set.""" + + async def mock_callback(config, flow_type): + return mock_authenticated_context + + context._context_state.user_auth_callback.set(mock_callback) + callback = context.user_auth_callback + assert callback is mock_callback + + def test_user_auth_callback_property_not_set(self, context): + """Test user_auth_callback property when callback is not set.""" + with pytest.raises(RuntimeError, match="User authentication callback is not set in the context"): + _ = context.user_auth_callback + + def test_get_static_method(self): + """Test Context.get() static method.""" + context = Context.get() + assert isinstance(context, Context) + assert isinstance(context._context_state, ContextState) + + +# --------------------------------------------------------------------------- # +# Test push_active_function Context Manager +# --------------------------------------------------------------------------- # + + +class TestPushActiveFunctionContextManager: + """Test the push_active_function context manager.""" + + def test_push_active_function_basic_functionality(self, context): + """Test basic functionality of push_active_function.""" + function_name = "test_function" + input_data = {"param": "value"} + + # Test that context manager works and returns correct manager + with context.push_active_function(function_name, input_data) as manager: + assert isinstance(manager, ActiveFunctionContextManager) + + # Check that active function was set + active_func = context.active_function + assert active_func.function_name == function_name + assert active_func.parent_id == "root" + assert active_func.parent_name == "root" + + # Set output for testing + test_output = "test result" + manager.set_output(test_output) + assert manager.output == test_output + + def test_push_active_function_with_exception(self, context): + """Test function execution with exception in push_active_function.""" + function_name = "failing_function" + input_data = {"param": "value"} + test_exception = ValueError("Test error") + + # Exception should be re-raised + with pytest.raises(ValueError, match="Test error"): + with context.push_active_function(function_name, input_data): + raise test_exception + + def test_push_active_function_restores_previous_function(self, context): + """Test that push_active_function restores the previous active function.""" + # Set initial active function + initial_func = InvocationNode(function_id="initial", function_name="initial_func") + context._context_state.active_function.set(initial_func) + + function_name = "nested_function" + input_data = {} + + with context.push_active_function(function_name, input_data): + # Inside context, should have new function + active_func = context.active_function + assert active_func.function_name == function_name + assert active_func.parent_id == "initial" + assert active_func.parent_name == "initial_func" + + # After context, should restore initial function + restored_func = context.active_function + assert restored_func is initial_func + + def test_push_active_function_nested_calls(self, context): + """Test nested push_active_function calls.""" + with context.push_active_function("func1", {"data": 
1}) as manager1: + func1 = context.active_function + assert func1.function_name == "func1" + assert func1.parent_name == "root" + + with context.push_active_function("func2", {"data": 2}) as manager2: + func2 = context.active_function + assert func2.function_name == "func2" + assert func2.parent_name == "func1" + + # Both managers should be different instances + assert manager1 is not manager2 + + # Should restore to func1 + restored_func1 = context.active_function + assert restored_func1 is func1 + + # Should restore to root + root_func = context.active_function + assert root_func.function_name == "root" + + def test_push_active_function_with_none_input_data(self, context): + """Test push_active_function with None input data.""" + function_name = "test_function" + with context.push_active_function(function_name, None) as manager: + assert isinstance(manager, ActiveFunctionContextManager) + active_func = context.active_function + assert active_func.function_name == function_name + + +# --------------------------------------------------------------------------- # +# Test Context in Multi-threading Environment +# --------------------------------------------------------------------------- # + + +class TestContextMultiThreading: + """Test Context behavior in multi-threading scenarios.""" + + def test_context_vars_isolation_between_threads(self): + """Test that ContextVars maintain isolation between threads.""" + results = {} + + def worker(thread_id): + context = Context.get() + # Set thread-specific values + context._context_state.conversation_id.set(f"conv-{thread_id}") + context._context_state.user_message_id.set(f"msg-{thread_id}") + + # Small delay to ensure threads overlap + time.sleep(0.01) + + # Read values back + results[thread_id] = { + 'conversation_id': context.conversation_id, 'user_message_id': context.user_message_id + } + + # Start multiple threads + threads = [] + for i in range(5): + thread = threading.Thread(target=worker, args=(i, )) + threads.append(thread) + thread.start() + + # Wait for all threads + for thread in threads: + thread.join() + + # Verify each thread maintained its own values + for i in range(5): + assert results[i]['conversation_id'] == f"conv-{i}" + assert results[i]['user_message_id'] == f"msg-{i}" + + @pytest.mark.asyncio + async def test_context_vars_isolation_in_async_tasks(self): + """Test that ContextVars maintain isolation in async tasks.""" + results = {} + + async def async_worker(task_id): + context = Context.get() + # Set task-specific values + context._context_state.conversation_id.set(f"conv-{task_id}") + context._context_state.user_message_id.set(f"msg-{task_id}") + + # Small delay to ensure tasks overlap + await asyncio.sleep(0.01) + + # Read values back + results[task_id] = {'conversation_id': context.conversation_id, 'user_message_id': context.user_message_id} + + # Start multiple async tasks + tasks = [] + for i in range(5): + task = asyncio.create_task(async_worker(i)) + tasks.append(task) + + # Wait for all tasks + await asyncio.gather(*tasks) + + # Verify each task maintained its own values + for i in range(5): + assert results[i]['conversation_id'] == f"conv-{i}" + assert results[i]['user_message_id'] == f"msg-{i}" + + +# --------------------------------------------------------------------------- # +# Test Integration with Managers +# --------------------------------------------------------------------------- # + + +class TestContextManagerIntegration: + """Test Context integration with various managers.""" + + def 
test_user_interaction_manager_integration(self, context, mock_interaction_prompt, mock_human_response): + """Test integration with UserInteractionManager.""" + + # Set up mock callback + async def mock_callback(prompt): + return mock_human_response + + context._context_state.user_input_callback.set(mock_callback) + + # Get manager and verify it uses the context + user_manager = context.user_interaction_manager + assert isinstance(user_manager, UserInteractionManager) + + # The manager should have access to the context state + assert user_manager._context_state is context._context_state + + def test_intermediate_step_manager_integration(self, context): + """Test integration with IntermediateStepManager.""" + manager = context.intermediate_step_manager + assert isinstance(manager, IntermediateStepManager) + + # The manager should have access to the context state + assert manager._context_state is context._context_state + + +# --------------------------------------------------------------------------- # +# Test Compatibility Aliases +# --------------------------------------------------------------------------- # + + +class TestCompatibilityAliases: + """Test compatibility aliases with previous releases.""" + + def test_aiq_context_state_alias(self): + """Test that AIQContextState is an alias for ContextState.""" + assert AIQContextState is ContextState + + def test_aiq_context_alias(self): + """Test that AIQContext is an alias for Context.""" + assert AIQContext is Context + + def test_aliases_work_as_expected(self, context_state): + """Test that aliases work as expected in practice.""" + # Create instances using aliases + aiq_context_state = AIQContextState() + aiq_context = AIQContext(context_state) + + # Should be the same types as the original classes + assert isinstance(aiq_context_state, ContextState) + assert isinstance(aiq_context, Context) + + +# --------------------------------------------------------------------------- # +# Test Edge Cases and Error Conditions +# --------------------------------------------------------------------------- # + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_context_with_none_context_state(self): + """Test Context behavior with None context state.""" + # This should work but might cause issues when accessing properties + context = Context(None) # type: ignore + + # Accessing properties should raise AttributeError + with pytest.raises(AttributeError): + _ = context.input_message + + def test_active_span_id_with_empty_stack(self, context): + """Test active_span_id behavior with empty stack.""" + # Manually set empty stack (this shouldn't happen in normal usage) + context._context_state._active_span_id_stack.set([]) + + # Should raise IndexError when trying to access last element + with pytest.raises(IndexError): + _ = context.active_span_id + + def test_context_state_singleton_thread_safety(self): + """Test that ContextState singleton is thread-safe.""" + instances = [] + + def create_instance(): + instances.append(ContextState.get()) + + # Create instances from multiple threads + threads = [] + for _ in range(10): + thread = threading.Thread(target=create_instance) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # All instances should be the same + first_instance = instances[0] + for instance in instances[1:]: + assert instance is first_instance + + +# --------------------------------------------------------------------------- # +# Performance and Stress Tests +# 
--------------------------------------------------------------------------- # + + +class TestPerformance: + """Test performance characteristics of Context classes.""" + + def test_context_creation_performance(self): + """Test that Context creation is reasonably fast.""" + start_time = time.time() + + # Create many contexts + contexts = [] + for _ in range(1000): + contexts.append(Context.get()) + + end_time = time.time() + + # Should complete in reasonable time (adjust threshold as needed) + assert end_time - start_time < 1.0 # Less than 1 second + + # All contexts should reference the same ContextState (singleton) + first_state = contexts[0]._context_state + for context in contexts[1:]: + assert context._context_state is first_state + + def test_nested_push_active_function_performance(self, context): + """Test performance of nested push_active_function calls.""" + start_time = time.time() + + # Create nested function calls + with context.push_active_function("func1", {}): + with context.push_active_function("func2", {}): + with context.push_active_function("func3", {}): + with context.push_active_function("func4", {}): + with context.push_active_function("func5", {}): + pass + + end_time = time.time() + + # Should complete in reasonable time + assert end_time - start_time < 1.0 # Less than 1 second (relaxed for CI) diff --git a/tests/nat/profiler/test_callback_handler.py b/tests/nat/profiler/callbacks/test_callback_handler.py similarity index 69% rename from tests/nat/profiler/test_callback_handler.py rename to tests/nat/profiler/callbacks/test_callback_handler.py index b078f2e50..34ff1af37 100644 --- a/tests/nat/profiler/test_callback_handler.py +++ b/tests/nat/profiler/callbacks/test_callback_handler.py @@ -21,6 +21,7 @@ from nat.builder.context import Context from nat.builder.framework_enum import LLMFrameworkEnum +from nat.data_models.intermediate_step import EventStatus from nat.data_models.intermediate_step import IntermediateStepPayload from nat.data_models.intermediate_step import IntermediateStepType from nat.data_models.intermediate_step import StreamEventData @@ -246,9 +247,7 @@ async def test_agno_handler_llm_call(reactive_stream: Subject): """ pytest.importorskip("litellm") - from nat.builder.context import Context from nat.profiler.callbacks.agno_callback_handler import AgnoProfilerHandler - from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel # Create handler and set up collection of results all_stats = [] @@ -466,7 +465,6 @@ async def test_agno_handler_tool_execution(reactive_stream: Subject): Note: This test simulates how tool execution is tracked in the tool_wrapper.py since AgnoProfilerHandler doesn't directly patch tool execution. 
""" - from nat.builder.context import Context from nat.data_models.intermediate_step import IntermediateStep from nat.data_models.invocation_node import InvocationNode from nat.profiler.callbacks.agno_callback_handler import AgnoProfilerHandler @@ -611,3 +609,305 @@ def execute_agno_tool(tool_func, *args, **kwargs): assert end_event.payload.name == "TestTool" assert "result" in end_event.payload.metadata.tool_outputs assert end_event.payload.metadata.tool_outputs["result"] == "Tool execution result" + + +async def test_langchain_handler_on_llm_error(reactive_stream: Subject): + """ + Test that the LangchainProfilerHandler properly handles LLM errors: + - Should generate LLM_END event with ERROR status when on_llm_error is called + - Should capture error details including message, exception type, and traceback + - Should clean up state properly after error + """ + from uuid import UUID + + all_stats = [] + handler = LangchainProfilerHandler() + _ = reactive_stream.subscribe(all_stats.append) + + # Simulate an LLM start event first + prompts = ["Tell me about AI"] + run_id = str(uuid4()) + run_uuid = UUID(run_id) + + await handler.on_llm_start(serialized={}, prompts=prompts, run_id=run_uuid) + + # Create a test exception + test_error = ValueError("LLM service temporarily unavailable") + + # Simulate an LLM error + await handler.on_llm_error(error=test_error, run_id=run_uuid) + + # Verify we have 2 events: start and error + assert len(all_stats) == 2, f"Expected 2 events, got {len(all_stats)}" + + # Check the start event + start_event = all_stats[0] + assert start_event.payload.event_type == IntermediateStepType.LLM_START + assert start_event.payload.status == EventStatus.SUCCESS + + # Check the error event + error_event = all_stats[1] + assert error_event.payload.event_type == IntermediateStepType.LLM_END + assert error_event.payload.status == EventStatus.ERROR + assert error_event.payload.UUID == run_id + + # Verify error details are captured + assert error_event.payload.metadata.error_details is not None + assert error_event.payload.metadata.error_details.message == "LLM service temporarily unavailable" + assert error_event.payload.metadata.error_details.exception_type == "ValueError" + assert "LLM service temporarily unavailable" in error_event.payload.metadata.error_details.traceback + + # Verify input is preserved from the start event + assert error_event.payload.data.input == prompts[-1] + + # Verify usage info is included with empty token usage + assert error_event.payload.usage_info is not None + assert error_event.payload.usage_info.token_usage.prompt_tokens == 0 + assert error_event.payload.usage_info.token_usage.completion_tokens == 0 + + # Verify state cleanup - these should be empty after error handling + assert run_id not in handler._run_id_to_model_name + assert run_id not in handler._run_id_to_llm_input + assert run_id not in handler._run_id_to_start_time + + +async def test_langchain_handler_on_llm_error_without_start(reactive_stream: Subject): + """ + Test on_llm_error when called without a preceding on_llm_start: + - The IntermediateStepManager should not push events without matching start events + - This tests that the error handler gracefully handles missing state + """ + from uuid import UUID + + all_stats = [] + handler = LangchainProfilerHandler() + _ = reactive_stream.subscribe(all_stats.append) + + # Create a test exception + test_error = RuntimeError("Unexpected LLM failure") + run_id = str(uuid4()) + run_uuid = UUID(run_id) + + # Call on_llm_error without 
preceding on_llm_start + await handler.on_llm_error(error=test_error, run_id=run_uuid) + + # Verify no events are pushed since there was no matching start event + # This is the correct behavior - the IntermediateStepManager logs a warning and returns + assert len(all_stats) == 0, f"Expected 0 events (no matching start), got {len(all_stats)}" + + # Verify state cleanup still works (should be empty since nothing was added) + assert run_id not in handler._run_id_to_model_name + assert run_id not in handler._run_id_to_llm_input + assert run_id not in handler._run_id_to_start_time + + +async def test_langchain_handler_on_tool_error(reactive_stream: Subject): + """ + Test that the LangchainProfilerHandler properly handles tool errors: + - Should generate TOOL_END event with ERROR status when on_tool_error is called + - Should capture error details including message, exception type, and traceback + - Should clean up state properly after error + """ + from uuid import UUID + + all_stats = [] + handler = LangchainProfilerHandler() + _ = reactive_stream.subscribe(all_stats.append) + + # Simulate a tool start event first + tool_input = "search for latest AI research" + run_id = str(uuid4()) + run_uuid = UUID(run_id) + serialized = {"name": "web_search", "description": "Search the web for information"} + inputs = {"query": "latest AI research", "max_results": 5} + + await handler.on_tool_start(serialized=serialized, input_str=tool_input, run_id=run_uuid, inputs=inputs) + + # Create a test exception + test_error = ConnectionError("Failed to connect to search API") + + # Simulate a tool error + await handler.on_tool_error(error=test_error, run_id=run_uuid, name="web_search") + + # Verify we have 2 events: start and error + assert len(all_stats) == 2, f"Expected 2 events, got {len(all_stats)}" + + # Check the start event + start_event = all_stats[0] + assert start_event.payload.event_type == IntermediateStepType.TOOL_START + assert start_event.payload.status == EventStatus.SUCCESS + assert start_event.payload.name == "web_search" + + # Check the error event + error_event = all_stats[1] + assert error_event.payload.event_type == IntermediateStepType.TOOL_END + assert error_event.payload.status == EventStatus.ERROR + assert error_event.payload.UUID == run_id + assert error_event.payload.name == "web_search" + + # Verify error details are captured + assert error_event.payload.metadata.error_details is not None + assert error_event.payload.metadata.error_details.message == "Failed to connect to search API" + assert error_event.payload.metadata.error_details.exception_type == "ConnectionError" + assert "Failed to connect to search API" in error_event.payload.metadata.error_details.traceback + + # Verify input is preserved from the start event + assert error_event.payload.data.input == tool_input + + # Verify usage info is included with empty token usage + assert error_event.payload.usage_info is not None + assert error_event.payload.usage_info.token_usage.prompt_tokens == 0 + assert error_event.payload.usage_info.token_usage.completion_tokens == 0 + + # Verify state cleanup + assert run_id not in handler._run_id_to_tool_input + assert run_id not in handler._run_id_to_start_time + + +async def test_langchain_handler_on_tool_error_without_start(reactive_stream: Subject): + """ + Test on_tool_error when called without a preceding on_tool_start: + - The IntermediateStepManager should not push events without matching start events + - This tests that the error handler gracefully handles missing state + """ + from 
uuid import UUID + + all_stats = [] + handler = LangchainProfilerHandler() + _ = reactive_stream.subscribe(all_stats.append) + + # Create a test exception + test_error = TimeoutError("Tool execution timed out") + run_id = str(uuid4()) + run_uuid = UUID(run_id) + + # Call on_tool_error without preceding on_tool_start + await handler.on_tool_error(error=test_error, run_id=run_uuid, name="timeout_tool") + + # Verify no events are pushed since there was no matching start event + # This is the correct behavior - the IntermediateStepManager logs a warning and returns + assert len(all_stats) == 0, f"Expected 0 events (no matching start), got {len(all_stats)}" + + # Verify state cleanup still works (should be empty since nothing was added) + assert run_id not in handler._run_id_to_tool_input + assert run_id not in handler._run_id_to_start_time + + +async def test_langchain_handler_error_with_complex_exception(reactive_stream: Subject): + """ + Test error handling with a complex exception that has nested causes: + - Should capture full traceback information + - Should handle exceptions with custom attributes + """ + from uuid import UUID + + all_stats = [] + handler = LangchainProfilerHandler() + _ = reactive_stream.subscribe(all_stats.append) + + # Create a complex exception with cause chain + root_cause = ValueError("Invalid API key") + wrapper_error = RuntimeError("Authentication failed") + wrapper_error.__cause__ = root_cause + + run_id = str(uuid4()) + run_uuid = UUID(run_id) + + # Set up some state first + await handler.on_llm_start(serialized={}, prompts=["test"], run_id=run_uuid) + + # Test LLM error with complex exception + await handler.on_llm_error(error=wrapper_error, run_id=run_uuid) + + assert len(all_stats) == 2 + error_event = all_stats[1] + + # Verify the wrapper exception is captured + assert error_event.payload.metadata.error_details.message == "Authentication failed" + assert error_event.payload.metadata.error_details.exception_type == "RuntimeError" + + # Verify traceback contains information about the cause chain + traceback_str = error_event.payload.metadata.error_details.traceback + assert "RuntimeError: Authentication failed" in traceback_str + # The exact format of cause chains in tracebacks may vary, but we should see the root cause + assert "ValueError: Invalid API key" in traceback_str + + +async def test_langchain_handler_concurrent_error_handling(reactive_stream: Subject): + """ + Test that error handling works correctly with concurrent requests: + - Multiple LLM calls with different run_ids + - Some succeed, some fail + - State cleanup should be isolated per run_id + """ + from uuid import UUID + + all_stats = [] + handler = LangchainProfilerHandler() + _ = reactive_stream.subscribe(all_stats.append) + + # Create multiple concurrent requests + run_id_1 = str(uuid4()) + run_id_2 = str(uuid4()) + run_id_3 = str(uuid4()) + + # Start all three requests + await asyncio.gather(handler.on_llm_start(serialized={}, prompts=["prompt 1"], run_id=UUID(run_id_1)), + handler.on_llm_start(serialized={}, prompts=["prompt 2"], run_id=UUID(run_id_2)), + handler.on_llm_start(serialized={}, prompts=["prompt 3"], run_id=UUID(run_id_3))) + + # Let two succeed and one fail + from langchain_core.messages import AIMessage + from langchain_core.outputs import ChatGeneration + from langchain_core.outputs import LLMResult + + # Success for run_id_1 + generation_1 = ChatGeneration(message=AIMessage(content="Response 1")) + llm_result_1 = LLMResult(generations=[[generation_1]]) + + # Error for 
run_id_2 + error_2 = ValueError("Model overloaded") + + # Success for run_id_3 + generation_3 = ChatGeneration(message=AIMessage(content="Response 3")) + llm_result_3 = LLMResult(generations=[[generation_3]]) + + # Process end events + await asyncio.gather(handler.on_llm_end(response=llm_result_1, run_id=UUID(run_id_1)), + handler.on_llm_error(error=error_2, run_id=UUID(run_id_2)), + handler.on_llm_end(response=llm_result_3, run_id=UUID(run_id_3))) + + # Should have 6 events total: 3 starts + 2 ends + 1 error + assert len(all_stats) == 6 + + # Find events by run_id + events_by_run_id = {} + for event in all_stats: + run_id = event.payload.UUID + if run_id not in events_by_run_id: + events_by_run_id[run_id] = [] + events_by_run_id[run_id].append(event) + + # Verify run_id_1 (success) + assert len(events_by_run_id[run_id_1]) == 2 + assert events_by_run_id[run_id_1][0].payload.event_type == IntermediateStepType.LLM_START + assert events_by_run_id[run_id_1][1].payload.event_type == IntermediateStepType.LLM_END + assert events_by_run_id[run_id_1][1].payload.status == EventStatus.SUCCESS + + # Verify run_id_2 (error) + assert len(events_by_run_id[run_id_2]) == 2 + assert events_by_run_id[run_id_2][0].payload.event_type == IntermediateStepType.LLM_START + assert events_by_run_id[run_id_2][1].payload.event_type == IntermediateStepType.LLM_END + assert events_by_run_id[run_id_2][1].payload.status == EventStatus.ERROR + assert events_by_run_id[run_id_2][1].payload.metadata.error_details.message == "Model overloaded" + + # Verify run_id_3 (success) + assert len(events_by_run_id[run_id_3]) == 2 + assert events_by_run_id[run_id_3][0].payload.event_type == IntermediateStepType.LLM_START + assert events_by_run_id[run_id_3][1].payload.event_type == IntermediateStepType.LLM_END + assert events_by_run_id[run_id_3][1].payload.status == EventStatus.SUCCESS + + # Verify all state has been cleaned up + assert len(handler._run_id_to_model_name) == 0 + assert len(handler._run_id_to_llm_input) == 0 + assert len(handler._run_id_to_start_time) == 0
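
Note: the EventStatus/ErrorDetails models introduced in this series compose outside of LangChain as well. Below is a minimal sketch of custom instrumentation pushing a failed TOOL_END step; `emit_tool_error` and the tool name are illustrative and not part of the patches — only the models and the push_intermediate_step() call mirror the committed code:

    import traceback

    from nat.builder.context import Context
    from nat.data_models.intermediate_step import ErrorDetails
    from nat.data_models.intermediate_step import EventStatus
    from nat.data_models.intermediate_step import IntermediateStepPayload
    from nat.data_models.intermediate_step import IntermediateStepType
    from nat.data_models.intermediate_step import StreamEventData
    from nat.data_models.intermediate_step import TraceMetadata


    def emit_tool_error(run_id: str, tool_input: str, start_time: float, error: BaseException) -> None:
        """Push a TOOL_END step marked as failed, mirroring on_tool_error in patch 1."""
        tb_str = "".join(traceback.format_exception(type(error), error, error.__traceback__))
        Context.get().intermediate_step_manager.push_intermediate_step(
            IntermediateStepPayload(
                UUID=run_id,  # must match the UUID of the earlier TOOL_START event
                event_type=IntermediateStepType.TOOL_END,
                status=EventStatus.ERROR,
                span_event_timestamp=start_time,
                name="example_tool",  # illustrative tool name
                data=StreamEventData(input=tool_input),
                metadata=TraceMetadata(
                    error_details=ErrorDetails(message=str(error),
                                               exception_type=type(error).__name__,
                                               traceback=tb_str))))

With the default span prefix, the span exporter then surfaces such a failure as a "nat.status" attribute of "error" and carries the serialized error_details in "nat.metadata", as exercised by test_process_end_event_error_no_output in patch 2.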