-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtalos_training_generator_CNN.py
More file actions
40 lines (32 loc) · 1.12 KB
/
talos_training_generator_CNN.py
File metadata and controls
40 lines (32 loc) · 1.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from models import classifier_cnn
import glob
import os
import talos
import pickle
import numpy as np
# Run a talos hyper-parameter scan of ClassifierCNN and pickle the results.

# Expose only GPU index 1 to the process.
# NOTE(review): this is set after `classifier_cnn` and `talos` are imported;
# it only takes effect if neither initialised CUDA at import time -- consider
# moving it above the imports to be safe.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Directory of preprocessed training instances stored as .npy arrays.
folder = '/projects/satdb/cnfgen_graph/dataset_preprocessed/train/'
# glob returns files in arbitrary order; sort so the sample used below to
# infer the dimensions is deterministic between runs.
files = sorted(glob.glob(os.path.join(folder, '*npy')))
if not files:
    raise FileNotFoundError(f"no '*npy' files found in {folder!r}")

# Load one instance only to infer the encoding dimensions.
# Assumes every file in `folder` has the same shape, with rows = clauses and
# columns = 2 * variables (as the assignments below imply) -- TODO confirm.
x0 = np.load(files[0])

model = classifier_cnn.ClassifierCNN()
experiment_name = 'CNN_Classifier_generator'
model.name = experiment_name
# talos `params` values are lists of candidate values, hence the 1-element list.
model.parameter_list['path'] = [folder]
model.max_variables = x0.shape[1] // 2
model.max_clauses = x0.shape[0]
model.encoding_size = x0.shape[0] * x0.shape[1]
model.verbose = True

# talos.Scan requires x/y and validation arrays. Here they are tiny
# placeholders; presumably `model.training` reads the real data itself from
# the 'path' parameter (generator-based training) -- verify in classifier_cnn.
# np.zeros (not np.empty) keeps the placeholders free of uninitialised values.
# NOTE(review): the axis order (2*variables, clauses) is transposed relative
# to x0's (clauses, 2*variables) layout -- confirm this is intentional.
placeholder_x_shape = (1, model.max_variables * 2, model.max_clauses, 1)
dummy_x = np.zeros(placeholder_x_shape)
dummy_y = np.zeros((1, 1))
testX = np.zeros(placeholder_x_shape)
testY = np.zeros((1, 1))

t = talos.Scan(x=dummy_x,
               y=dummy_y,
               x_val=testX,
               y_val=testY,
               model=model.training,
               experiment_name=experiment_name,
               params=model.parameter_list,
               round_limit=50)

# Persist the scan results table, naming the file after the experiment.
# The `with` block guarantees the file is flushed and closed.
with open("./" + experiment_name + ".obj", 'wb') as filehandler:
    pickle.dump(t.data, filehandler)