diff --git a/tiktoken-rs/src/api.rs b/tiktoken-rs/src/api.rs index 27cc7e4..c39b6a3 100644 --- a/tiktoken-rs/src/api.rs +++ b/tiktoken-rs/src/api.rs @@ -1,9 +1,14 @@ +use std::sync::Arc; + use anyhow::{anyhow, Result}; +use lazy_static::lazy_static; +use parking_lot::Mutex; use crate::{ cl100k_base, model::get_context_size, o200k_base, p50k_base, p50k_edit, r50k_base, + singleton::*, tokenizer::{get_tokenizer, Tokenizer}, CoreBPE, }; @@ -41,7 +46,8 @@ use crate::{ /// based on the given model and prompt. pub fn get_completion_max_tokens(model: &str, prompt: &str) -> Result { let context_size = get_context_size(model); - let bpe = get_bpe_from_model(model)?; + let bpe_mutex = get_bpe_from_model(model)?; + let bpe = bpe_mutex.lock(); let prompt_tokens = bpe.encode_with_special_tokens(prompt).len(); Ok(context_size.saturating_sub(prompt_tokens)) } @@ -102,7 +108,8 @@ pub fn num_tokens_from_messages( if tokenizer != Tokenizer::Cl100kBase && tokenizer != Tokenizer::O200kBase { anyhow::bail!("Chat completion is only supported chat models") } - let bpe = get_bpe_from_tokenizer(tokenizer)?; + let bpe_mutex = get_bpe_from_tokenizer(tokenizer)?; + let bpe = bpe_mutex.lock(); let (tokens_per_message, tokens_per_name) = if model.starts_with("gpt-3.5") { ( @@ -220,7 +227,7 @@ pub fn get_chat_completion_max_tokens( /// # Returns /// /// If successful, the function returns a `Result` containing the `CoreBPE` instance corresponding to the tokenizer used by the given model. -pub fn get_bpe_from_model(model: &str) -> Result { +pub fn get_bpe_from_model(model: &str) -> Result>> { let tokenizer = get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?; let bpe = get_bpe_from_tokenizer(tokenizer)?; @@ -253,15 +260,16 @@ pub fn get_bpe_from_model(model: &str) -> Result { /// # Returns /// /// If successful, the function returns a `Result` containing the `CoreBPE` instance corresponding to the given tokenizer. -pub fn get_bpe_from_tokenizer(tokenizer: Tokenizer) -> Result { - match tokenizer { - Tokenizer::O200kBase => o200k_base(), - Tokenizer::Cl100kBase => cl100k_base(), - Tokenizer::R50kBase => r50k_base(), - Tokenizer::P50kBase => p50k_base(), - Tokenizer::P50kEdit => p50k_edit(), - Tokenizer::Gpt2 => r50k_base(), - } +pub fn get_bpe_from_tokenizer(tokenizer: Tokenizer) -> Result>> { + let tok = match tokenizer { + Tokenizer::O200kBase => o200k_base_singleton(), + Tokenizer::Cl100kBase => cl100k_base_singleton(), + Tokenizer::R50kBase => r50k_base_singleton(), + Tokenizer::P50kBase => p50k_base_singleton(), + Tokenizer::P50kEdit => p50k_edit_singleton(), + Tokenizer::Gpt2 => r50k_base_singleton(), + }; + Ok(tok) } #[cfg(test)]