Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions aworld/models/openai_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,29 @@
ENDOFTEXT: 100256,
}

# Global cache to prevent memory leaks from repeatedly loading BPE files
_BPE_CACHE = {}


def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
"""Load tiktoken BPE file similar to qwen_tokenizer."""
"""Load tiktoken BPE file with caching to prevent memory leaks."""
# Check cache first
if tiktoken_bpe_file in _BPE_CACHE:
return _BPE_CACHE[tiktoken_bpe_file]

# Load and decode file
with open(tiktoken_bpe_file, 'rb') as f:
contents = f.read()
return {
base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)

result = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}

# Cache the result
_BPE_CACHE[tiktoken_bpe_file] = result
return result
Comment on lines +38 to +58
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

当前的缓存实现不是线程安全的。在多线程环境中,可能会发生竞态条件,导致多个线程同时尝试加载和缓存同一个BPE文件。这不仅效率低下,还可能导致数据不一致。建议使用 threading.Lock 来保护对缓存的并发访问。

另外,_load_tiktoken_bpe 函数及其缓存 _BPE_CACHE 与 aworld/models/qwen_tokenizer.py 中的代码完全相同。为了提高可维护性,建议将这部分重复代码提取到一个共享的工具模块中(例如,aworld/models/tokenizer_utils.py),这样不仅可以消除重复代码,还能实现一个统一的BPE缓存。

下面是一个线程安全实现的例子(使用双重检查锁定模式以提高效率):

import threading

_BPE_CACHE = {}
_BPE_CACHE_LOCK = threading.Lock()

def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
    # 快速路径,无锁检查
    if tiktoken_bpe_file in _BPE_CACHE:
        return _BPE_CACHE[tiktoken_bpe_file]
    
    # 加锁
    with _BPE_CACHE_LOCK:
        # 再次检查,防止其他线程已经填充了缓存
        if tiktoken_bpe_file in _BPE_CACHE:
            return _BPE_CACHE[tiktoken_bpe_file]
        
        # 加载文件并填充缓存
        with open(tiktoken_bpe_file, 'rb') as f:
            contents = f.read()
        result = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in contents.splitlines() if line)
        }
        _BPE_CACHE[tiktoken_bpe_file] = result
        return result



class OpenAITokenizer:
"""OpenAI tokenizer using local tiktoken file."""
Expand Down
19 changes: 17 additions & 2 deletions aworld/models/qwen_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,29 @@
))
SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)

# Global cache to prevent memory leaks from repeatedly loading BPE files
_BPE_CACHE = {}


def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
"""Load tiktoken BPE file with caching to prevent memory leaks."""
# Check cache first
if tiktoken_bpe_file in _BPE_CACHE:
return _BPE_CACHE[tiktoken_bpe_file]

# Load and decode file
with open(tiktoken_bpe_file, 'rb') as f:
contents = f.read()
return {
base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)

result = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}

# Cache the result
_BPE_CACHE[tiktoken_bpe_file] = result
return result


class QWenTokenizer:
"""QWen tokenizer."""
Expand Down
56 changes: 34 additions & 22 deletions aworld/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,34 @@
from aworld.models.openai_tokenizer import openai_tokenizer
from aworld.utils import import_package

# Global cache for tiktoken encodings to prevent memory leaks
_TIKTOKEN_ENCODING_CACHE = {}


def _get_cached_tiktoken_encoding(model: str):
"""
Get cached tiktoken encoding to prevent memory leaks.

Args:
model: Model name (e.g., 'gpt-4o', 'claude-3-opus')

Returns:
Cached tiktoken encoding object
"""
if model not in _TIKTOKEN_ENCODING_CACHE:
import tiktoken
try:
_TIKTOKEN_ENCODING_CACHE[model] = tiktoken.encoding_for_model(model)
logger.debug(f"Created and cached tiktoken encoding for model: {model}")
except KeyError:
logger.debug(f"{model} model not found. Using cl100k_base encoding.")
# Cache cl100k_base if not already cached
if "cl100k_base" not in _TIKTOKEN_ENCODING_CACHE:
_TIKTOKEN_ENCODING_CACHE["cl100k_base"] = tiktoken.get_encoding("cl100k_base")
# Reuse cl100k_base for this model
_TIKTOKEN_ENCODING_CACHE[model] = _TIKTOKEN_ENCODING_CACHE["cl100k_base"]
return _TIKTOKEN_ENCODING_CACHE[model]
Comment on lines +18 to +40
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The _get_cached_tiktoken_encoding function introduces a potential Denial of Service (DoS) vulnerability due to an unbounded cache. The global dictionary _TIKTOKEN_ENCODING_CACHE can store tiktoken encoding objects for every unique model string, including invalid ones, leading to memory exhaustion if the model parameter is user-controlled. Consider using a fixed-size cache (e.g., functools.lru_cache) or validating model names against an allow-list. Furthermore, this cache is not thread-safe, which can lead to race conditions in a multi-threaded environment. It is recommended to use threading.Lock to ensure thread safety for cache operations.



class ModelUtils:
"""Utility class for model-related operations"""
Expand Down Expand Up @@ -265,37 +293,26 @@ def usage_process(usage: Dict[str, Union[int, Dict[str, int]]] = {}, context: Co

def num_tokens_from_string(string: str, model: str = "openai"):
"""Return the number of tokens used by a list of messages."""
import tiktoken

if model.lower() == "qwen":
encoding = qwen_tokenizer
elif model.lower() == "openai":
encoding = openai_tokenizer
else:
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.debug(
f"{model} model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
# Use cached encoding to prevent memory leaks
encoding = _get_cached_tiktoken_encoding(model)
return len(encoding.encode(string))

def num_tokens_from_messages(messages, model="openai"):
"""Return the number of tokens used by a list of messages."""
import_package("tiktoken")
import tiktoken

if model.lower() == "qwen":
encoding = qwen_tokenizer
elif model.lower() == "openai":
encoding = openai_tokenizer
else:
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.warning(
f"{model} model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
# Use cached encoding to prevent memory leaks
encoding = _get_cached_tiktoken_encoding(model)

tokens_per_message = 3
tokens_per_name = 1
Expand All @@ -316,19 +333,14 @@ def num_tokens_from_messages(messages, model="openai"):

def truncate_tokens_from_messages(messages: List[Dict[str, Any]], max_tokens: int, keep_both_sides: bool = False, model: str = "gpt-4o"):
import_package("tiktoken")
import tiktoken

if model.lower() == "qwen":
return qwen_tokenizer.truncate(messages, max_tokens, keep_both_sides)
elif model.lower() == "openai":
return openai_tokenizer.truncate(messages, max_tokens, keep_both_sides)

try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.warning(f"{model} model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")

# Use cached encoding to prevent memory leaks
encoding = _get_cached_tiktoken_encoding(model)
return encoding.truncate(messages, max_tokens, keep_both_sides)


Expand Down