Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[shortfin llm] Add base classes for tracing #1067

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
125 changes: 125 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import functools
import logging
import time
import asyncio
from contextlib import contextmanager, asynccontextmanager
from typing import Any, Optional, Generator, AsyncGenerator, Callable, Union

# Module-level logger for the tracing helpers in this file, registered under
# the name "shortfin-llm.tracing"; LoggerTracingBackend writes its TRACE
# messages through it.
logger = logging.getLogger("shortfin-llm.tracing")

# Base class for tracing backends
class BaseTracingBackend:
    """Tracing backend whose hooks all do nothing.

    Concrete backends subclass this and override the hooks they care about.
    Each hook returns ``None``.
    """

    def init(self, app_name: str) -> None:
        """Set up the backend for the application called *app_name* (no-op here)."""
        return None

    def frame_enter(self, frame_name: str, request_id: str) -> None:
        """Mark the start of frame *frame_name* for *request_id* (no-op here)."""
        return None

    def frame_exit(self, frame_name: str, request_id: str) -> None:
        """Mark the end of frame *frame_name* for *request_id* (no-op here)."""
        return None


# Logging-based backend
class LoggerTracingBackend(BaseTracingBackend):
    """Tracing backend that logs how long each frame took.

    ``frame_enter`` records a start timestamp per ``(frame_name, request_id)``
    pair; ``frame_exit`` removes that record and logs the elapsed time in
    milliseconds at INFO level. An exit with no recorded enter logs a warning
    instead.
    """

    def __init__(self):
        # Frame tracking - maps (frame_name, request_id) to start time.
        # Start times are time.perf_counter() readings, so they are only
        # meaningful as differences, not as wall-clock timestamps.
        self._frames = {}

    def init(self, app_name: str) -> None:
        """Accept the application name; this backend needs no setup."""
        pass

    def frame_enter(self, frame_name: str, request_id: str) -> None:
        """Record the start of *frame_name* for *request_id*.

        Entering the same pair again before it exits overwrites the earlier
        start time.
        """
        # perf_counter() is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments the way time.time() can be.
        self._frames[(frame_name, request_id)] = time.perf_counter()

    def frame_exit(self, frame_name: str, request_id: str) -> None:
        """Log the elapsed time for *frame_name* / *request_id* and forget it.

        Logs a warning and returns if no matching ``frame_enter`` was
        recorded.
        """
        # A single pop both looks up and removes the entry.
        started_at = self._frames.pop((frame_name, request_id), None)
        if started_at is None:
            logger.warning(
                f"TRACE: Exit without matching enter for {frame_name} [task={request_id}]"
            )
            return

        duration_ms = round((time.perf_counter() - started_at) * 1e3)

        msg = f"TRACE: {frame_name} [task={request_id}] completed in {duration_ms}ms"
        logger.info(msg)


# Global tracing configuration
class TracingConfig:
    """Global tracing settings, stored as class attributes.

    The class is used directly (never instantiated): all state lives on the
    class and is changed through the classmethods below.
    """

    # Whether the tracing context managers report frames to the backend.
    enabled: bool = True
    # Name passed to ``backend.init()``.
    app_name: str = "ShortfinLLM"
    # Active backend; defaults to the logging backend.
    backend: BaseTracingBackend = LoggerTracingBackend()
    # True once ``backend.init()`` has been called for the current backend.
    _initialized: bool = False

    @classmethod
    def is_enabled(cls) -> bool:
        """Return whether tracing is currently enabled."""
        return cls.enabled

    @classmethod
    def set_enabled(cls, enabled: bool) -> None:
        """Turn tracing on or off, initializing the backend if now enabled."""
        cls.enabled = enabled
        cls._ensure_initialized()

    @classmethod
    def set_backend(cls, backend_name: str) -> None:
        """Replace the active backend with the one named *backend_name*.

        Args:
            backend_name: ``"log"`` for the logging backend or ``"tracy"``
                for the Tracy backend.

        Raises:
            NotImplementedError: if the Tracy backend cannot be imported.
            ValueError: if *backend_name* is not a supported name.
        """
        if backend_name == "log":
            cls.backend = LoggerTracingBackend()
        elif backend_name == "tracy":
            # Import the Tracy backend when requested
            try:
                from .tracy_tracing import TracyTracingBackend

                cls.backend = TracyTracingBackend()
            except ImportError as e:
                # Chain the cause so the underlying import failure stays
                # visible in the traceback.
                raise NotImplementedError("Tracy backend is not implemented") from e
        else:
            raise ValueError(f"Unsupported tracing backend: {backend_name}")
        # The new backend object has not had init() called yet, even if the
        # previous backend had; clear the flag so it gets initialized.
        cls._initialized = False
        cls._ensure_initialized()

    @classmethod
    def set_app_name(cls, app_name: str) -> None:
        """Set the application name used when the backend is initialized.

        If the backend has already been initialized, the new name is stored
        but is not passed to the backend.
        """
        cls.app_name = app_name
        cls._ensure_initialized()

    @classmethod
    def _ensure_initialized(cls) -> None:
        """Call ``backend.init(app_name)`` once, and only while enabled."""
        if not cls._initialized and cls.enabled:
            cls.backend.init(cls.app_name)
            cls._initialized = True


