[WIP] Proper tool calling support in torchtune #2794
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2794
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hey @krammnic, does this support tool calls for all formats (like openai, sharegpt etc)?
It's still WIP, but yes, it will.
Thanks for working on this! I left a few small comments. We should also add a test to ensure that this actually works and generates the expected outputs on a tool-calling dataset
token_ids = self.tokenizer.encode(text).ids
if add_bos and not self.hf_adds_bos and self.bos_token not in text:
    ...
# Both bos_id and eos_id might be None (null). Therefore, we need an additional check.
Is this related to tool-calling? Or a separate issue?
It is caused by a separate issue in HuggingFaceBaseTokenizer.
try:
    self.bos_token = self._get_token_from_config(self.config, "bos_token")
    self.eos_token = self._get_token_from_config(self.config, "eos_token")
except ValueError:
    pass
In this case I wonder whether we should just modify _get_token_from_config
to directly return None (possibly logging a warning) rather than use this try/except
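A minimal sketch of what that could look like (simplified to a free function here; the actual torchtune helper is a method and handles a couple more cases):

```python
# Sketch only, not the actual torchtune implementation: return None with a warning
# when the token key is missing, instead of raising ValueError at the call site.
import logging

logger = logging.getLogger(__name__)

def _get_token_from_config(config: dict, key: str):
    token = config.get(key)
    if token is None:
        logger.warning("%s not found in tokenizer config; defaulting to None.", key)
        return None
    # Tokens may be stored as a plain string or as a dict with a "content" field.
    if isinstance(token, dict):
        return token.get("content")
    return token
```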
masked (bool): whether the message is masked in the sample. If True, do not use
    in loss calculation. Default: False
ipython (bool): whether the message is a tool call. Default: False
tool_calls (Optional[list]): list of tool calls related to this message. Default: None
Should we also update the role "ipython" to "tool" to match what's done by Hugging Face?
Yes, good catch, that argument always seemed weird to me.
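For reference, the Hugging Face chat format uses role "tool" for tool results and attaches tool_calls to the assistant turn, roughly like this (illustrative values only):

```python
# Illustrative only: field names follow the Hugging Face chat format, values are made up.
messages = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "type": "function",
                "function": {"name": "get_weather", "arguments": {"city": "Paris"}},
            }
        ],
    },
    # The tool's response uses role "tool" rather than "ipython".
    {"role": "tool", "content": '{"temperature_c": 21}'},
]
```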
[project.optional-dependencies]
dev = [
-    "bitsandbytes>=0.43.0",
+    # "bitsandbytes>=0.43.0",
Was this an intentional removal?
Nope :/ I have to comment it out on Mac (otherwise it won't install) and forgot to revert it. Will open a separate PR to address this.
masked=d.get("masked", False),
ipython=d.get("ipython", False),
tool_calls=d.get("tool_calls", []),
tool=d.get("tool", False),
While I agree with this change, it currently breaks existing tokenizers. Repro:
from torchtune.datasets import alpaca_cleaned_dataset
from torchtune.models.qwen2_5 import qwen2_5_tokenizer
vocab_path = "/tmp/Qwen2.5-14B-Instruct/vocab.json"
merges_path = "/tmp/Qwen2.5-14B-Instruct/merges.txt"
tokenizer_json_path = "/tmp/Qwen2.5-14B-Instruct/tokenizer.json"
tokenizer_config_path = "/tmp/Qwen2.5-14B-Instruct/tokenizer_config.json"
tokenizer_qwen = qwen2_5_tokenizer(
    path=vocab_path,
    merges_file=merges_path,
    max_seq_len=512,
)
dataset_qwen = alpaca_cleaned_dataset(tokenizer=tokenizer_qwen, packed=False)
Hmm, yep. We need to remove ipython everywhere.
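One way to keep backward compatibility during the transition (just an illustration of a possible shim, not what this PR currently does) would be to accept the legacy key as an alias in from_dict:

```python
# Hypothetical BC shim for Message.from_dict: treat the legacy "ipython" key as an
# alias for the new "tool" field while both are still accepted.
import warnings

def _tool_flag_from_dict(d: dict) -> bool:
    if "ipython" in d and "tool" not in d:
        warnings.warn("'ipython' is deprecated, use 'tool' instead", DeprecationWarning)
        return bool(d["ipython"])
    return bool(d.get("tool", False))
```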
@krammnic This is great progress. It looks like a bit of work needs to be done on BC with existing tokenizers (unless the plan is to fully deprecate them). In addition I'm seeing some issues with the jinja rendering - I think it may require explicitly passing tools to the renderer (hf ref). Repro:

from torchtune.datasets import alpaca_cleaned_dataset
from torchtune.modules.transforms.tokenizers import HuggingFaceModelTokenizer

tokenizer_json_path = "/tmp/Qwen2.5-14B-Instruct/tokenizer.json"
tokenizer_config_path = "/tmp/Qwen2.5-14B-Instruct/tokenizer_config.json"

tokenizer_hf = HuggingFaceModelTokenizer(
    tokenizer_json_path=tokenizer_json_path,
    tokenizer_config_json_path=tokenizer_config_path,
    max_seq_len=512,
)
dataset_hf = alpaca_cleaned_dataset(tokenizer=tokenizer_hf, packed=True)

Basically optionally propagating …
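For context, on the transformers side the chat-template renderer takes the tool schemas as an explicit tools argument; a rough sketch of what that looks like (the model ID and tool schema are illustrative, and this is transformers usage, not torchtune code):

```python
# Sketch of passing tools explicitly to the Hugging Face chat-template renderer.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-14B-Instruct")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

rendered = tok.apply_chat_template(
    [{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,  # tool-use sections only render when `tools` is passed to the template
    tokenize=False,
    add_generation_prompt=True,
)
print(rendered)
```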
"role": m.role, | ||
"content": m.content[0]["content"], | ||
"tool_calls": m.tool_calls, | ||
} |
I had some issues specifically with the LLaMA 3(.3) tokenizer here, which didn't play nicely with empty tool calls [] or None. Ended up replacing with:
current_messages = [
    {
        "role": m.role,
        "content": m.content[0]["content"],
        **({"tool_calls": m.tool_calls} if m.tool_calls is not None else {})
    }
    for m in messages[: i + 1]
]
I'm not sure if this is the correct logic though (or if this plays nicely with all tokenizers)
This works fine with other tokenizers! Thanks
For me this broke Qwen coder (sigh). I think removing StrictUndefined may have fixed it? I don't recall exactly. I kept trying components from the transformers jinja compilation until I found something that worked, but I don't know much about how jinja works - there may be a better solution.
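For anyone hitting the same thing, here is a toy illustration (plain jinja2, not torchtune code) of how StrictUndefined changes behavior when a template references an optional variable like tools:

```python
# With jinja2's default Undefined, a missing `tools` variable is simply falsy;
# with StrictUndefined, even the boolean test raises UndefinedError.
from jinja2 import Environment, StrictUndefined
from jinja2.exceptions import UndefinedError

template_src = "{% if tools %}{{ tools | length }} tools available. {% endif %}Hello!"

print(Environment().from_string(template_src).render())  # -> "Hello!"

try:
    Environment(undefined=StrictUndefined).from_string(template_src).render()
except UndefinedError as exc:
    print("StrictUndefined raised:", exc)
```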
@nathan-az Hey! We might want to merge this. I will introduce some final changes tomorrow and then we will be able to merge. Thanks for the patience and comprehensive review.
Sounds good. Given the issues I had with other tokenizers I think it would be good to add tests with a couple of the popular models' tokenizers to confirm that:
Up to you whether we want to put their tokenizers directly in the torchtune source or if we add
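One possible shape for such a test, sketched assuming we download tokenizer files at test time (vendoring them in the repo is the other option mentioned above); the repo IDs, message fields, and assertion are illustrative and rely on the tool_calls / "tool" role support added in this PR:

```python
# Sketch only: not an existing torchtune test.
import pytest
from huggingface_hub import hf_hub_download
from torchtune.data import Message
from torchtune.modules.transforms.tokenizers import HuggingFaceModelTokenizer

@pytest.mark.parametrize(
    "repo_id", ["Qwen/Qwen2.5-7B-Instruct", "Qwen/Qwen2.5-Coder-7B-Instruct"]
)
def test_tool_call_conversation_tokenizes(repo_id):
    tokenizer = HuggingFaceModelTokenizer(
        tokenizer_json_path=hf_hub_download(repo_id, "tokenizer.json"),
        tokenizer_config_json_path=hf_hub_download(repo_id, "tokenizer_config.json"),
        max_seq_len=512,
    )
    messages = [
        Message(role="user", content="What's the weather in Paris?"),
        Message(
            role="assistant",
            content="",
            tool_calls=[
                {
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": {"city": "Paris"}},
                }
            ],
        ),
        Message(role="tool", content='{"temperature_c": 21}'),
    ]
    tokens, mask = tokenizer.tokenize_messages(messages)
    assert len(tokens) == len(mask) > 0
```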
Context
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
pre-commit install
pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example