36 changes: 28 additions & 8 deletions src/nat/builder/context.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import traceback
import typing
import uuid
from collections.abc import Awaitable
@@ -27,6 +29,8 @@
from nat.data_models.authentication import AuthProviderBaseConfig
from nat.data_models.interactive import HumanResponse
from nat.data_models.interactive import InteractionPrompt
from nat.data_models.intermediate_step import ErrorDetails
from nat.data_models.intermediate_step import EventStatus
from nat.data_models.intermediate_step import IntermediateStep
from nat.data_models.intermediate_step import IntermediateStepPayload
from nat.data_models.intermediate_step import IntermediateStepType
@@ -203,7 +207,7 @@ def push_active_function(self,
metadata: dict[str, typing.Any] | TraceMetadata | None = None):
"""
Set the 'active_function' in context, push an invocation node,
AND create an OTel child span for that function call.
AND create a child step for that function call.
"""
parent_function_node = self._context_state.active_function.get()
current_function_id = str(uuid.uuid4())
@@ -220,25 +224,41 @@
step_manager.push_intermediate_step(
IntermediateStepPayload(UUID=current_function_id,
event_type=IntermediateStepType.FUNCTION_START,
status=EventStatus.SUCCESS,
name=function_name,
data=StreamEventData(input=input_data),
metadata=metadata))

manager = ActiveFunctionContextManager()

try:
yield manager # run the function body
finally:
# 3) Record function end

data = StreamEventData(input=input_data, output=manager.output)
start_time = time.time()

def _emit_end(status: EventStatus,
output_value: typing.Any | None = None,
trace_metadata: dict[str, typing.Any] | TraceMetadata | None = None) -> None:
step_manager.push_intermediate_step(
IntermediateStepPayload(UUID=current_function_id,
event_type=IntermediateStepType.FUNCTION_END,
status=status,
span_event_timestamp=start_time,
name=function_name,
data=data))
data=StreamEventData(input=input_data, output=output_value),
metadata=trace_metadata))

try:
yield manager # run the function body
except Exception as e:
# 3) Record function end
# push failure event and re-raise
tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__))
error_metadata = TraceMetadata(
error_details=ErrorDetails(message=str(e), exception_type=type(e).__name__, traceback=tb_str))
_emit_end(EventStatus.ERROR, None, error_metadata)
raise
else:
# 3) Record function end
_emit_end(EventStatus.SUCCESS, manager.output, None) # push success event
finally:
# 4) Unset the function contextvar
self._context_state.active_function.reset(fn_token)
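The refactor above emits the end event exactly once per call: from the `except` branch (status `ERROR`, with traceback details carried in metadata) or the `else` branch (status `SUCCESS`), while `finally` only resets the context variable. A minimal, self-contained sketch of the same pattern; the `active_function` helper and `emit` callback here are illustrative stand-ins, not the NAT API:

```python
# Sketch only: mirrors the try/except/else shape introduced in push_active_function.
import time
import traceback
import typing
from contextlib import contextmanager
from enum import Enum


class EventStatus(str, Enum):
    SUCCESS = "success"
    ERROR = "error"


@contextmanager
def active_function(name: str, emit: typing.Callable[[dict], None]):
    emit({"event": "FUNCTION_START", "name": name, "status": EventStatus.SUCCESS})
    start_time = time.time()

    def _emit_end(status: EventStatus, error: BaseException | None = None) -> None:
        payload: dict = {"event": "FUNCTION_END", "name": name, "status": status,
                         "span_event_timestamp": start_time}
        if error is not None:
            payload["error_details"] = {
                "message": str(error),
                "exception_type": type(error).__name__,
                "traceback": "".join(traceback.format_exception(type(error), error, error.__traceback__)),
            }
        emit(payload)

    try:
        yield
    except Exception as e:
        _emit_end(EventStatus.ERROR, e)  # push failure event, then re-raise
        raise
    else:
        _emit_end(EventStatus.SUCCESS)  # push success event
    # the real code additionally resets the active-function contextvar in a finally block


events: list[dict] = []
try:
    with active_function("my_fn", events.append):
        raise ValueError("boom")
except ValueError:
    pass

assert events[-1]["status"] is EventStatus.ERROR
```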

21 changes: 21 additions & 0 deletions src/nat/data_models/intermediate_step.py
@@ -103,6 +103,15 @@ class ToolSchema(BaseModel):
function: ToolDetails = Field(..., description="The function details.")


class ErrorDetails(BaseModel):
"""
Standardized error details captured for failed intermediate steps.
"""
message: str | None = Field(default=None, description="Human-readable error message.")
exception_type: str | None = Field(default=None, description="Exception class name (e.g., ValueError).")
traceback: str | None = Field(default=None, description="Formatted traceback string, if available.")


class TraceMetadata(BaseModel):
chat_responses: typing.Any | None = None
chat_inputs: typing.Any | None = None
@@ -114,11 +123,22 @@ class TraceMetadata(BaseModel):
provided_metadata: typing.Any | None = None
tools_schema: list[ToolSchema] = Field(default_factory=list,
description="The schema of tools used in a tool calling request.")
error_details: ErrorDetails | None = Field(default=None,
description="Standardized error details if the step failed.")

# Allow extra fields in the model_config to support derived models
model_config = ConfigDict(extra="allow")


class EventStatus(str, Enum):
"""
The status of an intermediate step payload, used to track whether a step completed successfully or failed.
"""
SUCCESS = "success"
ERROR = "error"
UNKNOWN = "unknown"


class IntermediateStepPayload(BaseModel):
"""
IntermediateStep is a data model that represents an intermediate step in the NAT. Intermediate steps are
@@ -140,6 +160,7 @@ class IntermediateStepPayload(BaseModel):
data: StreamEventData | None = None
usage_info: UsageInfo | None = None
UUID: str = Field(default_factory=lambda: str(uuid.uuid4()))
status: EventStatus = Field(default=EventStatus.SUCCESS)

@property
def event_category(self) -> IntermediateStepCategory:
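Together, `ErrorDetails`, `EventStatus`, and the new `status` field let a single payload describe a failed step. A hedged construction sketch, assuming all of these symbols are exported from `nat.data_models.intermediate_step` as the imports in `context.py` above suggest:

```python
from nat.data_models.intermediate_step import (ErrorDetails, EventStatus,
                                               IntermediateStepPayload,
                                               IntermediateStepType, TraceMetadata)

# A FUNCTION_END step for a call that raised; error context travels in metadata.
payload = IntermediateStepPayload(
    event_type=IntermediateStepType.FUNCTION_END,
    name="lookup_tool",  # illustrative step name
    status=EventStatus.ERROR,
    metadata=TraceMetadata(error_details=ErrorDetails(
        message="division by zero",
        exception_type="ZeroDivisionError",
        traceback="Traceback (most recent call last): ...")),
)

assert payload.status is EventStatus.ERROR  # the default would be SUCCESS
```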
5 changes: 4 additions & 1 deletion src/nat/observability/exporter/span_exporter.py
@@ -237,6 +237,10 @@ def _process_end_event(self, event: IntermediateStep):
sub_span.set_attribute(SpanAttributes.LLM_TOKEN_COUNT_TOTAL.value,
usage_info.token_usage.total_tokens if usage_info.token_usage else 0)

# Set the status of the span
sub_span.set_attribute(f"{self._span_prefix}.status",
event.payload.status.value if event.payload.status else "unknown")

if event.payload.data and event.payload.data.output is not None:
serialized_output, is_json = self._serialize_payload(event.payload.data.output)
sub_span.set_attribute(SpanAttributes.OUTPUT_VALUE.value, serialized_output)
@@ -264,7 +268,6 @@ def _process_end_event(self, event: IntermediateStep):
sub_span.set_attribute(f"{self._span_prefix}.metadata", serialized_metadata)
sub_span.set_attribute(f"{self._span_prefix}.metadata.mime_type",
MimeTypes.JSON.value if is_json else MimeTypes.TEXT.value)

end_ns = ns_timestamp(event.payload.event_timestamp)

# End the subspan
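The net effect on exported spans is one extra attribute keyed off the exporter's span prefix. A small illustration; the `"nat"` prefix is an assumption here, since the real value of `self._span_prefix` comes from exporter configuration:

```python
from nat.data_models.intermediate_step import EventStatus

span_prefix = "nat"  # assumed; stands in for self._span_prefix above
status = EventStatus.ERROR

# Same expression the exporter uses when setting the attribute.
key, value = f"{span_prefix}.status", status.value if status else "unknown"
assert (key, value) == ("nat.status", "error")
```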
125 changes: 99 additions & 26 deletions src/nat/profiler/callbacks/langchain_callback_handler.py
@@ -19,6 +19,7 @@
import logging
import threading
import time
import traceback
from typing import Any
from uuid import UUID
from uuid import uuid4
@@ -31,6 +32,8 @@

from nat.builder.context import Context
from nat.builder.framework_enum import LLMFrameworkEnum
from nat.data_models.intermediate_step import ErrorDetails
from nat.data_models.intermediate_step import EventStatus
from nat.data_models.intermediate_step import IntermediateStepPayload
from nat.data_models.intermediate_step import IntermediateStepType
from nat.data_models.intermediate_step import StreamEventData
@@ -112,6 +115,7 @@ async def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **k
self._run_id_to_model_name[run_id] = model_name

stats = IntermediateStepPayload(event_type=IntermediateStepType.LLM_START,
status=EventStatus.SUCCESS,
framework=LLMFrameworkEnum.LANGCHAIN,
name=model_name,
UUID=run_id,
Expand All @@ -128,17 +132,15 @@ async def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **k
self.last_call_ts = time.time()
self._run_id_to_start_time[run_id] = time.time()

async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
**kwargs: Any,
) -> Any:
async def on_chat_model_start(self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
**kwargs: Any) -> Any:

model_name = ""
try:
@@ -151,6 +153,7 @@

stats = IntermediateStepPayload(
event_type=IntermediateStepType.LLM_START,
status=EventStatus.SUCCESS,
framework=LLMFrameworkEnum.LANGCHAIN,
name=model_name,
UUID=run_id,
@@ -183,6 +186,7 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:

stats = IntermediateStepPayload(
event_type=IntermediateStepType.LLM_NEW_TOKEN,
status=EventStatus.SUCCESS,
framework=LLMFrameworkEnum.LANGCHAIN,
name=model_name,
UUID=str(kwargs.get("run_id", str(uuid4()))),
@@ -200,11 +204,13 @@ async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
usage_metadata = {}

model_name = ""
run_id_str = str(kwargs.get("run_id", ""))

try:
model_name = response.llm_output["model_name"]
except Exception as e:
try:
model_name = self._run_id_to_model_name.get(str(kwargs.get("run_id", "")), "")
model_name = self._run_id_to_model_name.get(run_id_str, "")
except Exception as e_inner:
logger.exception("Error getting model name: %s from outer error %s", e_inner, e)

@@ -235,20 +241,56 @@ async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
# update shared state behind lock
with self._lock:
usage_stat = IntermediateStepPayload(
span_event_timestamp=self._run_id_to_start_time.get(str(kwargs.get("run_id", "")), time.time()),
span_event_timestamp=self._run_id_to_start_time.get(run_id_str, time.time()),
event_type=IntermediateStepType.LLM_END,
status=EventStatus.SUCCESS,
framework=LLMFrameworkEnum.LANGCHAIN,
name=model_name,
UUID=str(kwargs.get("run_id", str(uuid4()))),
data=StreamEventData(input=self._run_id_to_llm_input.get(str(kwargs.get("run_id", "")), ""),
output=llm_text_output),
data=StreamEventData(input=self._run_id_to_llm_input.get(run_id_str, ""), output=llm_text_output),
usage_info=UsageInfo(token_usage=self._extract_token_base_model(usage_metadata)),
metadata=TraceMetadata(chat_responses=[generation] if generation else []))

self.step_manager.push_intermediate_step(usage_stat)

self._state = IntermediateStepType.LLM_END

# Cleanup LLM state to prevent memory growth
self._run_id_to_model_name.pop(run_id_str, None)
self._run_id_to_llm_input.pop(run_id_str, None)
self._run_id_to_start_time.pop(run_id_str, None)

async def on_llm_error(self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any) -> Any:

run_id_str = str(run_id)
model_name = self._run_id_to_model_name.get(run_id_str, "")

tb_str = "".join(traceback.format_exception(type(error), error, error.__traceback__))

stats = IntermediateStepPayload(
event_type=IntermediateStepType.LLM_END,
status=EventStatus.ERROR,
span_event_timestamp=self._run_id_to_start_time.get(run_id_str, time.time()),
framework=LLMFrameworkEnum.LANGCHAIN,
name=model_name,
UUID=run_id_str,
data=StreamEventData(input=self._run_id_to_llm_input.get(run_id_str, "")),
metadata=TraceMetadata(
error_details=ErrorDetails(message=str(error), exception_type=type(error).__name__, traceback=tb_str)),
usage_info=UsageInfo(token_usage=TokenUsageBaseModel()))

self.step_manager.push_intermediate_step(stats)

# Cleanup LLM state to prevent memory growth
self._run_id_to_model_name.pop(run_id_str, None)
self._run_id_to_llm_input.pop(run_id_str, None)
self._run_id_to_start_time.pop(run_id_str, None)

async def on_tool_start(
self,
serialized: dict[str, Any],
@@ -263,6 +305,7 @@
) -> Any:

stats = IntermediateStepPayload(event_type=IntermediateStepType.TOOL_START,
status=EventStatus.SUCCESS,
framework=LLMFrameworkEnum.LANGCHAIN,
name=serialized.get("name", ""),
UUID=str(run_id),
@@ -275,23 +318,53 @@ async def on_tool_start(
self._run_id_to_tool_input[str(run_id)] = input_str
self._run_id_to_start_time[str(run_id)] = time.time()

async def on_tool_end(
self,
output: Any,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
) -> Any:
async def on_tool_end(self, output: Any, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any) -> Any:

run_id_str = str(run_id)

stats = IntermediateStepPayload(event_type=IntermediateStepType.TOOL_END,
span_event_timestamp=self._run_id_to_start_time.get(str(run_id), time.time()),
status=EventStatus.SUCCESS,
span_event_timestamp=self._run_id_to_start_time.get(run_id_str, time.time()),
framework=LLMFrameworkEnum.LANGCHAIN,
name=kwargs.get("name", ""),
UUID=str(run_id),
UUID=run_id_str,
metadata=TraceMetadata(tool_outputs=output),
usage_info=UsageInfo(token_usage=TokenUsageBaseModel()),
data=StreamEventData(input=self._run_id_to_tool_input.get(str(run_id), ""),
data=StreamEventData(input=self._run_id_to_tool_input.get(run_id_str, ""),
output=output))

self.step_manager.push_intermediate_step(stats)

# Cleanup tool state to prevent memory growth
self._run_id_to_tool_input.pop(run_id_str, None)
self._run_id_to_start_time.pop(run_id_str, None)

async def on_tool_error(self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any) -> Any:

run_id_str = str(run_id)

tb_str = "".join(traceback.format_exception(type(error), error, error.__traceback__))

stats = IntermediateStepPayload(
event_type=IntermediateStepType.TOOL_END,
status=EventStatus.ERROR,
span_event_timestamp=self._run_id_to_start_time.get(run_id_str, time.time()),
framework=LLMFrameworkEnum.LANGCHAIN,
name=kwargs.get("name", ""),
UUID=run_id_str,
metadata=TraceMetadata(
error_details=ErrorDetails(message=str(error), exception_type=type(error).__name__, traceback=tb_str)),
usage_info=UsageInfo(token_usage=TokenUsageBaseModel()),
data=StreamEventData(input=self._run_id_to_tool_input.get(run_id_str, "")))

# push error event
self.step_manager.push_intermediate_step(stats)

# Cleanup tool state to prevent memory growth
self._run_id_to_tool_input.pop(run_id_str, None)
self._run_id_to_start_time.pop(run_id_str, None)
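LangChain invokes `on_tool_error` (and `on_llm_error`) itself when a traced call raises, so each hook has to pair with an earlier start event and then clear the per-run caches. A toy stand-in, not the NAT handler class, that reproduces the start-then-error lifecycle and the cleanup pattern shown above:

```python
import asyncio
import time
import traceback
from typing import Any
from uuid import UUID, uuid4


class MiniToolErrorHandler:
    """Toy stand-in (not the NAT class) for the start -> error lifecycle."""

    def __init__(self) -> None:
        self._run_id_to_start_time: dict[str, float] = {}
        self.steps: list[dict[str, Any]] = []

    async def on_tool_start(self, serialized: dict[str, Any], input_str: str, *,
                            run_id: UUID, **kwargs: Any) -> Any:
        self._run_id_to_start_time[str(run_id)] = time.time()

    async def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> Any:
        run_id_str = str(run_id)
        self.steps.append({
            "event_type": "TOOL_END",
            "status": "error",
            "span_event_timestamp": self._run_id_to_start_time.get(run_id_str, time.time()),
            "error_message": str(error),
            "traceback": "".join(traceback.format_exception(type(error), error, error.__traceback__)),
        })
        # Cleanup per-run state to prevent memory growth, as in the handler above.
        self._run_id_to_start_time.pop(run_id_str, None)


async def demo() -> None:
    handler = MiniToolErrorHandler()
    run_id = uuid4()
    await handler.on_tool_start({"name": "search"}, "query text", run_id=run_id)
    try:
        raise RuntimeError("tool timed out")
    except RuntimeError as err:
        await handler.on_tool_error(err, run_id=run_id)
    assert handler.steps[-1]["status"] == "error"
    assert not handler._run_id_to_start_time  # cache was cleaned up


asyncio.run(demo())
```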