diff --git a/Cargo.lock b/Cargo.lock index 02c03811..5f0886e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2884,6 +2884,17 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" +[[package]] +name = "goblin" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d20fd25aa456527ce4f544271ae4fea65d2eda4a6561ea56f39fb3ee4f7e3884" +dependencies = [ + "log", + "plain", + "scroll", +] + [[package]] name = "group" version = "0.12.1" @@ -5326,6 +5337,7 @@ dependencies = [ "cudarc", "half", "serde", + "tvm-ffi", ] [[package]] @@ -5635,6 +5647,12 @@ version = "0.3.33" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + [[package]] name = "plotters" version = "0.3.7" @@ -5815,6 +5833,30 @@ dependencies = [ "toml_edit 0.25.11+spec-1.1.0", ] +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro-error-attr2" version = "2.0.0" @@ -6984,6 +7026,26 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"d68f2ec51b097e4c1a75b681a8bec621909b5e91f15bb7b840c4f2f7b01148b2" +[[package]] +name = "scroll" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fda28d4b4830b807a8b43f7b0e6b5df875311b3e7621d84577188c175b6ec1ec" +dependencies = [ + "scroll_derive", +] + +[[package]] +name = "scroll_derive" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aaaae8f38bb311444cfb7f1979af0bc9240d95795f75f9ceddf6a59b79ceffa0" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "sct" version = "0.7.1" @@ -8568,6 +8630,39 @@ dependencies = [ "tokio", ] +[[package]] +name = "tvm-ffi" +version = "0.1.0-alpha.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d1c9d5b22b336c5de9ca98442715ec64f09b7038c6d88817a7de8ae00b2c384" +dependencies = [ + "paste", + "tvm-ffi-macros", + "tvm-ffi-sys", +] + +[[package]] +name = "tvm-ffi-macros" +version = "0.1.0-alpha.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18901fd34d643368ca7139e1d592a6c484cb855f6554057a699b9b77c29851e3" +dependencies = [ + "goblin", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "tvm-ffi-sys" +version = "0.1.0-alpha.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7552573e760924afb4e24fce8b5d453236b59710c8c37438d00c1006a9b38f1c" +dependencies = [ + "cmake", +] + [[package]] name = "typeid" version = "1.0.3" diff --git a/docs/index.md b/docs/index.md index d2e71462..6be710d4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -119,6 +119,7 @@ Organized by domain (model line / subsystem / playbook / lesson) instead of by l | `subsystems/kernels/pegainfer-kernels-boundary.md` | Architecture decision: pegainfer should use reusable frontend/runtime/data-plane layers plus per-model engines; kernels become first-class assets through a ledger, simulator, and request 
tracing. | | `subsystems/kernels/kernel-op-reports.md` | Qwen3 kernel/report tooling is feature-gated: `qwen3_kernel_report` covers per-op kernel reports, and `qwen3_model_report` emits runtime-traced eager-DAG decode operator rollups with TensorSpec `KernelCall`s, latency stats, tables, and Graphviz DOT; measured FA2 `CTA_TILE_Q=64` prefill default in place. | | `subsystems/kernels/typed-forward-pipeline.md` | Reusable typed tensor pipeline macro in `pegainfer-kernels` so model crates can express common `typed_ops` chains without model-specific wrapper macros. | +| `subsystems/kernels/tvm-ffi-mvp.md` | Optional `tvm-ffi-triton-cubin` bridge in `pegainfer-kernels` plus a packed TVM wrapper for the Qwen3.5 GDR solve Triton AOT CUBIN launcher. | ## playbooks diff --git a/docs/subsystems/kernels/tvm-ffi-mvp.md b/docs/subsystems/kernels/tvm-ffi-mvp.md new file mode 100644 index 00000000..7d85e89f --- /dev/null +++ b/docs/subsystems/kernels/tvm-ffi-mvp.md @@ -0,0 +1,85 @@ +# TVM FFI Triton CUBIN Wrapper + +> **TL;DR:** `pegainfer-kernels` now has an optional `tvm-ffi-triton-cubin` bridge for the Qwen3.5 GDR solve Triton AOT CUBIN launcher, with unit coverage for wrapper registration and packed-ABI diagnostics. +> +> **Last touched:** 2026-06 + +## Preparation + +- **Read**: + - `docs/index.md` - routed this task to the kernels subsystem. + - `docs/subsystems/kernels/pegainfer-kernels-boundary.md` - confirmed DSL/kernel integration belongs at the kernels boundary rather than in model runtimes. + - `docs/subsystems/kernels/kernel-op-reports.md` - confirmed Triton/CuTe tooling is already feature-scoped in kernel infrastructure. + - `pegainfer-kernels/tools/triton/README.md` - described the current Triton AOT CUBIN generation and validation path. + - `pegainfer-kernels/build.rs` - showed generated Triton AOT C stubs and wrapper symbols. 
+ - `pegainfer-kernels/src/ffi/qwen35.rs` and `pegainfer-kernels/src/ffi/shared.rs` - showed the existing C ABI launch symbols used by Rust model code. + - Local `tvm-ffi` crate source - confirmed typed callbacks only cover up to 8 arguments, so Triton launchers need packed TVM FFI wrappers. +- **Relevant history**: + - GitHub issue `#191` proposed TVM FFI as the DSL interface direction. + - Draft PR `#202` kept TVM FFI optional/test-only; PR `#315` now keeps the bridge optional behind `tvm-ffi-triton-cubin` while focusing it on Triton CUBIN launch wrappers. +- **Plan**: + 1. Add `tvm-ffi` as an optional dependency of `pegainfer-kernels` behind `tvm-ffi-triton-cubin`. + 2. Add a `triton_cubin` module that exposes the current Qwen3.5 Triton AOT CUBIN launcher as a packed TVM FFI function. + 3. Keep existing C ABI and Rust call sites available; the TVM FFI layer is an additional DSL boundary, not a production scheduler/model migration. + 4. Add a small example that registers the wrapper and prints the function contract. + 5. Validate formatting and the strongest local build/test checks available. +- **Risks / open questions**: + - The `tvm-ffi-triton-cubin` feature means `tvm-ffi-config` and `libtvm_ffi` are build prerequisites only for the optional bridge path. + - The current wrapper accepts raw device pointer and stream handles as TVM integers or opaque pointers; a future DLPack/tensor-handle wrapper can sit on top once the DSL artifact contract is stable. + +## Execution Log + +### Step 1: Optional dependency and wrapper surface +- Added optional `tvm-ffi = "0.1.0-alpha.0"` to `pegainfer-kernels` behind `tvm-ffi-triton-cubin`. +- Added `pegainfer_kernels::triton_cubin`, which exposes metadata plus a packed TVM FFI callback for the generated Qwen3.5 GDR solve Triton AOT launcher. +- Kept existing CUDA C ABI symbols and model call sites unchanged. 
+ +### Step 2: Small example +- Added `pegainfer-kernels/examples/triton_cubin_tvm_ffi.rs` to register the TVM FFI global function and print the launch contract. + +### Step 3: Unit test coverage +- Added wrapper unit tests for: + - known/unknown wrapper lookup; + - global TVM FFI registry round-trip; + - accepted raw handle encodings (`u64` and opaque pointer); + - missing-argument diagnostics before CUDA launch; + - handle and scalar type diagnostics before CUDA launch. +- Kept tests on pre-launch validation paths so they do not require valid device memory or actually launch the Triton CUBIN. + +### Step 4: Validation +- `cargo fmt --all --check` passed. +- `cargo check --release -p pegainfer-kernels` no longer requires `tvm-ffi-config`; the TVM FFI bridge is feature-gated. +- Retried with the discovered local TVM FFI install on `PATH`: + - `PATH=/home/ziyang/gpu_memory_profiling/.venv/bin:$PATH cargo check --release -p pegainfer-kernels --features tvm-ffi-triton-cubin` + - `tvm-ffi` built successfully. + - The build then failed in the existing CUDA build at `pegainfer-kernels/csrc/shared/flashinfer_top1.cu` because the dirty `pegainfer-kernels/third_party/flashinfer` submodule commit changes the `TopKDispatch` API. This is unrelated to the TVM FFI wrapper and was left untouched. +- After adding tests: + - `cargo fmt --all --check` passed. + - `PATH=/home/ziyang/gpu_memory_profiling/.venv/bin:$PATH cargo test --release -p pegainfer-kernels --features tvm-ffi-triton-cubin triton_cubin --lib` builds the optional bridge, then currently hits the existing `flashinfer_top1.cu` `TopKDispatch` API mismatch before Rust unit tests can run in this dirty submodule checkout. 
+ +### Step 5: Review fixes +- Addressed xiaguan's requested changes on PR `#315`: + - made `tvm-ffi` optional behind `tvm-ffi-triton-cubin` so normal `pegainfer-kernels` builds do not require `tvm-ffi-config` / `libtvm_ffi`; + - replaced `expect_err(...)` in tests with explicit `Result` matching because `tvm_ffi::Any` does not implement `Debug`; + - updated the example and docs to require/pass the feature. +- Also addressed automated inline feedback by accepting TVM FFI packed integers as `i64` for pointer handles and scalar launch dimensions, with range checks before casting. +- Review-fix validation: + - `cargo fmt --all --check` passed. + - `cargo metadata --no-deps --format-version 1` passed. + - `cargo tree -p pegainfer-kernels -e normal --no-default-features --depth 1` shows normal dependencies only (`anyhow`, `cudarc`, `half`, `serde`), no `tvm-ffi`. + - `cargo tree -p pegainfer-kernels -e normal --features tvm-ffi-triton-cubin --depth 1` shows `tvm-ffi` only with the bridge feature enabled. + - `cargo check --release -p pegainfer-kernels` no longer needs `tvm-ffi-config`, then stops at the existing dirty-FlashInfer `flashinfer_top1.cu` `TopKDispatch` mismatch. + - `PATH=/home/ziyang/gpu_memory_profiling/.venv/bin:$PATH cargo test --release -p pegainfer-kernels --features tvm-ffi-triton-cubin triton_cubin --lib -- --nocapture` also stops at the same CUDA build-script mismatch before Rust tests run in this checkout. + +## Debrief + +- **Outcome**: Added optional TVM FFI dependency wiring plus a real Triton CUBIN wrapper MVP for the Qwen3.5 GDR solve launcher, with unit tests covering wrapper discovery, registry registration, packed handle conversion, and pre-launch diagnostics. +- **Pitfalls encountered**: + - `apply_patch` and normal shell commands were blocked by the sandbox namespace failure, so edits were applied with scoped `git apply` patches. 
+ - TVM FFI is now a real build prerequisite only when `tvm-ffi-triton-cubin` is enabled; hosts using that feature need `tvm-ffi-config` on `PATH`. + - Local full kernel-crate validation is currently blocked by the pre-existing dirty FlashInfer submodule, not by the TVM FFI code. +- **Lessons learned**: + - TVM FFI typed callbacks currently cover only up to 8 arguments, while Triton/CUDA launchers can exceed that, so the wrapper should use packed TVM FFI callbacks for launch surfaces. +- **Follow-ups**: + - Add packed TVM FFI wrappers for the remaining generated Triton AOT launchers once the FlashInfer submodule is back at the expected API or the CUDA call site is updated. + - Consider a higher-level DLPack/tensor-handle wrapper above the raw pointer/stream packed ABI once the DSL artifact contract is stable. diff --git a/pegainfer-kernels/Cargo.toml b/pegainfer-kernels/Cargo.toml index f36b8e12..873d1ffe 100644 --- a/pegainfer-kernels/Cargo.toml +++ b/pegainfer-kernels/Cargo.toml @@ -8,15 +8,21 @@ anyhow = { workspace = true } cudarc = { workspace = true } half = { workspace = true } serde = { workspace = true } +tvm-ffi = { version = "0.1.0-alpha.0", optional = true } [build-dependencies] cc = { workspace = true } [features] default = [] +tvm-ffi-triton-cubin = ["dep:tvm-ffi"] deepseek-v4 = [] deepseek-v4-cutedsl-diagnostic = ["deepseek-v4"] kimi-k2 = [] +[[example]] +name = "triton_cubin_tvm_ffi" +required-features = ["tvm-ffi-triton-cubin"] + [lints] workspace = true diff --git a/pegainfer-kernels/examples/triton_cubin_tvm_ffi.rs b/pegainfer-kernels/examples/triton_cubin_tvm_ffi.rs new file mode 100644 index 00000000..0ff923c6 --- /dev/null +++ b/pegainfer-kernels/examples/triton_cubin_tvm_ffi.rs @@ -0,0 +1,20 @@ +use pegainfer_kernels::triton_cubin::{self, QWEN35_GDR_CHUNK_SOLVE}; + +fn main() -> tvm_ffi::Result<()> { + triton_cubin::register_global_functions()?; + + println!("registered Triton CUBIN TVM FFI functions:"); + for spec in 
triton_cubin::TRITON_CUBIN_FUNCTIONS { + println!(" {} -> {}", spec.name, spec.ffi_symbol); + } + + let solve = triton_cubin::get_global_or_register(QWEN35_GDR_CHUNK_SOLVE.name)?; + println!( + "{} is ready; call it with packed args: {}", + QWEN35_GDR_CHUNK_SOLVE.name, + QWEN35_GDR_CHUNK_SOLVE.arg_names.join(", ") + ); + + drop(solve); + Ok(()) +} diff --git a/pegainfer-kernels/src/lib.rs b/pegainfer-kernels/src/lib.rs index 0bcbf7f0..6ae6e872 100644 --- a/pegainfer-kernels/src/lib.rs +++ b/pegainfer-kernels/src/lib.rs @@ -7,4 +7,6 @@ pub mod gpu_buffers; pub mod ops; pub mod paged_kv; pub mod tensor; +#[cfg(feature = "tvm-ffi-triton-cubin")] +pub mod triton_cubin; pub mod typed_ops; diff --git a/pegainfer-kernels/src/triton_cubin.rs b/pegainfer-kernels/src/triton_cubin.rs new file mode 100644 index 00000000..27cdba53 --- /dev/null +++ b/pegainfer-kernels/src/triton_cubin.rs @@ -0,0 +1,321 @@ +//! TVM FFI wrappers for Triton AOT CUBIN launchers. +//! +//! The generated Triton AOT C stubs remain the low-level CUDA launch owner. +//! This module exposes the launchers through TVM FFI so DSL-produced artifacts +//! can call them without depending on PegaInfer's private Rust operator APIs. + +use std::ffi::c_void; + +use cudarc::driver::sys::{CUresult, CUstream}; +use tvm_ffi::{ + Any, AnyView, Error, Function, RUNTIME_ERROR, Result as TvmResult, TYPE_ERROR, VALUE_ERROR, +}; + +use crate::ffi; + +/// Metadata for one Triton AOT CUBIN launcher exposed through TVM FFI. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct TritonCubinFunctionSpec { + /// TVM global function name. + pub name: &'static str, + /// Linked C ABI symbol that ultimately launches the generated Triton CUBIN. + pub ffi_symbol: &'static str, + /// Packed TVM FFI argument names, in call order. 
+ pub arg_names: &'static [&'static str], +} + +const QWEN35_GDR_CHUNK_SOLVE_ARGS: &[&str] = &[ + "a_tril_ptr", + "a_inv_ptr", + "seq_len", + "num_value_heads", + "stream", +]; + +/// Qwen3.5 gated-delta-rule triangular solve Triton AOT launcher. +pub const QWEN35_GDR_CHUNK_SOLVE: TritonCubinFunctionSpec = TritonCubinFunctionSpec { + name: "pegainfer.triton_cubin.qwen35.gated_delta_rule_chunk_solve", + ffi_symbol: "gated_delta_rule_prefill_chunk_solve_cuda", + arg_names: QWEN35_GDR_CHUNK_SOLVE_ARGS, +}; + +/// Triton AOT CUBIN launchers currently exposed through TVM FFI. +pub const TRITON_CUBIN_FUNCTIONS: &[TritonCubinFunctionSpec] = &[QWEN35_GDR_CHUNK_SOLVE]; + +/// Return a fresh TVM FFI function for a known Triton CUBIN launcher. +#[must_use] +pub fn function(name: &str) -> Option { + if name == QWEN35_GDR_CHUNK_SOLVE.name { + return Some(Function::from_packed(launch_qwen35_gdr_chunk_solve)); + } + None +} + +/// Register all current Triton CUBIN launchers in TVM FFI's global registry. +pub fn register_global_functions() -> TvmResult<()> { + for spec in TRITON_CUBIN_FUNCTIONS { + if Function::get_global(spec.name).is_ok() { + continue; + } + let func = function(spec.name).ok_or_else(|| { + Error::new( + RUNTIME_ERROR, + &format!("missing TVM FFI wrapper for {}", spec.name), + "", + ) + })?; + Function::register_global(spec.name, func)?; + } + Ok(()) +} + +/// Register wrappers if needed, then fetch one wrapper from TVM FFI's global registry. 
+pub fn get_global_or_register(name: &str) -> TvmResult { + register_global_functions()?; + Function::get_global(name) +} + +fn expect_args(args: &[AnyView<'_>], spec: TritonCubinFunctionSpec) -> TvmResult<()> { + if args.len() == spec.arg_names.len() { + return Ok(()); + } + Err(Error::new( + VALUE_ERROR, + &format!( + "{} expects {} arguments ({}) but got {}", + spec.name, + spec.arg_names.len(), + spec.arg_names.join(", "), + args.len() + ), + "", + )) +} + +fn type_error(spec: TritonCubinFunctionSpec, idx: usize, expected: &str) -> Error { /* Builds the TYPE_ERROR diagnostic for argument `idx`; the name lookup is checked so an out-of-range `idx` produces a message instead of panicking inside the packed callback. */ + Error::new( + TYPE_ERROR, + &format!( + "{} argument #{} `{}` must be {}", + spec.name, idx, spec.arg_names.get(idx).copied().unwrap_or("unknown"), expected + ), + "", + ) +} + +fn arg_handle(args: &[AnyView<'_>], spec: TritonCubinFunctionSpec, idx: usize) -> TvmResult { + let value = args + .get(idx) + .ok_or_else(|| type_error(spec, idx, "a non-negative integer or opaque pointer"))?; + if let Some(raw) = value.try_as::() { + return usize::try_from(raw) + .map_err(|_| type_error(spec, idx, "a non-negative integer or opaque pointer")); + } + if let Some(raw) = value.try_as::() { + return usize::try_from(raw) + .map_err(|_| type_error(spec, idx, "a non-negative integer or opaque pointer")); + } + if let Some(raw) = value.try_as::<*mut c_void>() { + return Ok(raw as usize); + } + Err(type_error( + spec, + idx, + "a non-negative integer or opaque pointer", + )) +} + +fn arg_i32(args: &[AnyView<'_>], spec: TritonCubinFunctionSpec, idx: usize) -> TvmResult { + let value = args + .get(idx) + .ok_or_else(|| type_error(spec, idx, "an i32-range integer"))?; + if let Some(raw) = value.try_as::() { + return i32::try_from(raw).map_err(|_| type_error(spec, idx, "an i32-range integer")); + } + if let Some(raw) = value.try_as::() { + return i32::try_from(raw).map_err(|_| type_error(spec, idx, "an i32-range integer")); + } + if let Some(raw) = value.try_as::() { + return Ok(raw); + } + Err(type_error(spec, idx, "an i32-range integer")) +} + +fn stream(args: &[AnyView<'_>], 
spec: TritonCubinFunctionSpec, idx: usize) -> TvmResult { + Ok(arg_handle(args, spec, idx)? as CUstream) +} + +fn f32_const( + args: &[AnyView<'_>], + spec: TritonCubinFunctionSpec, + idx: usize, +) -> TvmResult<*const f32> { + Ok(arg_handle(args, spec, idx)? as *const f32) +} + +fn half_mut( + args: &[AnyView<'_>], + spec: TritonCubinFunctionSpec, + idx: usize, +) -> TvmResult<*mut ffi::Half> { + Ok(arg_handle(args, spec, idx)? as *mut ffi::Half) +} + +fn cuda_result(spec: TritonCubinFunctionSpec, result: CUresult) -> TvmResult { + if result as u32 == 0 { + Ok(Any::from(())) + } else { + Err(Error::new( + RUNTIME_ERROR, + &format!( + "{} via {} returned CUDA result {:?}", + spec.name, spec.ffi_symbol, result + ), + "", + )) + } +} + +fn launch_qwen35_gdr_chunk_solve(args: &[AnyView<'_>]) -> TvmResult { + let spec = QWEN35_GDR_CHUNK_SOLVE; + expect_args(args, spec)?; + let result = unsafe { + ffi::gated_delta_rule_prefill_chunk_solve_cuda( + f32_const(args, spec, 0)?, + half_mut(args, spec, 1)?, + arg_i32(args, spec, 2)?, + arg_i32(args, spec, 3)?, + stream(args, spec, 4)?, + ) + }; + cuda_result(spec, result) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn error_message(result: TvmResult, context: &str) -> String { + match result { + Ok(_) => panic!("{context}"), + Err(err) => err.message().to_string(), + } + } + + #[test] + fn exposes_current_triton_cubin_specs() { + assert_eq!(TRITON_CUBIN_FUNCTIONS, &[QWEN35_GDR_CHUNK_SOLVE]); + assert!( + QWEN35_GDR_CHUNK_SOLVE + .name + .starts_with("pegainfer.triton_cubin.qwen35.") + ); + assert_eq!( + QWEN35_GDR_CHUNK_SOLVE.arg_names, + QWEN35_GDR_CHUNK_SOLVE_ARGS + ); + } + + #[test] + fn packed_wrapper_reports_argument_contract_before_launch() { + let func = function(QWEN35_GDR_CHUNK_SOLVE.name).expect("known wrapper"); + let err = error_message(func.call_packed(&[]), "missing args should fail"); + assert!(err.contains("expects 5 arguments")); + assert!(err.contains("a_tril_ptr")); + } + + #[test] + fn 
rejects_unknown_triton_cubin_function() { + assert!(function("pegainfer.triton_cubin.unknown").is_none()); + } + + #[test] + fn global_registry_round_trips_wrapper() { + let func = get_global_or_register(QWEN35_GDR_CHUNK_SOLVE.name).expect("registered wrapper"); + let err = error_message(func.call_packed(&[]), "missing args should fail"); + assert!(err.contains(QWEN35_GDR_CHUNK_SOLVE.name)); + assert!(err.contains("expects 5 arguments")); + } + + #[test] + fn handle_args_accept_integer_and_opaque_pointer() { + let tvm_integer_handle = 0x1234_i64; + let tvm_integer_args = [AnyView::from(&tvm_integer_handle)]; + assert_eq!( + arg_handle(&tvm_integer_args, QWEN35_GDR_CHUNK_SOLVE, 0).expect("i64 handle"), + tvm_integer_handle as usize + ); + + let rust_integer_handle = 0x3456_u64; + let rust_integer_args = [AnyView::from(&rust_integer_handle)]; + assert_eq!( + arg_handle(&rust_integer_args, QWEN35_GDR_CHUNK_SOLVE, 0).expect("u64 handle"), + rust_integer_handle as usize + ); + + let opaque_handle = 0x5678_usize as *mut c_void; + let opaque_args = [AnyView::from(&opaque_handle)]; + assert_eq!( + arg_handle(&opaque_args, QWEN35_GDR_CHUNK_SOLVE, 0).expect("opaque handle"), + opaque_handle as usize + ); + } + + #[test] + fn scalar_args_accept_tvm_i64_integer() { + let seq_len = 16_i64; + let args = [AnyView::from(&seq_len)]; /* one-element slice: `arg_i32` indexes the slice it is given, so the value is read at index 0 */ + assert_eq!( + arg_i32(&args, QWEN35_GDR_CHUNK_SOLVE, 0).expect("i64 scalar"), + 16_i32 + ); + } + + #[test] + fn packed_wrapper_reports_handle_type_errors_before_launch() { + let bad_handle = 1.25_f32; + let a_inv_ptr = 0_u64; + let seq_len = 16_i32; + let num_value_heads = 8_i32; + let stream = 0_u64; + let args = [ + AnyView::from(&bad_handle), + AnyView::from(&a_inv_ptr), + AnyView::from(&seq_len), + AnyView::from(&num_value_heads), + AnyView::from(&stream), + ]; + + let func = function(QWEN35_GDR_CHUNK_SOLVE.name).expect("known wrapper"); + let err = error_message( + func.call_packed(&args), + "bad handle should fail before launch", + ); + 
assert!(err.contains("argument #0 `a_tril_ptr`")); + assert!(err.contains("integer or opaque pointer")); + } + + #[test] + fn packed_wrapper_reports_scalar_type_errors_before_launch() { + let a_tril_ptr = 0_u64; + let a_inv_ptr = 0_u64; + let bad_seq_len = 16.0_f32; + let num_value_heads = 8_i32; + let stream = 0_u64; + let args = [ + AnyView::from(&a_tril_ptr), + AnyView::from(&a_inv_ptr), + AnyView::from(&bad_seq_len), + AnyView::from(&num_value_heads), + AnyView::from(&stream), + ]; + + let func = function(QWEN35_GDR_CHUNK_SOLVE.name).expect("known wrapper"); + let err = error_message( + func.call_packed(&args), + "bad scalar should fail before launch", + ); + assert!(err.contains("argument #2 `seq_len`")); + assert!(err.contains("must be an i32-range integer")); + } +} diff --git a/pegainfer-kernels/tools/triton/README.md b/pegainfer-kernels/tools/triton/README.md index 56b06b4d..192d62a6 100644 --- a/pegainfer-kernels/tools/triton/README.md +++ b/pegainfer-kernels/tools/triton/README.md @@ -3,6 +3,11 @@ `pegainfer` currently uses Triton AOT for the Qwen3.5 HD256 prefill kernel and the Qwen3.5 GDR chunkwise prefill kernels. +`pegainfer-kernels` can also expose selected generated CUBIN launchers through +TVM FFI under `pegainfer_kernels::triton_cubin` when the +`tvm-ffi-triton-cubin` feature is enabled. This is the DSL-facing wrapper layer; +the generated C stubs remain the low-level CUDA launch owner. + ## What this covers - Build-time generation of Triton AOT cubins for: @@ -20,6 +25,10 @@ export CUDA_HOME=/usr/local/cuda export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH ``` +The TVM FFI bridge is optional. Only install the TVM FFI runtime when building +with `--features tvm-ffi-triton-cubin`; in that mode `tvm-ffi-config` must be on +`PATH` and `libtvm_ffi` must be discoverable during build and runtime. 
+ Bootstrap a repo-local Triton Python once: ```bash @@ -68,6 +77,17 @@ Generated Triton artifacts are written to Cargo `OUT_DIR`, typically under: target/release/build/pegainfer-kernels-*/out/triton_aot/ ``` +## TVM FFI wrapper example + +```bash +cargo run --release -p pegainfer-kernels --features tvm-ffi-triton-cubin --example triton_cubin_tvm_ffi +``` + +The registered names use the `pegainfer.triton_cubin.qwen35.*` prefix. Pointer +and stream arguments are packed as TVM integers or opaque pointers; scalar launch +arguments use TVM integers. The wrapper returns `()` on CUDA success and a TVM +`RuntimeError` if the underlying CUBIN launcher returns a non-success CUDA result. + ## Validation Run the focused GPU tests for the active Triton-backed paths: