Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/aiq/profiler/callbacks/langchain_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ async def on_tool_start(
usage_info=UsageInfo(token_usage=TokenUsageBaseModel()))

self.step_manager.push_intermediate_step(stats)
self._run_id_to_tool_input[str(run_id)] = input_str
self._run_id_to_tool_input[str(run_id)] = copy.deepcopy(inputs)
self._run_id_to_start_time[str(run_id)] = time.time()

async def on_tool_end(
Expand All @@ -277,14 +277,15 @@ async def on_tool_end(
**kwargs: Any,
) -> Any:

inputs = self._run_id_to_tool_input.get(str(run_id), "")

stats = IntermediateStepPayload(event_type=IntermediateStepType.TOOL_END,
span_event_timestamp=self._run_id_to_start_time.get(str(run_id), time.time()),
framework=LLMFrameworkEnum.LANGCHAIN,
name=kwargs.get("name", ""),
UUID=str(run_id),
metadata=TraceMetadata(tool_outputs=output),
metadata=TraceMetadata(tool_inputs=inputs, tool_outputs=output),
usage_info=UsageInfo(token_usage=TokenUsageBaseModel()),
data=StreamEventData(input=self._run_id_to_tool_input.get(str(run_id), ""),
output=output))
data=StreamEventData(input=inputs, output=output))

self.step_manager.push_intermediate_step(stats)
129 changes: 104 additions & 25 deletions tests/aiq/profiler/test_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ async def test_llama_index_handler_order(reactive_stream: Subject):
from llama_index.core.llms import ChatMessage
from llama_index.core.llms import ChatResponse

payload_start = {EventPayload.PROMPT: "Say something wise."}
payload_start = {EventPayload.PROMPT.value: "Say something wise."}
handler.on_event_start(event_type=CBEventType.LLM, payload=payload_start, event_id="evt-1")

