-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtest.py
225 lines (170 loc) · 6.63 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
#
#
# 0=================================0
# | Kernel Point Convolutions |
# 0=================================0
#
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Callable script to test any model on any dataset
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Hugues THOMAS - 11/06/2018
# Nicolas DONATI - 03/2020
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Imports and global variables
# \**********************************/
#
# Common libs
import time
import os
import numpy as np
# My libs
from utils.config import Config
from configs.SurrealConfig import SurrealConfig
from configs.FAUST_rConfig import FAUST_rConfig
from configs.SCAPE_rConfig import SCAPE_rConfig
from configs.SHREC_rConfig import SHREC_rConfig
from utils.tester import ModelTester
from models.KPCNN_FM_model import KernelPointCNN_FM
# Datasets
from datasets.Surreal import SurrealDataset
from datasets.FAUST_remeshed import FAUST_r_Dataset
from datasets.SCAPE_remeshed import SCAPE_r_Dataset
from datasets.SHREC_remeshed import SHREC_r_Dataset
# ----------------------------------------------------------------------------------------------------------------------
#
# Utility functions
# \***********************/
#
def test_caller(path, step_ind):
    """Restore a trained KPCNN functional-map model and run shape-matching tests.

    Parameters
    ----------
    path : str
        Training log folder; must contain a 'snapshots' subfolder with
        TensorFlow checkpoints named 'snap-<step>.*'.
    step_ind : int
        Index into the sorted list of snapshot steps (-1 = most recent).

    Raises
    ------
    ValueError
        If the hard-coded dataset name is not one of the supported datasets.
    """

    ##########################
    # Initiate the environment
    ##########################

    # Choose which gpu to use
    GPU_ID = '0'

    # Set GPU visible device
    os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID

    # Disable warnings
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'

    ###########################
    # Load the model parameters
    ###########################

    test_dataset = 'FAUST_r'  # 'FAUST_r' # 'SCAPE_r' # 'SHREC_r'

    ##############
    # Prepare Data
    ##############

    print()
    print('Dataset Preparation')
    print('*******************')

    # Dispatch table: dataset name -> (config class, dataset class).
    # NOTE(review): the previous version also had a 'SHREC' branch referencing
    # SHRECConfig / SHREC_Dataset, which are never imported (only the remeshed
    # '_r' variants are) and would raise NameError — removed. It also built a
    # FAUST config AND dataset unconditionally before dispatching, discarding
    # them for other datasets; construction now happens once, after dispatch.
    dataset_constructors = {
        'FAUST_r': (FAUST_rConfig, FAUST_r_Dataset),
        'SCAPE_r': (SCAPE_rConfig, SCAPE_r_Dataset),
        'SHREC_r': (SHREC_rConfig, SHREC_r_Dataset),
    }
    if test_dataset not in dataset_constructors:
        raise ValueError('dataset not supported')
    config_class, dataset_class = dataset_constructors[test_dataset]
    config = config_class()
    dataset = dataset_class(config)

    config.epoch_steps = 0  # to avoid it should be None
    config.load(path)  # get the exact same parameters for the network (which will be re-loaded afterwards)
    print('number of eigenvectors in loaded model :', config.neig)

    # re-define some config parameters for testing by overwriting config
    print('rotations at training were set to :', config.augment_rotation)
    config.augment_rotation = 'none'  # no augmentation at test time
    config.batch_num = 1
    config.split = 'test'
    dataset.split = 'test'
    dataset.num_train = -1  # we want all test data
    dataset.load_subsampled_clouds(config.first_subsampling_dl)

    # Initialize input pipelines to get batch generator
    dataset.init_test_input_pipeline(config)

    ##############
    # Define Model
    ##############

    print('Creating Model')
    print('**************\n')
    t1 = time.time()
    model = KernelPointCNN_FM(dataset.flat_inputs, dataset.flat_inputs_2, config)

    # Find all snapshots in the chosen training folder: checkpoint steps are
    # parsed from the '.meta' file names ('snap-<step>.meta').
    snap_path = os.path.join(path, 'snapshots')
    snap_steps = [int(f[:-5].split('-')[-1]) for f in os.listdir(snap_path) if f.endswith('.meta')]

    # Find which snapshot to restore (step_ind indexes the sorted steps)
    print('restoring snapshot')
    chosen_step = np.sort(snap_steps)[step_ind]
    chosen_snap = os.path.join(snap_path, 'snap-{:d}'.format(chosen_step))
    print(chosen_snap)

    # Create a tester class
    print('Model tester')
    model.fld_name = dataset.dataset_name
    tester = ModelTester(model, restore_snap=chosen_snap)
    t2 = time.time()
    print('\n----------------')
    print('Done in {:.1f} s'.format(t2 - t1))
    print('----------------\n')

    ############
    # Start test
    ############

    print('Start Test')
    print('**********\n')
    tester.test_shape_matching(model, dataset)
# ----------------------------------------------------------------------------------------------------------------------
#
# Main Call
# \***************/
#
if __name__ == '__main__':

    ##########################
    # Choose the model to test
    ##########################

    #chosen_log = 'results/Log_2020-03-16_16-01-48' # Log_2020-03-12_16-23-10' # 'Log_2020-02-13_16-09-59' # surreal100
    #chosen_log = 'results/Log_2020-03-04_17-26-11' # 'Log_2020-02-27_14-55-10' # surreal5k
    #chosen_log = 'results/Log_2020-03-04_21-01-24' # s2k
    #chosen_log = 'results/Log_2020-03-10_16-51-30' # 'Log_2020-03-06_19-04-32' # s500
    #chosen_log = 'results/Log_2020-03-10_16-51-51' # 'Log_2020-03-07_13-34-24' # s100
    #chosen_log = 'results/Log_2020-03-05_08-34-57' # s5k NoReg
    #chosen_log = 'results/Log_2020-03-08_07-58-35' #FAUSTr
    #chosen_log = 'results/Log_2020-03-09_08-35-52' #SCAPEr
    # NOTE(review): the line below was an uncommented assignment that was
    # immediately overwritten (dead code) — commented out like the other
    # candidate logs so the effective value is unchanged.
    #chosen_log = 'results/Log_2020-03-28_17-51-04'
    chosen_log = 'results/Log_2020-03-29_20-41-26'

    #
    # You can also choose the index of the snapshot to load (last by default)
    #
    chosen_snapshot = -1

    #
    # Eventually, you can choose to test your model on the validation set
    #
    on_val = False  # NOTE(review): currently unused in this script — verify before relying on it

    #
    # If you want to modify certain parameters in the Config class, for example, to stop augmenting the input data,
    # there is a section for it in the function "test_caller" defined above.
    #

    # Check if log exists
    if not os.path.exists(chosen_log):
        raise ValueError('The given log does not exists: ' + chosen_log)

    # Let's go
    test_caller(chosen_log, chosen_snapshot)