run_expt.py
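
"""Entry point for DISC experiments: parse args, prepare the data, train, and log final results."""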
import os
import os.path as osp
import numpy as np
import pandas as pd
import warnings
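# Silence library warnings before the heavy imports below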
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
import torchvision
from args import parse_args
from disc.train import train
from disc.models import model_attributes
from disc.dataset.cifar10_dataset import prepare_cifar10_data
from disc.dataset.fmow_dataset import prepare_fmow_data
from disc.dataset.dro_dataset import DRODataset
from disc.dataset.folds import Subset, get_fold
from disc.dataset.load_data import dataset_attributes, shift_types, prepare_data, log_data, log_meta_data
from disc.utils.tools import set_seed, Logger, CSVBatchLogger, log_args, get_model, check_args, set_log_dir

if __name__ == '__main__':
    # Load args
    args = parse_args()
    set_log_dir(args)
    check_args(args)
    set_seed(args.seed)

    ## Initialize logs
    if not osp.exists(args.log_dir):
        os.makedirs(args.log_dir)
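    # Append to the existing log when resuming from a checkpoint; otherwise start a fresh file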
    args.mode = 'a' if (osp.exists(osp.join(args.log_dir, 'last_model.pth')) and args.resume) else 'w'
    logger = Logger(osp.join(args.log_dir, 'log.txt'), args.mode)

    # Prepare data
    if args.dataset == 'CIFAR10':
        train_data, val_data, test_data = prepare_cifar10_data(args)
    elif args.dataset == 'FMoW':
        train_data, val_data, test_data = prepare_fmow_data(args)
    elif args.shift_type == 'confounder':
        train_data, val_data, test_data = prepare_data(args, train=True)
    elif args.shift_type == 'label_shift_step':  # not used
        train_data, val_data = prepare_data(args, train=True)
    else:
        raise NotImplementedError

    # Record args
    args.n_groups = train_data.n_groups
    args.n_classes = train_data.n_classes
    args.input_size = train_data.input_size()
    log_args(args, logger)

    # Prepare loaders
    args.loader_kwargs = {'batch_size': args.batch_size, 'num_workers': 1, 'pin_memory': False}
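    # Optionally carve a validation fold out of the training set for hyperparameter sweeps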
    if args.fold:
        train_data, val_data = get_fold(
            train_data, args.fold,
            num_valid_per_point=args.num_sweeps,
            cross_validation_ratio=(1 / args.num_folds_per_sweep),
            seed=args.seed
        )
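    # LISA-style mix-up draws samples per group, so build one loader per (non-empty) group here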
    if args.lisa_mix_up:
        train_loader = {}
        for i in range(train_data.n_groups):
            idxes = np.where(train_data.get_group_array() == i)[0]
            if len(idxes) == 0:
                continue
            subset = DRODataset(
                Subset(train_data, idxes),
                process_item_fn=None,
                n_groups=train_data.n_groups,
                n_classes=train_data.n_classes,
                group_str_fn=train_data.group_str
            )
            train_loader[i] = subset.get_loader(
                train=True, reweight_groups=False, **args.loader_kwargs
            )
    else:
        train_loader = train_data.get_loader(reweight_groups=args.reweight_groups,
                                             train=True, **args.loader_kwargs)
    test_loader = test_data.get_loader(train=False, reweight_groups=None, **args.loader_kwargs)
    val_loader = val_data.get_loader(train=False, reweight_groups=None, **args.loader_kwargs)

    # Gather all loaders and datasets
    data = {
        'train_data': train_data,
        'test_data': test_data,
        'val_data': val_data,
        'train_loader': train_loader,
        'test_loader': test_loader,
        'val_loader': val_loader
    }

    ## Output logger to file
    if "Meta" in args.dataset:
        log_meta_data(data, logger)
    else:
        log_data(data, logger)

    ## Initialize model
    model = get_model(
        args, args.n_classes,
        args.input_size[0], args.resume
    )
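    # (get_model presumably reloads last_model.pth when args.resume is set)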
    model = model.cuda()
    logger.flush()

    ## Define the objective
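    # Per-sample losses (reduction='none') so the trainer can reweight or aggregate them by group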
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    # Get resume information if needed
    if args.resume:
        df = pd.read_csv(osp.join(args.log_dir, 'test.csv'))
        epoch_offset = df.loc[len(df) - 1, 'epoch'] + 1
        logger.write(f'starting from epoch {epoch_offset}')
    else:
        epoch_offset = 0

    # Set up CSV loggers
    csv_loggers = {}
    for split in ['train', 'test', 'val']:
        csv_loggers[split] = CSVBatchLogger(
            args, osp.join(args.log_dir, f'{split}.csv'),
            data[f'{split}_data'].n_groups, mode=args.mode
        )

    # Training
    train(
        args, model, criterion,
        data, logger, csv_loggers,
        args.n_classes, epoch_offset=epoch_offset
    )
    for split in ['train', 'test', 'val']:
        csv_loggers[split].close()

    # Final results
    val_csv = pd.read_csv(osp.join(args.log_dir, 'val.csv'))
    test_csv = pd.read_csv(osp.join(args.log_dir, 'test.csv'))
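    # ISIC is scored by ROC AUC, the other benchmarks by worst-group accuracy;
    # pick the epoch with the best validation score and report the test metrics at that epoch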
    metric = 'roc_auc' if args.dataset == 'ISIC' else 'worst_group_acc'
    idx = np.argmax(val_csv[metric].values)
    logger.write(str(test_csv[[metric, 'mean_differences', 'group_avg_acc', 'avg_acc']].iloc[idx]))