# Simulate an LLM end event
payload_end = {
EventPayload.RESPONSE:
EventPayload.RESPONSE.value:
ChatResponse(message=ChatMessage.from_str("42 is the meaning of life."), raw="42 is the meaning of life.")
}
handler.on_event_end(event_type=CBEventType.LLM, payload=payload_end, event_id="evt-1")
Expand Down Expand Up @@ -246,9 +246,7 @@ async def test_agno_handler_llm_call(reactive_stream: Subject):
"""
pytest.importorskip("litellm")

from aiq.builder.context import AIQContext
from aiq.profiler.callbacks.agno_callback_handler import AgnoProfilerHandler
from aiq.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel

# Create handler and set up collection of results
all_stats = []
Expand All @@ -258,7 +256,7 @@ async def test_agno_handler_llm_call(reactive_stream: Subject):
step_manager = AIQContext.get().intermediate_step_manager

# Mock the original LLM call function that would be patched
def original_completion(*args, **kwargs):
def original_completion(*_args, **_kwargs): # pylint: disable=unused-argument
return None

handler._original_llm_call = original_completion
Expand Down Expand Up @@ -405,32 +403,35 @@ def wrapped(*args, **kwargs):
# Find IntermediateStep objects in all_stats
intermediate_steps = [event for event in all_stats if hasattr(event, 'payload')]

# If we don't have IntermediateStep objects, check step_manager
# If we don't have IntermediateStep objects, check all_stats directly
if len(intermediate_steps) < 2:
print("Not enough IntermediateStep objects in all_stats, checking step_manager...")
steps = step_manager.get_intermediate_steps()
print(f"Found {len(steps)} steps in step_manager")
for i, step in enumerate(steps):
print(f"Step {i}: {step.event_type}")
print("Not enough IntermediateStep objects in all_stats, using all_stats directly...")
print(f"Found {len(all_stats)} items in all_stats")
for i, stat in enumerate(all_stats):
print(f"Stat {i}: {type(stat)}")

# Verify steps in step_manager
assert len(steps) >= 2, f"Expected at least 2 steps in step_manager, got {len(steps)}"
# Verify we have events in all_stats
assert len(all_stats) >= 2, f"Expected at least 2 events in all_stats, got {len(all_stats)}"

# Find the START and END events from step_manager
start_events = [s for s in steps if s.event_type == IntermediateStepType.LLM_START]
end_events = [s for s in steps if s.event_type == IntermediateStepType.LLM_END]
# Find the START and END events from all_stats
start_events = [
s for s in all_stats if hasattr(s, 'payload') and s.payload.event_type == IntermediateStepType.LLM_START
]
end_events = [
s for s in all_stats if hasattr(s, 'payload') and s.payload.event_type == IntermediateStepType.LLM_END
]

assert len(start_events) > 0, "No LLM_START events found in step_manager"
assert len(end_events) > 0, "No LLM_END events found in step_manager"
assert len(start_events) > 0, "No LLM_START events found in all_stats"
assert len(end_events) > 0, "No LLM_END events found in all_stats"

# Use the latest events for our test
start_event = start_events[-1]
end_event = end_events[-1]

# Check token usage values in the end event
assert end_event.usage_info.token_usage.prompt_tokens == token_usage_obj.prompt_tokens
assert end_event.usage_info.token_usage.completion_tokens == token_usage_obj.completion_tokens
assert end_event.usage_info.token_usage.total_tokens == token_usage_obj.total_tokens
assert end_event.payload.usage_info.token_usage.prompt_tokens == token_usage_obj.prompt_tokens
assert end_event.payload.usage_info.token_usage.completion_tokens == token_usage_obj.completion_tokens
assert end_event.payload.usage_info.token_usage.total_tokens == token_usage_obj.total_tokens
else:
# Find the START and END events in our intermediate steps
start_events = [e for e in intermediate_steps if e.payload.event_type == IntermediateStepType.LLM_START]
Expand Down Expand Up @@ -466,7 +467,6 @@ async def test_agno_handler_tool_execution(reactive_stream: Subject):
Note: This test simulates how tool execution is tracked in the tool_wrapper.py
since AgnoProfilerHandler doesn't directly patch tool execution.
"""
from aiq.builder.context import AIQContext
from aiq.data_models.intermediate_step import IntermediateStep
from aiq.data_models.invocation_node import InvocationNode
from aiq.profiler.callbacks.agno_callback_handler import AgnoProfilerHandler
Expand All @@ -479,7 +479,7 @@ async def test_agno_handler_tool_execution(reactive_stream: Subject):
step_manager = AIQContext.get().intermediate_step_manager

# Define a simple tool function
def sample_tool(arg1, arg2, param1=None, tool_name="SampleTool"):
def sample_tool(arg1, arg2, param1=None, _tool_name="SampleTool"): # pylint: disable=unused-argument
print(f"Tool called with {arg1}, {arg2}, {param1}")
return "Tool execution result"

Expand Down Expand Up @@ -516,7 +516,9 @@ def execute_agno_tool(tool_func, *args, **kwargs):

# Call the tool function
try:
result = tool_func(*args, **kwargs)
# Remove tool_name from kwargs before calling the function since it's only used for metadata
tool_kwargs = {k: v for k, v in kwargs.items() if k != "tool_name"}
result = tool_func(*args, **tool_kwargs)

