-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathadapter_args_helper.py
83 lines (80 loc) · 2.85 KB
/
adapter_args_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from dataclasses import dataclass, field
from transformers import (
TrainingArguments,
)
from typing import Optional
@dataclass
class ModelArguments:
    """
    Model selection arguments: which pretrained model/config/tokenizer to use.
    """

    # GPT2 size variant; consumed by the model-loading code (not shown here).
    model_size: Optional[str] = field(
        metadata={"help": "The size of pretrained GPT2 model."},
        default="medium",
    )
    # Empty string means "no adapter checkpoint to restore".
    load_checkpoint_adapter: Optional[str] = field(
        metadata={"help": "Path to adapter checkpoint."},
        default="",
    )
    # Upper bound on tokenized input length.
    max_seq_len: Optional[int] = field(
        metadata={"help": "Maximum sequence length the model can process."},
        default=512,
    )
@dataclass
class DataArguments:
    """
    Arguments pertaining to the data loading and preprocessing pipeline.
    """
    # Path to the on-disk (Arrow-format) bookcorpusopen dataset.
    dataset_path: Optional[str] = field(
        default='/home/bryan/datasets/bookcorpusopen/bookcorpusopen_chunked.arrow',
        metadata={"help": "Dataset path."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=16,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    bookcorpusopen_story_column_name: Optional[str] = field(
        default="chunk",
        metadata={"help": "The name of the dataset column containing the story data."},
    )
    genre: Optional[str] = field(
        default="Fiction",
        metadata={"help": "Genre that we want the adapter to be trained with."},
    )
    # -1 presumably means "no adapter id assigned yet" — TODO confirm against caller.
    adapter_id: Optional[int] = field(
        default=-1,
        metadata={"help": "Id for the genre we want the adapter to be trained with."},
    )
    # BUGFIX: the original help strings used backslash line-continuations
    # *inside* the string literal, which embedded long runs of indentation
    # spaces into the rendered --help text. Implicit string concatenation
    # produces a clean single-space-separated message instead.
    match_up_to_n_genres: Optional[int] = field(
        default=None,
        metadata={"help": "How many of the first bookcorpusopen genres entries "
                          "are considered as a genre to match with the genre input. "
                          "None defaults to using all bookcorpusopen genres to match."},
    )
    sample_row: Optional[int] = field(
        default=None,
        metadata={"help": "Set the int number to sample the dataset; "
                          "None means using all the dataset samples available. "
                          "Setting it too small (e.g. 200) triggers a batching error."},
    )
@dataclass
class TrainingArguments(TrainingArguments):
    """
    Arguments pertaining to the training pipeline.

    NOTE(review): this subclass deliberately shadows the imported
    ``transformers.TrainingArguments`` to override a few defaults; renaming
    it would break any caller importing ``TrainingArguments`` from this
    module, so the shadowing is left as-is.
    """
    # Override the transformers default output location.
    output_dir: Optional[str] = field(
        default="./save",
        metadata={"help": "Output directory"},
    )
    # None keeps the transformers default (accumulate on device).
    eval_accumulation_steps: Optional[int] = field(
        default=None,
        metadata={"help": "Evaluation accumulation steps"}
    )
    # Extra field consumed when constructing an EarlyStoppingCallback;
    # not a native transformers TrainingArguments field.
    early_stopping_patience: Optional[int] = field(
        default=5,
        metadata={"help": "Early stopping patience for EarlyStoppingCallback"}
    )
    # EarlyStoppingCallback requires load_best_model_at_end=True, hence
    # the flipped default.
    load_best_model_at_end: Optional[bool] = field(
        default=True,
        metadata={"help": "Needed for EarlyStoppingCallback"}
    )