Refactoring training for readability #296
base: main
Conversation
…dling, tokenization, batch allocation and training distribution
…side the train() function
Could we remove the old code? Keeping it around as comments makes the file harder to read.
Thank you for the PR, by the way.
@musab-mk Np. Check my new commit; I removed the commented-out code.
@@ -248,6 +265,55 @@ def train():
        print_rank0("***** HERE ARE SOME EXAMPLES FROM EVALUATION ***")
        training_utils.print_some_examples(eval_dataset, tokenizer)

    # Dynamic batch size based on max tokens per batch
While a dynamic batch size might seem like a good idea for dealing with memory issues, it would cause instabilities in training due to the difference in gradient updates per conversation. I would prefer stability over memory efficiency: stable updates at a higher GPU cost are preferable to cheaper and faster training.
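For context, the kind of token-budget batching being debated looks roughly like the sketch below (a minimal PyTorch sampler; the class name and the `max_tokens_per_batch` parameter are illustrative, not the PR's actual API). Because batch sizes vary with sequence lengths, each optimizer step averages gradients over a different number of conversations, which is the instability concern raised above:

```python
from torch.utils.data import Sampler

class TokenBudgetBatchSampler(Sampler):
    """Groups example indices so each batch stays under a token budget.
    Batch sizes therefore vary from batch to batch."""

    def __init__(self, lengths, max_tokens_per_batch):
        self.lengths = lengths          # token count per example
        self.max_tokens = max_tokens_per_batch

    def __iter__(self):
        batch, budget = [], 0
        for idx, n_tokens in enumerate(self.lengths):
            # Start a new batch once the budget would be exceeded.
            # (A single over-long example still gets its own batch.)
            if batch and budget + n_tokens > self.max_tokens:
                yield batch
                batch, budget = [], 0
            batch.append(idx)
            budget += n_tokens
        if batch:
            yield batch

    def __len__(self):
        # Upper bound; the exact count depends on how examples pack.
        return len(self.lengths)
```

Such a sampler would be passed to a `DataLoader` via its `batch_sampler` argument.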
@@ -139,6 +134,23 @@ def trainer_save_model_safe(trainer: transformers.Trainer):
        trainer.save_model()

"""
Below is the updated train() function from LEVENT OZBEK.
Authorship tracking is a responsibility of git. We should remove all authorship info from the code.
Most of the changes are identical to those in train_lora.py. I simply applied the changes to the utility code in training_utils.py.
I commented out the original train() function.

- training_utils.tokenize_and_cache() is used for both training and evaluation datasets to avoid repetition.
- dynamic_batch_size() auto-adjusts batch sizes based on token counts. I did not implement this in train_lora.py, since LoRAs are trained on smaller data, so it didn't seem necessary there.
- DataLoaders are constructed using BatchSampler to dynamically adjust the batch size per epoch.
- a distributed DataLoader is used if local_rank != -1.
- updated to use the optimized preprocess_logits_for_metrics and compute_metrics from training_utils.py.

Advantages of these changes:
- handles datasets with varying sequence lengths dynamically
- supports both single-GPU and distributed setups.
"""
Updates like this would better go into the PR description, not the code.
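As an aside for readers, the single-GPU vs. distributed DataLoader split that docstring describes typically looks something like this; the helper name and arguments below are illustrative, not the PR's actual code:

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

def build_dataloader(dataset, batch_size, local_rank, collate_fn):
    # In a distributed run (local_rank != -1) each rank gets a disjoint
    # shard of the data; otherwise fall back to a plain sequential pass.
    if local_rank != -1:
        sampler = DistributedSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=False)
    return DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)
```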
"""Preprocesses the logits during evaluation by computing the greedy token predictions for | ||
accuracy calculation and loss values for perplexity calculation. Both pred_ids and loss are | ||
of shape (batch_size x seq_len)""" | ||
|
We should keep the docstring. If you think it's obsolete in terms of information, it's better to update it than to remove it completely.
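For reference, a function matching that docstring could look roughly like the following. This is a sketch assuming the Hugging Face Trainer's `preprocess_logits_for_metrics` hook, not functionary's actual implementation (which may, for example, also shift labels):

```python
import torch

def preprocess_logits_for_metrics(logits, labels):
    """Preprocesses the logits during evaluation by computing the greedy token
    predictions for accuracy calculation and loss values for perplexity
    calculation. Both pred_ids and loss are of shape (batch_size x seq_len)."""
    pred_ids = logits.argmax(dim=-1)  # greedy prediction at each position
    # Unreduced cross-entropy so later code can average over labeled
    # positions only; -100 marks ignored (unlabeled) positions.
    loss_fn = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
    loss = loss_fn(logits.transpose(1, 2), labels)  # (batch_size, seq_len)
    return pred_ids, loss
```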
def compute_metrics(eval_preds, id2token, tokenizer):
    """Computes next-token accuracy and perplexity metrics for evaluation"""
Let's have a docstring here too.
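If a fuller docstring is wanted, something along these lines could work; the parameter descriptions are assumptions inferred from the names, not taken from the PR:

```python
def compute_metrics(eval_preds, id2token, tokenizer):
    """Computes next-token accuracy and perplexity metrics for evaluation.

    Args:
        eval_preds: (predictions, labels) pair, where predictions are the
            (pred_ids, loss) tensors from preprocess_logits_for_metrics.
        id2token: mapping from tracked token id to its string form, used
            for the per-token accuracy breakdown.
        tokenizer: tokenizer used to decode ids for first-token statistics.

    Returns:
        A dict mapping metric names (accuracy, perplexity, first-token and
        per-token breakdowns) to their values.
    """
```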
     metrics = {
         "accuracy": acc_count / total_num,
         "perplexity": perplexity,
-        "accuracy_first_token": first_token_correct_count / first_token_total_count,
-        "total_number_first_token": first_token_total_count,
-        "first_token_param_values": first_token_param_value_acc
-        / first_token_param_value_total,
-        "first_token_param_values_total": first_token_param_value_total,
+        "accuracy_first_token": first_token_correct / max(first_token_total, 1),
     }

-    for token_id, stat in sorted(
-        first_token_label_dic.items(), key=lambda x: -x[1]["total"]
-    )[:5]:
-        token = tokenizer.decode([token_id])
-        metrics[f"accuracy_first_token_{token}"] = stat["correct"] / stat["total"]
-        metrics[f"accuracy_first_token_{token}_total"] = stat["total"]

-    for token_id in dic:
+    # Token-specific accuracies
+    for token_id, stat in token_stats.items():
         token = id2token[token_id]
-        total_num = dic[token_id]["total"]
-        acc = -1
-        if total_num > 0:
-            acc = dic[token_id]["acc"] / total_num
-        metrics[f"accuracy_{token}"] = acc
-        metrics[f"accuracy_total_num_{token}"] = total_num
I see some metrics have been removed. Could you kindly explain why these changes have been made?
- Removed redundant imports
- Added optimization functions for data handling, tokenization, batch allocation and distribution of training