14
14
from together .utils import finetune_price_to_dollars , log_warn , parse_timestamp
15
15
16
16
17
+ _CONFIRMATION_MESSAGE = (
18
+ "You are about to create a fine-tuning job. "
19
+ "The cost of your job will be determined by the model size, the number of tokens "
20
+ "in the training file, the number of tokens in the validation file, the number of epochs, and "
21
+ "the number of evaluations. Visit https://www.together.ai/pricing to get a price estimate.\n "
22
+ "You can pass `-y` or `--confirm` to your command to skip this message.\n \n "
23
+ "Do you want to proceed?"
24
+ )
25
+
26
+
17
27
class DownloadCheckpointTypeChoice (click .Choice ):
18
28
def __init__ (self ) -> None :
19
29
super ().__init__ ([ct .value for ct in DownloadCheckpointType ])
@@ -67,6 +77,14 @@ def fine_tuning(ctx: click.Context) -> None:
67
77
"--suffix" , type = str , default = None , help = "Suffix for the fine-tuned model name"
68
78
)
69
79
@click .option ("--wandb-api-key" , type = str , default = None , help = "Wandb API key" )
80
+ @click .option (
81
+ "--confirm" ,
82
+ "-y" ,
83
+ type = bool ,
84
+ is_flag = True ,
85
+ default = False ,
86
+ help = "Whether to skip the launch confirmation message" ,
87
+ )
70
88
def create (
71
89
ctx : click .Context ,
72
90
training_file : str ,
@@ -84,6 +102,7 @@ def create(
84
102
lora_trainable_modules : str ,
85
103
suffix : str ,
86
104
wandb_api_key : str ,
105
+ confirm : bool ,
87
106
) -> None :
88
107
"""Start fine-tuning"""
89
108
client : Together = ctx .obj
@@ -111,32 +130,37 @@ def create(
111
130
"You have specified a number of evaluation loops but no validation file."
112
131
)
113
132
114
- response = client .fine_tuning .create (
115
- training_file = training_file ,
116
- model = model ,
117
- n_epochs = n_epochs ,
118
- validation_file = validation_file ,
119
- n_evals = n_evals ,
120
- n_checkpoints = n_checkpoints ,
121
- batch_size = batch_size ,
122
- learning_rate = learning_rate ,
123
- lora = lora ,
124
- lora_r = lora_r ,
125
- lora_dropout = lora_dropout ,
126
- lora_alpha = lora_alpha ,
127
- lora_trainable_modules = lora_trainable_modules ,
128
- suffix = suffix ,
129
- wandb_api_key = wandb_api_key ,
130
- verbose = True ,
131
- )
133
+ if confirm or click .confirm (_CONFIRMATION_MESSAGE , default = True , show_default = True ):
134
+ response = client .fine_tuning .create (
135
+ training_file = training_file ,
136
+ model = model ,
137
+ n_epochs = n_epochs ,
138
+ validation_file = validation_file ,
139
+ n_evals = n_evals ,
140
+ n_checkpoints = n_checkpoints ,
141
+ batch_size = batch_size ,
142
+ learning_rate = learning_rate ,
143
+ lora = lora ,
144
+ lora_r = lora_r ,
145
+ lora_dropout = lora_dropout ,
146
+ lora_alpha = lora_alpha ,
147
+ lora_trainable_modules = lora_trainable_modules ,
148
+ suffix = suffix ,
149
+ wandb_api_key = wandb_api_key ,
150
+ verbose = True ,
151
+ )
132
152
133
- report_string = f"Successfully submitted a fine-tuning job { response .id } "
134
- if response .created_at is not None :
135
- created_time = datetime .strptime (response .created_at , "%Y-%m-%dT%H:%M:%S.%f%z" )
136
- # created_at reports UTC time, we use .astimezone() to convert to local time
137
- formatted_time = created_time .astimezone ().strftime ("%m/%d/%Y, %H:%M:%S" )
138
- report_string += f" at { formatted_time } "
139
- rprint (report_string )
153
+ report_string = f"Successfully submitted a fine-tuning job { response .id } "
154
+ if response .created_at is not None :
155
+ created_time = datetime .strptime (
156
+ response .created_at , "%Y-%m-%dT%H:%M:%S.%f%z"
157
+ )
158
+ # created_at reports UTC time, we use .astimezone() to convert to local time
159
+ formatted_time = created_time .astimezone ().strftime ("%m/%d/%Y, %H:%M:%S" )
160
+ report_string += f" at { formatted_time } "
161
+ rprint (report_string )
162
+ else :
163
+ click .echo ("No confirmation received, stopping job launch" )
140
164
141
165
142
166
@fine_tuning .command ()
0 commit comments