diff --git a/README.md b/README.md index 3a671d4..f2c9125 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,111 @@ python3 -m pytest tests/ -v --timeout=180 # CI: CPU on every push, GPU on MI355X runners ``` +## Embedding + +rocm-trace-lite can be embedded into another profiler or tracer by redirecting events through callback hooks instead of writing to the built-in SQLite database. This is how [rocmProfileData (RPD)](https://github.com/ROCm/rocmProfileData) integrates rocm-trace-lite as its `RtlDataSource`. + +### Callback API + +Register callbacks **before** HSA `OnLoad` fires. In practice this means setting them during your library's initialization, before the application makes its first HIP/HSA call: + +```cpp +#include "trace_db.h" + +// Kernel dispatch completion (GPU-side, called from the completion worker thread) +void my_kernel_handler(const trace_db::KernelEventRecord& event, void* user_data) { + // event.name — demangled kernel name + // event.device_id — GPU index + // event.queue_id — HSA queue handle + // event.start_ns — GPU start timestamp (ns) + // event.end_ns — GPU end timestamp (ns) + // event.correlation_id + // event.wg_x/y/z — workgroup dimensions + // event.grid_x/y/z — grid dimensions +} + +// HIP API call (CPU-side, called from the app thread, requires RTL_MODE=hip) +void my_api_handler(const trace_db::ApiEventRecord& event, void* user_data) { + // event.name — e.g. "hipModuleLaunchKernel" + // event.args — formatted argument string + // event.start_ns — host start timestamp (ns, CLOCK_MONOTONIC) + // event.end_ns — host end timestamp (ns) + // event.correlation_id + // event.pid, event.tid +} + +// 1. Register callbacks +trace_db::set_kernel_event_callback(my_kernel_handler, nullptr); +trace_db::set_api_event_callback(my_api_handler, nullptr); + +// 2. ... application runs, callbacks fire ... + +// 3. At shutdown, drain pending events before finalizing your storage +trace_db::rtl_trigger_shutdown(); // joins completion worker, delivers remaining events +// Now safe to close your database / flush tables +``` + +### What happens when callbacks are set + +- **No SQLite file is created** — `get_trace_db()` lazy init is skipped +- **Kernel completions** route through the callback instead of `TraceDB::record_kernel()` +- **HIP API wrappers** route through the callback instead of `TraceDB::record_hip_api()` +- **`is_trace_ready()`** returns true for HIP wrappers even without a TraceDB, so API recording works immediately +- **Shutdown** skips `TraceDB::flush()`/`close()` — the consumer owns flushing +- **Fallback**: if no callback is set, everything works as before (standalone mode) + +### Shutdown ordering + +Call `rtl_trigger_shutdown()` before finalizing your own storage. This function: + +1. Joins the HSA completion worker thread (waits for in-flight kernel signals) +2. Drains remaining dispatch data, delivering final events through your callback +3. Cleans up the signal pool and queue map + +After it returns, no more callbacks will fire. Safe to call multiple times (idempotent). + +### Build integration + +Add rocm-trace-lite as a git submodule and compile its source files directly into your project: + +```makefile +RTL_SRC = rocm-trace-lite/src + +# Submodule sources — compiled with your project +RTL_OBJS = hsa_intercept.o hip_api_intercept.o trace_db.o + +# Your source that registers callbacks +MY_OBJS += my_bridge.o + +CXXFLAGS += -I$(RTL_SRC) -DAMD_INTERNAL_BUILD -std=c++17 -fPIC +LDFLAGS += -lhsa-runtime64 -lsqlite3 -ldl -lpthread + +hsa_intercept.o: $(RTL_SRC)/hsa_intercept.cpp + $(CXX) -o $@ -c $< $(CXXFLAGS) + +hip_api_intercept.o: $(RTL_SRC)/hip_api_intercept.cpp + $(CXX) -o $@ -c $< $(CXXFLAGS) + +trace_db.o: $(RTL_SRC)/trace_db.cpp + $(CXX) -o $@ -c $< $(CXXFLAGS) +``` + +Notes: +- `-DAMD_INTERNAL_BUILD` is required so `hsa_api_trace.h` resolves its includes correctly +- `trace_db.cpp` is still needed — it provides `tick()`, `next_correlation_id()`, and the callback storage +- SQLite is linked but unused when callbacks are active (no file I/O occurs) +- The submodule's `OnLoad`/`OnUnload` symbols are exported from your shared library; set `HSA_TOOLS_LIB` to point to it (or auto-set it via `dladdr` during init — see RPD example below) + +### Example: RPD integration + +RPD adds rocm-trace-lite as a submodule and compiles the source files directly into `librpd_tracer.so`. The entire integration is a single file: + +**RtlDataSource.cpp** (~160 lines): +1. `init()` — registers callbacks with `set_kernel_event_callback()` / `set_api_event_callback()`, auto-sets `HSA_TOOLS_LIB` via `dladdr` so `runTracer.sh` works unmodified +2. `on_kernel_event()` — writes `KernelEventRecord` to RPD `OpTable` + `StringTable` +3. `on_api_event()` — writes `ApiEventRecord` to RPD `ApiTable` / `KernelApiTable` / `CopyApiTable` +4. `end()` — calls `rtl_trigger_shutdown()`, ensuring all events are delivered before Logger finalizes tables + ## Acknowledgments This project was inspired by and builds upon the work of: diff --git a/src/hip_api_intercept.cpp b/src/hip_api_intercept.cpp index 30fadce..83d1c5f 100644 --- a/src/hip_api_intercept.cpp +++ b/src/hip_api_intercept.cpp @@ -90,6 +90,30 @@ static int get_tid() { return (int)syscall(SYS_gettid); } +static bool is_recording_ready() { + return is_trace_ready() || trace_db::get_api_event_callback() != nullptr; +} + +static void deliver_hip_api(const char* name, const char* args, + uint64_t start_ns, uint64_t duration_ns, + uint64_t correlation_id) { + auto cb = trace_db::get_api_event_callback(); + if (cb) { + trace_db::ApiEventRecord rec; + rec.name = name; + rec.args = args; + rec.start_ns = start_ns; + rec.end_ns = start_ns + duration_ns; + rec.correlation_id = correlation_id; + rec.pid = getpid(); + rec.tid = get_tid(); + cb(rec, trace_db::get_api_event_callback_data()); + } else { + get_trace_db().record_hip_api(name, args, start_ns, duration_ns, + correlation_id, getpid(), get_tid()); + } +} + // HIP type definitions (avoid including hip_runtime_api.h to keep zero-dep) typedef int hipError_t; typedef void* hipStream_t; @@ -210,7 +234,7 @@ hipError_t hipModuleLaunchKernel( resolve_hipModuleLaunchKernel(); if (!real_hipModuleLaunchKernel) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipModuleLaunchKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, stream, kernelParams, extra); @@ -231,8 +255,8 @@ hipError_t hipModuleLaunchKernel( snprintf(args, sizeof(args), "grid=%u,%u,%u block=%u,%u,%u shared=%u", gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes); - get_trace_db().record_hip_api("hipModuleLaunchKernel", args, - t0, t1 - t0, corr, getpid(), get_tid()); + deliver_hip_api("hipModuleLaunchKernel", args, + t0, t1 - t0, corr); return ret; } @@ -247,7 +271,7 @@ hipError_t hipExtModuleLaunchKernel( resolve_hipExtModuleLaunchKernel(); if (!real_hipExtModuleLaunchKernel) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipExtModuleLaunchKernel(f, globalWorkSizeX, globalWorkSizeY, globalWorkSizeZ, localWorkSizeX, localWorkSizeY, localWorkSizeZ, sharedMemBytes, stream, kernelParams, extra, @@ -270,8 +294,8 @@ hipError_t hipExtModuleLaunchKernel( snprintf(args, sizeof(args), "grid=%u,%u,%u block=%u,%u,%u shared=%u", globalWorkSizeX, globalWorkSizeY, globalWorkSizeZ, localWorkSizeX, localWorkSizeY, localWorkSizeZ, sharedMemBytes); - get_trace_db().record_hip_api("hipExtModuleLaunchKernel", args, - t0, t1 - t0, corr, getpid(), get_tid()); + deliver_hip_api("hipExtModuleLaunchKernel", args, + t0, t1 - t0, corr); return ret; } @@ -280,7 +304,7 @@ hipError_t hipMemcpy(void* dst, const void* src, size_t sizeBytes, resolve_hipMemcpy(); if (!real_hipMemcpy) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipMemcpy(dst, src, sizeBytes, kind); } ScopedReentrancyGuard _guard; @@ -290,8 +314,8 @@ hipError_t hipMemcpy(void* dst, const void* src, size_t sizeBytes, uint64_t t1 = tick(); char args[64]; snprintf(args, sizeof(args), "size=%zu kind=%d", sizeBytes, (int)kind); - get_trace_db().record_hip_api("hipMemcpy", args, t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipMemcpy", args, t0, t1 - t0, + corr); return ret; } @@ -300,7 +324,7 @@ hipError_t hipMemcpyAsync(void* dst, const void* src, size_t sizeBytes, resolve_hipMemcpyAsync(); if (!real_hipMemcpyAsync) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipMemcpyAsync(dst, src, sizeBytes, kind, stream); } ScopedReentrancyGuard _guard; @@ -310,8 +334,8 @@ hipError_t hipMemcpyAsync(void* dst, const void* src, size_t sizeBytes, uint64_t t1 = tick(); char args[64]; snprintf(args, sizeof(args), "size=%zu kind=%d", sizeBytes, (int)kind); - get_trace_db().record_hip_api("hipMemcpyAsync", args, t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipMemcpyAsync", args, t0, t1 - t0, + corr); return ret; } @@ -319,7 +343,7 @@ hipError_t hipMalloc(void** ptr, size_t size) { resolve_hipMalloc(); if (!real_hipMalloc) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipMalloc(ptr, size); } ScopedReentrancyGuard _guard; @@ -329,8 +353,8 @@ hipError_t hipMalloc(void** ptr, size_t size) { uint64_t t1 = tick(); char args[64]; snprintf(args, sizeof(args), "size=%zu", size); - get_trace_db().record_hip_api("hipMalloc", args, t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipMalloc", args, t0, t1 - t0, + corr); return ret; } @@ -338,7 +362,7 @@ hipError_t hipFree(void* ptr) { resolve_hipFree(); if (!real_hipFree) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipFree(ptr); } ScopedReentrancyGuard _guard; @@ -346,8 +370,8 @@ hipError_t hipFree(void* ptr) { uint64_t t0 = tick(); hipError_t ret = real_hipFree(ptr); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipFree", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipFree", "", t0, t1 - t0, + corr); return ret; } @@ -355,7 +379,7 @@ hipError_t hipStreamSynchronize(hipStream_t stream) { resolve_hipStreamSynchronize(); if (!real_hipStreamSynchronize) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipStreamSynchronize(stream); } ScopedReentrancyGuard _guard; @@ -363,8 +387,8 @@ hipError_t hipStreamSynchronize(hipStream_t stream) { uint64_t t0 = tick(); hipError_t ret = real_hipStreamSynchronize(stream); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipStreamSynchronize", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipStreamSynchronize", "", t0, t1 - t0, + corr); return ret; } @@ -372,7 +396,7 @@ hipError_t hipDeviceSynchronize() { resolve_hipDeviceSynchronize(); if (!real_hipDeviceSynchronize) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipDeviceSynchronize(); } ScopedReentrancyGuard _guard; @@ -380,8 +404,8 @@ hipError_t hipDeviceSynchronize() { uint64_t t0 = tick(); hipError_t ret = real_hipDeviceSynchronize(); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipDeviceSynchronize", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipDeviceSynchronize", "", t0, t1 - t0, + corr); return ret; } @@ -389,7 +413,7 @@ hipError_t hipGraphLaunch(hipGraphExec_t graphExec, hipStream_t stream) { resolve_hipGraphLaunch(); if (!real_hipGraphLaunch) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipGraphLaunch(graphExec, stream); } ScopedReentrancyGuard _guard; @@ -401,8 +425,8 @@ hipError_t hipGraphLaunch(hipGraphExec_t graphExec, hipStream_t stream) { hipError_t ret = real_hipGraphLaunch(graphExec, stream); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipGraphLaunch", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipGraphLaunch", "", t0, t1 - t0, + corr); return ret; } @@ -412,7 +436,7 @@ hipError_t hipSetDevice(int deviceId) { resolve_hipSetDevice(); if (!real_hipSetDevice) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipSetDevice(deviceId); } ScopedReentrancyGuard _guard; @@ -422,8 +446,8 @@ hipError_t hipSetDevice(int deviceId) { uint64_t t1 = tick(); char args[32]; snprintf(args, sizeof(args), "device=%d", deviceId); - get_trace_db().record_hip_api("hipSetDevice", args, t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipSetDevice", args, t0, t1 - t0, + corr); return ret; } @@ -431,7 +455,7 @@ hipError_t hipStreamCreate(hipStream_t* stream) { resolve_hipStreamCreate(); if (!real_hipStreamCreate) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipStreamCreate(stream); } ScopedReentrancyGuard _guard; @@ -439,8 +463,8 @@ hipError_t hipStreamCreate(hipStream_t* stream) { uint64_t t0 = tick(); hipError_t ret = real_hipStreamCreate(stream); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipStreamCreate", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipStreamCreate", "", t0, t1 - t0, + corr); return ret; } @@ -448,7 +472,7 @@ hipError_t hipStreamDestroy(hipStream_t stream) { resolve_hipStreamDestroy(); if (!real_hipStreamDestroy) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipStreamDestroy(stream); } ScopedReentrancyGuard _guard; @@ -456,8 +480,8 @@ hipError_t hipStreamDestroy(hipStream_t stream) { uint64_t t0 = tick(); hipError_t ret = real_hipStreamDestroy(stream); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipStreamDestroy", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipStreamDestroy", "", t0, t1 - t0, + corr); return ret; } @@ -465,7 +489,7 @@ hipError_t hipEventCreate(hipEvent_t* event) { resolve_hipEventCreate(); if (!real_hipEventCreate) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipEventCreate(event); } ScopedReentrancyGuard _guard; @@ -473,8 +497,8 @@ hipError_t hipEventCreate(hipEvent_t* event) { uint64_t t0 = tick(); hipError_t ret = real_hipEventCreate(event); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipEventCreate", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipEventCreate", "", t0, t1 - t0, + corr); return ret; } @@ -482,7 +506,7 @@ hipError_t hipEventDestroy(hipEvent_t event) { resolve_hipEventDestroy(); if (!real_hipEventDestroy) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipEventDestroy(event); } ScopedReentrancyGuard _guard; @@ -490,8 +514,8 @@ hipError_t hipEventDestroy(hipEvent_t event) { uint64_t t0 = tick(); hipError_t ret = real_hipEventDestroy(event); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipEventDestroy", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipEventDestroy", "", t0, t1 - t0, + corr); return ret; } @@ -499,7 +523,7 @@ hipError_t hipEventRecord(hipEvent_t event, hipStream_t stream) { resolve_hipEventRecord(); if (!real_hipEventRecord) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipEventRecord(event, stream); } ScopedReentrancyGuard _guard; @@ -507,8 +531,8 @@ hipError_t hipEventRecord(hipEvent_t event, hipStream_t stream) { uint64_t t0 = tick(); hipError_t ret = real_hipEventRecord(event, stream); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipEventRecord", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipEventRecord", "", t0, t1 - t0, + corr); return ret; } @@ -516,7 +540,7 @@ hipError_t hipEventSynchronize(hipEvent_t event) { resolve_hipEventSynchronize(); if (!real_hipEventSynchronize) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipEventSynchronize(event); } ScopedReentrancyGuard _guard; @@ -524,8 +548,8 @@ hipError_t hipEventSynchronize(hipEvent_t event) { uint64_t t0 = tick(); hipError_t ret = real_hipEventSynchronize(event); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipEventSynchronize", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipEventSynchronize", "", t0, t1 - t0, + corr); return ret; } @@ -533,7 +557,7 @@ hipError_t hipGraphCreate(hipGraph_t* graph, unsigned int flags) { resolve_hipGraphCreate(); if (!real_hipGraphCreate) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipGraphCreate(graph, flags); } ScopedReentrancyGuard _guard; @@ -541,8 +565,8 @@ hipError_t hipGraphCreate(hipGraph_t* graph, unsigned int flags) { uint64_t t0 = tick(); hipError_t ret = real_hipGraphCreate(graph, flags); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipGraphCreate", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipGraphCreate", "", t0, t1 - t0, + corr); return ret; } @@ -551,7 +575,7 @@ hipError_t hipGraphInstantiate(hipGraphExec_t* exec, hipGraph_t graph, resolve_hipGraphInstantiate(); if (!real_hipGraphInstantiate) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipGraphInstantiate(exec, graph, errNode, errLog, bufSize); } ScopedReentrancyGuard _guard; @@ -559,8 +583,8 @@ hipError_t hipGraphInstantiate(hipGraphExec_t* exec, hipGraph_t graph, uint64_t t0 = tick(); hipError_t ret = real_hipGraphInstantiate(exec, graph, errNode, errLog, bufSize); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipGraphInstantiate", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipGraphInstantiate", "", t0, t1 - t0, + corr); return ret; } @@ -568,7 +592,7 @@ hipError_t hipGraphExecDestroy(hipGraphExec_t exec) { resolve_hipGraphExecDestroy(); if (!real_hipGraphExecDestroy) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipGraphExecDestroy(exec); } ScopedReentrancyGuard _guard; @@ -576,8 +600,8 @@ hipError_t hipGraphExecDestroy(hipGraphExec_t exec) { uint64_t t0 = tick(); hipError_t ret = real_hipGraphExecDestroy(exec); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipGraphExecDestroy", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipGraphExecDestroy", "", t0, t1 - t0, + corr); return ret; } @@ -585,7 +609,7 @@ hipError_t hipHostMalloc(void** ptr, size_t size, unsigned int flags) { resolve_hipHostMalloc(); if (!real_hipHostMalloc) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipHostMalloc(ptr, size, flags); } ScopedReentrancyGuard _guard; @@ -595,8 +619,8 @@ hipError_t hipHostMalloc(void** ptr, size_t size, unsigned int flags) { uint64_t t1 = tick(); char args[64]; snprintf(args, sizeof(args), "size=%zu flags=%u", size, flags); - get_trace_db().record_hip_api("hipHostMalloc", args, t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipHostMalloc", args, t0, t1 - t0, + corr); return ret; } @@ -604,7 +628,7 @@ hipError_t hipHostFree(void* ptr) { resolve_hipHostFree(); if (!real_hipHostFree) return kHipErrorUnresolved; if (tls_in_hip_api || !hip_api::g_enabled.load(std::memory_order_relaxed) - || !is_trace_ready()) { + || !is_recording_ready()) { return real_hipHostFree(ptr); } ScopedReentrancyGuard _guard; @@ -612,8 +636,8 @@ hipError_t hipHostFree(void* ptr) { uint64_t t0 = tick(); hipError_t ret = real_hipHostFree(ptr); uint64_t t1 = tick(); - get_trace_db().record_hip_api("hipHostFree", "", t0, t1 - t0, - corr, getpid(), get_tid()); + deliver_hip_api("hipHostFree", "", t0, t1 - t0, + corr); return ret; } diff --git a/src/hsa_intercept.cpp b/src/hsa_intercept.cpp index 21baca5..5ad637e 100644 --- a/src/hsa_intercept.cpp +++ b/src/hsa_intercept.cpp @@ -255,15 +255,33 @@ static void completion_worker() { g_drop_ts_invalid.fetch_add(1, std::memory_order_relaxed); } else { std::string name = lookup_kernel_name(dd->kernel_object); - char dispatch_info[128]; - snprintf(dispatch_info, sizeof(dispatch_info), - "hwq=0x%" PRIx64 " wg=%u,%u,%u grid=%u,%u,%u", - dd->hw_queue_addr, - (unsigned)dd->wg_x, (unsigned)dd->wg_y, (unsigned)dd->wg_z, - dd->grid_x, dd->grid_y, dd->grid_z); - get_trace_db().record_kernel(name.c_str(), dd->device_id, dd->queue_id, - time.start, time.end, dd->correlation_id, - dispatch_info); + auto kernel_cb = trace_db::get_kernel_event_callback(); + if (kernel_cb) { + trace_db::KernelEventRecord rec; + rec.name = name.c_str(); + rec.device_id = dd->device_id; + rec.queue_id = dd->queue_id; + rec.start_ns = time.start; + rec.end_ns = time.end; + rec.correlation_id = dd->correlation_id; + rec.wg_x = dd->wg_x; + rec.wg_y = dd->wg_y; + rec.wg_z = dd->wg_z; + rec.grid_x = dd->grid_x; + rec.grid_y = dd->grid_y; + rec.grid_z = dd->grid_z; + kernel_cb(rec, trace_db::get_kernel_event_callback_data()); + } else { + char dispatch_info[128]; + snprintf(dispatch_info, sizeof(dispatch_info), + "hwq=0x%" PRIx64 " wg=%u,%u,%u grid=%u,%u,%u", + dd->hw_queue_addr, + (unsigned)dd->wg_x, (unsigned)dd->wg_y, (unsigned)dd->wg_z, + dd->grid_x, dd->grid_y, dd->grid_z); + get_trace_db().record_kernel(name.c_str(), dd->device_id, dd->queue_id, + time.start, time.end, dd->correlation_id, + dispatch_info); + } g_recorded_ok.fetch_add(1, std::memory_order_relaxed); } @@ -676,9 +694,11 @@ static void shutdown() { } } - // Flush and close trace DB - trace_db::get_trace_db().flush(); - trace_db::get_trace_db().close(); + // Flush and close trace DB (skip if callback hooks are set — consumer owns flushing) + if (!trace_db::get_kernel_event_callback() && !trace_db::get_api_event_callback()) { + trace_db::get_trace_db().flush(); + trace_db::get_trace_db().close(); + } // Destroy signal pool. // At this point the system is quiesced: g_shutdown is true, the worker @@ -705,6 +725,12 @@ static void shutdown() { } // namespace hsa_intercept +namespace trace_db { +void rtl_trigger_shutdown() { + hsa_intercept::shutdown(); +} +} // namespace trace_db + // ---- Entry points for HSA_TOOLS_LIB ---- extern "C" bool OnLoad(void* pTable, @@ -715,8 +741,10 @@ extern "C" bool OnLoad(void* pTable, using namespace hsa_intercept; - // Ensure trace database is open - (void)trace_db::get_trace_db(); + // Ensure trace database is open (skip if callback hooks are set) + if (!trace_db::get_kernel_event_callback() && !trace_db::get_api_event_callback()) { + (void)trace_db::get_trace_db(); + } // Save original API tables HsaApiTable* table = reinterpret_cast(pTable); diff --git a/src/trace_db.cpp b/src/trace_db.cpp index bd36b60..f5a3c0d 100644 --- a/src/trace_db.cpp +++ b/src/trace_db.cpp @@ -386,4 +386,29 @@ void TraceDB::record_roctx(const char* message, uint64_t start_ns, uint64_t dura } } +// ---- Callback hook storage ---- +// Set once before OnLoad, read from hot paths afterward. The store +// happens-before any reader thread exists (OnLoad hasn't fired yet), +// so plain pointers are safe — no concurrent write/read is possible. + +static ApiEventCallback g_api_event_cb = nullptr; +static void* g_api_event_cb_data = nullptr; +static KernelEventCallback g_kernel_event_cb = nullptr; +static void* g_kernel_event_cb_data = nullptr; + +void set_api_event_callback(ApiEventCallback cb, void* user_data) { + g_api_event_cb = cb; + g_api_event_cb_data = user_data; +} + +void set_kernel_event_callback(KernelEventCallback cb, void* user_data) { + g_kernel_event_cb = cb; + g_kernel_event_cb_data = user_data; +} + +ApiEventCallback get_api_event_callback() { return g_api_event_cb; } +KernelEventCallback get_kernel_event_callback() { return g_kernel_event_cb; } +void* get_api_event_callback_data() { return g_api_event_cb_data; } +void* get_kernel_event_callback_data() { return g_kernel_event_cb_data; } + } // namespace trace_db diff --git a/src/trace_db.h b/src/trace_db.h index 0f27f6d..1d8913f 100644 --- a/src/trace_db.h +++ b/src/trace_db.h @@ -75,4 +75,60 @@ bool is_trace_ready(); // Global correlation ID counter uint64_t next_correlation_id(); +// ---- Optional callback hooks ---- +// When set, interception code calls these instead of writing to TraceDB. +// This allows embedding (e.g., RPD tracer) to redirect events without +// modifying the interception logic. +// +// String lifetime: all const char* fields in event records point to +// thread-local or stack storage. Pointers are valid only for the +// duration of the callback invocation. Embedders must copy any strings +// they need to retain. +// +// Thread safety: callbacks must be registered before OnLoad fires +// (i.e., before any HIP/HSA call). The store happens-before any +// reader thread exists, so plain pointers are safe. +// +// Callback requirements: callbacks execute on hot paths (completion +// worker thread for kernel events, application thread for API events). +// They must be noexcept and non-blocking. + +struct ApiEventRecord { + const char* name; // valid only during callback + const char* args; // valid only during callback + uint64_t start_ns; + uint64_t end_ns; + uint64_t correlation_id; + int pid; + int tid; +}; + +struct KernelEventRecord { + const char* name; // valid only during callback + int device_id; + uint64_t queue_id; + uint64_t start_ns; + uint64_t end_ns; + uint64_t correlation_id; + uint16_t wg_x, wg_y, wg_z; + uint32_t grid_x, grid_y, grid_z; +}; + +using ApiEventCallback = void(*)(const ApiEventRecord& event, void* user_data); +using KernelEventCallback = void(*)(const KernelEventRecord& event, void* user_data); + +void set_api_event_callback(ApiEventCallback cb, void* user_data); +void set_kernel_event_callback(KernelEventCallback cb, void* user_data); + +ApiEventCallback get_api_event_callback(); +KernelEventCallback get_kernel_event_callback(); +void* get_api_event_callback_data(); +void* get_kernel_event_callback_data(); + +// Trigger shutdown of the HSA intercept (joins worker, drains queue). +// Idempotent: guarded by atomic flag, second and subsequent calls are +// no-ops. Exposed so embedders can drain pending events before +// finalizing their own storage. +void rtl_trigger_shutdown(); + } // namespace trace_db diff --git a/tests/test_embedded_callback.py b/tests/test_embedded_callback.py new file mode 100644 index 0000000..e560dd3 --- /dev/null +++ b/tests/test_embedded_callback.py @@ -0,0 +1,166 @@ +"""Tests for the embedded callback API (trace_db.h hooks). + +Validates the callback registration, event delivery, and shutdown +contracts used by embedders like RPD's RtlDataSource. + +CPU-only tests validate API contracts via source inspection. +GPU tests (marked gpu) validate end-to-end event delivery. +""" +import os +import re +import subprocess +import tempfile + +import pytest + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +SRC_DIR = os.path.join(REPO_ROOT, "src") +TRACE_DB_H = os.path.join(SRC_DIR, "trace_db.h") +TRACE_DB_CPP = os.path.join(SRC_DIR, "trace_db.cpp") +HSA_INTERCEPT = os.path.join(SRC_DIR, "hsa_intercept.cpp") +HIP_INTERCEPT = os.path.join(SRC_DIR, "hip_api_intercept.cpp") + + +class TestCallbackAPIContract: + """Verify the callback API exists and has the documented contract.""" + + def test_event_structs_exist(self): + with open(TRACE_DB_H) as f: + content = f.read() + assert "struct ApiEventRecord" in content + assert "struct KernelEventRecord" in content + + def test_setter_functions_declared(self): + with open(TRACE_DB_H) as f: + content = f.read() + assert "void set_api_event_callback(" in content + assert "void set_kernel_event_callback(" in content + + def test_getter_functions_declared(self): + with open(TRACE_DB_H) as f: + content = f.read() + assert "ApiEventCallback get_api_event_callback()" in content + assert "KernelEventCallback get_kernel_event_callback()" in content + + def test_shutdown_function_declared(self): + with open(TRACE_DB_H) as f: + content = f.read() + assert "void rtl_trigger_shutdown()" in content + + +class TestStringLifetimeContract: + """Verify string lifetime is documented.""" + + def test_lifetime_documented_in_header(self): + with open(TRACE_DB_H) as f: + content = f.read() + assert "valid only during callback" in content + + def test_api_name_annotated(self): + with open(TRACE_DB_H) as f: + content = f.read() + match = re.search(r'struct ApiEventRecord\s*\{(.*?)\}', content, re.DOTALL) + assert match, "Could not find ApiEventRecord" + body = match.group(1) + assert "valid only during callback" in body + + def test_kernel_name_annotated(self): + with open(TRACE_DB_H) as f: + content = f.read() + match = re.search(r'struct KernelEventRecord\s*\{(.*?)\}', content, re.DOTALL) + assert match, "Could not find KernelEventRecord" + body = match.group(1) + assert "valid only during callback" in body + + +class TestCallbackThreadSafety: + """Verify thread safety contract is documented and implemented.""" + + def test_set_before_onload_documented(self): + with open(TRACE_DB_H) as f: + content = f.read() + assert "before OnLoad" in content + + def test_noexcept_documented(self): + with open(TRACE_DB_H) as f: + content = f.read() + assert "noexcept" in content + + def test_non_blocking_documented(self): + with open(TRACE_DB_H) as f: + content = f.read() + assert "non-blocking" in content + + +class TestShutdownIdempotency: + """Verify rtl_trigger_shutdown is idempotent.""" + + def test_idempotency_documented(self): + with open(TRACE_DB_H) as f: + content = f.read() + assert "Idempotent" in content or "idempotent" in content + + def test_shutdown_has_atomic_guard(self): + """The underlying shutdown() must have a once-guard.""" + with open(HSA_INTERCEPT) as f: + content = f.read() + match = re.search( + r'static void shutdown\(\)\s*\{(.*?)\n\}', + content, re.DOTALL + ) + assert match, "Could not find shutdown()" + body = match.group(1) + assert "shutdown_done" in body, ( + "shutdown() must use an atomic guard for idempotency" + ) + + +class TestSQLiteGating: + """Verify SQLite is skipped when either callback is set.""" + + def test_onload_checks_both_callbacks(self): + with open(HSA_INTERCEPT) as f: + content = f.read() + # The gating around get_trace_db() in OnLoad should check both + assert "get_api_event_callback()" in content + + def test_shutdown_checks_both_callbacks(self): + """flush/close gating must check both kernel and API callbacks.""" + with open(HSA_INTERCEPT) as f: + content = f.read() + # Find the shutdown function and verify it checks both + match = re.search( + r'static void shutdown\(\)\s*\{(.*?)\n\}', + content, re.DOTALL + ) + assert match, "Could not find shutdown()" + body = match.group(1) + assert "get_api_event_callback" in body, ( + "shutdown must check API callback before flushing SQLite" + ) + + +class TestHipInterceptCallbackPath: + """Verify HIP wrappers route through callbacks.""" + + def test_deliver_hip_api_exists(self): + with open(HIP_INTERCEPT) as f: + content = f.read() + assert "deliver_hip_api(" in content + + def test_is_recording_ready_exists(self): + with open(HIP_INTERCEPT) as f: + content = f.read() + assert "is_recording_ready()" in content + + def test_deliver_checks_callback(self): + """deliver_hip_api must check get_api_event_callback.""" + with open(HIP_INTERCEPT) as f: + content = f.read() + match = re.search( + r'static void deliver_hip_api\((.*?\n\})', + content, re.DOTALL + ) + assert match, "Could not find deliver_hip_api" + body = match.group(1) + assert "get_api_event_callback" in body