Skip to content

Commit 204ba96

Browse files
committed
Made Vocabulary properties initialize only ONCE, on creation
1 parent f8a7263 commit 204ba96

File tree

1 file changed

+56
-182
lines changed

1 file changed

+56
-182
lines changed

LLama/Native/SafeLlamaModelHandle.cs

+56-182
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using System.Diagnostics;
44
using System.IO;
5+
using System.Linq;
56
using System.Text;
67
using LLama.Exceptions;
78

@@ -631,34 +632,51 @@ public sealed class Vocabulary
631632

632633
internal unsafe LLamaVocabNative* VocabNative => llama_model_get_vocab(_model);
633634

635+
/// <summary>
636+
/// Cache of all the tokens in the vocabulary, and their string representation
637+
/// </summary>
638+
public readonly IReadOnlyDictionary<LLamaToken, string> TokenToString;
639+
634640
internal Vocabulary(SafeLlamaModelHandle model)
635641
{
636642
_model = model;
637-
}
638-
639-
private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
640-
{
641-
if (!token.HasValue)
642-
return null;
643+
TokenToString = GetVocabCache();
643644

644-
// Try to convert using a fixed size buffer
645-
const int buffSize = 32;
646-
Span<byte> buff = stackalloc byte[buffSize];
647-
var tokenLength = _model.TokenToSpan((LLamaToken)token, buff, special: isSpecialToken);
648-
649-
// Negative indicates that there was no result
650-
if (tokenLength <= 0)
651-
return null;
652-
653-
// if the original buffer wasn't large enough, try again with one that's the right size
654-
if (tokenLength > buffSize)
645+
unsafe
655646
{
656-
buff = stackalloc byte[(int)tokenLength];
657-
_ = _model.TokenToSpan((LLamaToken)token, buff, special: isSpecialToken);
647+
var vocabNative = llama_model_get_vocab(_model);
648+
Count = LLamaVocabNative.llama_vocab_n_tokens(vocabNative);
649+
Type = LLamaVocabNative.llama_vocab_type(vocabNative);
650+
BOS = Normalize(LLamaVocabNative.llama_vocab_bos(vocabNative));
651+
EOS = Normalize(LLamaVocabNative.llama_vocab_eos(vocabNative));
652+
Newline = Normalize(LLamaVocabNative.llama_vocab_nl(vocabNative));
653+
Pad = Normalize(LLamaVocabNative.llama_vocab_pad(vocabNative));
654+
SEP = Normalize(LLamaVocabNative.llama_vocab_sep(vocabNative));
655+
InfillPrefix = Normalize(LLamaVocabNative.llama_vocab_fim_pre(vocabNative));
656+
InfillMiddle = Normalize(LLamaVocabNative.llama_vocab_fim_mid(vocabNative));
657+
InfillSuffix = Normalize(LLamaVocabNative.llama_vocab_fim_suf(vocabNative));
658+
InfillPad = Normalize(LLamaVocabNative.llama_vocab_fim_pad(vocabNative));
659+
InfillRep = Normalize(LLamaVocabNative.llama_vocab_fim_rep(vocabNative));
660+
InfillSep = Normalize(LLamaVocabNative.llama_vocab_fim_sep(vocabNative));
661+
EOT = Normalize(LLamaVocabNative.llama_vocab_eot(vocabNative));
662+
DecoderStartToken = Normalize(llama_model_decoder_start_token(_model));
663+
ShouldAddBOS = LLamaVocabNative.llama_vocab_get_add_bos(vocabNative);
664+
ShouldAddEOS = LLamaVocabNative.llama_vocab_get_add_eos(vocabNative);
658665
}
666+
}
659667

660-
var slice = buff.Slice(0, (int)tokenLength);
661-
return Encoding.UTF8.GetStringFromSpan(slice);
668+
private Dictionary<LLamaToken, string> GetVocabCache()
669+
{
670+
var decoder = Encoding.UTF8.GetDecoder();
671+
var (bytesArr, charsArr) = (new byte[1024], new char[1024]);
672+
return Enumerable.Range(0, Count).ToDictionary(
673+
keySelector: i => (LLamaToken) i,
674+
elementSelector: i =>
675+
{
676+
decoder.Convert(bytesArr, 0, (int) _model.TokenToSpan(i, bytesArr), charsArr, 0, charsArr.Length, true, out var _, out var charsUsed, out var _);
677+
return string.Join("", charsArr.Take(charsUsed));
678+
}
679+
);
662680
}
663681

664682
private static LLamaToken? Normalize(LLamaToken token)
@@ -669,232 +687,88 @@ internal Vocabulary(SafeLlamaModelHandle model)
669687
/// <summary>
670688
/// Total number of tokens in this vocabulary
671689
/// </summary>
672-
public int Count
673-
{
674-
get
675-
{
676-
unsafe
677-
{
678-
return LLamaVocabNative.llama_vocab_n_tokens(VocabNative);
679-
}
680-
}
681-
}
690+
public int Count { get; init; }
682691

683692
/// <summary>
684693
/// Get the the type of this vocabulary
685694
/// </summary>
686-
public LLamaVocabType Type
687-
{
688-
get
689-
{
690-
unsafe
691-
{
692-
return LLamaVocabNative.llama_vocab_type(VocabNative);
693-
}
694-
}
695-
}
695+
public LLamaVocabType Type { get; init; }
696696

697697
/// <summary>
698698
/// Get the Beginning of Sentence token for this model
699699
/// </summary>
700-
public LLamaToken? BOS
701-
{
702-
get
703-
{
704-
unsafe
705-
{
706-
return Normalize(LLamaVocabNative.llama_vocab_bos(VocabNative));
707-
}
708-
}
709-
}
700+
public LLamaToken? BOS { get; init; }
710701

