Integrate new prompt mechanism into training #40
Conversation
train.py
Outdated
When you said putting it into the root, I thought you meant under "functionary/".
As this can create confusion, we should move all training modules under functionary/train/.
train.py
Outdated
replace_llama_attn_with_flash_attn()
from functionary.prompt import EndToken
from train.custom_datasets import CustomDataset, split_data
from train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
Transformers must be imported after the patch. Let's revert to the previous order.
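A minimal sketch of the ordering being asked for, assuming the patch works by rebinding attributes that later imports pick up:

# Apply the flash-attention monkey patch first...
from train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()

# ...and only then import transformers, so the patched attention is the one in use.
import transformers  # noqa: E402

from functionary.prompt import EndToken  # noqa: E402
from train.custom_datasets import CustomDataset, split_data  # noqa: E402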
train.py
Outdated
special_tokens_dict = {"additional_special_tokens": added_tokens}
smart_tokenizer_and_embedding_resize(
    special_tokens_dict=special_tokens_dict, tokenizer=tokenizer, model=model
)
I think we should refactor the tokenizer initialization into a separate function.
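A possible shape for that refactor, as a sketch only (the initialize_tokenizer name and its signature are illustrative, not from this PR):

import transformers

def initialize_tokenizer(model, model_name_or_path, added_tokens, cache_dir=None):
    # Illustrative helper: load the tokenizer, register the new special tokens,
    # and resize the model embeddings in one place.
    tokenizer = transformers.LlamaTokenizer.from_pretrained(
        model_name_or_path, cache_dir=cache_dir
    )
    special_tokens_dict = {"additional_special_tokens": added_tokens}
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict, tokenizer=tokenizer, model=model
    )
    return tokenizer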
train/custom_datasets.py
Outdated
prompt_str = (
    "system:\n"
    + generate_schema_from_functions(functions=messages["functions"])
    + "\nsystem:\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary\n"
We should let the get_prompt_from_messages function prepare the input. So, prepend this as a message dict to the original messages.
With proper role and content fields, etc.
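A sketch of that idea, assuming get_prompt_from_messages accepts the usual list of role/content dicts and that the raw sample exposes "functions" and "messages" keys (the exact keys are an assumption here):

# Prepend the function schema and the system message as ordinary message
# dicts, then let get_prompt_from_messages render the final prompt string.
prompt_messages = [
    {
        "role": "system",
        "content": generate_schema_from_functions(functions=messages["functions"]),
    },
    {"role": "system", "content": SYSTEM_MESSAGE},
]
prompt_messages += messages["messages"]  # "messages" key assumed
prompt_str = get_prompt_from_messages(prompt_messages)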
train/custom_datasets.py
Outdated
"""Prepares a list of messages for the model by calling `prepare_message_for_model` function on each of them and | ||
concatenating the returned input_ids and targets. Also, the function merges the text of the messages. | ||
def prepare_training_inputs( | ||
messages: List[Dict], |
Are you sure this is a list?
Maybe we should add a CI action to catch these kinds of issues.
train/custom_datasets.py
Outdated
"input_ids": ret["input_ids"], | ||
"labels": ret["labels"], | ||
"attention_mask": ret["attention_mask"], | ||
"input_ids": ret[1]["input_ids"], |
Why is this ret[1]? Can we find a more explicit way? If we cannot, we should at least add some clarification.
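One more explicit option, sketched on the assumption that the function returns a (prompt, inputs) pair; unpacking into named variables documents what index 1 actually is:

# Unpack instead of indexing, so "ret[1]" gets a readable name.
final_prompt, model_inputs = prepare_training_inputs(messages, tokenizer)
batch = {
    "input_ids": model_inputs["input_ids"],
    "labels": model_inputs["labels"],
    "attention_mask": model_inputs["attention_mask"],
}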
functionary/train/train.py
Outdated
import torch
import torch.distributed
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
import transformers
Transformers must be imported after the monkey patch.
tests/test_prompt_creation.py
Outdated
self.assertEqual(
    final_prompt.strip(), self.final_prompt.strip(), "wrong final prompt from: get_prompt_from_messages"
    final_prompt.strip(),
Do you know why .strip() is necessary? Is there a bug?
Fixed
def test_prepare_training_inputs(self):
    """this function is used to test function: prepare_training_inputs"""
    # note that must set legacy=True, read more: https://github.com/huggingface/transformers/issues/25176
    tokenizer = LlamaTokenizer.from_pretrained("musabgultekin/functionary-7b-v1", legacy=True)
    tokenizer = LlamaTokenizer.from_pretrained(
        "musabgultekin/functionary-7b-v1", legacy=True
Why is this using legacy=True and a fast tokenizer while our training uses the slow tokenizer with legacy=False? Can we keep the two consistent?
I've set both to use the fast tokenizer with legacy=True.
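One way to protect that invariant going forward is a single helper that both training and the tests import, so the settings cannot drift apart (helper name and module location are illustrative):

# e.g. functionary/train/tokenizer_utils.py (hypothetical module)
from transformers import AutoTokenizer

def load_training_tokenizer(model_name_or_path: str):
    # Single source of truth for tokenizer settings: fast tokenizer with
    # legacy=True, per https://github.com/huggingface/transformers/issues/25176.
    return AutoTokenizer.from_pretrained(model_name_or_path, legacy=True, use_fast=True)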
functionary/train/custom_datasets.py
Outdated
Can we rename this file to something more appropriate?
In general, let's spend more time on naming things.
@@ -5,6 +5,10 @@
import torch

from functionary.schema import generate_schema_from_functions

SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
There was already a system message on the inference side. Let's either delete this one and import that one, or delete the other one.
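For example, keeping a single definition and importing it here (the module holding the inference-side constant is an assumption; adjust to wherever it actually lives):

# custom_datasets.py: reuse the existing constant instead of redefining it.
from functionary.prompt import SYSTEM_MESSAGE  # assumed location of the inference-side message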
len(tokenizer) instead of tokenizer.vocab_size when reshaping the tensors/arrays.
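The distinction, sketched: vocab_size reflects only the base vocabulary, while len(tokenizer) also counts tokens registered via add_special_tokens, so the embedding matrix has to be resized with len(tokenizer):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("musabgultekin/functionary-7b-v1", legacy=True)
model = AutoModelForCausalLM.from_pretrained("musabgultekin/functionary-7b-v1")

# "<illustrative-token>" is a made-up token, purely for demonstration.
tokenizer.add_special_tokens({"additional_special_tokens": ["<illustrative-token>"]})

# tokenizer.vocab_size does NOT include the token added above, but
# len(tokenizer) does -- so any reshaping must use len(tokenizer).
model.resize_token_embeddings(len(tokenizer))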