train.py
import os

# Select the physical GPU before any CUDA work starts; with this mask,
# GPU 1 is exposed to the process as "cuda:0".
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import gc
import logging

import torch
from transformers import LEDForConditionalGeneration, Trainer, TrainingArguments

from data_prep import prepare_data

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def train():
    # Release any cached GPU memory left over from previous runs.
    torch.cuda.empty_cache()
    gc.collect()

    device = torch.device("cuda:0")
    logger.info(f"Using device: {device}")

    dataset, tokenizer = prepare_data()
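    # prepare_data (defined in data_prep, not shown here) is assumed to return
    # a DatasetDict with tokenized "train"/"validation" splits whose features
    # (input_ids, attention_mask, labels, and LED's global_attention_mask) are
    # padded to fixed lengths, so the Trainer's default data collator suffices.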

    # use_cache is incompatible with gradient checkpointing, so disable it;
    # checkpointing itself is enabled via TrainingArguments below (passing
    # gradient_checkpointing to from_pretrained is deprecated).
    model = LEDForConditionalGeneration.from_pretrained(
        "allenai/led-base-16384",
        use_cache=False,
    ).to(device)

    # Total optimizer steps = (examples * epochs) / (batch_size * grad_accum);
    # warm up over the first 10% of them.
    total_steps = int((len(dataset["train"]) * 3) / (2 * 8))
    warmup_steps = int(0.1 * total_steps)
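    # For example, with a hypothetical 800 training examples:
    # total_steps = 800 * 3 / 16 = 150 and warmup_steps = 15.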

    training_args = TrainingArguments(
        output_dir="./legal_led_model",
        per_device_train_batch_size=2,  # increased to 2 since the dataset is smaller
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=8,  # reduced to 8; effective batch size 2 * 8 = 16
        learning_rate=2e-5,
        num_train_epochs=3,
        save_steps=25,  # adjusted for the smaller dataset
        eval_steps=25,
        logging_steps=10,
        eval_strategy="steps",
        fp16=True,
        gradient_checkpointing=True,
        report_to="none",
        max_grad_norm=1.0,
        warmup_steps=warmup_steps,
        dataloader_pin_memory=False,
        optim="adamw_torch",
        save_total_limit=2,  # keep only the two most recent checkpoints
        load_best_model_at_end=True,
        metric_for_best_model="loss",
        greater_is_better=False,
    )
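    # Note: with load_best_model_at_end=True, the Trainer requires the save and
    # evaluation strategies to match, and save_steps must be a round multiple
    # of eval_steps (both are 25 here, so this holds).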

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
    )

    trainer.train()

    # Save the model (with the best checkpoint's weights reloaded at the end
    # of training) together with the tokenizer, so the final directory is
    # self-contained for inference.
    model.save_pretrained("./legal_led_final")
    tokenizer.save_pretrained("./legal_led_final")


if __name__ == "__main__":
    train()
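# To resume an interrupted run from the latest checkpoint under
# ./legal_led_model, trainer.train(resume_from_checkpoint=True) can be
# called in place of the plain trainer.train() above.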