Skip to content

Commit 4063f55

Browse files
authored
Merge pull request #819 from fellhorn/dennis/mtmd-improvements
Multimodality improvements
2 parents a6565b0 + 94a83e9 commit 4063f55

File tree

2 files changed

+30
-26
lines changed

2 files changed

+30
-26
lines changed

examples/mtmd/src/mtmd.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ pub struct MtmdCliParams {
5050
/// Number of threads
5151
#[arg(short = 't', long = "threads", value_name = "N", default_value = "4")]
5252
pub n_threads: i32,
53+
/// Number of tokens to process in a batch during eval chunks
54+
#[arg(long = "batch-size", value_name = "b", default_value = "1")]
55+
pub batch_size: i32,
5356
/// Maximum number of tokens in context
5457
#[arg(long = "n-tokens", value_name = "N", default_value = "4096")]
5558
pub n_tokens: NonZeroU32,
@@ -140,6 +143,7 @@ impl MtmdCliContext {
140143
context: &mut LlamaContext,
141144
msg: LlamaChatMessage,
142145
add_bos: bool,
146+
batch_size: i32,
143147
) -> Result<(), Box<dyn std::error::Error>> {
144148
self.chat.push(msg);
145149

@@ -168,7 +172,7 @@ impl MtmdCliContext {
168172
// Clear bitmaps after tokenization
169173
self.bitmaps.clear();
170174

171-
self.n_past = chunks.eval_chunks(&self.mtmd_ctx, context, 0, 0, 1, true)?;
175+
self.n_past = chunks.eval_chunks(&self.mtmd_ctx, context, 0, 0, batch_size, true)?;
172176
Ok(())
173177
}
174178

@@ -186,7 +190,7 @@ impl MtmdCliContext {
186190

187191
for _i in 0..max_predict {
188192
// Sample next token
189-
let token = sampler.sample(context, 0);
193+
let token = sampler.sample(context, -1);
190194
generated_tokens.push(token);
191195
sampler.accept(token);
192196

@@ -244,7 +248,7 @@ fn run_single_turn(
244248
println!("Evaluating message: {msg:?}");
245249

246250
// Evaluate the message (prefill)
247-
ctx.eval_message(model, context, msg, true)?;
251+
ctx.eval_message(model, context, msg, true, params.batch_size)?;
248252

249253
// Generate response (decode)
250254
ctx.generate_response(model, context, sampler, params.n_predict)?;
@@ -286,7 +290,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
286290
// Create context
287291
let context_params = LlamaContextParams::default()
288292
.with_n_threads(params.n_threads)
289-
.with_n_batch(1)
293+
.with_n_batch(params.batch_size.try_into()?)
290294
.with_n_ctx(Some(params.n_tokens));
291295
let mut context = model.new_context(&backend, context_params)?;
292296

llama-cpp-2/src/mtmd.rs

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,18 @@ use crate::token::LlamaToken;
2525
/// let audio_chunk = MtmdInputChunkType::Audio;
2626
///
2727
/// assert_eq!(text_chunk, MtmdInputChunkType::Text);
28+
/// assert_eq!(text_chunk, llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT.into());
2829
/// assert_ne!(text_chunk, image_chunk);
2930
/// ```
3031
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32+
#[repr(u32)]
3133
pub enum MtmdInputChunkType {
3234
/// Text input chunk
33-
Text = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT as isize,
35+
Text = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT as _,
3436
/// Image input chunk
35-
Image = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE as isize,
37+
Image = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE as _,
3638
/// Audio input chunk
37-
Audio = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO as isize,
39+
Audio = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO as _,
3840
}
3941

4042
impl From<llama_cpp_sys_2::mtmd_input_chunk_type> for MtmdInputChunkType {
@@ -43,7 +45,7 @@ impl From<llama_cpp_sys_2::mtmd_input_chunk_type> for MtmdInputChunkType {
4345
llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT => MtmdInputChunkType::Text,
4446
llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE => MtmdInputChunkType::Image,
4547
llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO => MtmdInputChunkType::Audio,
46-
_ => panic!("Unknown MTMD input chunk type"),
48+
_ => panic!("Unknown MTMD input chunk type: {chunk_type}"),
4749
}
4850
}
4951
}
@@ -106,9 +108,7 @@ impl From<llama_cpp_sys_2::mtmd_context_params> for MtmdContextParams {
106108
use_gpu: params.use_gpu,
107109
print_timings: params.print_timings,
108110
n_threads: params.n_threads,
109-
media_marker: unsafe { CStr::from_ptr(params.media_marker) }
110-
.to_owned()
111-
.into(),
111+
media_marker: unsafe { CStr::from_ptr(params.media_marker) }.to_owned(),
112112
}
113113
}
114114
}
@@ -211,10 +211,11 @@ impl MtmdContext {
211211
}
212212

213213
/// Get audio bitrate in Hz (e.g., 16000 for Whisper).
214-
/// Returns -1 if audio is not supported.
214+
/// Returns None if audio is not supported.
215215
#[must_use]
216-
pub fn get_audio_bitrate(&self) -> i32 {
217-
unsafe { llama_cpp_sys_2::mtmd_get_audio_bitrate(self.context.as_ptr()) }
216+
pub fn get_audio_bitrate(&self) -> Option<u32> {
217+
let rate = unsafe { llama_cpp_sys_2::mtmd_get_audio_bitrate(self.context.as_ptr()) };
218+
(rate > 0).then_some(rate.unsigned_abs())
218219
}
219220

220221
/// Tokenize input text and bitmaps into chunks.
@@ -275,7 +276,7 @@ impl MtmdContext {
275276
llama_cpp_sys_2::mtmd_tokenize(
276277
self.context.as_ptr(),
277278
chunks.chunks.as_ptr(),
278-
&input_text,
279+
&raw const input_text,
279280
bitmap_ptrs.as_ptr().cast_mut(),
280281
bitmaps.len(),
281282
)
@@ -626,15 +627,11 @@ impl MtmdInputChunks {
626627
let chunk_ptr =
627628
unsafe { llama_cpp_sys_2::mtmd_input_chunks_get(self.chunks.as_ptr(), index) };
628629

629-
if chunk_ptr.is_null() {
630-
None
631-
} else {
632-
// Note: We don't own this chunk, it's owned by the chunks collection
633-
Some(MtmdInputChunk {
634-
chunk: NonNull::new(chunk_ptr.cast_mut()).unwrap(),
635-
owned: false,
636-
})
637-
}
630+
// Note: We don't own this chunk, it's owned by the chunks collection
631+
NonNull::new(chunk_ptr.cast_mut()).map(|ptr| MtmdInputChunk {
632+
chunk: ptr,
633+
owned: false,
634+
})
638635
}
639636

640637
/// Get total number of tokens across all chunks.
@@ -701,7 +698,7 @@ impl MtmdInputChunks {
701698
seq_id,
702699
n_batch,
703700
logits_last,
704-
&mut new_n_past,
701+
&raw mut new_n_past,
705702
)
706703
};
707704

@@ -753,7 +750,10 @@ impl MtmdInputChunk {
753750

754751
let mut n_tokens = 0usize;
755752
let tokens_ptr = unsafe {
756-
llama_cpp_sys_2::mtmd_input_chunk_get_tokens_text(self.chunk.as_ptr(), &mut n_tokens)
753+
llama_cpp_sys_2::mtmd_input_chunk_get_tokens_text(
754+
self.chunk.as_ptr(),
755+
&raw mut n_tokens,
756+
)
757757
};
758758

759759
if tokens_ptr.is_null() || n_tokens == 0 {

0 commit comments

Comments (0)