
Commit 6942e13

Update llama.cpp to latest version

1 parent 904fbda

4 files changed, +24 -65 lines

llama-cpp-2/src/context.rs

Lines changed: 2 additions & 2 deletions
@@ -318,7 +318,7 @@ impl<'model> LlamaContext<'model> {
         scale: f32,
     ) -> Result<(), LlamaLoraAdapterSetError> {
         let err_code = unsafe {
-            llama_cpp_sys_2::llama_lora_adapter_set(
+            llama_cpp_sys_2::llama_set_adapter_lora(
                 self.context.as_ptr(),
                 adapter.lora_adapter.as_ptr(),
                 scale,
@@ -342,7 +342,7 @@ impl<'model> LlamaContext<'model> {
         adapter: &mut LlamaLoraAdapter,
     ) -> Result<(), LlamaLoraAdapterRemoveError> {
         let err_code = unsafe {
-            llama_cpp_sys_2::llama_lora_adapter_remove(
+            llama_cpp_sys_2::llama_rm_adapter_lora(
                 self.context.as_ptr(),
                 adapter.lora_adapter.as_ptr(),
             )
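
Only the `llama-cpp-sys-2` symbols change in these two hunks (`llama_lora_adapter_set` becomes `llama_set_adapter_lora`, `llama_lora_adapter_remove` becomes `llama_rm_adapter_lora`); the safe wrappers keep their signatures and error types. A minimal caller-side sketch, assuming the wrapper methods on `LlamaContext` are named `lora_adapter_set` and `lora_adapter_remove` (the hunks above only show their bodies, so treat the method names and import paths as illustrative):

    use llama_cpp_2::context::LlamaContext;
    use llama_cpp_2::model::LlamaLoraAdapter;

    // Hypothetical downstream code: the sys-level rename is internal, so callers
    // going through the safe wrappers should not need to change.
    fn with_lora(
        ctx: &mut LlamaContext<'_>,
        adapter: &mut LlamaLoraAdapter,
    ) -> Result<(), Box<dyn std::error::Error>> {
        // Attach the adapter at full strength (scale = 1.0).
        ctx.lora_adapter_set(adapter, 1.0)?;
        // ... run decoding with the adapter active ...
        // Detach the adapter once it is no longer needed.
        ctx.lora_adapter_remove(adapter)?;
        Ok(())
    }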

llama-cpp-2/src/model.rs

Lines changed: 17 additions & 14 deletions
@@ -31,7 +31,7 @@ pub struct LlamaModel {
 #[repr(transparent)]
 #[allow(clippy::module_name_repetitions)]
 pub struct LlamaLoraAdapter {
-    pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_lora_adapter>,
+    pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
 }
 
 /// A Safe wrapper around `llama_chat_message`
@@ -74,6 +74,10 @@ unsafe impl Send for LlamaModel {}
 unsafe impl Sync for LlamaModel {}
 
 impl LlamaModel {
+    pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab {
+        unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) }
+    }
+
     /// get the number of tokens the model was trained on
     ///
     /// # Panics
@@ -99,28 +103,28 @@ impl LlamaModel {
     /// Get the beginning of stream token.
     #[must_use]
     pub fn token_bos(&self) -> LlamaToken {
-        let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.model.as_ptr()) };
+        let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) };
         LlamaToken(token)
     }
 
     /// Get the end of stream token.
     #[must_use]
     pub fn token_eos(&self) -> LlamaToken {
-        let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.model.as_ptr()) };
+        let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) };
         LlamaToken(token)
     }
 
     /// Get the newline token.
     #[must_use]
     pub fn token_nl(&self) -> LlamaToken {
-        let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.model.as_ptr()) };
+        let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) };
         LlamaToken(token)
     }
 
     /// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
     #[must_use]
     pub fn is_eog_token(&self, token: LlamaToken) -> bool {
-        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.model.as_ptr(), token.0) }
+        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) }
     }
 
     /// Get the decoder start token.
@@ -225,7 +229,7 @@ impl LlamaModel {
 
         let size = unsafe {
             llama_cpp_sys_2::llama_tokenize(
-                self.model.as_ptr(),
+                self.vocab_ptr(),
                 c_string.as_ptr(),
                 c_int::try_from(c_string.as_bytes().len())?,
                 buffer.as_mut_ptr() as *mut llama_cpp_sys_2::llama_token,
@@ -241,7 +245,7 @@ impl LlamaModel {
             buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
             unsafe {
                 llama_cpp_sys_2::llama_tokenize(
-                    self.model.as_ptr(),
+                    self.vocab_ptr(),
                     c_string.as_ptr(),
                     c_int::try_from(c_string.as_bytes().len())?,
                     buffer.as_mut_ptr() as *mut llama_cpp_sys_2::llama_token,
@@ -268,7 +272,7 @@ impl LlamaModel {
     /// If the token type is not known to this library.
     #[must_use]
     pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
-        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.model.as_ptr(), id) };
+        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) };
         LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
     }
 
@@ -347,7 +351,7 @@ impl LlamaModel {
         let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
         let size = unsafe {
             llama_cpp_sys_2::llama_token_to_piece(
-                self.model.as_ptr(),
+                self.vocab_ptr(),
                 token.0,
                 buf,
                 len,
@@ -374,7 +378,7 @@ impl LlamaModel {
     /// without issue.
     #[must_use]
     pub fn n_vocab(&self) -> i32 {
-        unsafe { llama_cpp_sys_2::llama_n_vocab(self.model.as_ptr()) }
+        unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) }
     }
 
     /// The type of vocab the model was trained on.
@@ -384,7 +388,8 @@ impl LlamaModel {
     /// If llama-cpp emits a vocab type that is not known to this library.
     #[must_use]
     pub fn vocab_type(&self) -> VocabType {
-        let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.model.as_ptr()) };
+        // llama_cpp_sys_2::llama_model_get_vocab
+        let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) };
         VocabType::try_from(vocab_type).expect("invalid vocab type")
     }
 
@@ -479,7 +484,7 @@ impl LlamaModel {
 
         let cstr = CString::new(path)?;
         let adapter =
-            unsafe { llama_cpp_sys_2::llama_lora_adapter_init(self.model.as_ptr(), cstr.as_ptr()) };
+            unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
 
         let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
 
@@ -548,7 +553,6 @@ impl LlamaModel {
 
         let res = unsafe {
             llama_cpp_sys_2::llama_chat_apply_template(
-                self.model.as_ptr(),
                 tmpl_ptr,
                 chat.as_ptr(),
                 chat.len(),
@@ -563,7 +567,6 @@ impl LlamaModel {
 
         let res = unsafe {
             llama_cpp_sys_2::llama_chat_apply_template(
-                self.model.as_ptr(),
                 tmpl_ptr,
                 chat.as_ptr(),
                 chat.len(),
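
The thread running through these model.rs hunks is the new private `vocab_ptr` helper: it wraps `llama_model_get_vocab`, and the token-level sys calls (`llama_token_bos`, `llama_token_eos`, `llama_token_nl`, `llama_token_is_eog`, `llama_tokenize`, `llama_token_get_attr`, `llama_token_to_piece`, `llama_n_vocab`, `llama_vocab_type`) now receive that `*const llama_vocab` instead of the model pointer, while `llama_chat_apply_template` simply drops its model argument. The public `LlamaModel` methods keep their signatures, so downstream code along these lines should compile unchanged (a sketch using only accessors visible in the diff; the `.0` field access on `LlamaToken` mirrors the removed `penalties_simple` code):

    use llama_cpp_2::model::LlamaModel;

    // Sketch: these public accessors are unchanged; only their bodies now route
    // through vocab_ptr() -> llama_model_get_vocab() rather than the raw model pointer.
    fn print_vocab_summary(model: &LlamaModel) {
        println!("vocab size : {}", model.n_vocab());
        println!("bos token  : {}", model.token_bos().0);
        println!("eos token  : {}", model.token_eos().0);
        println!("newline    : {}", model.token_nl().0);
        println!("eos is eog : {}", model.is_eog_token(model.token_eos()));
    }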

llama-cpp-2/src/sampling.rs

Lines changed: 4 additions & 48 deletions
@@ -238,7 +238,7 @@ impl LlamaSampler {
 
         let sampler = unsafe {
             llama_cpp_sys_2::llama_sampler_init_grammar(
-                model.model.as_ptr(),
+                model.vocab_ptr(),
                 grammar_str.as_ptr(),
                 grammar_root.as_ptr(),
             )
@@ -264,14 +264,15 @@
     ) -> Self {
         let seq_breakers: Vec<CString> = seq_breakers
             .into_iter()
-            .map(|s| CString::new(s.as_ref()).unwrap())
+            .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
             .collect();
         let mut seq_breaker_pointers: Vec<*const CChar> =
             seq_breakers.iter().map(|s| s.as_ptr()).collect();
 
         let sampler = unsafe {
             llama_cpp_sys_2::llama_sampler_init_dry(
-                model.model.as_ptr(),
+                model.vocab_ptr(),
+                model.n_ctx_train().try_into().expect("n_ctx_train is greater than two billion"),
                 multiplier,
                 base,
                 allowed_length,
@@ -286,74 +287,29 @@ impl LlamaSampler {
     /// Penalizes tokens for being present in the context.
     ///
     /// Parameters:
-    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
-    /// - ``special_eos)id``: [`LlamaModel::token_eos`]
-    /// - ``linefeed_id``: [`LlamaModel::token_nl`]
     /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
     /// - ``penalty_repeat``: 1.0 = disabled
     /// - ``penalty_freq``: 0.0 = disabled
     /// - ``penalty_present``: 0.0 = disabled
-    /// - ``penalize_nl``: consider newlines as a repeatable token
-    /// - ``ignore_eos``: ignore the end-of-sequence token
     #[allow(clippy::too_many_arguments)]
     #[must_use]
     pub fn penalties(
-        n_vocab: i32,
-        special_eos_id: i32,
-        linefeed_id: i32,
         penalty_last_n: i32,
         penalty_repeat: f32,
         penalty_freq: f32,
         penalty_present: f32,
-        penalize_nl: bool,
-        ignore_eos: bool,
     ) -> Self {
         let sampler = unsafe {
             llama_cpp_sys_2::llama_sampler_init_penalties(
-                n_vocab,
-                special_eos_id,
-                linefeed_id,
                 penalty_last_n,
                 penalty_repeat,
                 penalty_freq,
                 penalty_present,
-                penalize_nl,
-                ignore_eos,
             )
         };
         Self { sampler }
     }
 
-    /// Same as [`Self::penalties`], but with `n_vocab`, `special_eos_id`, and `linefeed_id`
-    /// initialized from `model`, `penalize_nl = false`, and `ignore_eos = true`.
-    ///
-    /// Parameters:
-    /// - ``model``: The model's tokenizer to use to initialize the sampler.
-    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
-    /// - ``penalty_repeat``: 1.0 = disabled
-    /// - ``penalty_freq``: 0.0 = disabled
-    /// - ``penalty_present``: 0.0 = disabled
-    #[must_use]
-    pub fn penalties_simple(
-        model: &LlamaModel,
-        penalty_last_n: i32,
-        penalty_repeat: f32,
-        penalty_freq: f32,
-        penalty_present: f32,
-    ) -> Self {
-        Self::penalties(
-            model.n_vocab(),
-            model.token_eos().0,
-            model.token_nl().0,
-            penalty_last_n,
-            penalty_repeat,
-            penalty_freq,
-            penalty_present,
-            false,
-            true,
-        )
-    }
-
     /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
     ///
     /// # Parameters:
llama-cpp-sys-2/llama.cpp

Lines changed: 1 addition & 1 deletion (the vendored llama.cpp submodule reference is bumped to the newer upstream commit)

0 commit comments