# Context managers for manual tracing
@contextmanager
def trace_context(frame_name: str, request_id: str) -> Generator[None, None, None]:
    """Context manager for manual tracing of code blocks.

    When tracing is enabled, reports ``frame_enter`` to the active backend
    before the block runs and ``frame_exit`` after it finishes, including
    when the block raises. When tracing is disabled the block simply runs.

    Args:
        frame_name: Name of the frame being traced.
        request_id: Identifier of the request the frame belongs to.
    """
    if not TracingConfig.is_enabled():
        yield
        return

    TracingConfig._ensure_initialized()
    # Bind the backend once so enter and exit go to the same object even if
    # TracingConfig.backend is replaced while the block is running.
    backend = TracingConfig.backend
    backend.frame_enter(frame_name, request_id)
    try:
        yield
    finally:
        backend.frame_exit(frame_name, request_id)


@asynccontextmanager
async def async_trace_context(
    frame_name: str, request_id: str
) -> AsyncGenerator[None, None]:
    """Async context manager for manual tracing of code blocks.

    When tracing is enabled, reports ``frame_enter`` to the active backend
    before the block runs and ``frame_exit`` after it finishes, including
    when the block raises. When tracing is disabled the block simply runs.

    Args:
        frame_name: Name of the frame being traced.
        request_id: Identifier of the request the frame belongs to.
    """
    if not TracingConfig.is_enabled():
        yield
        return

    TracingConfig._ensure_initialized()
    # Bind the backend once: other code may run while the traced block is
    # suspended at an await, and if it replaces TracingConfig.backend the
    # exit must still go to the backend that recorded the enter.
    backend = TracingConfig.backend
    backend.frame_enter(frame_name, request_id)
    try:
        yield
    finally:
        backend.frame_exit(frame_name, request_id)
1 change: 1 addition & 0 deletions shortfin/requirements-tests.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pytest
pytest-timeout
pytest-asyncio
requests
fastapi
onnx
Expand Down
125 changes: 125 additions & 0 deletions shortfin/tests/apps/llm/components/tracing_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest
import time
import asyncio
import logging
from unittest.mock import patch, MagicMock

from shortfin_apps.llm.components.tracing import (
LoggerTracingBackend,
TracingConfig,
trace_context,
async_trace_context,
)


