Skip to content

Commit

Permalink
feat(mnn): Use generated bindings for SessionInfoCode instead of hard…
Browse files Browse the repository at this point in the history
…-coding
  • Loading branch information
uttarayan21 committed Oct 24, 2024
1 parent f36f901 commit 4ba0ad8
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 7 deletions.
6 changes: 3 additions & 3 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 31 additions & 0 deletions mnn-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ fn main() -> Result<()> {
mnn_c_build(PathBuf::from(MANIFEST_DIR).join("mnn_c"), &vendor)
.with_context(|| "Failed to build mnn_c")?;
mnn_c_bindgen(&vendor, &out_dir).with_context(|| "Failed to generate mnn_c bindings")?;
mnn_cpp_bindgen(&vendor, &out_dir).with_context(|| "Failed to generate mnn_cpp bindings")?;
println!("cargo:include={vendor}/include", vendor = vendor.display());
if *TARGET_OS == "macos" {
#[cfg(feature = "metal")]
Expand Down Expand Up @@ -225,6 +226,36 @@ pub fn mnn_c_bindgen(vendor: impl AsRef<Path>, out: impl AsRef<Path>) -> Result<
Ok(())
}

pub fn mnn_cpp_bindgen(vendor: impl AsRef<Path>, out: impl AsRef<Path>) -> Result<()> {
let vendor = vendor.as_ref();
let bindings = bindgen::Builder::default()
.clang_args(["-x", "c++"])
.clang_args(["-std=c++11"])
.clang_arg(CxxOption::VULKAN.cxx())
.clang_arg(CxxOption::METAL.cxx())
.clang_arg(CxxOption::COREML.cxx())
.clang_arg(CxxOption::OPENCL.cxx())
.clang_arg(format!("-I{}", vendor.join("include").to_string_lossy()))
.generate_cstr(true)
.generate_inline_functions(true)
.size_t_is_usize(true)
.emit_diagnostics()
.ctypes_prefix("core::ffi")
.header(
vendor
.join("include")
.join("MNN")
.join("Interpreter.hpp")
.to_string_lossy(),
)
.allowlist_item(".*SessionInfoCode.*");
// let cmd = bindings.command_line_flags().join(" ");
// println!("cargo:warn=bindgen: {}", cmd);
let bindings = bindings.generate()?;
bindings.write_to_file(out.as_ref().join("mnn_cpp.rs"))?;
Ok(())
}

pub fn mnn_c_build(path: impl AsRef<Path>, vendor: impl AsRef<Path>) -> Result<()> {
let mnn_c = path.as_ref();
let files = mnn_c.read_dir()?.flatten().map(|e| e.path()).filter(|e| {
Expand Down
6 changes: 6 additions & 0 deletions mnn-sys/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
use std::ffi::CStr;

pub mod cpp {
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
include!(concat!(env!("OUT_DIR"), "/mnn_cpp.rs"));
}
mod sys {
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
Expand Down
2 changes: 1 addition & 1 deletion mnn-sys/vendor
Submodule vendor updated 151 files
16 changes: 13 additions & 3 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,12 @@ impl Interpreter {
let mut memory = 0f32;
let memory_ptr = &mut memory as *mut f32;
let ret = unsafe {
mnn_sys::Interpreter_getSessionInfo(self.inner, session.inner, 0, memory_ptr.cast())
mnn_sys::Interpreter_getSessionInfo(
self.inner,
session.inner,
mnn_sys::cpp::MNN_Interpreter_SessionInfoCode_MEMORY as _,
memory_ptr.cast(),
)
};
ensure!(
ret == 1,
Expand All @@ -544,7 +549,7 @@ impl Interpreter {
mnn_sys::Interpreter_getSessionInfo(
self.inner,
session.inner,
1,
mnn_sys::cpp::MNN_Interpreter_SessionInfoCode_FLOPS as _,
flop_ptr.cast::<libc::c_void>(),
)
};
Expand All @@ -560,7 +565,12 @@ impl Interpreter {
let mut resize_status = 0i32;
let ptr = &mut resize_status as *mut i32;
let ret = unsafe {
mnn_sys::Interpreter_getSessionInfo(self.inner, session.inner, 3, ptr.cast())
mnn_sys::Interpreter_getSessionInfo(
self.inner,
session.inner,
mnn_sys::cpp::MNN_Interpreter_SessionInfoCode_RESIZE_STATUS as _,
ptr.cast(),
)
};
ensure!(
ret == 1,
Expand Down

0 comments on commit 4ba0ad8

Please sign in to comment.