Speed up function _estimate_string_tokens #2156

Merged
31 changes: 15 additions & 16 deletions pydantic_ai_slim/pydantic_ai/models/function.py
```diff
@@ -16,9 +16,7 @@
 from .. import _utils, usage
 from .._utils import PeekableAsyncStream
 from ..messages import (
-    AudioUrl,
     BinaryContent,
-    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -308,18 +306,19 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
 def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
     if not content:
         return 0
 
     if isinstance(content, str):
-        return len(re.split(r'[\s",.:]+', content.strip()))
-    else:
-        tokens = 0
-        for part in content:
-            if isinstance(part, str):
-                tokens += len(re.split(r'[\s",.:]+', part.strip()))
-            # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
-            if isinstance(part, (AudioUrl, ImageUrl)):
-                tokens += 0
-            elif isinstance(part, BinaryContent):
-                tokens += len(part.data)
-            else:
-                tokens += 0
-        return tokens
+        return len(_TOKEN_SPLIT_RE.split(content.strip()))
+
+    tokens = 0
+    for part in content:
+        if isinstance(part, str):
+            tokens += len(_TOKEN_SPLIT_RE.split(part.strip()))
+        elif isinstance(part, BinaryContent):
+            tokens += len(part.data)
+        # TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.
+
+    return tokens
+
+
+_TOKEN_SPLIT_RE = re.compile(r'[\s",.:]+')
```
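For context, here is a minimal standalone sketch of the string path of the estimator. The split pattern is copied from the diff; the function name and example inputs are mine, for illustration only:

```python
import re

# Same pattern as _TOKEN_SPLIT_RE in the diff above: one "token" per run of
# characters between whitespace, double quotes, commas, periods, and colons.
_TOKEN_SPLIT_RE = re.compile(r'[\s",.:]+')


def estimate_string_tokens(content: str) -> int:
    """String-only sketch of _estimate_string_tokens."""
    if not content:
        return 0
    return len(_TOKEN_SPLIT_RE.split(content.strip()))


print(estimate_string_tokens('hello world'))      # 2
print(estimate_string_tokens('one, two: three'))  # 3
```

This is a cheap word-count heuristic rather than a real tokenizer, which is all the usage estimates in `models/function.py` need.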
Member:

This will have some overhead at import time; it's small, but it'll add up if we do this with all regular expressions. Should we stick with `re.split(r'[\s",.:]+', part.strip())`, since it caches the compiled regex the first time it's run?

Contributor Author:

To validate the performance characteristics, I ran an experiment where I replaced the current suggestion with inline `re.split`, ran it on the generated test set, and timed the runtime, so the only change is the global `re.compile` vs inline `re.split`:

- global `re.compile` → 1.68 ms
- inline `re.split` → 2.57 ms

Yes, `re` does cache the compiled regex for future use, but the cache lookup has overhead that can be significant, especially when used in a loop. In my experience with optimizations discovered with codeflash, I've seen `re.compile` be faster. In this case, since the regex is used multiple times and in a loop, I would recommend compiling it, although it's your decision.
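For anyone who wants to reproduce the comparison locally, here is a minimal sketch; the corpus and iteration counts are made up (not the generated test set mentioned above), so absolute numbers will differ:

```python
import re
import timeit

# Hypothetical stand-in for the generated test set mentioned above.
PARTS = ['hello, "world": some text with, punctuation.'] * 1_000

_TOKEN_SPLIT_RE = re.compile(r'[\s",.:]+')


def precompiled() -> int:
    tokens = 0
    for part in PARTS:
        tokens += len(_TOKEN_SPLIT_RE.split(part.strip()))
    return tokens


def inline() -> int:
    tokens = 0
    for part in PARTS:
        # re.split looks the pattern up in re's internal cache on every call.
        tokens += len(re.split(r'[\s",.:]+', part.strip()))
    return tokens


print('precompiled:', timeit.timeit(precompiled, number=100))
print('inline:     ', timeit.timeit(inline, number=100))
```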

Contributor:

This file needs to be imported explicitly, so when it's imported we can assume the `_estimate_string_tokens` function is going to be used, and compiling the regex at import time is fine. I'd feel differently if this were in a file that's always imported by Pydantic AI itself, where we wouldn't know whether the regex was actually going to be used.
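If import-time cost ever did matter (per the comment above, it doesn't here), one illustrative alternative is to defer compilation to first use; this is a hypothetical pattern, not what the PR does:

```python
import functools
import re


@functools.cache
def _token_split_re() -> re.Pattern[str]:
    # Compiled on the first call, then reused; nothing happens at import time.
    return re.compile(r'[\s",.:]+')


def estimate(text: str) -> int:
    return len(_token_split_re().split(text.strip()))
```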