Skip to content

Commit 2a3fec6

Browse files
authored
Add a confirmation message before submitting a fine-tuning job (#190)
* Add a confirmation message before submitting a fine-tuning job * Clarify wording in the message * Bump the version * Revert the version bump * Fix formatting
1 parent 4d613c2 commit 2a3fec6

File tree

1 file changed

+49
-25
lines changed

1 file changed

+49
-25
lines changed

src/together/cli/api/finetune.py

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@
1414
from together.utils import finetune_price_to_dollars, log_warn, parse_timestamp
1515

1616

17+
_CONFIRMATION_MESSAGE = (
18+
"You are about to create a fine-tuning job. "
19+
"The cost of your job will be determined by the model size, the number of tokens "
20+
"in the training file, the number of tokens in the validation file, the number of epochs, and "
21+
"the number of evaluations. Visit https://www.together.ai/pricing to get a price estimate.\n"
22+
"You can pass `-y` or `--confirm` to your command to skip this message.\n\n"
23+
"Do you want to proceed?"
24+
)
25+
26+
1727
class DownloadCheckpointTypeChoice(click.Choice):
1828
def __init__(self) -> None:
1929
super().__init__([ct.value for ct in DownloadCheckpointType])
@@ -67,6 +77,14 @@ def fine_tuning(ctx: click.Context) -> None:
6777
"--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
6878
)
6979
@click.option("--wandb-api-key", type=str, default=None, help="Wandb API key")
80+
@click.option(
81+
"--confirm",
82+
"-y",
83+
type=bool,
84+
is_flag=True,
85+
default=False,
86+
help="Whether to skip the launch confirmation message",
87+
)
7088
def create(
7189
ctx: click.Context,
7290
training_file: str,
@@ -84,6 +102,7 @@ def create(
84102
lora_trainable_modules: str,
85103
suffix: str,
86104
wandb_api_key: str,
105+
confirm: bool,
87106
) -> None:
88107
"""Start fine-tuning"""
89108
client: Together = ctx.obj
@@ -111,32 +130,37 @@ def create(
111130
"You have specified a number of evaluation loops but no validation file."
112131
)
113132

114-
response = client.fine_tuning.create(
115-
training_file=training_file,
116-
model=model,
117-
n_epochs=n_epochs,
118-
validation_file=validation_file,
119-
n_evals=n_evals,
120-
n_checkpoints=n_checkpoints,
121-
batch_size=batch_size,
122-
learning_rate=learning_rate,
123-
lora=lora,
124-
lora_r=lora_r,
125-
lora_dropout=lora_dropout,
126-
lora_alpha=lora_alpha,
127-
lora_trainable_modules=lora_trainable_modules,
128-
suffix=suffix,
129-
wandb_api_key=wandb_api_key,
130-
verbose=True,
131-
)
133+
if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
134+
response = client.fine_tuning.create(
135+
training_file=training_file,
136+
model=model,
137+
n_epochs=n_epochs,
138+
validation_file=validation_file,
139+
n_evals=n_evals,
140+
n_checkpoints=n_checkpoints,
141+
batch_size=batch_size,
142+
learning_rate=learning_rate,
143+
lora=lora,
144+
lora_r=lora_r,
145+
lora_dropout=lora_dropout,
146+
lora_alpha=lora_alpha,
147+
lora_trainable_modules=lora_trainable_modules,
148+
suffix=suffix,
149+
wandb_api_key=wandb_api_key,
150+
verbose=True,
151+
)
132152

133-
report_string = f"Successfully submitted a fine-tuning job {response.id}"
134-
if response.created_at is not None:
135-
created_time = datetime.strptime(response.created_at, "%Y-%m-%dT%H:%M:%S.%f%z")
136-
# created_at reports UTC time, we use .astimezone() to convert to local time
137-
formatted_time = created_time.astimezone().strftime("%m/%d/%Y, %H:%M:%S")
138-
report_string += f" at {formatted_time}"
139-
rprint(report_string)
153+
report_string = f"Successfully submitted a fine-tuning job {response.id}"
154+
if response.created_at is not None:
155+
created_time = datetime.strptime(
156+
response.created_at, "%Y-%m-%dT%H:%M:%S.%f%z"
157+
)
158+
# created_at reports UTC time, we use .astimezone() to convert to local time
159+
formatted_time = created_time.astimezone().strftime("%m/%d/%Y, %H:%M:%S")
160+
report_string += f" at {formatted_time}"
161+
rprint(report_string)
162+
else:
163+
click.echo("No confirmation received, stopping job launch")
140164

141165

142166
@fine_tuning.command()

0 commit comments

Comments
 (0)