 using System.Collections.Generic;
 using System.Diagnostics;
 using System.IO;
+using System.Linq;
 using System.Text;
 using LLama.Exceptions;
 using LLama.Extensions;
@@ -631,34 +632,51 @@ public sealed class Vocabulary
             internal unsafe LLamaVocabNative* VocabNative => llama_model_get_vocab(_model);
 
+            /// <summary>
+            /// Cache of all the tokens in the vocabulary, and their string representation
+            /// </summary>
+            public readonly IReadOnlyDictionary<LLamaToken, string> TokenToString;
+
             internal Vocabulary(SafeLlamaModelHandle model)
             {
                 _model = model;
-            }
-
-            private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
-            {
-                if (!token.HasValue)
-                    return null;
-
-                // Try to convert using a fixed size buffer
-                const int buffSize = 32;
-                Span<byte> buff = stackalloc byte[buffSize];
-                var tokenLength = _model.TokenToSpan((LLamaToken)token, buff, special: isSpecialToken);
-
-                // Negative indicates that there was no result
-                if (tokenLength <= 0)
-                    return null;
-
-                // if the original buffer wasn't large enough, try again with one that's the right size
-                if (tokenLength > buffSize)
+                unsafe
                 {
-                    buff = stackalloc byte[(int)tokenLength];
-                    _ = _model.TokenToSpan((LLamaToken)token, buff, special: isSpecialToken);
+                    var vocabNative = llama_model_get_vocab(_model);
+                    Count = LLamaVocabNative.llama_vocab_n_tokens(vocabNative);
+                    Type = LLamaVocabNative.llama_vocab_type(vocabNative);
+                    BOS = Normalize(LLamaVocabNative.llama_vocab_bos(vocabNative));
+                    EOS = Normalize(LLamaVocabNative.llama_vocab_eos(vocabNative));
+                    Newline = Normalize(LLamaVocabNative.llama_vocab_nl(vocabNative));
+                    Pad = Normalize(LLamaVocabNative.llama_vocab_pad(vocabNative));
+                    SEP = Normalize(LLamaVocabNative.llama_vocab_sep(vocabNative));
+                    InfillPrefix = Normalize(LLamaVocabNative.llama_vocab_fim_pre(vocabNative));
+                    InfillMiddle = Normalize(LLamaVocabNative.llama_vocab_fim_mid(vocabNative));
+                    InfillSuffix = Normalize(LLamaVocabNative.llama_vocab_fim_suf(vocabNative));
+                    InfillPad = Normalize(LLamaVocabNative.llama_vocab_fim_pad(vocabNative));
+                    InfillRep = Normalize(LLamaVocabNative.llama_vocab_fim_rep(vocabNative));
+                    InfillSep = Normalize(LLamaVocabNative.llama_vocab_fim_sep(vocabNative));
+                    EOT = Normalize(LLamaVocabNative.llama_vocab_eot(vocabNative));
+                    DecoderStartToken = Normalize(llama_model_decoder_start_token(_model));
+                    ShouldAddBOS = LLamaVocabNative.llama_vocab_get_add_bos(vocabNative);
+                    ShouldAddEOS = LLamaVocabNative.llama_vocab_get_add_eos(vocabNative);
                 }
+
+                // Build the cache only after Count has been initialised above,
+                // since GetVocabCache enumerates Count tokens.
+                TokenToString = GetVocabCache();
+            }
 
-                var slice = buff.Slice(0, (int)tokenLength);
-                return Encoding.UTF8.GetStringFromSpan(slice);
+            private Dictionary<LLamaToken, string> GetVocabCache()
+            {
+                // One shared decoder and one pair of scratch buffers for the
+                // whole vocabulary, to avoid per-token allocations.
+                var decoder = Encoding.UTF8.GetDecoder();
+                var (bytesArr, charsArr) = (new byte[1024], new char[1024]);
+                return Enumerable.Range(0, Count).ToDictionary(
+                    keySelector: i => (LLamaToken)i,
+                    elementSelector: i =>
+                    {
+                        decoder.Convert(bytesArr, 0, (int)_model.TokenToSpan(i, bytesArr), charsArr, 0, charsArr.Length, true, out var _, out var charsUsed, out var _);
+                        return string.Join("", charsArr.Take(charsUsed));
+                    }
+                );
+            }
 
             private static LLamaToken? Normalize(LLamaToken token)
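The hunk above trades the old on-demand `LLamaTokenToString` helper for a dictionary built once per model load. The part worth studying is `GetVocabCache`: a single `Encoding.UTF8.GetDecoder()` and one pair of scratch buffers are reused across the whole vocabulary, so the only per-token allocation is the resulting string. Below is a minimal, self-contained sketch of that same pattern; the `TokenToSpan` here is a hypothetical stand-in for the native-backed method on `SafeLlamaModelHandle`.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

static class VocabCacheSketch
{
    // Hypothetical stand-in for _model.TokenToSpan: writes the token's
    // UTF-8 bytes into `buffer` and returns the number of bytes written.
    static uint TokenToSpan(int token, byte[] buffer)
    {
        var bytes = Encoding.UTF8.GetBytes($"<tok_{token}>");
        bytes.CopyTo(buffer, 0);
        return (uint)bytes.Length;
    }

    static Dictionary<int, string> BuildCache(int count)
    {
        var decoder = Encoding.UTF8.GetDecoder();
        var (bytesArr, charsArr) = (new byte[1024], new char[1024]);
        return Enumerable.Range(0, count).ToDictionary(
            keySelector: i => i,
            elementSelector: i =>
            {
                // flush: true drains the decoder for each token, so a token
                // whose bytes end mid code point cannot leak partial state
                // into the next entry.
                decoder.Convert(bytesArr, 0, (int)TokenToSpan(i, bytesArr),
                                charsArr, 0, charsArr.Length, true,
                                out _, out var charsUsed, out _);
                return new string(charsArr, 0, charsUsed);
            });
    }

    static void Main()
    {
        var cache = BuildCache(3);
        Console.WriteLine(cache[2]); // prints "<tok_2>"
    }
}
```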
@@ -669,232 +687,88 @@ internal Vocabulary(SafeLlamaModelHandle model)
             /// <summary>
             /// Total number of tokens in this vocabulary
             /// </summary>
-            public int Count
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return LLamaVocabNative.llama_vocab_n_tokens(VocabNative);
-                    }
-                }
-            }
+            public int Count { get; init; }
 
             /// <summary>
             /// Get the type of this vocabulary
             /// </summary>
-            public LLamaVocabType Type
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return LLamaVocabNative.llama_vocab_type(VocabNative);
-                    }
-                }
-            }
+            public LLamaVocabType Type { get; init; }
 
             /// <summary>
             /// Get the Beginning of Sentence token for this model
             /// </summary>
-            public LLamaToken? BOS
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return Normalize(LLamaVocabNative.llama_vocab_bos(VocabNative));
-                    }
-                }
-            }
+            public LLamaToken? BOS { get; init; }
 
             /// <summary>
             /// Get the End of Sentence token for this model
             /// </summary>
-            public LLamaToken? EOS
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return Normalize(LLamaVocabNative.llama_vocab_eos(VocabNative));
-                    }
-                }
-            }
+            public LLamaToken? EOS { get; init; }
 
             /// <summary>
             /// Get the newline token for this model
             /// </summary>
-            public LLamaToken? Newline
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return Normalize(LLamaVocabNative.llama_vocab_nl(VocabNative));
-                    }
-                }
-            }
+            public LLamaToken? Newline { get; init; }
 
             /// <summary>
             /// Get the padding token for this model
             /// </summary>
-            public LLamaToken? Pad
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return Normalize(LLamaVocabNative.llama_vocab_pad(VocabNative));
-                    }
-                }
-            }
+            public LLamaToken? Pad { get; init; }
 
             /// <summary>
             /// Get the sentence separator token for this model
             /// </summary>
-            public LLamaToken? SEP
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return Normalize(LLamaVocabNative.llama_vocab_sep(VocabNative));
-                    }
-                }
-            }
+            public LLamaToken? SEP { get; init; }
 
             /// <summary>
             /// Codellama beginning of infill prefix
             /// </summary>
-            public LLamaToken? InfillPrefix
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return Normalize(LLamaVocabNative.llama_vocab_fim_pre(VocabNative));
-                    }
-                }
-            }
+            public LLamaToken? InfillPrefix { get; init; }
 
             /// <summary>
             /// Codellama beginning of infill middle
             /// </summary>
-            public LLamaToken? InfillMiddle
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return Normalize(LLamaVocabNative.llama_vocab_fim_mid(VocabNative));
-                    }
-                }
-            }
+            public LLamaToken? InfillMiddle { get; init; }
 
             /// <summary>
             /// Codellama beginning of infill suffix
             /// </summary>
-            public LLamaToken? InfillSuffix
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return Normalize(LLamaVocabNative.llama_vocab_fim_suf(VocabNative));
-                    }
-                }
-            }
+            public LLamaToken? InfillSuffix { get; init; }
 
             /// <summary>
             /// Codellama pad
             /// </summary>
-            public LLamaToken? InfillPad
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return Normalize(LLamaVocabNative.llama_vocab_fim_pad(VocabNative));
-                    }
-                }
-            }
+            public LLamaToken? InfillPad { get; init; }
 
             /// <summary>
             /// Codellama rep
             /// </summary>
-            public LLamaToken? InfillRep
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return Normalize(LLamaVocabNative.llama_vocab_fim_rep(VocabNative));
-                    }
-                }
-            }
+            public LLamaToken? InfillRep { get; init; }
 
             /// <summary>
             /// Codellama sep
             /// </summary>
-            public LLamaToken? InfillSep
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return Normalize(LLamaVocabNative.llama_vocab_fim_sep(VocabNative));
-                    }
-                }
-            }
+            public LLamaToken? InfillSep { get; init; }
 
             /// <summary>
             /// end-of-turn token
             /// </summary>
-            public LLamaToken? EOT
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return Normalize(LLamaVocabNative.llama_vocab_eot(VocabNative));
-                    }
-                }
-            }
+            public LLamaToken? EOT { get; init; }
 
             /// <summary>
             /// For encoder-decoder models, this returns the id of the token that must be provided
             /// to the decoder to start generating the output sequence.
             /// </summary>
-            public LLamaToken? DecoderStartToken => Normalize(llama_model_decoder_start_token(_model));
+            public LLamaToken? DecoderStartToken { get; init; }
 
             /// <summary>
             /// Check if the current model requires a BOS token to be added
             /// </summary>
-            public bool ShouldAddBOS
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return LLamaVocabNative.llama_vocab_get_add_bos(llama_model_get_vocab(_model));
-                    }
-                }
-            }
+            public bool ShouldAddBOS { get; init; }
 
             /// <summary>
             /// Check if the current model requires an EOS token to be added
             /// </summary>
-            public bool ShouldAddEOS
-            {
-                get
-                {
-                    unsafe
-                    {
-                        return LLamaVocabNative.llama_vocab_get_add_eos(llama_model_get_vocab(_model));
-                    }
-                }
-            }
+            public bool ShouldAddEOS { get; init; }
         }
     }
 }
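The second hunk is the same change applied seventeen times: each vocabulary property loses its computed getter, which made an `unsafe` native call on every read, and becomes an init-only value captured once in the constructor. The snapshot is sound because a model's vocabulary metadata never changes after loading. A toy sketch of the trade-off, not LLamaSharp code; `ExpensiveNativeCall` is a stand-in for the `llama_vocab_*` functions:

```csharp
using System;

sealed class SnapshotVsComputed
{
    static int _nativeCalls;

    // Stand-in for a llama_vocab_* P/Invoke.
    static int ExpensiveNativeCall()
    {
        _nativeCalls++;
        return 32_000; // e.g. a typical llama vocab size
    }

    // Before: every read pays for a native call.
    public int CountComputed => ExpensiveNativeCall();

    // After: one native call at construction, plain field reads thereafter.
    public int Count { get; init; }

    public SnapshotVsComputed()
    {
        Count = ExpensiveNativeCall();
    }

    static void Main()
    {
        var v = new SnapshotVsComputed();
        for (var i = 0; i < 1000; i++) _ = v.Count;         // no extra calls
        for (var i = 0; i < 1000; i++) _ = v.CountComputed; // 1000 calls
        Console.WriteLine(_nativeCalls); // prints 1001
    }
}
```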