Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor c strings type mapping #78

Merged
merged 3 commits into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/sherpa-rs-sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "sherpa-rs-sys"
version = "0.6.4"
version = "0.6.5"
edition = "2021"
authors = ["thewh1teagle"]
homepage = "https://github.com/thewh1teagle/sherpa-rs"
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "sherpa-rs"
version = "0.6.4"
version = "0.6.5"
edition = "2021"
authors = ["thewh1teagle"]
license = "MIT"
Expand All @@ -21,7 +21,7 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
eyre = "0.6.12"
hound = { version = "3.5.1" }
sherpa-rs-sys = { path = "../sherpa-rs-sys", version = "0.6.4", default-features = false }
sherpa-rs-sys = { path = "../sherpa-rs-sys", version = "0.6.5", default-features = false }
tracing = "0.1.40"

[dev-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion crates/sherpa-rs/src/audio_tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl AudioTag {

for i in 0..self.config.top_k {
let event = *results.add(i.try_into().unwrap());
let event_name = cstr_to_string((*event).name);
let event_name = cstr_to_string((*event).name as _);
events.push(event_name);
}

Expand Down
8 changes: 4 additions & 4 deletions crates/sherpa-rs/src/embedding_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl EmbeddingManager {
if name.is_null() {
return None;
}
let name = cstr_to_string(name);
let name = cstr_to_string(name as _);
Some(name)
}
}
Expand All @@ -57,7 +57,7 @@ impl EmbeddingManager {
let mut matches: Vec<SpeakerMatch> = Vec::new();
for i in 0..result.count {
let match_c = matches_c[i as usize];
let name = cstr_to_string(match_c.name);
let name = cstr_to_string(match_c.name as _);
let score = match_c.score;
matches.push(SpeakerMatch { name, score });
}
Expand All @@ -76,7 +76,7 @@ impl EmbeddingManager {
);

if status.is_negative() {
bail!("Failed to register {}", name)
bail!("Failed to register {}", name);
}
Ok(())
}
Expand All @@ -90,6 +90,6 @@ impl Drop for EmbeddingManager {
fn drop(&mut self) {
unsafe {
sherpa_rs_sys::SherpaOnnxDestroySpeakerEmbeddingManager(self.manager);
};
}
}
}
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/keyword_spot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl KeywordSpot {
let spotter = unsafe { sherpa_rs_sys::SherpaOnnxCreateKeywordSpotter(&sherpa_config) };

if spotter.is_null() {
bail!("Failed to create keyword spotter")
bail!("Failed to create keyword spotter");
}
let stream = unsafe { sherpa_rs_sys::SherpaOnnxCreateKeywordStream(spotter) };
if stream.is_null() {
Expand Down Expand Up @@ -139,7 +139,7 @@ impl KeywordSpot {
let result_ptr = sherpa_rs_sys::SherpaOnnxGetKeywordResult(self.spotter, self.stream);
let mut keyword = None;
if !result_ptr.is_null() {
let decoded_keyword = cstr_to_string((*result_ptr).keyword);
let decoded_keyword = cstr_to_string((*result_ptr).keyword as _);
if !decoded_keyword.is_empty() {
keyword = Some(decoded_keyword);
}
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/language_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ impl SpokenLanguageId {
let language_result_ptr =
sherpa_rs_sys::SherpaOnnxSpokenLanguageIdentificationCompute(self.slid, stream);
if language_result_ptr.is_null() || (*language_result_ptr).lang.is_null() {
bail!("language ptr is null")
bail!("language ptr is null");
}
let language_ptr = (*language_result_ptr).lang;
let language = cstr_to_string(language_ptr);
let language = cstr_to_string(language_ptr as _);
// Free
sherpa_rs_sys::SherpaOnnxDestroySpokenLanguageIdentificationResult(language_result_ptr);
sherpa_rs_sys::SherpaOnnxDestroyOfflineStream(stream);
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/moonshine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl MoonshineRecognizer {
let recognizer = unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflineRecognizer(&config) };

if recognizer.is_null() {
bail!("Failed to create recognizer")
bail!("Failed to create recognizer");
}

Ok(Self { recognizer })
Expand All @@ -140,7 +140,7 @@ impl MoonshineRecognizer {
sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream);
let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream);
let raw_result = result_ptr.read();
let text = cstr_to_string(raw_result.text);
let text = cstr_to_string(raw_result.text as _);
let result = MoonshineRecognizerResult { text };
// Free
sherpa_rs_sys::SherpaOnnxDestroyOfflineRecognizerResult(result_ptr);
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/punctuate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl Punctuation {
unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflinePunctuation(&sherpa_config) };

if audio_punctuation.is_null() {
bail!("Failed to create audio punctuation")
bail!("Failed to create audio punctuation");
}
Ok(Self { audio_punctuation })
}
Expand All @@ -51,7 +51,7 @@ impl Punctuation {
self.audio_punctuation,
text.as_ptr(),
);
let text_with_punct = cstr_to_string(text_with_punct_ptr);
let text_with_punct = cstr_to_string(text_with_punct_ptr as _);
sherpa_rs_sys::SherpaOfflinePunctuationFreeText(text_with_punct_ptr);
text_with_punct
}
Expand Down
2 changes: 1 addition & 1 deletion crates/sherpa-rs/src/tts/kokoro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl KokoroTts {
},
};
let config = sherpa_rs_sys::SherpaOnnxOfflineTtsConfig {
max_num_sentences: 0,
max_num_sentences: config.common_config.max_num_sentences,
model: model_config,
rule_fars: tts_config.rule_fars.map(|v| v.as_ptr()).unwrap_or(null()),
rule_fsts: tts_config.rule_fsts.map(|v| v.as_ptr()).unwrap_or(null()),
Expand Down
2 changes: 1 addition & 1 deletion crates/sherpa-rs/src/tts/matcha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl MatchaTts {
kokoro: mem::zeroed::<_>(),
};
let config = sherpa_rs_sys::SherpaOnnxOfflineTtsConfig {
max_num_sentences: 0,
max_num_sentences: config.common_config.max_num_sentences,
model: model_config,
rule_fars: tts_config.rule_fars.map(|v| v.as_ptr()).unwrap_or(null()),
rule_fsts: tts_config.rule_fsts.map(|v| v.as_ptr()).unwrap_or(null()),
Expand Down
34 changes: 4 additions & 30 deletions crates/sherpa-rs/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
use std::ffi::CString;
use std::ffi::{c_char, CString};

// Smart pointer for CString
pub struct RawCStr {
#[cfg(target_os = "android")]
ptr: *mut u8,

#[cfg(not(target_os = "android"))]
ptr: *mut i8,
ptr: *mut std::ffi::c_char,
}

impl RawCStr {
Expand All @@ -23,13 +19,7 @@ impl RawCStr {
/// This function only returns the raw pointer and does not transfer ownership.
/// The pointer remains valid as long as the `CStr` instance exists.
/// Be cautious not to deallocate or modify the pointer after using `CStr::new`.
#[cfg(target_os = "android")]
pub fn as_ptr(&self) -> *const u8 {
self.ptr as *const u8
}

#[cfg(not(target_os = "android"))]
pub fn as_ptr(&self) -> *const i8 {
pub fn as_ptr(&self) -> *const c_char {
self.ptr
}
}
Expand All @@ -38,29 +28,13 @@ impl Drop for RawCStr {
fn drop(&mut self) {
if !self.ptr.is_null() {
unsafe {
#[cfg(target_os = "android")]
let _ = CString::from_raw(self.ptr as *mut u8);

#[cfg(not(target_os = "android"))]
let _ = CString::from_raw(self.ptr);
}
}
}
}

#[cfg(target_os = "android")]
pub fn cstr_to_string(ptr: *const u8) -> String {
unsafe {
if ptr.is_null() {
String::new()
} else {
std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned()
}
}
}

#[cfg(not(target_os = "android"))]
pub fn cstr_to_string(ptr: *const i8) -> String {
pub fn cstr_to_string(ptr: *mut c_char) -> String {
unsafe {
if ptr.is_null() {
String::new()
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/whisper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl WhisperRecognizer {
let recognizer = unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflineRecognizer(&config) };

if recognizer.is_null() {
bail!("Failed to create recognizer")
bail!("Failed to create recognizer");
}

Ok(Self { recognizer })
Expand All @@ -147,7 +147,7 @@ impl WhisperRecognizer {
sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream);
let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream);
let raw_result = result_ptr.read();
let text = cstr_to_string(raw_result.text);
let text = cstr_to_string(raw_result.text as _);
// let timestamps: &[f32] =
// std::slice::from_raw_parts(raw_result.timestamps, raw_result.count as usize);
let result = WhisperRecognizerResult { text };
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/zipformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl ZipFormer {
unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflineRecognizer(&recognizer_config) };

if recognizer.is_null() {
bail!("Failed to create recognizer")
bail!("Failed to create recognizer");
}
Ok(Self { recognizer })
}
Expand All @@ -113,7 +113,7 @@ impl ZipFormer {
sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream);
let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream);
let raw_result = result_ptr.read();
let text = cstr_to_string(raw_result.text);
let text = cstr_to_string(raw_result.text as _);

// Free
sherpa_rs_sys::SherpaOnnxDestroyOfflineRecognizerResult(result_ptr);
Expand Down