Skip to content

Commit

Permalink
Merge pull request #5 from aftershootco/docs
Browse files Browse the repository at this point in the history
Added basic docs
  • Loading branch information
uttarayan21 authored Oct 23, 2024
2 parents ef6b107 + b392211 commit 7cff340
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 3 deletions.
87 changes: 84 additions & 3 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ impl SessionMode {
}
}

/// net data holder. multiple sessions could share same net.
#[repr(transparent)]
pub struct Interpreter {
pub(crate) inner: *mut mnn_sys::Interpreter,
Expand All @@ -136,6 +137,11 @@ impl Drop for Interpreter {
}

impl Interpreter {
/// Create an net/interpreter from a file.
///
/// `path`: the file path of the model
///
/// return: the created net/interpreter
pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
ensure!(path.exists(), ErrorKind::IOError; path.to_string_lossy().to_string(), "File not found");
Expand All @@ -149,6 +155,11 @@ impl Interpreter {
})
}

/// Create an net/interpreter from a buffer.
///
/// `bytes`: the buffer of the model
///
/// return: the created net/interpreter
pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self> {
let bytes = bytes.as_ref();
let size = bytes.len();
Expand All @@ -161,14 +172,31 @@ impl Interpreter {
})
}

/// Set session mode
///
/// `mode`: the session mode
///
/// **Warning:**
/// It should be called before create session!
pub fn set_session_mode(&mut self, mode: SessionMode) {
unsafe { mnn_sys::Interpreter_setSessionMode(self.inner, mode.to_mnn_sys()) }
}

///call this function to get tensors ready.
///
///output tensor buffer (host or deviceId) should be retrieved after resize of any input tensor.
///
///`session`: the session to be prepared
pub fn resize_session(&self, session: &mut crate::Session) {
unsafe { mnn_sys::Interpreter_resizeSession(self.inner, session.inner) }
}

/// Resize session and reallocate the buffer.
///
/// `session`: the session to be prepared.
///
/// # Note
/// NeedRelloc is default to 1, 1 means need realloc!
pub fn resize_session_reallocate(&self, session: &mut crate::Session) {
unsafe { mnn_sys::Interpreter_resizeSessionWithFlag(self.inner, session.inner, 1i32) }
}
Expand Down Expand Up @@ -206,6 +234,11 @@ impl Interpreter {
}
}

/// Create a session with session config. Session will be managed in net/interpreter.
///
/// `schedule` : the config of the session
///
/// return: the created session
pub fn create_session(
&mut self,
schedule: crate::ScheduleConfig,
Expand All @@ -221,6 +254,11 @@ impl Interpreter {
})
}

/// Create multi-path session with schedule configs and user-specified runtime. created session will be managed in net/interpreter.
///
/// `schedule` : the config of the session
///
/// return: the created session
pub fn create_multipath_session(
&mut self,
schedule: impl IntoIterator<Item = ScheduleConfig>,
Expand All @@ -238,6 +276,7 @@ impl Interpreter {
})
}

