-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinat_wt_diff.py
More file actions
125 lines (81 loc) · 2.44 KB
/
inat_wt_diff.py
File metadata and controls
125 lines (81 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# -*- coding: utf-8 -*-
import torch
import numpy as np
import pandas as pd
import resnet_XAI as models
import torch.utils.data
# Short names of the data-augmentation (DA) methods compared by this script.
# Index 0 ('imb') is the baseline every other method is compared against
# when the weight differences are computed.
meth_eng = ['imb', 'ros', 'rem', 'dsm', 'eos']
# Checkpoint files, grouped per method in the same order as `meth_eng`, with
# one checkpoint per run inside each group; they are consumed sequentially.
# NOTE(review): the trailing _0/_10/_100 presumably distinguish the runs
# (seeds?) — confirm.
paths = [
    # base ('imb')
    ".../INAT/_aug_INat_mod_37_0.pth",
    ".../INAT/_aug_INat_mod_38_10.pth",
    ".../INAT/_aug_INat_mod_36_100.pth",
    # ROS
    ".../INAT/ROS_aug_INat_mod_32_0.pth",
    ".../INAT/ROS_aug_INat_mod_33_10.pth",
    ".../INAT/ROS_aug_INat_mod_36_100.pth",
    # REM
    ".../INAT/rem_INat_mod_35_0.pth",
    ".../INAT/rem_INat_mod_35_10.pth",
    ".../INAT/rem_INat_mod_27_100.pth",
    # DSM
    ".../INAT/inat_DSM_0_comb.pth",
    ".../INAT/inat_DSM_10_comb.pth",
    ".../INAT/inat_DSM_100_comb.pth",
    # EOS
    ".../INAT/inat_EOS_0_comb.pth",
    # Fixed: the directory separator was missing ("INATinat_...").
    ".../INAT/inat_EOS_10_comb.pth",
    ".../INAT/inat_EOS_100_comb.pth",
]
# Runs (checkpoints) per method, number of classes (rows of the final linear
# layer's weight), and number of weight columns per class.
runs, num_cls, num_wts = 3, 5, 64
# One zero-filled collector per DA method. Each stacks the final-layer weights
# of all runs on top of each other, so its shape is (runs * num_cls, num_wts).
imb, ros, rem, dsm, eos = (np.zeros((runs * num_cls, num_wts)) for _ in range(5))
meths = [imb, ros, rem, dsm, eos]
# Load every checkpoint, take its final linear layer's weight matrix
# (num_cls x num_wts), and write it into the collector of its method in
# `meths`, one num_cls-row slab per run.
expected_paths = len(meth_eng) * runs
if len(paths) != expected_paths:
    # Fail fast instead of silently ignoring extra paths or hitting an
    # IndexError halfway through the loop.
    raise ValueError(
        f"expected {expected_paths} checkpoint paths "
        f"({len(meth_eng)} methods x {runs} runs), got {len(paths)}"
    )
# Model/device configuration is identical for every checkpoint: set it once.
use_norm = False
torch.cuda.set_device(0)
model_count = 0
for m in range(len(meth_eng)):
    print('method', m, meth_eng[m])
    row_count = 0
    for r in range(runs):
        print('run', r)
        print('model', model_count)
        print(paths[model_count])
        model = models.resnet56(num_classes=num_cls, use_norm=use_norm)
        model = model.cuda(0)
        # Deserialize onto the CPU: load_state_dict copies the tensors into
        # the GPU parameters, and this avoids failures when a checkpoint was
        # saved from a GPU index that is not present on this machine.
        state_dict = torch.load(paths[model_count], map_location='cpu')
        model.load_state_dict(state_dict)
        wt = model.linear.weight.detach().cpu().numpy()
        print(wt.shape)
        meth = meths[m]
        print(meth.shape)
        print(row_count, row_count + num_cls)
        # Run r occupies rows [r * num_cls, (r + 1) * num_cls) of the collector.
        meth[row_count:row_count + num_cls, :] = wt
        row_count += num_cls
        model_count += 1
# Dump each method's stacked weight matrix, followed by a blank line.
for idx, method_name in enumerate(meth_eng):
    print(method_name)
    print(meths[idx])
    print()
# For every non-baseline method, compute the mean element-wise relative
# absolute difference of its stacked weights from the baseline (index 0):
#     mean(|W_method - W_base| / |W_base|)
# NOTE: an exactly-zero baseline weight makes the ratio inf/nan for that
# element, which then propagates into the mean.
len_meth = len(meth_eng)
baseline = meths[0]
diffs = np.array(
    [
        np.mean(np.abs(meths[k] - baseline) / np.abs(baseline))
        for k in range(1, len_meth)
    ],
    dtype=float,
)
print(diffs)
# Save the differences as a single CSV row with one column per non-baseline
# method, in the same order as meth_eng[1:] ('ros', 'rem', 'dsm', 'eos').
dif = diffs.reshape(1, -1)
# Use a flat list of labels: a nested list ([[...]]) makes pandas build a
# one-level MultiIndex for the columns, so pdf['ROS'] would return a
# DataFrame instead of a Series.
pdf = pd.DataFrame(data=dif, columns=['ROS', 'REMIX', 'DSM', 'EOS'])
f = ".../INAT/inat_wt_diffs.csv"
pdf.to_csv(f, index=False)