-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil.py
63 lines (40 loc) · 1.32 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Miscellaneous utils."""
import csv
import json
import os
# Fixed file names for the artifacts kept under an experiment root directory;
# the get_*_path helpers in this module join them onto that root.
CHECKPOINT_FILENAME = "checkpoint.pt"  # stored under <root>/checkpoints/
CONFIG_FILENAME = "config.json"        # JSON config at <root>/
LOG_FILENAME = "log.csv"               # CSV log at <root>/
class CSVLogger:
    """Write dict rows with a fixed set of columns to a CSV file.

    The file at ``filepath`` is created (or truncated) when the logger is
    constructed, and the header row is written immediately. The file is
    flushed after the header and after every row, so logged data is not left
    sitting in Python's write buffer if the program stops abruptly.

    The logger can be used as a context manager so the file is always
    closed::

        with CSVLogger(['epoch', 'loss'], path) as logger:
            logger.writerow({'epoch': 0, 'loss': 1.0})

    Args:
        fieldnames: Column names in output order; passed to ``csv.DictWriter``.
        filepath: Path of the CSV file to create or overwrite.
    """

    def __init__(self, fieldnames, filepath):
        self.filepath = filepath
        # The csv module requires newline='': the writer emits its own '\r\n'
        # terminators, and text-mode newline translation would otherwise turn
        # them into '\r\r\n' (blank rows) on Windows.
        self.csv_file = open(filepath, 'w', newline='')
        try:
            self.writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
            self.writer.writeheader()
            self.csv_file.flush()
        except BaseException:
            # Don't leak the open handle if the header cannot be written.
            self.csv_file.close()
            raise

    def writerow(self, row):
        """Write one row (a mapping of field name to value) and flush it.

        Raises:
            ValueError: From ``csv.DictWriter`` if ``row`` has a key that is
                not one of the logger's ``fieldnames``.
        """
        self.writer.writerow(row)
        self.csv_file.flush()

    def close(self):
        """Close the underlying file; calling it again has no effect."""
        self.csv_file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never suppress an exception raised inside the with-block.
        return False
def write_dict_to_json(data, filepath):
    """Serialize ``data`` to JSON and write it to ``filepath``.

    The whole document is encoded before the file is opened. If ``data``
    cannot be encoded, an existing file at ``filepath`` is therefore left
    untouched rather than truncated to a partial document.

    Args:
        data: A JSON-serializable object (typically a dict).
        filepath: Destination path; overwritten if it already exists.

    Raises:
        TypeError: If ``data`` contains a value ``json`` cannot serialize.
        ValueError: If ``data`` contains a circular reference.
    """
    # json.dump() writes chunk by chunk into a file that open(..., 'w') has
    # already truncated, so an error half-way through leaves corrupt JSON.
    serialized = json.dumps(data)
    with open(filepath, 'w') as f:
        f.write(serialized)
def write_args_to_json(args, filepath):
    """Save every attribute of ``args`` to ``filepath`` as a JSON object.

    Args:
        args: Any object accepted by ``vars()``, such as an
            ``argparse.Namespace``. Each attribute name becomes a key, and its
            value (read with ``getattr``) becomes the JSON value.
        filepath: Destination file, written via ``write_dict_to_json``.
    """
    settings = {name: getattr(args, name) for name in vars(args)}
    write_dict_to_json(settings, filepath)
def get_log_path(root_dir):
    """Return the path of the CSV log file (``LOG_FILENAME``) in ``root_dir``."""
    log_path = os.path.join(root_dir, LOG_FILENAME)
    return log_path
def get_config_path(root_dir):
    """Return the path of the JSON config file (``CONFIG_FILENAME``) in ``root_dir``."""
    config_path = os.path.join(root_dir, CONFIG_FILENAME)
    return config_path
def get_checkpoint_path(root_dir):
    """Return ``<root_dir>/checkpoints/<CHECKPOINT_FILENAME>``.

    This only builds the path string; the ``checkpoints`` subdirectory is not
    created here.
    """
    checkpoint_dir = os.path.join(root_dir, 'checkpoints')
    return os.path.join(checkpoint_dir, CHECKPOINT_FILENAME)
def create_directory(directory):
    """Create ``directory`` along with any missing parent directories.

    Does nothing if the directory already exists, so the function is safe to
    call repeatedly, including concurrently from several processes.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
        OSError: For other failures, e.g. insufficient permissions.
    """
    # exist_ok=True replaces the old exists()-then-makedirs() check. That
    # check raced: if another process created the directory in between,
    # makedirs() raised FileExistsError.
    os.makedirs(directory, exist_ok=True)
def load_config(filepath):
    """Read the JSON file at ``filepath`` and return the decoded value.

    Args:
        filepath: Path to a JSON document.

    Returns:
        The parsed JSON (a dict when the document is a JSON object).
    """
    with open(filepath, 'r') as config_file:
        raw_text = config_file.read()
    return json.loads(raw_text)