711702
/// <summary>
712703
/// Get the End of Sentence token for this model
713704
/// </summary>
714-
public LLamaToken? EOS
715-
{
716-
get
717-
{
718-
unsafe
719-
{
720-
return Normalize(LLamaVocabNative.llama_vocab_eos(VocabNative));
721-
}
722-
}
723-
}
705+
public LLamaToken? EOS { get; init; }
724706

725707
/// <summary>
726708
/// Get the newline token for this model
727709
/// </summary>
728-
public LLamaToken? Newline
729-
{
730-
get
731-
{
732-
unsafe
733-
{
734-
return Normalize(LLamaVocabNative.llama_vocab_nl(VocabNative));
735-
}
736-
}
737-
}
710+
public LLamaToken? Newline { get; init; }
738711

739712
/// <summary>
740713
/// Get the padding token for this model
741714
/// </summary>
742-
public LLamaToken? Pad
743-
{
744-
get
745-
{
746-
unsafe
747-
{
748-
return Normalize(LLamaVocabNative.llama_vocab_pad(VocabNative));
749-
}
750-
}
751-
}
715+
public LLamaToken? Pad { get; init; }
752716

753717
/// <summary>
754718
/// Get the sentence separator token for this model
755719
/// </summary>
756-
public LLamaToken? SEP
757-
{
758-
get
759-
{
760-
unsafe
761-
{
762-
return Normalize(LLamaVocabNative.llama_vocab_sep(VocabNative));
763-
}
764-
}
765-
}
720+
public LLamaToken? SEP { get; init; }
766721

767722
/// <summary>
768723
/// Codellama beginning of infill prefix
769724
/// </summary>
770-
public LLamaToken? InfillPrefix
771-
{
772-
get
773-
{
774-
unsafe
775-
{
776-
return Normalize(LLamaVocabNative.llama_vocab_fim_pre(VocabNative));
777-
}
778-
}
779-
}
725+
public LLamaToken? InfillPrefix { get; init; }
780726

781727
/// <summary>
782728
/// Codellama beginning of infill middle
783729
/// </summary>
784-
public LLamaToken? InfillMiddle
785-
{
786-
get
787-
{
788-
unsafe
789-
{
790-
return Normalize(LLamaVocabNative.llama_vocab_fim_mid(VocabNative));
791-
}
792-
}
793-
}
730+
public LLamaToken? InfillMiddle { get; init; }
794731

795732
/// <summary>
796733
/// Codellama beginning of infill suffix
797734
/// </summary>
798-
public LLamaToken? InfillSuffix
799-
{
800-
get
801-
{
802-
unsafe
803-
{
804-
return Normalize(LLamaVocabNative.llama_vocab_fim_suf(VocabNative));
805-
}
806-
}
807-
}
735+
public LLamaToken? InfillSuffix { get; init; }
808736

809737
/// <summary>
810738
/// Codellama pad
811739
/// </summary>
812-
public LLamaToken? InfillPad
813-
{
814-
get
815-
{
816-
unsafe
817-
{
818-
return Normalize(LLamaVocabNative.llama_vocab_fim_pad(VocabNative));
819-
}
820-
}
821-
}
740+
public LLamaToken? InfillPad { get; init; }
822741

823742
/// <summary>
824743
/// Codellama rep
825744
/// </summary>
826-
public LLamaToken? InfillRep
827-
{
828-
get
829-
{
830-
unsafe
831-
{
832-
return Normalize(LLamaVocabNative.llama_vocab_fim_rep(VocabNative));
833-
}
834-
}
835-
}
745+
public LLamaToken? InfillRep { get; init; }
836746

837747
/// <summary>
838748
/// Codellama rep
839749
/// </summary>
840-
public LLamaToken? InfillSep
841-
{
842-
get
843-
{
844-
unsafe
845-
{
846-
return Normalize(LLamaVocabNative.llama_vocab_fim_sep(VocabNative));
847-
}
848-
}
849-
}
750+
public LLamaToken? InfillSep { get; init; }
850751

851752
/// <summary>
852753
/// end-of-turn token
853754
/// </summary>
854-
public LLamaToken? EOT
855-
{
856-
get
857-
{
858-
unsafe
859-
{
860-
return Normalize(LLamaVocabNative.llama_vocab_eot(VocabNative));
861-
}
862-
}
863-
}
755+
public LLamaToken? EOT { get; init; }
864756

865757
/// <summary>
866758
/// For encoder-decoder models, this function returns id of the token that must be provided
867759
/// to the decoder to start generating output sequence.
868760
/// </summary>
869-
public LLamaToken? DecoderStartToken => Normalize(llama_model_decoder_start_token(_model));
761+
public LLamaToken? DecoderStartToken { get; init; }
870762

871763
/// <summary>
872764
/// Check if the current model requires a BOS token added
873765
/// </summary>
874-
public bool ShouldAddBOS
875-
{
876-
get
877-
{
878-
unsafe
879-
{
880-
return LLamaVocabNative.llama_vocab_get_add_bos(llama_model_get_vocab(_model));
881-
}
882-
}
883-
}
766+
public bool ShouldAddBOS { get; init; }
884767

885768
/// <summary>
886769
/// Check if the current model requires a EOS token added
887770
/// </summary>
888-
public bool ShouldAddEOS
889-
{
890-
get
891-
{
892-
unsafe
893-
{
894-
return LLamaVocabNative.llama_vocab_get_add_eos(llama_model_get_vocab(_model));
895-
}
896-
}
897-
}
771+
public bool ShouldAddEOS { get; init; }
898772
}
899773
}
900774
}

0 commit comments

Comments (0)