fix(tokenizer): support Kimi K2/K2.5/K2.6 tiktoken models #1482
New file: the `kimi_k2_tokenizer` module (hunk `@@ -0,0 +1,240 @@`, all 240 lines added).

```rust
//! Kimi-K2 / K2.5 / K2.6 detection and special-token helpers.
//!
//! Kimi models use the standard tiktoken BPE engine but with a Han-aware regex
//! and a 256-slot reserved-special-token range starting at `len(mergeable_ranks)`.
//! These helpers let the generic `TiktokenTokenizer` loader specialize itself
//! when it sees a Kimi directory, without exposing a separate public type.
//!
//! Upstream reference (identical across all three Kimi variants):
//! - moonshotai/Kimi-K2-Thinking/tokenization_kimi.py
//! - moonshotai/Kimi-K2.5/tokenization_kimi.py
//! - moonshotai/Kimi-K2.6/tokenization_kimi.py

use std::{
    collections::{HashMap, HashSet},
    path::Path,
};

use serde_json::Value;

use crate::traits::TokenIdType;

const NUM_RESERVED_SPECIAL_TOKENS: usize = 256;

/// Han-aware tokenization regex used by Kimi K2/K2.5/K2.6. Byte-identical to
/// the `pat_str` in upstream `tokenization_kimi.py`.
pub(crate) const KIMI_K2_PATTERN: &str = r"[\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";

/// Returns true if `dir` looks like a Kimi K2/K2.5/K2.6 model directory.
///
/// Primary signal: `tokenizer_config.json` references `tokenization_kimi`
/// (via `auto_map`, `tokenizer_class`, etc.). Fallback: `config.json::model_type`
/// is one of the known Kimi values.
pub(crate) fn matches_dir(dir: &Path) -> bool {
    if let Some(config) = read_json(&dir.join("tokenizer_config.json")) {
        if value_mentions_kimi_tokenizer(&config) {
            return true;
        }
    }

    read_json(&dir.join("config.json")).is_some_and(|config| model_config_is_kimi(&config))
}

/// Fill the 256-slot reserved special-token range starting at `base_vocab_size`
/// with synthetic `<|reserved_token_{id}|>` entries, preserving any explicit
/// `added_tokens_decoder` entries that already occupy slots in that range.
///
/// Mirrors upstream `tokenization_kimi.py`:
/// ```python
/// {special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
///  for i in range(num_base_tokens, num_base_tokens + 256)}
/// ```
/// where `num_base_tokens = len(mergeable_ranks)` — i.e., `encoder.len()`.
pub(crate) fn apply_reserved_special_tokens(
    added_tokens: &mut HashMap<String, TokenIdType>,
    base_vocab_size: usize,
) {
    let Ok(start) = TokenIdType::try_from(base_vocab_size) else {
        return;
    };

    let occupied_ids: HashSet<TokenIdType> = added_tokens.values().copied().collect();
    for offset in 0..NUM_RESERVED_SPECIAL_TOKENS {
        let id = start + offset as TokenIdType;
        if occupied_ids.contains(&id) {
            continue;
        }

        added_tokens
            .entry(format!("<|reserved_token_{id}|>"))
            .or_insert(id);
    }
}

fn read_json(path: &Path) -> Option<Value> {
    let content = std::fs::read_to_string(path).ok()?;
    serde_json::from_str(&content).ok()
}

fn model_config_is_kimi(config: &Value) -> bool {
    let model_type = config.get("model_type").and_then(Value::as_str);
    matches!(model_type, Some("kimi_k2") | Some("kimi_k25"))
}

fn value_mentions_kimi_tokenizer(value: &Value) -> bool {
    match value {
        Value::String(s) => s.contains("tokenization_kimi"),
        Value::Array(values) => values.iter().any(value_mentions_kimi_tokenizer),
        Value::Object(map) => map.values().any(value_mentions_kimi_tokenizer),
        _ => false,
    }
}

#[cfg(test)]
mod tests {
    use base64::{engine::general_purpose::STANDARD, Engine as _};

    use super::*;
    use crate::{
        tiktoken::TiktokenTokenizer,
        traits::{Decoder, Encoder, Tokenizer},
    };

    // Minimal BPE: bytes 'a' (rank 0), 'b' (rank 1). Used for tests that only
    // exercise decode of synthetic special tokens (no BPE encode).
    const MINIMAL_TIKTOKEN_MODEL: &str = "YQ== 0\nYg== 1\n";
    // Minimal BPE: "hello's" (rank 0), "hello" (rank 1), "'s" (rank 2).
    const CONTRACTION_TIKTOKEN_MODEL: &str = "aGVsbG8ncw== 0\naGVsbG8= 1\nJ3M= 2\n";

    /// Build a tiktoken model file with all 256 single-byte tokens (ranks 0..256)
    /// plus the given multi-byte tokens at successive ranks starting at 256.
    /// tiktoken's BPE requires every input byte to have a rank, so a real-world
    /// encode test needs the full byte layer.
    fn full_byte_bpe(extra: &[&[u8]]) -> String {
        let mut out = String::new();
        for b in 0u32..256 {
            out.push_str(&format!("{} {}\n", STANDARD.encode([b as u8]), b));
        }
        for (offset, bytes) in extra.iter().enumerate() {
            out.push_str(&format!("{} {}\n", STANDARD.encode(bytes), 256 + offset));
        }
        out
    }

    fn write_kimi_dir(model: &str, tokenizer_config: &str) -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("tiktoken.model"), model).unwrap();
        std::fs::write(dir.path().join("tokenizer_config.json"), tokenizer_config).unwrap();
        dir
    }

    const KIMI_AUTO_MAP_CONFIG: &str = r#"{
        "tokenizer_class": "TikTokenTokenizer",
        "auto_map": {
            "AutoTokenizer": ["tokenization_kimi.TikTokenTokenizer", null]
        }
    }"#;

    #[test]
    fn reserved_special_tokens_are_synthesized() {
        let dir = write_kimi_dir(
            MINIMAL_TIKTOKEN_MODEL,
            r#"{
                "tokenizer_class": "TikTokenTokenizer",
                "auto_map": {
                    "AutoTokenizer": ["tokenization_kimi.TikTokenTokenizer", null]
                },
                "added_tokens_decoder": {
                    "2": { "content": "[BOS]", "special": true },
                    "5": { "content": "<|im_assistant|>", "special": true }
                }
            }"#,
        );
        let tokenizer = TiktokenTokenizer::from_dir(dir.path()).unwrap();

        assert_eq!(tokenizer.vocab_size(), 258);
        assert_eq!(
            tokenizer.decode(&[4], false).unwrap(),
            "<|reserved_token_4|>"
        );
        assert_eq!(tokenizer.decode(&[5], false).unwrap(), "<|im_assistant|>");
        assert_eq!(tokenizer.token_to_id("<|reserved_token_4|>"), Some(4));
        assert_eq!(
            tokenizer.id_to_token(4).as_deref(),
            Some("<|reserved_token_4|>")
        );
    }

    #[test]
    fn matches_via_model_type_kimi_k2() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("tiktoken.model"), MINIMAL_TIKTOKEN_MODEL).unwrap();
        std::fs::write(
            dir.path().join("tokenizer_config.json"),
            r#"{ "added_tokens_decoder": {} }"#,
        )
        .unwrap();
        std::fs::write(
            dir.path().join("config.json"),
            r#"{ "model_type": "kimi_k2" }"#,
        )
        .unwrap();

        assert!(matches_dir(dir.path()));
        // Round-trip a synthetic reserved token to confirm the Kimi load path was taken.
        let tokenizer = TiktokenTokenizer::from_dir(dir.path()).unwrap();
        assert_eq!(
            tokenizer.decode(&[42], false).unwrap(),
            "<|reserved_token_42|>"
        );
    }

    #[test]
    fn matches_via_model_type_kimi_k25() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("tiktoken.model"), MINIMAL_TIKTOKEN_MODEL).unwrap();
        std::fs::write(
            dir.path().join("tokenizer_config.json"),
            r#"{ "added_tokens_decoder": {} }"#,
        )
        .unwrap();
        std::fs::write(
            dir.path().join("config.json"),
            r#"{ "model_type": "kimi_k25" }"#,
        )
        .unwrap();

        assert!(matches_dir(dir.path()));
    }

    #[test]
    fn uses_kimi_pattern_for_contractions() {
        let dir = write_kimi_dir(CONTRACTION_TIKTOKEN_MODEL, KIMI_AUTO_MAP_CONFIG);
        let tokenizer = TiktokenTokenizer::from_dir(dir.path()).unwrap();
        // The Kimi regex keeps "hello's" as a single match (contraction handling
        // in the second alternation), so the BPE returns rank 0.
        assert_eq!(
            tokenizer.encode("hello's", false).unwrap().token_ids(),
            &[0]
        );
    }

    #[test]
    fn han_input_round_trips_through_kimi_pattern() {
        // The Kimi regex's leading alternation is `[\p{Han}]+`. The main
        // regressions this guards against are (a) the character-class
        // intersection `[X&&[^\p{Han}]]` failing to compile under tiktoken-rs's
        // fancy-regex backend, and (b) Han input being rejected at the
        // pre-tokenizer. A minimal synthetic BPE can't reproduce a real Kimi
        // vocab, so we assert a byte-level round trip rather than exact token
        // IDs: encode must not panic, and decode must reconstruct the input.
        let model = full_byte_bpe(&[]);
        let dir = write_kimi_dir(&model, KIMI_AUTO_MAP_CONFIG);
        let tokenizer = TiktokenTokenizer::from_dir(dir.path()).unwrap();

        let text = "你好世界 hello!";
        let encoding = tokenizer.encode(text, false).unwrap();
        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }
}
```
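For readers unfamiliar with the pattern's alternation order, the following sketch (not part of the diff) exercises `KIMI_K2_PATTERN` directly with the `fancy-regex` crate, the same engine tiktoken-rs compiles it with. It assumes `fancy-regex` is available as a dev-dependency and that the constant above is in scope:

```rust
// Hypothetical extra check: how KIMI_K2_PATTERN pre-tokenizes mixed Han/Latin
// input before any BPE merging happens.
#[test]
fn kimi_pattern_pretokenizes_mixed_text() {
    let pat = fancy_regex::Regex::new(KIMI_K2_PATTERN).expect("pattern should compile");
    let pieces: Vec<&str> = pat
        .find_iter("你好世界 hello's world")
        .map(|m| m.expect("match should not error").as_str())
        .collect();
    // `[\p{Han}]+` keeps the Han run together; the second alternation keeps
    // " hello's" (optional leading space + lowercase run + contraction) whole.
    assert_eq!(pieces, ["你好世界", " hello's", " world"]);
}
```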
Changes to the existing `tiktoken` module (home of `TiktokenTokenizer`), shown as unified diff hunks:

```diff
@@ -14,17 +14,14 @@ use crate::{
         ChatTemplateState, ThinkingKeyName, ThinkingToggle,
     },
     factory::discover_chat_template_in_dir,
+    kimi_k2_tokenizer,
     traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
 };
 
 /// Regex pattern for cl100k_base tokenization.
 ///
-/// This pattern is correct for OpenAI models and most open-source tiktoken models (e.g.
-/// DeepSeek, Kimi K2). Some models use a different regex — for example, Kimi K2's native
-/// regex includes `\p{Han}` for Chinese character splitting — but encode/decode roundtrips
-/// still work correctly because BPE vocab handles tokenization; the regex only affects exact
-/// token boundary placement. A future enhancement could parse the regex from HuggingFace's
-/// `generation_config.json` or similar metadata.
+/// This pattern is correct for OpenAI models and most open-source tiktoken models. Models
+/// with a tokenizer-specific regex specialize the pattern inside `load_from_path`.
 const CL100K_BASE_PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
 
 type Rank = u32;
```
```diff
@@ -62,6 +59,15 @@ fn load_tiktoken_config(config_path: &Path) -> Result<TiktokenConfig> {
     })
 }
 
+fn load_tiktoken_config_from_dir(dir: &Path) -> Result<TiktokenConfig> {
+    let config_path = dir.join("tokenizer_config.json");
+    if config_path.exists() {
+        load_tiktoken_config(&config_path)
+    } else {
+        Ok(TiktokenConfig::default())
+    }
+}
+
 /// Parse `added_tokens_decoder` from config JSON.
 ///
 /// Format: `{ "163584": { "content": "[BOS]", "special": true }, ... }`
```
```diff
@@ -224,11 +230,19 @@ impl TiktokenTokenizer {
         let dir = tiktoken_path
             .parent()
             .ok_or_else(|| Error::msg("Cannot determine parent directory of tiktoken file"))?;
-        let config_path = dir.join("tokenizer_config.json");
-        let config = if config_path.exists() {
-            load_tiktoken_config(&config_path)?
+        let mut config = load_tiktoken_config_from_dir(dir)?;
+
+        // Kimi-K2/K2.5/K2.6 specialize the regex and pre-fill 256 reserved
+        // special-token slots starting at `len(mergeable_ranks)`; all other
+        // tiktoken models use the cl100k pattern unchanged.
+        let pattern = if kimi_k2_tokenizer::matches_dir(dir) {
+            kimi_k2_tokenizer::apply_reserved_special_tokens(
+                &mut config.added_tokens,
+                encoder.len(),
+            );
+            kimi_k2_tokenizer::KIMI_K2_PATTERN
         } else {
-            TiktokenConfig::default()
+            CL100K_BASE_PATTERN
         };
 
         // 3. Build special tokens encoder for CoreBPE (needs FxHashMap)
```
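To make the fill semantics the loader now runs concrete, here is a small sketch (not part of the diff) of `apply_reserved_special_tokens` on a vocabulary roughly the size of a real Kimi checkpoint. It assumes `TokenIdType` is `u32`, which the module's `type Rank = u32;` alias suggests; the `<|im_end|>` id is illustrative:

```rust
use std::collections::HashMap;

// Sketch: starting from `encoder.len()`, every unoccupied slot in the
// 256-token window gets a synthetic name; explicit entries keep theirs.
fn demo() {
    let mut added: HashMap<String, u32> = HashMap::new();
    added.insert("<|im_end|>".to_string(), 163_841); // explicit entry in the window

    apply_reserved_special_tokens(&mut added, 163_840);

    assert_eq!(added["<|reserved_token_163840|>"], 163_840);
    assert_eq!(added["<|im_end|>"], 163_841); // not overwritten
    assert!(!added.contains_key("<|reserved_token_163841|>")); // slot was taken
    assert_eq!(added.len(), 256); // 255 synthetic + 1 explicit
}
```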
```diff
@@ -248,7 +262,7 @@ impl TiktokenTokenizer {
             .map(|id| id as usize + 1)
             .unwrap_or(0);
         let (vocab, reverse_vocab) = build_vocab_maps(&encoder, &config.added_tokens);
-        let tokenizer = CoreBPE::new(encoder, special_tokens_encoder, CL100K_BASE_PATTERN)?;
+        let tokenizer = CoreBPE::new(encoder, special_tokens_encoder, pattern)?;
 
         // 5. Load chat template — propagate errors for explicit paths,
         // silently fall back for auto-discovery
```
```diff
@@ -473,6 +487,9 @@ impl Decoder for TiktokenTokenizer {
     fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
         match self.tokenizer.decode(token_ids.to_vec()) {
             Ok(text) => Ok(text),
+            Err(err) if is_unknown_tiktoken_decode_error(&err) => Err(Error::msg(format!(
+                "tiktoken decode failed for unknown token id: {err}"
+            ))),
             Err(err) => {
                 // Fallback to lossy decoding for incomplete UTF-8 sequences
                 let bytes: Vec<u8> = self
@@ -491,6 +508,17 @@
         }
     }
 }
+
+/// Detect tiktoken's "unknown token id" error so we can surface a clean error
+/// instead of letting the lossy-decode fallback panic on a missing key.
+///
+/// We match on the `Display` string because tiktoken-rs's `DecodeKeyError` lives
+/// in a private `vendor_tiktoken` module and isn't re-exported (as of 0.9.1),
+/// so a typed `downcast_ref` is not available. The message format is stable —
+/// see `vendor_tiktoken::DecodeKeyError::fmt` upstream.
+fn is_unknown_tiktoken_decode_error(err: &Error) -> bool {
+    err.to_string().starts_with("Invalid token for decoding:")
+}
 
 impl TokenizerTrait for TiktokenTokenizer {
     fn vocab_size(&self) -> usize {
         self.vocab_size
```
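Since the check is purely textual, it can be demonstrated without constructing a real CoreBPE decode failure. A sketch (not part of the diff), reusing the crate's own `Error::msg` constructor seen elsewhere in this file:

```rust
// Sketch: the helper keys off tiktoken-rs's stable Display prefix, so any
// error rendering that prefix is treated as an unknown-token failure.
#[test]
fn unknown_decode_error_is_detected_by_prefix() {
    let unknown = Error::msg("Invalid token for decoding: 999999");
    assert!(is_unknown_tiktoken_decode_error(&unknown));

    let other = Error::msg("some unrelated decode failure");
    assert!(!is_unknown_tiktoken_decode_error(&other)); // falls through to lossy path
}
```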
```diff
@@ -557,6 +585,21 @@ mod tests {
     use super::*;
     use crate::traits::{Decoder, Encoder, Tokenizer};
 
+    const MINIMAL_TIKTOKEN_MODEL: &str = "YQ== 0\nYg== 1\n";
+
+    fn write_minimal_tiktoken_dir(
+        tokenizer_config: &str,
+        model_config: Option<&str>,
+    ) -> tempfile::TempDir {
+        let dir = tempfile::tempdir().unwrap();
+        std::fs::write(dir.path().join("tiktoken.model"), MINIMAL_TIKTOKEN_MODEL).unwrap();
+        std::fs::write(dir.path().join("tokenizer_config.json"), tokenizer_config).unwrap();
+        if let Some(model_config) = model_config {
+            std::fs::write(dir.path().join("config.json"), model_config).unwrap();
+        }
+        dir
+    }
+
     #[test]
     fn test_tiktoken_creation() {
         let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
```
```diff
@@ -754,6 +797,26 @@
         assert_eq!(tokens.get("<|im_end|>"), Some(&163586));
     }
 
+    #[test]
+    fn test_tiktoken_unknown_token_decode_returns_error() {
+        let dir = write_minimal_tiktoken_dir(
+            r#"{
+                "added_tokens_decoder": {
+                    "2": { "content": "[BOS]", "special": true }
+                }
+            }"#,
+            None,
+        );
+        let tokenizer = TiktokenTokenizer::from_dir(dir.path()).unwrap();
+
+        let err = tokenizer.decode(&[4], false).unwrap_err();
+        assert!(
+            err.to_string()
+                .contains("tiktoken decode failed for unknown token id"),
+            "unexpected error: {err}"
+        );
+    }
+
     #[test]
     fn test_parse_special_tokens() {
         let config: serde_json::Value = serde_json::json!({
```
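Taken together, no call sites change: detection and specialization happen inside the loader. A hypothetical usage sketch (the checkpoint path is illustrative, and the error plumbing assumes the crate's `Result` interoperates with `?`):

```rust
use std::path::Path;

// Sketch: any directory holding tiktoken.model plus a tokenizer_config.json
// that references tokenization_kimi takes the specialized path automatically.
fn load_kimi() -> Result<()> {
    let tokenizer = TiktokenTokenizer::from_dir(Path::new("/models/Kimi-K2-Thinking"))?;
    let encoding = tokenizer.encode("你好世界 hello", false)?;
    let round_trip = tokenizer.decode(encoding.token_ids(), false)?;
    assert_eq!(round_trip, "你好世界 hello");
    Ok(())
}
```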