-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathtrainer.py
More file actions
99 lines (85 loc) · 2.66 KB
/
Copy pathtrainer.py
File metadata and controls
99 lines (85 loc) · 2.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import click
from typing import Optional, Union, List
from rorf.router.utils import write_config_to_json
import sys
import logging
logger = logging.getLogger(__name__)

# Layout of every log line: "<time> <LEVEL>:<message>".
LOG_FORMAT = "%(asctime)s %(levelname)s:%(message)s"
# %H (24-hour clock) rather than %I: %I is a 12-hour clock and the format has
# no AM/PM marker, so 01:05 and 13:05 would otherwise print identically.
LOG_DATEFMT = "%H:%M:%S"

# Configure the root logger at import time: INFO and above, written to stdout.
# Per the logging docs this is a no-op if the root logger already has handlers.
logging.basicConfig(
    format=LOG_FORMAT,
    level=logging.INFO,
    stream=sys.stdout,
    datefmt=LOG_DATEFMT,
)
@click.group()
@click.pass_context
def run(ctx):
    # Root command group. A fresh dict is installed as the shared context object
    # on every invocation; subcommands store the trainer and its config in it,
    # and the group's result callback reads them back afterwards.
    # (No docstring on purpose: click would turn it into `--help` text.)
    ctx.obj = dict()
def _parse_max_features(ctx, param, value):
    """Click callback for ``--max_features``: numbers -> float, other strings as-is.

    Numeric input (e.g. ``"1.0"``, ``"0.5"``) is converted with ``float()``,
    exactly as click's inferred FLOAT type did before. Non-numeric input such
    as ``"sqrt"`` or ``"log2"`` is returned unchanged, so the trainer receives
    the string (click used to reject it outright).
    """
    if value is None:
        return None
    try:
        return float(value)
    except ValueError:
        # NOTE(review): assumes RoRFTrainer forwards this to a scikit-learn-style
        # random forest that understands string settings ("sqrt", "log2") — confirm.
        return value


@run.command()
@click.option("--model_a", type=str, default="llama-3.1-405b-instruct")
@click.option("--model_b", type=str, default="llama-3.1-70b-instruct")
@click.option("--dataset_path", type=str)
@click.option("--eval_dataset", type=str)
@click.option("--embedding_provider", type=str, default="jina")
@click.option("--prompt_embedding_cache", type=str, default=None)
@click.option("--max_depth", type=int, default=20)
# Explicit str type + callback: with no type, click inferred FLOAT from the
# default and rejected string values despite the Union[float, str] annotation.
@click.option(
    "--max_features", type=str, default="1.0", callback=_parse_max_features
)
@click.option("--n_estimators", type=int, default=100)
@click.option("--save_dir", type=str, default="checkpoints")
@click.option("--model_id", type=str, default=None)
@click.option("--model_org", type=str)
@click.pass_context
def rorf_classifier(
    ctx,
    model_a: str,
    model_b: str,
    dataset_path: Optional[str],
    eval_dataset: Optional[str],
    embedding_provider: str,
    prompt_embedding_cache: Optional[str],
    max_depth: int,
    max_features: Union[float, str],
    n_estimators: int,
    save_dir: str,
    model_id: Optional[str],
    model_org: Optional[str],
):
    """Set up a RoRF router trainer for --model_a vs --model_b and save its config.

    Training itself runs once this command returns.
    """
    # Imported here rather than at module top — presumably to keep CLI start-up
    # (e.g. `--help`) light; confirm before moving it.
    from rorf.router.rorf import RoRFTrainer

    # Keyword arguments passed verbatim to RoRFTrainer. The saved config is
    # these same values plus a "trainer" tag, built from one dict so the
    # recorded config and the actual constructor arguments cannot drift apart.
    trainer_kwargs = {
        "llms": [model_a, model_b],
        "dataset_path": dataset_path,
        "eval_dataset": eval_dataset,
        "embedding_provider": embedding_provider,
        "prompt_embedding_cache": prompt_embedding_cache,
        "max_depth": max_depth,
        "max_features": max_features,
        "n_estimators": n_estimators,
        "save_dir": save_dir,
        "model_id": model_id,
        "model_org": model_org,
    }
    configs = {"trainer": "RoRF", **trainer_kwargs}

    trainer_obj = RoRFTrainer(**trainer_kwargs)
    # Persist the run's configuration next to the trainer's outputs
    # (save_path is provided by RoRFTrainer).
    write_config_to_json(configs, trainer_obj.save_path)

    # Hand off to the group's result callback, which performs the training.
    ctx.obj["trainer"] = trainer_obj
    ctx.obj["configs"] = configs
@run.result_callback()
@click.pass_context
def process_result(ctx, result, **kwargs):
    """Train the trainer that the invoked subcommand stored on ``ctx.obj``.

    Click calls this after a ``run`` subcommand finishes. ``result`` is that
    subcommand's return value and ``kwargs`` are the group's own parameters;
    neither is used here.

    Returns:
        Whatever ``trainer.train()`` returns.

    Raises:
        click.UsageError: if the subcommand did not register a trainer.
    """
    trainer = (ctx.obj or {}).get("trainer")
    if trainer is None:
        # Without this guard a missing trainer surfaced as an opaque
        # AttributeError on None.
        raise click.UsageError(
            "No trainer was configured by the invoked subcommand."
        )
    return trainer.train()
if __name__ == "__main__":
    # Calling a click command is shorthand for `.main()`; invoking it explicitly
    # makes the standalone entry point (parse sys.argv, run, exit) obvious.
    run.main()