# Create end event payload
end_payload = IntermediateStepPayload(event_type=IntermediateStepType.TOOL_END,
Expand Down Expand Up @@ -606,8 +608,85 @@ def execute_agno_tool(tool_func, *args, **kwargs):
# Verify event details
assert start_event.payload.name == "TestTool"
assert "args" in start_event.payload.metadata.tool_inputs
assert tool_args[0] in start_event.payload.metadata.tool_inputs["args"]
assert tool_args[0] in start_event.payload.metadata.tool_inputs.get("args", [])

assert end_event.payload.name == "TestTool"
assert "result" in end_event.payload.metadata.tool_outputs
assert end_event.payload.metadata.tool_outputs["result"] == "Tool execution result"


async def test_langchain_handler_tool_execution(reactive_stream: Subject):
    """
    Test that the LangchainProfilerHandler properly stores and retrieves
    structured tool inputs for TOOL_START and TOOL_END events.

    The handler is driven through one ``on_tool_start`` / ``on_tool_end`` pair
    that share a ``run_id``. The test then checks that:

    * exactly two events are published on ``reactive_stream``;
    * both events carry the expected type, name and framework;
    * the TOOL_END event exposes the structured ``inputs`` (not the legacy
      ``input_str``) in both ``metadata.tool_inputs`` and ``data.input``;
    * the stored inputs are an independent deep copy, so mutating the caller's
      dict afterwards (top-level *and* nested) does not change the event.
    """

    all_stats = []
    handler = LangchainProfilerHandler()
    _ = reactive_stream.subscribe(all_stats.append)

    # Simulate a tool start event with structured inputs
    tool_name = "TestSearchTool"
    run_id = uuid4()

    # Create structured input data; the nested dict/list are what make the
    # deep-copy check at the end of the test meaningful.
    structured_inputs = {
        "query": "test search query",
        "max_results": 5,
        "filters": {
            "date_range": {
                "start": "2025-01-01",
                "end": "2025-12-31",
            },
            "category": ["tech", "science"],
        },
    }

    await handler.on_tool_start(
        serialized={"name": tool_name},  # Tool name should be in serialized dict
        input_str="test search query",  # This was the old format
        run_id=run_id,
        inputs=structured_inputs,  # This is the new structured format
        name=tool_name)

    # Simulate tool processing time
    await asyncio.sleep(0.1)

    # Create tool output
    tool_output = {
        "results": [
            {
                "title": "Result 1",
                "url": "http://example.com/1",
            },
            {
                "title": "Result 2",
                "url": "http://example.com/2",
            },
        ],
        "count": 2,
    }

    # Simulate tool end event
    await handler.on_tool_end(output=tool_output, run_id=run_id, name=tool_name)

    # Verify we have the correct number of events
    assert len(all_stats) == 2, f"Expected 2 events but got {len(all_stats)}"

    tool_start_event = all_stats[0]
    tool_end_event = all_stats[1]

    # Verify TOOL_START event
    assert tool_start_event.event_type == IntermediateStepType.TOOL_START
    assert tool_start_event.name == tool_name
    assert tool_start_event.framework == LLMFrameworkEnum.LANGCHAIN

    # Verify TOOL_END event
    assert tool_end_event.event_type == IntermediateStepType.TOOL_END
    assert tool_end_event.name == tool_name
    assert tool_end_event.framework == LLMFrameworkEnum.LANGCHAIN

    # Verify that structured inputs are preserved in TOOL_END event
    assert tool_end_event.metadata.tool_inputs == structured_inputs
    assert tool_end_event.metadata.tool_outputs == tool_output
    assert tool_end_event.data.input == structured_inputs
    assert tool_end_event.data.output == tool_output

    # Verify that the inputs are deep copied (not just referenced)
    # Modify original inputs and ensure event data is unchanged
    structured_inputs["query"] = "modified query"
    assert tool_end_event.metadata.tool_inputs.get("query") == "test search query"
    assert tool_end_event.data.input.get("query") == "test search query"

    # Rebinding a top-level key would also pass with a shallow copy, so mutate
    # the nested containers in place: only a real deep copy keeps the event
    # isolated from these changes.
    structured_inputs["filters"]["category"].append("modified category")
    structured_inputs["filters"]["date_range"]["start"] = "1999-01-01"

    for stored_inputs in (tool_end_event.metadata.tool_inputs, tool_end_event.data.input):
        assert stored_inputs["filters"]["category"] == ["tech", "science"]
        assert stored_inputs["filters"]["date_range"]["start"] == "2025-01-01"