Commit 1df9568

Added cache for EOG and Control tokens to Vocabulary

1 parent 204ba96

2 files changed: +17, -9 lines


LLama/Native/LLamaToken.cs (+2, -8)
@@ -98,10 +98,7 @@ public bool IsControl(SafeLlamaModelHandle model)
     /// <returns></returns>
     public bool IsControl(SafeLlamaModelHandle.Vocabulary vocab)
     {
-        unsafe
-        {
-            return LLamaVocabNative.llama_vocab_is_control(vocab.VocabNative, this);
-        }
+        return vocab.ControlTokens.Contains(this);
     }

     /// <summary>
@@ -121,10 +118,7 @@ public bool IsEndOfGeneration(SafeLlamaModelHandle model)
     /// <returns></returns>
     public bool IsEndOfGeneration(SafeLlamaModelHandle.Vocabulary vocab)
    {
-        unsafe
-        {
-            return LLamaVocabNative.llama_vocab_is_eog(vocab.VocabNative, this);
-        }
+        return vocab.EOGTokens.Contains(this);
     }

     /// <inheritdoc />
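Note: the change above swaps a per-call unsafe P/Invoke for a lookup in the sets now cached on Vocabulary. Below is a minimal sketch (not part of this commit) of a call site that benefits; it assumes the LLama.Native namespace implied by the file paths, and the ShouldStop helper is purely illustrative.

using LLama.Native;

// Illustrative helper: the per-token end-of-generation check a sampling loop
// would run. Before this commit each call crossed into native code via
// llama_vocab_is_eog; after it, the check is a managed
// HashSet<LLamaToken>.Contains lookup on the cached Vocabulary.EOGTokens set.
static bool ShouldStop(LLamaToken token, SafeLlamaModelHandle.Vocabulary vocab)
{
    return token.IsEndOfGeneration(vocab);
}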

LLama/Native/SafeLlamaModelHandle.cs (+15, -1)
@@ -633,15 +633,26 @@ public sealed class Vocabulary
         internal unsafe LLamaVocabNative* VocabNative => llama_model_get_vocab(_model);

         /// <summary>
-        /// Cache of all the tokens in the vocabulary, and their string representation
+        /// Map of each token in this vocabulary to its string representation
         /// </summary>
         public readonly IReadOnlyDictionary<LLamaToken, string> TokenToString;

+        /// <summary>
+        /// Contains unique tokens that are supposed to end the generation (e.g.: EOS, EOT, etc)
+        /// </summary>
+        public readonly HashSet<LLamaToken> EOGTokens;
+
+        /// <summary>
+        /// Contains unique tokens that exist for inference control rather than text output
+        /// </summary>
+        public readonly HashSet<LLamaToken> ControlTokens;
+
         internal Vocabulary(SafeLlamaModelHandle model)
         {
             _model = model;
             TokenToString = GetVocabCache();

+            // Cache the various properties that llama.cpp API exposes about the vocab
             unsafe
             {
                 var vocabNative = llama_model_get_vocab(_model);
@@ -662,6 +673,9 @@ internal Vocabulary(SafeLlamaModelHandle model)
                 DecoderStartToken = Normalize(llama_model_decoder_start_token(_model));
                 ShouldAddBOS = LLamaVocabNative.llama_vocab_get_add_bos(vocabNative);
                 ShouldAddEOS = LLamaVocabNative.llama_vocab_get_add_eos(vocabNative);
+
+                EOGTokens = new HashSet<LLamaToken>(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_eog(vocabNative, token)));
+                ControlTokens = new HashSet<LLamaToken>(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_control(vocabNative, token)));
             }
         }
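Note: the two HashSet caches are filled once in the Vocabulary constructor by querying llama.cpp for every token already in TokenToString, so later membership checks stay in managed code. Below is a small sketch (not part of this commit) of reading the caches directly; the ToDisplayText helper is illustrative, and obtaining a Vocabulary instance is assumed to happen elsewhere.

using System.Collections.Generic;
using System.Linq;
using LLama.Native;

// Illustrative helper: drop control tokens from displayed output using the
// cached sets. Both ControlTokens and TokenToString are built once in the
// Vocabulary constructor, so this loop makes no native calls per token.
static string ToDisplayText(IEnumerable<LLamaToken> tokens, SafeLlamaModelHandle.Vocabulary vocab)
{
    return string.Concat(tokens
        .Where(t => !vocab.ControlTokens.Contains(t))
        .Select(t => vocab.TokenToString[t]));
}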