class TestLoggerTracingBackend:
    """Tests for LoggerTracingBackend's frame bookkeeping and log messages."""

    def test_frame_tracking(self):
        """An enter/exit pair is tracked, logged at INFO, and then forgotten."""
        tracer = LoggerTracingBackend()
        tracer.init("TestApp")
        frame_key = ("test_frame", "task123")

        tracer.frame_enter(*frame_key)
        assert len(tracer._frames) == 1
        assert frame_key in tracer._frames

        with patch("logging.Logger.info") as mock_info:
            tracer.frame_exit(*frame_key)
            mock_info.assert_called_once()
            logged_message = mock_info.call_args[0][0]
            assert "TRACE: test_frame [task=task123] completed in" in logged_message

        assert len(tracer._frames) == 0

    def test_frame_exit_without_enter(self):
        """Exiting a frame that was never entered logs a warning."""
        tracer = LoggerTracingBackend()
        tracer.init("TestApp")

        with patch("logging.Logger.warning") as mock_warning:
            tracer.frame_exit("unknown_frame", "task123")
            mock_warning.assert_called_once()
            logged_message = mock_warning.call_args[0][0]
            assert (
                "TRACE: Exit without matching enter for unknown_frame"
                in logged_message
            )


class TestTraceContext:
    """Tests for the sync and async tracing context managers.

    Each test swaps a mock backend into the global TracingConfig, so the
    configuration is snapshotted before and restored after every test.
    """

    def setup_method(self):
        # Remember the global configuration so teardown_method can put it
        # back; otherwise a disabled flag or mock backend would leak into
        # tests that run after this class.
        self._saved_config = (
            TracingConfig.enabled,
            TracingConfig.app_name,
            TracingConfig.backend,
            TracingConfig._initialized,
        )
        TracingConfig.enabled = True
        TracingConfig.app_name = "TestApp"
        TracingConfig.backend = LoggerTracingBackend()
        TracingConfig._initialized = False

    def teardown_method(self):
        (
            TracingConfig.enabled,
            TracingConfig.app_name,
            TracingConfig.backend,
            TracingConfig._initialized,
        ) = self._saved_config

    def test_sync_trace_context(self):
        """frame_enter fires on entry; frame_exit fires only after the block."""
        backend = MagicMock(spec=LoggerTracingBackend)
        TracingConfig.backend = backend
        op_name = "test_operation"
        request_id = "task123"

        with trace_context(op_name, request_id):
            backend.frame_enter.assert_called_once_with(op_name, request_id)
            backend.frame_exit.assert_not_called()

        backend.frame_exit.assert_called_once_with(op_name, request_id)

    @pytest.mark.asyncio
    async def test_async_trace_context(self):
        """Async variant: enter on entry, exit only after the block."""
        backend = MagicMock(spec=LoggerTracingBackend)
        TracingConfig.backend = backend
        op_name = "async_operation"
        request_id = "task456"

        async with async_trace_context(op_name, request_id):
            backend.frame_enter.assert_called_once_with(op_name, request_id)
            backend.frame_exit.assert_not_called()

        backend.frame_exit.assert_called_once_with(op_name, request_id)

    def test_sync_trace_context_with_exception(self):
        """The exception propagates and frame_exit is still reported."""
        backend = MagicMock(spec=LoggerTracingBackend)
        TracingConfig.backend = backend
        op_name = "test_operation"
        request_id = "task123"

        # pytest.raises fails the test if the exception is swallowed, which
        # a bare try/except would not detect.
        with pytest.raises(ValueError):
            with trace_context(op_name, request_id):
                raise ValueError("Test exception")

        backend.frame_enter.assert_called_once_with(op_name, request_id)
        backend.frame_exit.assert_called_once_with(op_name, request_id)

    @pytest.mark.asyncio
    async def test_async_trace_context_with_exception(self):
        """Async variant: the exception propagates and frame_exit is reported."""
        backend = MagicMock(spec=LoggerTracingBackend)
        TracingConfig.backend = backend
        op_name = "async_operation"
        request_id = "task456"

        with pytest.raises(ValueError):
            async with async_trace_context(op_name, request_id):
                raise ValueError("Test exception")

        backend.frame_enter.assert_called_once_with(op_name, request_id)
        backend.frame_exit.assert_called_once_with(op_name, request_id)

    def test_trace_context_disabled(self):
        """With tracing disabled, the backend is never called."""
        TracingConfig.set_enabled(False)
        backend = MagicMock(spec=LoggerTracingBackend)
        TracingConfig.backend = backend

        with trace_context("test_operation", "task123"):
            pass

        backend.frame_enter.assert_not_called()
        backend.frame_exit.assert_not_called()
Loading