diff --git a/aworld/models/openai_tokenizer.py b/aworld/models/openai_tokenizer.py
index b1319de48..0f60947ba 100644
--- a/aworld/models/openai_tokenizer.py
+++ b/aworld/models/openai_tokenizer.py
@@ -34,15 +34,29 @@
     ENDOFTEXT: 100256,
 }
 
+# Global cache to prevent memory leaks from repeatedly loading BPE files
+_BPE_CACHE = {}
+
 
 def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
-    """Load tiktoken BPE file similar to qwen_tokenizer."""
+    """Load tiktoken BPE file with caching to prevent memory leaks."""
+    # Check cache first
+    if tiktoken_bpe_file in _BPE_CACHE:
+        return _BPE_CACHE[tiktoken_bpe_file]
+
+    # Load and decode file
     with open(tiktoken_bpe_file, 'rb') as f:
         contents = f.read()
-    return {
-        base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)
+
+    result = {
+        base64.b64decode(token): int(rank)
+        for token, rank in (line.split() for line in contents.splitlines() if line)
     }
+    # Cache the result
+    _BPE_CACHE[tiktoken_bpe_file] = result
+    return result
+
 
 class OpenAITokenizer:
     """OpenAI tokenizer using local tiktoken file."""
diff --git a/aworld/models/qwen_tokenizer.py b/aworld/models/qwen_tokenizer.py
index 73c03afca..1d3775451 100644
--- a/aworld/models/qwen_tokenizer.py
+++ b/aworld/models/qwen_tokenizer.py
@@ -45,14 +45,29 @@
 ))
 SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)
 
+# Global cache to prevent memory leaks from repeatedly loading BPE files
+_BPE_CACHE = {}
+
 
 def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
+    """Load tiktoken BPE file with caching to prevent memory leaks."""
+    # Check cache first
+    if tiktoken_bpe_file in _BPE_CACHE:
+        return _BPE_CACHE[tiktoken_bpe_file]
+
+    # Load and decode file
     with open(tiktoken_bpe_file, 'rb') as f:
         contents = f.read()
-    return {
-        base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)
+
+    result = {
+        base64.b64decode(token): int(rank)
+        for token, rank in (line.split() for line in contents.splitlines() if line)
     }
+    # Cache the result
+    _BPE_CACHE[tiktoken_bpe_file] = result
+    return result
+
 
 class QWenTokenizer:
     """QWen tokenizer."""
diff --git a/aworld/models/utils.py b/aworld/models/utils.py
index aa1678b93..ae98eed22 100644
--- a/aworld/models/utils.py
+++ b/aworld/models/utils.py
@@ -11,6 +11,34 @@
 from aworld.models.openai_tokenizer import openai_tokenizer
 from aworld.utils import import_package
 
+# Global cache for tiktoken encodings to prevent memory leaks
+_TIKTOKEN_ENCODING_CACHE = {}
+
+
+def _get_cached_tiktoken_encoding(model: str):
+    """
+    Get cached tiktoken encoding to prevent memory leaks.
+
+    Args:
+        model: Model name (e.g., 'gpt-4o', 'claude-3-opus')
+
+    Returns:
+        Cached tiktoken encoding object
+    """
+    if model not in _TIKTOKEN_ENCODING_CACHE:
+        import tiktoken
+        try:
+            _TIKTOKEN_ENCODING_CACHE[model] = tiktoken.encoding_for_model(model)
+            logger.debug(f"Created and cached tiktoken encoding for model: {model}")
+        except KeyError:
+            logger.debug(f"{model} model not found. Using cl100k_base encoding.")
+            # Cache cl100k_base if not already cached
+            if "cl100k_base" not in _TIKTOKEN_ENCODING_CACHE:
+                _TIKTOKEN_ENCODING_CACHE["cl100k_base"] = tiktoken.get_encoding("cl100k_base")
+            # Reuse cl100k_base for this model
+            _TIKTOKEN_ENCODING_CACHE[model] = _TIKTOKEN_ENCODING_CACHE["cl100k_base"]
+    return _TIKTOKEN_ENCODING_CACHE[model]
+
 
 class ModelUtils:
     """Utility class for model-related operations"""
@@ -265,37 +293,26 @@ def usage_process(usage: Dict[str, Union[int, Dict[str, int]]] = {}, context: Co
 def num_tokens_from_string(string: str, model: str = "openai"):
     """Return the number of tokens used by a list of messages."""
-    import tiktoken
-
     if model.lower() == "qwen":
         encoding = qwen_tokenizer
     elif model.lower() == "openai":
         encoding = openai_tokenizer
     else:
-        try:
-            encoding = tiktoken.encoding_for_model(model)
-        except KeyError:
-            logger.debug(
-                f"{model} model not found. Using cl100k_base encoding.")
-            encoding = tiktoken.get_encoding("cl100k_base")
+        # Use cached encoding to prevent memory leaks
+        encoding = _get_cached_tiktoken_encoding(model)
     return len(encoding.encode(string))
 
 
 def num_tokens_from_messages(messages, model="openai"):
     """Return the number of tokens used by a list of messages."""
     import_package("tiktoken")
-    import tiktoken
 
     if model.lower() == "qwen":
         encoding = qwen_tokenizer
     elif model.lower() == "openai":
         encoding = openai_tokenizer
     else:
-        try:
-            encoding = tiktoken.encoding_for_model(model)
-        except KeyError:
-            logger.warning(
-                f"{model} model not found. Using cl100k_base encoding.")
-            encoding = tiktoken.get_encoding("cl100k_base")
+        # Use cached encoding to prevent memory leaks
+        encoding = _get_cached_tiktoken_encoding(model)
 
     tokens_per_message = 3
     tokens_per_name = 1
@@ -316,19 +333,14 @@ def num_tokens_from_messages(messages, model="openai"):
     return num_tokens
 
 
 def truncate_tokens_from_messages(messages: List[Dict[str, Any]], max_tokens: int, keep_both_sides: bool = False, model: str = "gpt-4o"):
     import_package("tiktoken")
-    import tiktoken
 
     if model.lower() == "qwen":
         return qwen_tokenizer.truncate(messages, max_tokens, keep_both_sides)
     elif model.lower() == "openai":
         return openai_tokenizer.truncate(messages, max_tokens, keep_both_sides)
 
-    try:
-        encoding = tiktoken.encoding_for_model(model)
-    except KeyError:
-        logger.warning(f"{model} model not found. Using cl100k_base encoding.")
-        encoding = tiktoken.get_encoding("cl100k_base")
-
+    # Use cached encoding to prevent memory leaks
+    encoding = _get_cached_tiktoken_encoding(model)
     return encoding.truncate(messages, max_tokens, keep_both_sides)
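
For reviewers: the two tokenizer changes apply the same memoization pattern to `_load_tiktoken_bpe`. A minimal standalone sketch of that pattern, assuming only the stdlib (the file path in the usage comment is hypothetical, not part of this PR):

```python
import base64
from typing import Dict

# Module-level cache, as in the PR: one parsed rank table per BPE file path.
_BPE_CACHE: Dict[str, Dict[bytes, int]] = {}


def load_tiktoken_bpe_cached(path: str) -> Dict[bytes, int]:
    """Parse a tiktoken BPE file ("<base64-token> <rank>" per line) once per path."""
    if path not in _BPE_CACHE:
        with open(path, "rb") as f:
            contents = f.read()
        _BPE_CACHE[path] = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in contents.splitlines() if line)
        }
    return _BPE_CACHE[path]


# Tokenizers built from the same file now share one dict instead of
# re-parsing the rank table on every instantiation:
# ranks_a = load_tiktoken_bpe_cached("cl100k_base.tiktoken")  # hypothetical path
# ranks_b = load_tiktoken_bpe_cached("cl100k_base.tiktoken")
# assert ranks_a is ranks_b
```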
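The utils change replaces the per-call `tiktoken.encoding_for_model(...)` construction with the shared `_get_cached_tiktoken_encoding` helper. Below is a sketch of the same fallback-sharing behavior outside the aworld package (requires `tiktoken`; the function name is illustrative, and the `assert` lines stand in for tests):

```python
import tiktoken

_ENCODING_CACHE = {}


def get_encoding_cached(model: str):
    """Build each Encoding once; unknown models share one cl100k_base fallback."""
    if model not in _ENCODING_CACHE:
        try:
            _ENCODING_CACHE[model] = tiktoken.encoding_for_model(model)
        except KeyError:
            # Unknown model name: fall back to cl100k_base, constructed at most once.
            if "cl100k_base" not in _ENCODING_CACHE:
                _ENCODING_CACHE["cl100k_base"] = tiktoken.get_encoding("cl100k_base")
            _ENCODING_CACHE[model] = _ENCODING_CACHE["cl100k_base"]
    return _ENCODING_CACHE[model]


if __name__ == "__main__":
    a = get_encoding_cached("gpt-4o")
    b = get_encoding_cached("gpt-4o")
    assert a is b  # second lookup is a cache hit; no new Encoding is built
    c = get_encoding_cached("some-unknown-model")
    d = get_encoding_cached("another-unknown-model")
    assert c is d  # unrecognized names all alias the single cl100k_base entry
    print(len(a.encode("hello world")))  # token counting via the cached object
```

One design note on the aliasing: each unrecognized model name still gets its own cache key, but every such key points at the same fallback object, so only one fallback Encoding is ever constructed no matter how many distinct unknown names callers pass.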