diff --git a/flake.lock b/flake.lock index 6ef5aca..8b8f1c8 100644 --- a/flake.lock +++ b/flake.lock @@ -145,11 +145,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1725634671, - "narHash": "sha256-v3rIhsJBOMLR8e/RNWxr828tB+WywYIoajrZKFM+0Gg=", + "lastModified": 1729665710, + "narHash": "sha256-AlcmCXJZPIlO5dmFzV3V2XF6x/OpNWUV8Y/FMPGd8Z4=", "owner": "nixos", "repo": "nixpkgs", - "rev": "574d1eac1c200690e27b8eb4e24887f8df7ac27c", + "rev": "2768c7d042a37de65bb1b5b3268fc987e534c49d", "type": "github" }, "original": { diff --git a/mnn-sys/build.rs b/mnn-sys/build.rs index de79a86..525318b 100644 --- a/mnn-sys/build.rs +++ b/mnn-sys/build.rs @@ -122,6 +122,7 @@ fn main() -> Result<()> { mnn_c_build(PathBuf::from(MANIFEST_DIR).join("mnn_c"), &vendor) .with_context(|| "Failed to build mnn_c")?; mnn_c_bindgen(&vendor, &out_dir).with_context(|| "Failed to generate mnn_c bindings")?; + mnn_cpp_bindgen(&vendor, &out_dir).with_context(|| "Failed to generate mnn_cpp bindings")?; println!("cargo:include={vendor}/include", vendor = vendor.display()); if *TARGET_OS == "macos" { #[cfg(feature = "metal")] @@ -225,6 +226,36 @@ pub fn mnn_c_bindgen(vendor: impl AsRef, out: impl AsRef) -> Result< Ok(()) } +pub fn mnn_cpp_bindgen(vendor: impl AsRef, out: impl AsRef) -> Result<()> { + let vendor = vendor.as_ref(); + let bindings = bindgen::Builder::default() + .clang_args(["-x", "c++"]) + .clang_args(["-std=c++11"]) + .clang_arg(CxxOption::VULKAN.cxx()) + .clang_arg(CxxOption::METAL.cxx()) + .clang_arg(CxxOption::COREML.cxx()) + .clang_arg(CxxOption::OPENCL.cxx()) + .clang_arg(format!("-I{}", vendor.join("include").to_string_lossy())) + .generate_cstr(true) + .generate_inline_functions(true) + .size_t_is_usize(true) + .emit_diagnostics() + .ctypes_prefix("core::ffi") + .header( + vendor + .join("include") + .join("MNN") + .join("Interpreter.hpp") + .to_string_lossy(), + ) + .allowlist_item(".*SessionInfoCode.*"); + // let cmd = bindings.command_line_flags().join(" "); + // println!("cargo:warn=bindgen: {}", cmd); + let bindings = bindings.generate()?; + bindings.write_to_file(out.as_ref().join("mnn_cpp.rs"))?; + Ok(()) +} + pub fn mnn_c_build(path: impl AsRef, vendor: impl AsRef) -> Result<()> { let mnn_c = path.as_ref(); let files = mnn_c.read_dir()?.flatten().map(|e| e.path()).filter(|e| { diff --git a/mnn-sys/src/lib.rs b/mnn-sys/src/lib.rs index 89d6a96..ce2bafb 100644 --- a/mnn-sys/src/lib.rs +++ b/mnn-sys/src/lib.rs @@ -1,5 +1,11 @@ use std::ffi::CStr; +pub mod cpp { + #![allow(non_upper_case_globals)] + #![allow(non_camel_case_types)] + #![allow(non_snake_case)] + include!(concat!(env!("OUT_DIR"), "/mnn_cpp.rs")); +} mod sys { #![allow(non_upper_case_globals)] #![allow(non_camel_case_types)] diff --git a/mnn-sys/vendor b/mnn-sys/vendor index a74551b..407a1c1 160000 --- a/mnn-sys/vendor +++ b/mnn-sys/vendor @@ -1 +1 @@ -Subproject commit a74551b4f34b46ce7027c64e800d49fcab497261 +Subproject commit 407a1c141d459d093f655bf2fed2a8a5e22a77ce diff --git a/src/interpreter.rs b/src/interpreter.rs index dd2043c..2abe0b3 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -526,7 +526,12 @@ impl Interpreter { let mut memory = 0f32; let memory_ptr = &mut memory as *mut f32; let ret = unsafe { - mnn_sys::Interpreter_getSessionInfo(self.inner, session.inner, 0, memory_ptr.cast()) + mnn_sys::Interpreter_getSessionInfo( + self.inner, + session.inner, + mnn_sys::cpp::MNN_Interpreter_SessionInfoCode_MEMORY as _, + memory_ptr.cast(), + ) }; ensure!( ret == 1, @@ -544,7 +549,7 @@ impl Interpreter { mnn_sys::Interpreter_getSessionInfo( self.inner, session.inner, - 1, + mnn_sys::cpp::MNN_Interpreter_SessionInfoCode_FLOPS as _, flop_ptr.cast::(), ) }; @@ -560,7 +565,12 @@ impl Interpreter { let mut resize_status = 0i32; let ptr = &mut resize_status as *mut i32; let ret = unsafe { - mnn_sys::Interpreter_getSessionInfo(self.inner, session.inner, 3, ptr.cast()) + mnn_sys::Interpreter_getSessionInfo( + self.inner, + session.inner, + mnn_sys::cpp::MNN_Interpreter_SessionInfoCode_RESIZE_STATUS as _, + ptr.cast(), + ) }; ensure!( ret == 1,