-
Notifications
You must be signed in to change notification settings - Fork 120
为tiktoken新增缓存,修复内存泄漏问题 #765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,34 @@ | |
| from aworld.models.openai_tokenizer import openai_tokenizer | ||
| from aworld.utils import import_package | ||
|
|
||
import threading

# Global cache for tiktoken encodings to prevent memory leaks.
# Access is guarded by _TIKTOKEN_CACHE_LOCK and the size is bounded, so a
# user-controlled `model` string cannot grow the cache without limit (DoS
# hardening) and concurrent callers cannot race while loading the same
# BPE file.
_TIKTOKEN_ENCODING_CACHE = {}
_TIKTOKEN_CACHE_LOCK = threading.Lock()
# Upper bound on distinct cached entries; past this limit encodings are
# still returned to the caller, just no longer inserted into the cache.
_TIKTOKEN_CACHE_MAX_SIZE = 128


def _get_cached_tiktoken_encoding(model: str):
    """
    Get a cached tiktoken encoding to prevent memory leaks.

    Unknown model names fall back to the ``cl100k_base`` encoding. The
    cache is thread-safe and bounded to ``_TIKTOKEN_CACHE_MAX_SIZE``
    entries.

    Args:
        model: Model name (e.g., 'gpt-4o', 'claude-3-opus')

    Returns:
        Cached tiktoken encoding object
    """
    # Fast path: already cached. The dict read happens under the lock so
    # it cannot interleave with a concurrent insert.
    with _TIKTOKEN_CACHE_LOCK:
        encoding = _TIKTOKEN_ENCODING_CACHE.get(model)
    if encoding is not None:
        return encoding

    # Build the encoding outside the lock: loading a BPE file is slow and
    # holding the lock here would serialize every caller.
    import tiktoken
    try:
        encoding = tiktoken.encoding_for_model(model)
        logger.debug(f"Created and cached tiktoken encoding for model: {model}")
    except KeyError:
        logger.debug(f"{model} model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    with _TIKTOKEN_CACHE_LOCK:
        # Re-check under the lock: another thread may have inserted the
        # same key while we were building the encoding.
        existing = _TIKTOKEN_ENCODING_CACHE.get(model)
        if existing is not None:
            return existing
        # Only insert while below the bound, so arbitrary model strings
        # cannot exhaust memory.
        if len(_TIKTOKEN_ENCODING_CACHE) < _TIKTOKEN_CACHE_MAX_SIZE:
            _TIKTOKEN_ENCODING_CACHE[model] = encoding
    return encoding
|
Comment on lines
+18
to
+40
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The _get_cached_tiktoken_encoding function introduces a potential Denial of Service (DoS) vulnerability due to an unbounded cache. The global dictionary _TIKTOKEN_ENCODING_CACHE can store tiktoken encoding objects for every unique model string, including invalid ones, leading to memory exhaustion if the model parameter is user-controlled. Consider using a fixed-size cache (e.g., functools.lru_cache) or validating model names against an allow-list. Furthermore, this cache is not thread-safe, which can lead to race conditions in a multi-threaded environment. It is recommended to use threading.Lock to ensure thread safety for cache operations. |
||
|
|
||
|
|
||
| class ModelUtils: | ||
| """Utility class for model-related operations""" | ||
|
|
@@ -265,37 +293,26 @@ def usage_process(usage: Dict[str, Union[int, Dict[str, int]]] = {}, context: Co | |
|
|
||
def num_tokens_from_string(string: str, model: str = "openai"):
    """Return the number of tokens used by a single string.

    Args:
        string: Text to tokenize.
        model: ``"qwen"`` / ``"openai"`` select the project tokenizers;
            any other value is treated as a tiktoken model name (unknown
            names fall back to ``cl100k_base`` inside the cached helper).

    Returns:
        Number of tokens produced by the selected encoding.
    """
    if model.lower() == "qwen":
        encoding = qwen_tokenizer
    elif model.lower() == "openai":
        encoding = openai_tokenizer
    else:
        # Use the shared cached encoding instead of calling
        # tiktoken.encoding_for_model() directly, so repeated calls do
        # not re-load BPE files (the memory leak this cache fixes).
        encoding = _get_cached_tiktoken_encoding(model)
    return len(encoding.encode(string))
|
|
||
| def num_tokens_from_messages(messages, model="openai"): | ||
| """Return the number of tokens used by a list of messages.""" | ||
| import_package("tiktoken") | ||
| import tiktoken | ||
|
|
||
| if model.lower() == "qwen": | ||
| encoding = qwen_tokenizer | ||
| elif model.lower() == "openai": | ||
| encoding = openai_tokenizer | ||
| else: | ||
| try: | ||
| encoding = tiktoken.encoding_for_model(model) | ||
| except KeyError: | ||
| logger.warning( | ||
| f"{model} model not found. Using cl100k_base encoding.") | ||
| encoding = tiktoken.get_encoding("cl100k_base") | ||
| # Use cached encoding to prevent memory leaks | ||
| encoding = _get_cached_tiktoken_encoding(model) | ||
|
|
||
| tokens_per_message = 3 | ||
| tokens_per_name = 1 | ||
|
|
@@ -316,19 +333,14 @@ def num_tokens_from_messages(messages, model="openai"): | |
|
|
||
def truncate_tokens_from_messages(messages: List[Dict[str, Any]], max_tokens: int, keep_both_sides: bool = False, model: str = "gpt-4o"):
    """Truncate a message list so it fits within ``max_tokens``.

    Args:
        messages: Chat messages to truncate.
        max_tokens: Token budget the result must fit into.
        keep_both_sides: Keep content from both ends instead of one side.
        model: ``"qwen"`` / ``"openai"`` select the project tokenizers;
            any other value is treated as a tiktoken model name.

    Returns:
        The truncated messages as produced by the selected tokenizer.
    """
    # Ensure tiktoken is installed before the cached helper imports it.
    import_package("tiktoken")

    lowered = model.lower()
    if lowered == "qwen":
        return qwen_tokenizer.truncate(messages, max_tokens, keep_both_sides)
    if lowered == "openai":
        return openai_tokenizer.truncate(messages, max_tokens, keep_both_sides)

    encoding = _get_cached_tiktoken_encoding(model)
    if hasattr(encoding, "truncate"):
        return encoding.truncate(messages, max_tokens, keep_both_sides)
    # BUG FIX: tiktoken Encoding objects expose encode()/decode() but no
    # truncate(), so the previous code raised AttributeError for every
    # model that is neither "qwen" nor "openai". Fall back to the openai
    # tokenizer's truncate, which does implement message truncation.
    logger.warning(
        f"tiktoken encoding for {model} has no truncate(); "
        "falling back to openai_tokenizer.truncate")
    return openai_tokenizer.truncate(messages, max_tokens, keep_both_sides)
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
当前的缓存实现不是线程安全的。在多线程环境中，可能会发生竞态条件，导致多个线程同时尝试加载和缓存同一个BPE文件。这不仅效率低下，还可能导致数据不一致。建议使用 `threading.Lock` 来保护对缓存的并发访问。另外，`_load_tiktoken_bpe` 函数及其缓存 `_BPE_CACHE` 与 aworld/models/qwen_tokenizer.py 中的代码完全相同。为了提高可维护性，建议将这部分重复代码提取到一个共享的工具模块中（例如，aworld/models/tokenizer_utils.py），这样不仅可以消除重复代码，还能实现一个统一的BPE缓存。下面是一个线程安全实现的例子（使用双重检查锁定模式以提高效率）：