
Refactoring training for readability #296

Open

levoz92 wants to merge 5 commits into main
Conversation

levoz92

@levoz92 levoz92 commented Nov 24, 2024

Removed redundant imports.
Added optimization functions for data handling, tokenization, batch allocation, and distributed training.

Collaborator

@musab-mk musab-mk left a comment


Could we remove the old code? Keeping it around as comments hurts readability.

@musab-mk
Collaborator

musab-mk commented Dec 5, 2024

Thank you for the PR btw

@musab-mk musab-mk changed the title Levent update Refactoring for readability Dec 5, 2024
@musab-mk musab-mk changed the title Refactoring for readability Refactoring training for readability Dec 5, 2024
@levoz92
Author

levoz92 commented Dec 6, 2024

@musab-mk Np. Check my new commit; I removed the commented-out code.

@@ -248,6 +265,55 @@ def train():
print_rank0("***** HERE ARE SOME EXAMPLES FROM EVALUATION ***")
training_utils.print_some_examples(eval_dataset, tokenizer)

# Dynamic batch size based on max tokens per batch
Collaborator


While a dynamic batch size might seem like a good idea for dealing with memory issues, it would cause instability in training due to the difference in gradient updates per conversation. I would prefer stability over memory efficiency: stable updates at a higher GPU cost should be preferred over cheaper and faster training.
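For readers following along, here is a minimal editorial sketch of the kind of token-budget batching being discussed. It is not the PR's actual implementation: the class name, the `max_tokens_per_batch` parameter, and the use of `len(sample["input_ids"])` as the length measure are all assumptions.

```python
# Illustrative sketch only, not the PR's code: group dataset indices so that
# each batch stays under a fixed token budget.
from torch.utils.data import Sampler


class TokenBudgetBatchSampler(Sampler):
    """Yields lists of indices whose total token count stays under a budget."""

    def __init__(self, dataset, max_tokens_per_batch: int):
        self.dataset = dataset
        self.max_tokens_per_batch = max_tokens_per_batch

    def __iter__(self):
        batch, batch_tokens = [], 0
        for idx in range(len(self.dataset)):
            n_tokens = len(self.dataset[idx]["input_ids"])  # assumed length measure
            if batch and batch_tokens + n_tokens > self.max_tokens_per_batch:
                yield batch
                batch, batch_tokens = [], 0
            batch.append(idx)
            batch_tokens += n_tokens
        if batch:
            yield batch
```

Because each yielded batch can contain a different number of conversations, the per-step gradient estimate varies in exactly the way described above, which is the argument for keeping a fixed batch size.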

@@ -139,6 +134,23 @@ def trainer_save_model_safe(trainer: transformers.Trainer):
trainer.save_model()


"""
Below is the updated train() function from LEVENT OZBEK.
Collaborator


Authorship tracking is a responsibility of git. We should remove all authorship info from the code.

Comment on lines +139 to +153
Most of the changes are identical to those in train_lora.py; I simply applied them to the utility code in training_utils.py.
I commented out the original train() function.

- training_utils.tokenize_and_cache() is used for both the training and evaluation datasets to avoid repetition.
- The dynamic_batch_size() function automatically adjusts batch sizes based on token counts. I did not implement this in train_lora.py, since LoRAs are trained on smaller data and I felt it wasn't necessary there.
- DataLoaders are constructed with a BatchSampler to dynamically adjust the batch size per epoch.
- A distributed DataLoader is used if local_rank != -1.
- Updated to use the optimized preprocess_logits_for_metrics and compute_metrics from training_utils.py.

Advantages of these changes:
- Handles datasets with varying sequence lengths dynamically.
- Supports both single-GPU and distributed setups.
"""


Collaborator


Updates like this would be better placed in the PR description, not in the code.
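For context, here is a rough editorial sketch of how the pieces listed in the description above could fit together. It is not the PR's train() code: the exact signatures of training_utils.tokenize_and_cache() and dynamic_batch_size(), the import path, and the helper name build_train_dataloader are assumptions.

```python
# Editorial sketch (assumed signatures) of wiring cached tokenization, a
# dynamic batch size, and single-GPU vs. distributed sampling together.
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

import training_utils  # the repo's helper module referenced above (import path assumed)


def build_train_dataloader(raw_dataset, tokenizer, local_rank, max_tokens_per_batch):
    # Tokenize once and cache so the work is not repeated for train and eval.
    dataset = training_utils.tokenize_and_cache(raw_dataset, tokenizer)

    # Derive a batch size from the dataset's token counts (assumed helper).
    batch_size = training_utils.dynamic_batch_size(dataset, max_tokens_per_batch)

    # Shard across processes only when launched in distributed mode.
    sampler = DistributedSampler(dataset) if local_rank != -1 else SequentialSampler(dataset)

    batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=False)
    return DataLoader(dataset, batch_sampler=batch_sampler)
```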

Comment on lines -75 to -78
"""Preprocesses the logits during evaluation by computing the greedy token predictions for
accuracy calculation and loss values for perplexity calculation. Both pred_ids and loss are
of shape (batch_size x seq_len)"""

Collaborator


We should keep the docstring. If you think the information in it is obsolete, it's better to update it than to remove it completely.
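For readers who don't have the file open, here is a minimal sketch of what that docstring describes (greedy predictions plus per-token loss). It is an illustration rather than the repo's exact implementation, and it assumes any shifting needed to align predictions with targets is handled elsewhere.

```python
# Illustrative version of the documented behaviour: greedy token predictions
# for accuracy and per-token loss for perplexity, both (batch_size, seq_len).
import torch


def preprocess_logits_for_metrics(logits, labels):
    # Greedy token predictions, shape (batch_size, seq_len).
    pred_ids = torch.argmax(logits, dim=-1)

    # Per-token cross-entropy; positions labelled -100 contribute zero loss.
    # Prediction/target alignment is assumed to be handled by the caller.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
    loss = loss_fct(logits.transpose(1, 2), labels)  # shape (batch_size, seq_len)

    return pred_ids, loss
```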

def compute_metrics(eval_preds, id2token, tokenizer):
"""Computes next-token accuracy and perplexity metrics for evaluation"""
Collaborator


Let's have a docstring here too.
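As a companion to the sketch above, here is a bare-bones illustration of the two metrics that docstring names. The real compute_metrics(eval_preds, id2token, tokenizer) in training_utils.py also reports first-token and per-token breakdowns (see the diff below); the simplified signature here is an assumption.

```python
# Bare-bones sketch of next-token accuracy and perplexity; not the repo's
# full implementation, which also reports per-token statistics.
import math


def compute_metrics(eval_preds):
    # eval_preds.predictions is assumed to be the (pred_ids, per_token_loss)
    # pair returned by preprocess_logits_for_metrics; labels use -100 as padding.
    (pred_ids, per_token_loss), labels = eval_preds.predictions, eval_preds.label_ids
    mask = labels != -100  # only score positions with real labels

    accuracy = float((pred_ids[mask] == labels[mask]).mean())
    perplexity = math.exp(float(per_token_loss[mask].mean()))

    return {"accuracy": accuracy, "perplexity": perplexity}
```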

Comment on lines 184 to -182

     metrics = {
         "accuracy": acc_count / total_num,
         "perplexity": perplexity,
-        "accuracy_first_token": first_token_correct_count / first_token_total_count,
-        "total_number_first_token": first_token_total_count,
-        "first_token_param_values": first_token_param_value_acc
-        / first_token_param_value_total,
-        "first_token_param_values_total": first_token_param_value_total,
+        "accuracy_first_token": first_token_correct / max(first_token_total, 1),
     }

-    for token_id, stat in sorted(
-        first_token_label_dic.items(), key=lambda x: -x[1]["total"]
-    )[:5]:
-        token = tokenizer.decode([token_id])
-        metrics[f"accuracy_first_token_{token}"] = stat["correct"] / stat["total"]
-        metrics[f"accuracy_first_token_{token}_total"] = stat["total"]

-    for token_id in dic:
+    # Token-specific accuracies
+    for token_id, stat in token_stats.items():
-        token = id2token[token_id]
-        total_num = dic[token_id]["total"]
-        acc = -1
-        if total_num > 0:
-            acc = dic[token_id]["acc"] / total_num
-        metrics[f"accuracy_{token}"] = acc
-        metrics[f"accuracy_total_num_{token}"] = total_num
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see some metrics have been removed. Could you kindly explain why these changes have been made?
