diff --git a/Cargo.lock b/Cargo.lock index c5faa9c2..e62fae7a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2884,6 +2884,17 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" +[[package]] +name = "goblin" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d20fd25aa456527ce4f544271ae4fea65d2eda4a6561ea56f39fb3ee4f7e3884" +dependencies = [ + "log", + "plain", + "scroll", +] + [[package]] name = "group" version = "0.12.1" @@ -5325,6 +5336,7 @@ dependencies = [ "cudarc", "half", "serde", + "tvm-ffi", ] [[package]] @@ -5628,6 +5640,12 @@ version = "0.3.33" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + [[package]] name = "plotters" version = "0.3.7" @@ -5808,6 +5826,30 @@ dependencies = [ "toml_edit 0.25.11+spec-1.1.0", ] +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro-error-attr2" version = "2.0.0" @@ -6977,6 +7019,26 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d68f2ec51b097e4c1a75b681a8bec621909b5e91f15bb7b840c4f2f7b01148b2" +[[package]] +name = "scroll" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fda28d4b4830b807a8b43f7b0e6b5df875311b3e7621d84577188c175b6ec1ec" +dependencies = [ + "scroll_derive", +] + +[[package]] +name = "scroll_derive" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aaaae8f38bb311444cfb7f1979af0bc9240d95795f75f9ceddf6a59b79ceffa0" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "sct" version = "0.7.1" @@ -8561,6 +8623,39 @@ dependencies = [ "tokio", ] +[[package]] +name = "tvm-ffi" +version = "0.1.0-alpha.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d1c9d5b22b336c5de9ca98442715ec64f09b7038c6d88817a7de8ae00b2c384" +dependencies = [ + "paste", + "tvm-ffi-macros", + "tvm-ffi-sys", +] + +[[package]] +name = "tvm-ffi-macros" +version = "0.1.0-alpha.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18901fd34d643368ca7139e1d592a6c484cb855f6554057a699b9b77c29851e3" +dependencies = [ + "goblin", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "tvm-ffi-sys" +version = "0.1.0-alpha.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7552573e760924afb4e24fce8b5d453236b59710c8c37438d00c1006a9b38f1c" +dependencies = [ + "cmake", +] + [[package]] name = "typeid" version = "1.0.3" diff --git a/pegainfer-kernels/Cargo.toml b/pegainfer-kernels/Cargo.toml index f36b8e12..2f23cdf7 100644 --- a/pegainfer-kernels/Cargo.toml +++ b/pegainfer-kernels/Cargo.toml @@ -8,6 +8,7 @@ anyhow = { workspace = true } cudarc = { workspace = true } half = { workspace = true } serde = { workspace = true } +tvm-ffi = { version = "0.1.0-alpha.0", optional = true } [build-dependencies] cc = { workspace = true } @@ -17,6 +18,7 @@ default = [] deepseek-v4 = [] deepseek-v4-cutedsl-diagnostic = ["deepseek-v4"] kimi-k2 = [] +tvm-ffi-interop = ["dep:tvm-ffi"] [lints] workspace = true diff --git a/pegainfer-kernels/tests/fixtures/tvm_ffi_fixture.cc b/pegainfer-kernels/tests/fixtures/tvm_ffi_fixture.cc new file mode 100644 index 00000000..35d1625d --- /dev/null +++ b/pegainfer-kernels/tests/fixtures/tvm_ffi_fixture.cc @@ -0,0 +1,30 @@ +#include + +namespace ffi = tvm::ffi; + +int64_t AddOneScalar(int64_t x) { return x + 1; } + +int64_t ApplyCallback(ffi::Function callback, int64_t x) { + return callback(x).cast(); +} + +int64_t CallRegisteredHostAddThree(int64_t x) { + ffi::Function callback = + ffi::Function::GetGlobalRequired("pegainfer.testing.add_three"); + return callback(x).cast(); +} + +int64_t CallRegisteredHostFailIfNegative(int64_t x) { + ffi::Function callback = + ffi::Function::GetGlobalRequired("pegainfer.testing.fail_if_negative"); + return callback(x).cast(); +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_scalar, AddOneScalar); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_callback, ApplyCallback); +TVM_FFI_DLL_EXPORT_TYPED_FUNC( + call_registered_host_add_three, + CallRegisteredHostAddThree); +TVM_FFI_DLL_EXPORT_TYPED_FUNC( + call_registered_host_fail_if_negative, + CallRegisteredHostFailIfNegative); diff --git a/pegainfer-kernels/tests/tvm_ffi_bidirectional.rs b/pegainfer-kernels/tests/tvm_ffi_bidirectional.rs new file mode 100644 index 00000000..2cf2e046 --- /dev/null +++ b/pegainfer-kernels/tests/tvm_ffi_bidirectional.rs @@ -0,0 +1,219 @@ +#![cfg(feature = "tvm-ffi-interop")] + +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; +use std::sync::OnceLock; +use std::time::{SystemTime, UNIX_EPOCH}; + +use anyhow::{Context, Result, anyhow, ensure}; +use tvm_ffi::{Error, Function, Module, Result as TvmResult, VALUE_ERROR, into_typed_fn}; + +// Manual command: +// cargo test --release -p pegainfer-kernels --features tvm-ffi-interop tvm_ffi_bidirectional -- --ignored --nocapture + +static FIXTURE_LIB: OnceLock = OnceLock::new(); + +fn tvm(result: TvmResult, context: impl AsRef) -> Result { + result.map_err(|err| anyhow!("{}: {err}", context.as_ref())) +} + +fn fixture_extension() -> &'static str { + if cfg!(target_os = "macos") { + "dylib" + } else if cfg!(target_os = "windows") { + "dll" + } else { + "so" + } +} + +fn shell_quote(path: &Path) -> String { + let raw = path.to_string_lossy(); + format!("'{}'", raw.replace('\'', "'\"'\"'")) +} + +fn run_command(command: &str, context: &str) -> Result { + let output = Command::new("sh") + .arg("-c") + .arg(command) + .stdin(Stdio::null()) + .output() + .with_context(|| format!("failed to launch shell for {context}"))?; + if output.status.success() { + Ok(output) + } else { + Err(anyhow!( + "{context} failed\nstatus: {}\nstdout:\n{}\nstderr:\n{}", + output.status, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + )) + } +} + +fn require_tvm_ffi_config() -> Result { + let output = Command::new("tvm-ffi-config") + .arg("--libdir") + .stdin(Stdio::null()) + .output() + .context( + "failed to run tvm-ffi-config --libdir; install apache-tvm-ffi and ensure tvm-ffi-config is on PATH before enabling tvm-ffi-interop", + )?; + ensure!( + output.status.success(), + "tvm-ffi-config --libdir failed\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + let libdir = String::from_utf8(output.stdout) + .context("tvm-ffi-config produced non-UTF8 output")? + .trim() + .to_string(); + ensure!( + !libdir.is_empty(), + "tvm-ffi-config --libdir returned an empty path" + ); + Ok(PathBuf::from(libdir)) +} + +fn build_fixture_library() -> Result { + if let Some(path) = FIXTURE_LIB.get() { + return Ok(path.clone()); + } + + let _libdir = require_tvm_ffi_config()?; + let cxx = std::env::var("PEGAINFER_TVM_FFI_CXX").unwrap_or_else(|_| "c++".to_string()); + let source = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("fixtures") + .join("tvm_ffi_fixture.cc"); + ensure!( + source.is_file(), + "fixture source missing: {}", + source.display() + ); + + let stamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .context("system clock is before UNIX_EPOCH")? + .as_nanos(); + let build_dir = std::env::temp_dir().join(format!("pegainfer-tvm-ffi-{stamp}")); + std::fs::create_dir_all(&build_dir) + .with_context(|| format!("failed to create {}", build_dir.display()))?; + let output = build_dir.join(format!("tvm_ffi_fixture.{}", fixture_extension())); + + let command = format!( + "{cxx} -shared -O3 -std=c++17 -fPIC -fvisibility=hidden \ + -o {output} {source} \ + $(tvm-ffi-config --cxxflags) \ + $(tvm-ffi-config --ldflags) \ + $(tvm-ffi-config --libs)", + output = shell_quote(&output), + source = shell_quote(&source), + ); + run_command(&command, "building TVM FFI fixture library").with_context(|| { + format!( + "failed to build TVM FFI fixture with compiler {cxx}; override with PEGAINFER_TVM_FFI_CXX if needed" + ) + })?; + + let _ = FIXTURE_LIB.set(output.clone()); + Ok(output) +} + +#[test] +#[ignore = "requires tvm-ffi runtime/tooling on PATH; run manually on a host with apache-tvm-ffi installed"] +fn tvm_ffi_bidirectional_smoke() -> Result<()> { + let fixture = build_fixture_library()?; + let fixture_path = fixture.to_string_lossy(); + let module = tvm( + Module::load_from_file(fixture_path.as_ref()), + format!( + "failed to load fixture module {}; verify libtvm_ffi runtime libraries are visible", + fixture.display() + ), + )?; + + let add_one = tvm( + module.get_function("add_one_scalar"), + "missing add_one_scalar export in TVM FFI fixture module", + )?; + let add_one = into_typed_fn!(add_one, Fn(i64) -> TvmResult); + assert_eq!(tvm(add_one(41), "calling add_one_scalar")?, 42); + + let apply_callback = tvm( + module.get_function("apply_callback"), + "missing apply_callback export in TVM FFI fixture module", + )?; + let apply_callback = into_typed_fn!(apply_callback, Fn(Function, i64) -> TvmResult); + let host_add_five = Function::from_typed(|x: i64| -> TvmResult { Ok(x + 5) }); + assert_eq!( + tvm( + apply_callback(host_add_five, 7), + "calling apply_callback with a Rust callback", + )?, + 12 + ); + + tvm( + Function::register_global( + "pegainfer.testing.add_three", + Function::from_typed(|x: i64| -> TvmResult { Ok(x + 3) }), + ), + "failed to register pegainfer.testing.add_three", + )?; + let call_registered = tvm( + module.get_function("call_registered_host_add_three"), + "missing call_registered_host_add_three export in TVM FFI fixture module", + )?; + let call_registered = into_typed_fn!(call_registered, Fn(i64) -> TvmResult); + assert_eq!( + tvm(call_registered(9), "calling call_registered_host_add_three")?, + 12 + ); + + tvm( + Function::register_global( + "pegainfer.testing.fail_if_negative", + Function::from_typed(|x: i64| -> TvmResult { + if x < 0 { + Err(Error::new( + VALUE_ERROR, + "negative input rejected by Rust callback", + "", + )) + } else { + Ok(x) + } + }), + ), + "failed to register pegainfer.testing.fail_if_negative", + )?; + let fail_callback = tvm( + module.get_function("call_registered_host_fail_if_negative"), + "missing call_registered_host_fail_if_negative export in TVM FFI fixture module", + )?; + let fail_callback = into_typed_fn!(fail_callback, Fn(i64) -> TvmResult); + let err = tvm( + fail_callback(-1), + "calling call_registered_host_fail_if_negative", + ) + .expect_err("negative callback should propagate an error"); + let err_text = err.to_string(); + assert!( + err_text.contains("negative input rejected by Rust callback"), + "unexpected callback error: {err_text}" + ); + + let missing = match module.get_function("does_not_exist") { + Ok(_) => panic!("missing TVM FFI symbol lookup should fail"), + Err(err) => err, + }; + let missing_text = missing.to_string(); + assert!( + missing_text.contains("Cannot convert from type `None` to `ffi.Function`"), + "unexpected missing-symbol error: {missing_text}" + ); + + Ok(()) +}