@@ -3,9 +3,8 @@
 
 import torch
 from peft import LoraConfig
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                          BitsAndBytesConfig, TrainingArguments)
-from trl import SFTTrainer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from trl import SFTTrainer, SFTConfig
 
 from dataset import SFTDataCollator, SFTDataset
 from merge import merge_lora_to_base_model
@@ -44,7 +43,7 @@ def train_and_merge(
         bnb_4bit_compute_dtype=torch.bfloat16,
     )
 
-    training_args = TrainingArguments(
+    training_args = SFTConfig(
         per_device_train_batch_size=training_args.per_device_train_batch_size,
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
         warmup_steps=100,
@@ -55,6 +54,7 @@ def train_and_merge(
         optim="paged_adamw_8bit",
         remove_unused_columns=False,
         num_train_epochs=training_args.num_train_epochs,
+        max_seq_length=context_length,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_id,
@@ -81,9 +81,7 @@ def train_and_merge(
         train_dataset=dataset,
         args=training_args,
         peft_config=lora_config,
-        packing=True,
         data_collator=SFTDataCollator(tokenizer, max_seq_length=context_length),
-        max_seq_length=context_length,
     )
 
     # Train model
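For context, this diff follows TRL's move of sequence-length and packing options from the SFTTrainer constructor into SFTConfig (a drop-in replacement for TrainingArguments); packing=True is dropped entirely, presumably because packing does not combine well with the custom SFTDataCollator. Below is a minimal, self-contained sketch of the new-style call, assuming a TRL version (>= 0.9) where SFTConfig accepts max_seq_length (renamed max_length in later releases); the model id and dataset are illustrative placeholders, not part of this repository.

# Minimal sketch of the new-style TRL API reflected in the diff above.
# Assumptions: TRL >= 0.9 (SFTConfig exists and carries max_seq_length);
# the model id and dataset below are placeholders, not this repo's values.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")  # placeholder dataset

config = SFTConfig(
    output_dir="sft-output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    max_seq_length=1024,  # formerly passed directly to SFTTrainer(...)
    packing=False,        # packing is likewise an SFTConfig field now
)

trainer = SFTTrainer(
    model="Qwen/Qwen2-0.5B",  # placeholder; SFTTrainer also accepts a model id string
    train_dataset=dataset,
    args=config,
)
trainer.train()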