/// Print all input and output tensors info.
pub fn model_print_io(path: impl AsRef<Path>) -> Result<()> {
let path = path.as_ref();
crate::ensure!(path.exists(), ErrorKind::IOError);
Expand All @@ -247,11 +286,23 @@ impl Interpreter {
Ok(())
}

/// Get the input tensor of the session.
///
/// `session`: the session to get input tensor
///
/// return: List of input tensors
pub fn inputs(&self, session: &crate::Session) -> TensorList {
let inputs = unsafe { mnn_sys::Interpreter_getSessionInputAll(self.inner, session.inner) };
TensorList::from_ptr(inputs)
}

/// Get the input tensor of the session by name.
///
/// `session`: the session to get input tensor from
///
/// `name`: the name of the input tensor
///
/// return: the input tensor
pub fn input<'s, H: HalideType>(
&self,
session: &'s crate::Session,
Expand Down Expand Up @@ -291,7 +342,7 @@ impl Interpreter {
}

/// # Safety
/// We Still don't know the safety guarantees of this function so it's marked unsafe
/// **Warning** We Still don't know the safety guarantees of this function so it's marked unsafe
pub unsafe fn input_unresized<'s, H: HalideType>(
&self,
session: &'s crate::Session,
Expand All @@ -314,7 +365,7 @@ impl Interpreter {
}

/// # Safety
/// Very unsafe since it doesn't check the type of the tensor
/// Very **unsafe** since it doesn't check the type of the tensor
/// as well as the shape of the tensor
pub unsafe fn input_unchecked<'s, H: HalideType>(
&self,
Expand All @@ -329,6 +380,10 @@ impl Interpreter {
}

/// Get the output tensor of a session by name
///
/// `session` : the session to get output tensor from
///
/// `name` : the name of the output tensor
pub fn output<'s, H: HalideType>(
&self,
session: &'s crate::Session,
Expand Down Expand Up @@ -366,6 +421,7 @@ impl Interpreter {
Ok(RawTensor::from_ptr(output))
}

/// Run a session
pub fn run_session(&mut self, session: &crate::session::Session) -> Result<()> {
profile!("Running session"; {
let ret = unsafe { mnn_sys::Interpreter_runSession(self.inner, session.inner) };
Expand All @@ -377,6 +433,15 @@ impl Interpreter {
})
}

/// Run a session with a callback
///
/// `session` : the session to run
///
/// `before` : a callback before each op. return true to run the op; return false to skip the op.
///
/// `after` : a callback after each op. return true to continue running; return false to interrupt the session.
///
/// `sync` : synchronously wait for finish of execution or not.
pub fn run_session_with_callback(
&mut self,
session: &crate::session::Session,
Expand All @@ -403,12 +468,24 @@ impl Interpreter {
Ok(())
}

/// Get all output tensors of a session
pub fn outputs(&self, session: &crate::session::Session) -> TensorList {
let outputs =
unsafe { mnn_sys::Interpreter_getSessionOutputAll(self.inner, session.inner) };
TensorList::from_ptr(outputs)
}

/// If the cache exist, try to load cache from file.
/// After createSession, try to save cache to file.
///
/// `cache_file` : the file path to save or load cache.
///
/// `key_size` : the size of key
///
/// # Note
/// The API should be called before create session.
///
/// Key Depercerate, keeping for future use!
pub fn set_cache_file(&mut self, path: impl AsRef<Path>, key_size: usize) -> Result<()> {
let path = path.as_ref();
let path = dunce::simplified(path);
Expand All @@ -417,6 +494,8 @@ impl Interpreter {
unsafe { mnn_sys::Interpreter_setCacheFile(self.inner, c_path.as_ptr(), key_size) }
Ok(())
}

/// Update cache file
pub fn update_cache_file(&mut self, session: &mut crate::session::Session) -> Result<()> {
MNNError::from_error_code(unsafe {
mnn_sys::Interpreter_updateCacheFile(self.inner, session.inner)
Expand All @@ -433,6 +512,7 @@ impl Interpreter {
});
}

/// Get memory usage of a session in MB
pub fn memory(&self, session: &crate::session::Session) -> Result<f32> {
let mut memory = 0f32;
let memory_ptr = &mut memory as *mut f32;
Expand All @@ -447,6 +527,7 @@ impl Interpreter {
Ok(memory)
}

/// Get float operation needed in session in M
pub fn flops(&self, session: &crate::Session) -> Result<f32> {
let mut flop = 0.0f32;
let flop_ptr = &mut flop as *mut f32;
Expand Down Expand Up @@ -535,7 +616,7 @@ impl<'t, 'tl> TensorInfo<'t, 'tl> {
}

/// # Safety
/// The shape is not checked so it's marked unsafe since futher calls to interpreter might be unsafe with this
/// The shape is not checked so it's marked unsafe since futher calls to interpreter might be **unsafe** with this
pub unsafe fn tensor_unresized<H: HalideType>(&self) -> Result<Tensor<RefMut<'t, Device<H>>>> {
debug_assert!(!self.tensor_info.is_null());
unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) };
Expand Down
Loading

0 comments on commit 7cff340

Please sign in to comment.