@@ -31,7 +31,7 @@ pub struct LlamaModel {
31
31
#[ repr( transparent) ]
32
32
#[ allow( clippy:: module_name_repetitions) ]
33
33
pub struct LlamaLoraAdapter {
34
- pub ( crate ) lora_adapter : NonNull < llama_cpp_sys_2:: llama_lora_adapter > ,
34
+ pub ( crate ) lora_adapter : NonNull < llama_cpp_sys_2:: llama_adapter_lora > ,
35
35
}
36
36
37
37
/// A Safe wrapper around `llama_chat_message`
@@ -74,6 +74,10 @@ unsafe impl Send for LlamaModel {}
74
74
unsafe impl Sync for LlamaModel { }
75
75
76
76
impl LlamaModel {
77
+ pub ( crate ) fn vocab_ptr ( & self ) -> * const llama_cpp_sys_2:: llama_vocab {
78
+ unsafe { llama_cpp_sys_2:: llama_model_get_vocab ( self . model . as_ptr ( ) ) }
79
+ }
80
+
77
81
/// get the number of tokens the model was trained on
78
82
///
79
83
/// # Panics
@@ -99,28 +103,28 @@ impl LlamaModel {
99
103
/// Get the beginning of stream token.
100
104
#[ must_use]
101
105
pub fn token_bos ( & self ) -> LlamaToken {
102
- let token = unsafe { llama_cpp_sys_2:: llama_token_bos ( self . model . as_ptr ( ) ) } ;
106
+ let token = unsafe { llama_cpp_sys_2:: llama_token_bos ( self . vocab_ptr ( ) ) } ;
103
107
LlamaToken ( token)
104
108
}
105
109
106
110
/// Get the end of stream token.
107
111
#[ must_use]
108
112
pub fn token_eos ( & self ) -> LlamaToken {
109
- let token = unsafe { llama_cpp_sys_2:: llama_token_eos ( self . model . as_ptr ( ) ) } ;
113
+ let token = unsafe { llama_cpp_sys_2:: llama_token_eos ( self . vocab_ptr ( ) ) } ;
110
114
LlamaToken ( token)
111
115
}
112
116
113
117
/// Get the newline token.
114
118
#[ must_use]
115
119
pub fn token_nl ( & self ) -> LlamaToken {
116
- let token = unsafe { llama_cpp_sys_2:: llama_token_nl ( self . model . as_ptr ( ) ) } ;
120
+ let token = unsafe { llama_cpp_sys_2:: llama_token_nl ( self . vocab_ptr ( ) ) } ;
117
121
LlamaToken ( token)
118
122
}
119
123
120
124
/// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
121
125
#[ must_use]
122
126
pub fn is_eog_token ( & self , token : LlamaToken ) -> bool {
123
- unsafe { llama_cpp_sys_2:: llama_token_is_eog ( self . model . as_ptr ( ) , token. 0 ) }
127
+ unsafe { llama_cpp_sys_2:: llama_token_is_eog ( self . vocab_ptr ( ) , token. 0 ) }
124
128
}
125
129
126
130
/// Get the decoder start token.
@@ -225,7 +229,7 @@ impl LlamaModel {
225
229
226
230
let size = unsafe {
227
231
llama_cpp_sys_2:: llama_tokenize (
228
- self . model . as_ptr ( ) ,
232
+ self . vocab_ptr ( ) ,
229
233
c_string. as_ptr ( ) ,
230
234
c_int:: try_from ( c_string. as_bytes ( ) . len ( ) ) ?,
231
235
buffer. as_mut_ptr ( ) as * mut llama_cpp_sys_2:: llama_token ,
@@ -241,7 +245,7 @@ impl LlamaModel {
241
245
buffer. reserve_exact ( usize:: try_from ( -size) . expect ( "usize's are larger " ) ) ;
242
246
unsafe {
243
247
llama_cpp_sys_2:: llama_tokenize (
244
- self . model . as_ptr ( ) ,
248
+ self . vocab_ptr ( ) ,
245
249
c_string. as_ptr ( ) ,
246
250
c_int:: try_from ( c_string. as_bytes ( ) . len ( ) ) ?,
247
251
buffer. as_mut_ptr ( ) as * mut llama_cpp_sys_2:: llama_token ,
@@ -268,7 +272,7 @@ impl LlamaModel {
268
272
/// If the token type is not known to this library.
269
273
#[ must_use]
270
274
pub fn token_attr ( & self , LlamaToken ( id) : LlamaToken ) -> LlamaTokenAttrs {
271
- let token_type = unsafe { llama_cpp_sys_2:: llama_token_get_attr ( self . model . as_ptr ( ) , id) } ;
275
+ let token_type = unsafe { llama_cpp_sys_2:: llama_token_get_attr ( self . vocab_ptr ( ) , id) } ;
272
276
LlamaTokenAttrs :: try_from ( token_type) . expect ( "token type is valid" )
273
277
}
274
278
@@ -347,7 +351,7 @@ impl LlamaModel {
347
351
let lstrip = lstrip. map_or ( 0 , |it| i32:: from ( it. get ( ) ) ) ;
348
352
let size = unsafe {
349
353
llama_cpp_sys_2:: llama_token_to_piece (
350
- self . model . as_ptr ( ) ,
354
+ self . vocab_ptr ( ) ,
351
355
token. 0 ,
352
356
buf,
353
357
len,
@@ -374,7 +378,7 @@ impl LlamaModel {
374
378
/// without issue.
375
379
#[ must_use]
376
380
pub fn n_vocab ( & self ) -> i32 {
377
- unsafe { llama_cpp_sys_2:: llama_n_vocab ( self . model . as_ptr ( ) ) }
381
+ unsafe { llama_cpp_sys_2:: llama_n_vocab ( self . vocab_ptr ( ) ) }
378
382
}
379
383
380
384
/// The type of vocab the model was trained on.
@@ -384,7 +388,8 @@ impl LlamaModel {
384
388
/// If llama-cpp emits a vocab type that is not known to this library.
385
389
#[ must_use]
386
390
pub fn vocab_type ( & self ) -> VocabType {
387
- let vocab_type = unsafe { llama_cpp_sys_2:: llama_vocab_type ( self . model . as_ptr ( ) ) } ;
391
+ // llama_cpp_sys_2::llama_model_get_vocab
392
+ let vocab_type = unsafe { llama_cpp_sys_2:: llama_vocab_type ( self . vocab_ptr ( ) ) } ;
388
393
VocabType :: try_from ( vocab_type) . expect ( "invalid vocab type" )
389
394
}
390
395
@@ -479,7 +484,7 @@ impl LlamaModel {
479
484
480
485
let cstr = CString :: new ( path) ?;
481
486
let adapter =
482
- unsafe { llama_cpp_sys_2:: llama_lora_adapter_init ( self . model . as_ptr ( ) , cstr. as_ptr ( ) ) } ;
487
+ unsafe { llama_cpp_sys_2:: llama_adapter_lora_init ( self . model . as_ptr ( ) , cstr. as_ptr ( ) ) } ;
483
488
484
489
let adapter = NonNull :: new ( adapter) . ok_or ( LlamaLoraAdapterInitError :: NullResult ) ?;
485
490
@@ -548,7 +553,6 @@ impl LlamaModel {
548
553
549
554
let res = unsafe {
550
555
llama_cpp_sys_2:: llama_chat_apply_template (
551
- self . model . as_ptr ( ) ,
552
556
tmpl_ptr,
553
557
chat. as_ptr ( ) ,
554
558
chat. len ( ) ,
@@ -563,7 +567,6 @@ impl LlamaModel {
563
567
564
568
let res = unsafe {
565
569
llama_cpp_sys_2:: llama_chat_apply_template (
566
- self . model . as_ptr ( ) ,
567
570
tmpl_ptr,
568
571
chat. as_ptr ( ) ,
569
572
chat. len ( ) ,
0 commit comments