-
Notifications
You must be signed in to change notification settings - Fork 28
/
common_wandb.py
86 lines (76 loc) · 2.91 KB
/
common_wandb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import time
from common import limit_depth, max_completion_depth, count_depth, limit_tokens
import llm
from cmdline import args
import wandb
# Start a Weights & Biases run at import time, but only when W&B logging was
# requested on the command line.
if args.use_wandb:
    # All run metadata comes from the parsed CLI arguments; the full argument
    # set (via args.dict()) is stored as the run config.
    wandb.init(**{
        "entity": args.wandb_entity,
        "project": args.wandb_project,
        "group": args.wandb_group,
        "config": args.dict(),
        "name": args.wandb_name,
    })
def compute_gen_stat(pre_gen_time, pre_gen_toks, text, depth):
    """Build the per-generation metrics dict to be logged to W&B.

    Returns an empty dict when W&B logging is disabled.

    Args:
        pre_gen_time: ``time.time()`` value captured just before generation;
            elapsed seconds are measured from it.
        pre_gen_toks: value of ``llm.token_counter`` captured just before
            generation; the token delta is measured from it.
        text: the generated completion, or None if none was produced.
        depth: completion depth, logged verbatim.

    Returns:
        A dict with ``generate/*`` keys, or ``{}`` when ``args.use_wandb`` is
        false.
    """
    if not args.use_wandb:
        return {}
    return {
        "generate/gen_time": time.time() - pre_gen_time,
        "generate/gen_length": llm.token_counter - pre_gen_toks,
        # +1 when a completion was produced, -1 when `text` is None.
        "generate/score_sign": 1 if text is not None else -1,
        "generate/completion_depth": depth,
    }
def log_tree(montecarlo, gen_stat, node):
    """Log a snapshot of search-tree statistics to W&B, merged with *gen_stat*.

    Does nothing unless ``args.use_wandb`` is set.

    Args:
        montecarlo: search object; every entry of ``get_stat_dict()`` is logged
            under a ``tree/`` prefix, and a non-None ``solution`` additionally
            triggers logging of ``final/solution_depth``.
        gen_stat: metrics dict (e.g. from ``compute_gen_stat``) sent in the
            same ``wandb.log`` call; tree keys override it on collision.
        node: the node just processed; its depth is logged.
    """
    if not args.use_wandb:
        return
    tree_stats = {f"tree/{key}": value for key, value in montecarlo.get_stat_dict().items()}
    tree_stats["tree/node_depth"] = count_depth(node)
    tree_stats["tree/n_tokens"] = llm.token_counter
    if montecarlo.solution is not None:
        # Count `node` itself plus each ancestor up to the root.
        # NOTE(review): this is the depth of `node`, which presumably is the
        # node that produced the solution — confirm at the call site.
        depth = 1
        ancestor = node.parent
        while ancestor is not None:
            depth += 1
            ancestor = ancestor.parent
        tree_stats["final/solution_depth"] = depth
    wandb.log({**gen_stat, **tree_stats})
def compute_summary(montecarlo, node_dups_counter, init_time, ver_avg = 0, ver_count = 0, llm_avg = 0, llm_count = 0):
    """Log end-of-run summary metrics for an MCTS run to W&B.

    Does nothing unless ``args.use_wandb`` is set.

    Args:
        montecarlo: search object; its ``solution`` is logged as
            ``final/text`` and every ``get_stat_dict()`` entry is logged under
            a ``final/`` prefix.
        node_dups_counter: logged verbatim as ``final/node_dups``.
        init_time: ``time.time()`` value captured at the start of the run.
        ver_avg, ver_count, llm_avg, llm_count: caller-supplied aggregates,
            logged verbatim under ``final/``.
    """
    if not args.use_wandb:
        return
    summary = {
        "final/time": time.time() - init_time,
        # Solved == the token limit was not hit.
        "final/solved": not limit_tokens(),
        "final/text": montecarlo.solution,
        "final/n_tokens": llm.token_counter,
        "final/node_dups": node_dups_counter,
        "final/ver_avg": ver_avg,
        "final/ver_count": ver_count,
        "final/llm_avg": llm_avg,
        "final/llm_count": llm_count,
    }
    # pass@t: 1 if the whole run used at most t tokens, else 0.
    # NOTE(review): only the token count is checked, so this is a solve
    # indicator only if unsolved runs always exceed the largest threshold —
    # confirm against the configured token budget.
    summary.update(
        (f"final/pass_at_{budget}", int(llm.token_counter <= budget))
        for budget in (500, 1000, 2000, 5000)
    )
    # Tree statistics go last so they override any colliding key above.
    summary.update({f"final/{key}": value for key, value in montecarlo.get_stat_dict().items()})
    wandb.log(summary)
def compute_summary_nomc(solution, init_time):
    """Log end-of-run summary metrics to W&B for a run without an MCTS tree.

    Counterpart of ``compute_summary`` for code paths that produce a single
    ``solution`` directly. Does nothing unless ``args.use_wandb`` is set.

    Args:
        solution: final solution text, logged verbatim as ``final/text``.
        init_time: ``time.time()`` value captured at the start of the run.
    """
    if not args.use_wandb:
        return
    stat = {}
    stat["final/time"] = time.time() - init_time
    # Same definition as compute_summary: a run counts as solved when it did
    # NOT hit the token limit. Logging limit_tokens() directly inverted the
    # metric relative to the MCTS summary for the same key.
    # NOTE(review): assumes limit_tokens() returns True once the token budget
    # is exhausted — confirm in common.py.
    stat["final/solved"] = not limit_tokens()
    stat["final/text"] = solution
    stat["final/n_tokens"] = llm.token_counter
    # pass@t: 1 if the whole run used at most t tokens, else 0.
    for t in (500, 1000, 2000, 5000):
        stat[f"final/pass_at_{t}"] = int(llm.token_counter <= t)
    wandb.log(stat)