# run.py
import copy
import itertools
import json
import logging
import multiprocessing
from typing import List

import pandas as pd
from river import metrics
from tqdm import tqdm

from config import MODELS, N_CHECKPOINTS, TRACKS

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)
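# The script relies on config.py providing three names. A sketch of the
# expected shapes, inferred from how they are used below (an assumption, not a
# copy of the actual file):
#
#   MODELS: dict          # track name -> {model name -> river estimator}
#   N_CHECKPOINTS: int    # number of checkpoints each track.run() yields
#   TRACKS: list          # evaluation tracks exposing .name, .datasets, .run()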
def run_dataset(model_str, no_dataset, no_track):
    """Benchmark one model on one dataset and return a list of checkpoint rows."""
    model_name = model_str
    track = TRACKS[no_track]
    dataset = track.datasets[no_dataset]
    # Binary classifiers also compete on the multiclass datasets. The merge is
    # repeated here because this function runs in a separate worker process,
    # which may hold its own copy of MODELS.
    MODELS["Binary classification"].update(MODELS["Multiclass classification"])
    model = MODELS[track.name][model_name].clone()
    print(f"Processing {model_str} on {dataset.__class__.__name__}")

    results = []
    # Work on a private copy of the track so parallel workers don't share state.
    track = copy.deepcopy(track)
    time = 0.0
    for i in tqdm(
        track.run(model, dataset, n_checkpoints=N_CHECKPOINTS),
        total=N_CHECKPOINTS,
    ):
        time += i["Time"].total_seconds()
        res = {
            "step": i["Step"],
            "track": track.name,
            "model": model_name,
            "dataset": dataset.__class__.__name__,
        }
        # Flatten every metric in the checkpoint to its current value.
        for k, v in i.items():
            if isinstance(v, metrics.base.Metric):
                res[k] = v.get()
        res["Memory in Mb"] = i["Memory"] / 1024**2
        res["Time in s"] = time
        results.append(res)
        # Give up on runs that exceed one hour.
        if time > 3600:
            break
    return results
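# One returned row looks roughly like this (illustrative values only; the
# metric columns depend on the track):
#   {"step": 20000, "track": "Binary classification", "model": "Logistic regression",
#    "dataset": "Bananas", "Accuracy": 0.87, "Memory in Mb": 0.4, "Time in s": 2.1}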
def run_track(models: List[str], no_track: int, n_workers: int = 50):
    """Benchmark every (model, dataset) pair of a track in parallel, then dump a CSV."""
    track = TRACKS[no_track]
    runs = list(itertools.product(models, range(len(track.datasets)), [no_track]))
    results = []
    # The context manager guarantees the worker pool is closed and joined.
    with multiprocessing.Pool(processes=n_workers) as pool:
        for val in pool.starmap(run_dataset, runs):
            results.extend(val)
    csv_name = track.name.replace(" ", "_").lower()
    pd.DataFrame(results).to_csv(f"./{csv_name}.csv", index=False)
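# Example call (hypothetical model name): benchmark the first track with four
# workers, producing e.g. ./binary_classification.csv:
#   run_track(models=["Logistic regression"], no_track=0, n_workers=4)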
if __name__ == "__main__":
    # Binary classifiers also compete on the multiclass track.
    MODELS["Binary classification"].update(MODELS["Multiclass classification"])

    # Describe every dataset and model per track, then persist to details.json.
    details = {}
    for track in TRACKS:
        details[track.name] = {"Dataset": {}, "Model": {}}
        for dataset in track.datasets:
            details[track.name]["Dataset"][dataset.__class__.__name__] = repr(dataset)
        for model_name, model in MODELS[track.name].items():
            details[track.name]["Model"][model_name] = repr(model)
    with open("details.json", "w") as f:
        json.dump(details, f, indent=2)

    # Run the benchmarks, one CSV per track.
    for i, track in enumerate(TRACKS):
        run_track(models=list(MODELS[track.name].keys()), no_track=i, n_workers=10)
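# details.json ends up keyed by track name (illustrative entry, hypothetical
# dataset and model names):
#   {"Binary classification": {
#        "Dataset": {"Bananas": "Bananas()"},
#        "Model": {"Logistic regression": "LogisticRegression()"}}}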