
Commit 50c99ac

Fix/trl-dependency-and-training-args (#8)
* chore: upgrade trl
* fix: replace TrainingArguments with SFTConfig
1 parent 534d606 commit 50c99ac

2 files changed

Lines changed: 5 additions & 7 deletions


demo.py

Lines changed: 4 additions & 6 deletions
@@ -3,9 +3,8 @@
 
 import torch
 from peft import LoraConfig
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                          BitsAndBytesConfig, TrainingArguments)
-from trl import SFTTrainer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from trl import SFTTrainer, SFTConfig
 
 from dataset import SFTDataCollator, SFTDataset
 from merge import merge_lora_to_base_model
@@ -44,7 +43,7 @@ def train_and_merge(
         bnb_4bit_compute_dtype=torch.bfloat16,
     )
 
-    training_args = TrainingArguments(
+    training_args = SFTConfig(
         per_device_train_batch_size=training_args.per_device_train_batch_size,
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
         warmup_steps=100,
@@ -55,6 +54,7 @@ def train_and_merge(
         optim="paged_adamw_8bit",
         remove_unused_columns=False,
         num_train_epochs=training_args.num_train_epochs,
+        max_seq_length=context_length,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_id,
@@ -81,9 +81,7 @@ def train_and_merge(
         train_dataset=dataset,
         args=training_args,
         peft_config=lora_config,
-        packing=True,
         data_collator=SFTDataCollator(tokenizer, max_seq_length=context_length),
-        max_seq_length=context_length,
     )
 
     # Train model
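
Background for the change above: from trl 0.9 onward, SFTTrainer options such as max_seq_length are meant to be set on SFTConfig (a TrainingArguments subclass) rather than passed to the SFTTrainer constructor, which is why the commit moves max_seq_length into the config and drops packing/max_seq_length from the trainer call. Below is a minimal, self-contained sketch of the resulting pattern; the tiny checkpoint, output directory, and toy dataset are placeholders for illustration, not objects from this repository.

# Minimal sketch of the SFTConfig + SFTTrainer pattern adopted by this commit
# (trl >= 0.9). Placeholder model and data; not taken from demo.py.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

model_name = "sshleifer/tiny-gpt2"  # placeholder checkpoint, small enough for CPU
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 tokenizers ship without a pad token

train_data = Dataset.from_dict({"text": ["hello world", "a tiny fine-tuning demo"]})

# Sequence-length and dataset options live on SFTConfig instead of the trainer.
args = SFTConfig(
    output_dir="sft-demo",            # hypothetical output path
    per_device_train_batch_size=1,
    num_train_epochs=1,
    max_seq_length=64,
    dataset_text_field="text",
    report_to="none",                 # keep the toy run free of logging integrations
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_data,
    args=args,
)
trainer.train()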

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -2,6 +2,6 @@ torch>=1.13.1
 transformers>=4.37.2
 peft>=0.10.0
 loguru
-trl>=0.8.1
+trl>=0.9.3
 bitsandbytes
 pyyaml
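
When updating an existing environment to the new floor, an optional check like the one below (a sketch, not part of the repository) confirms the installed trl satisfies the bumped requirement:

# Verify the installed trl against the new minimum in requirements.txt.
# packaging ships as a dependency of transformers, so nothing extra is needed.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("trl"))
assert installed >= Version("0.9.3"), f"trl {installed} is older than the required 0.9.3"
print(f"trl {installed} satisfies trl>=0.9.3")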
