|
| 1 | +import os |
| 2 | +import sys |
| 3 | +import tensorflow as tf |
| 4 | +import numpy as np |
| 5 | +import pickle as pk |
| 6 | +from tensorflow.examples.tutorials.mnist import input_data |
| 7 | +home='/home/liuyang' |
| 8 | + |
| 9 | +INPUT_NODE = 784 |
| 10 | +OUTPUT_NODE = 10 |
| 11 | +IMAGE_SIZE = 28 |
| 12 | +NUM_CHANNELS = 1 |
| 13 | +NUM_LABELS = 10 |
| 14 | +CONV1_DEEP = 32 |
| 15 | +CONV1_SIZE = 5 |
| 16 | +CONV2_DEEP = 64 |
| 17 | +CONV2_SIZE = 5 |
| 18 | +FC_SIZE = 512 |
| 19 | + |
| 20 | +os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
| 21 | +config=tf.ConfigProto() |
| 22 | +config.gpu_options.allow_growth=True#增长式 |
| 23 | + |
| 24 | +# 配置神经网络的参数 |
| 25 | +BATCH_SIZE = 100 |
| 26 | +REGULARAZTION_RATE = 1e-4 |
| 27 | +TRAINING_STEPS = 30000 |
| 28 | +LABEL='relu' |
| 29 | + |
| 30 | +def train(mnist): |
| 31 | + # 定义输入输出placeholder |
| 32 | + x = tf.placeholder(tf.float32, [None, |
| 33 | + IMAGE_SIZE, # 第一维表示一个batch中样例的个数 |
| 34 | + IMAGE_SIZE, # 第二维和第三维表示图片的尺寸 |
| 35 | + NUM_CHANNELS], # 第四维表示图片的深度,对于RGB格式的图片,深度为5 |
| 36 | + name='x-input') |
| 37 | + y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input') |
| 38 | + dropout=tf.placeholder(tf.float32) |
| 39 | + regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE) |
| 40 | + |
| 41 | + def print_shape(t): |
| 42 | + print(t.op.name,' ',t.get_shape().as_list()) |
| 43 | + def weight(shape): |
| 44 | + return tf.Variable(tf.truncated_normal(shape,stddev=0.1)) |
| 45 | + def bias(shape): |
| 46 | + return tf.Variable(tf.constant(0.1,shape=shape)) |
| 47 | + def activate(x,b,label): |
| 48 | + if label=='relu': |
| 49 | + return tf.nn.relu(x+b) |
| 50 | + |
| 51 | + w1=weight([CONV1_SIZE, CONV1_SIZE, NUM_CHANNELS, CONV1_DEEP]) |
| 52 | + b1=bias([CONV1_DEEP]) |
| 53 | + net1 = tf.nn.conv2d(x, w1, strides=[1, 1, 1, 1], padding='SAME',name='conv1') |
| 54 | + print_shape(net1) |
| 55 | + net = activate(net1, b1, LABEL) |
| 56 | + net = tf.nn.max_pool(net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME',name='pool1') |
| 57 | + |
| 58 | + w2=weight([CONV2_SIZE, CONV2_SIZE, CONV1_DEEP, CONV2_DEEP]) |
| 59 | + b2=bias([CONV2_DEEP]) |
| 60 | + net2 = tf.nn.conv2d(net, w2, strides=[1, 1, 1, 1], padding='SAME',name='con2') |
| 61 | + print_shape(net2) |
| 62 | + net = activate(net2, b2, LABEL) |
| 63 | + net = tf.nn.max_pool(net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME',name='pool2') |
| 64 | + |
| 65 | + #w3_=weight([2, 2, 64, 64]) |
| 66 | + #b3_=bias([64]) |
| 67 | + #net3 = tf.nn.conv2d(net, w3_, strides=[1, 1, 1, 1], padding='SAME',name='con3') |
| 68 | + #print_shape(net3) |
| 69 | + #net = activate(net3, b3_, LABEL) |
| 70 | + #net = tf.nn.max_pool(net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME',name='pool3') |
| 71 | + |
| 72 | + reshaped = tf.reshape(net, [-1, 3136]) |
| 73 | + w3 = weight([3136, FC_SIZE]) |
| 74 | + if regularizer != None: |
| 75 | + tf.add_to_collection('losses', regularizer(w3)) |
| 76 | + b3 = bias([FC_SIZE]) |
| 77 | + net = activate(tf.matmul(reshaped, w3), b3, LABEL) |
| 78 | + net = tf.nn.dropout(net, dropout) |
| 79 | + |
| 80 | + w4 = weight([FC_SIZE, NUM_LABELS]) |
| 81 | + if regularizer != None: |
| 82 | + tf.add_to_collection('losses', regularizer(w4)) |
| 83 | + b4 = bias([NUM_LABELS]) |
| 84 | + |
| 85 | + logit = tf.matmul(net, w4,name='logit') + b4 |
| 86 | + y = tf.nn.softmax(logit) |
| 87 | + cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y+1e-10),reduction_indices=[1])) |
| 88 | + train_step=tf.train.AdamOptimizer(REGULARAZTION_RATE).minimize(cross_entropy) |
| 89 | + correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) |
| 90 | + accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) |
| 91 | + |
| 92 | + with tf.Session(config=config) as sess: |
| 93 | + tf.global_variables_initializer().run() |
| 94 | + # 在训练过程中不再测试模型在验证数据上的表现,验证和测试的过程将会有一个独立的程序来完成 |
| 95 | + for i in range(TRAINING_STEPS): |
| 96 | + xs, ys = mnist.train.next_batch(BATCH_SIZE) |
| 97 | + reshaped_xs = np.reshape(xs, (BATCH_SIZE, |
| 98 | + IMAGE_SIZE, |
| 99 | + IMAGE_SIZE, |
| 100 | + NUM_CHANNELS)) |
| 101 | + _, loss_value, acc= sess.run([train_step, cross_entropy,accuracy], feed_dict={x: reshaped_xs, y_: ys, dropout: 0.5}) |
| 102 | + if i%2000==0: |
| 103 | + print('training',i,'acc:',acc) |
| 104 | + |
| 105 | + test_accuracy=0.0 |
| 106 | + for j in range(int(10000/BATCH_SIZE)): |
| 107 | + xs,ys=mnist.test.next_batch(BATCH_SIZE) |
| 108 | + reshaped_xs = np.reshape(xs, (BATCH_SIZE, |
| 109 | + IMAGE_SIZE, |
| 110 | + IMAGE_SIZE, |
| 111 | + NUM_CHANNELS)) |
| 112 | + acc=accuracy.eval(feed_dict={x:reshaped_xs,y_:ys, dropout: 1.0}) |
| 113 | + test_accuracy+=acc |
| 114 | + test_accuracy= test_accuracy/int(10000/BATCH_SIZE) |
| 115 | + print('test_accuracy:',test_accuracy) |
| 116 | + |
| 117 | + def get(n,label): |
| 118 | + reshaped_x = np.reshape(mnist.test.images[n],(1,28,28,1)) |
| 119 | + reshaped_y = np.reshape(mnist.test.labels[n],(1,10)) |
| 120 | + feed_dict = {x: reshaped_x,y_: reshaped_y, dropout:1.0} |
| 121 | + |
| 122 | + y_pre , y_label= sess.run([y,y_],feed_dict=feed_dict) |
| 123 | + y_prediction = np.reshape(y_pre,(10)) |
| 124 | + y_prediction_label = np.reshape(y_label,(10)) |
| 125 | + |
| 126 | + y_prediction = y_prediction.tolist() |
| 127 | + y_prediction_label = y_prediction_label.tolist() |
| 128 | + |
| 129 | + prediction = y_prediction.index(max(y_prediction)) |
| 130 | + prediction_label = y_prediction_label.index(max(y_prediction_label)) |
| 131 | + if prediction == prediction_label: |
| 132 | + if not os.path.isdir(home+'/save/%s/%d'%(label,n)): |
| 133 | + os.makedirs(home+'/save/%s/%d'%(label,n)) |
| 134 | + feature_map1 = net1.eval(feed_dict=feed_dict) |
| 135 | + f=open(home+'/save/%s/%d/feature_map1.pk'%(label,n),'wb') |
| 136 | + pk.dump(feature_map1,f) |
| 137 | + f.close() |
| 138 | + feature_map2 = net2.eval(feed_dict=feed_dict) |
| 139 | + f=open(home+'/save/%s/%d/feature_map2.pk'%(label,n),'wb') |
| 140 | + pk.dump(feature_map2,f) |
| 141 | + f.close() |
| 142 | + f=open(home+'/save/%s/%d/prediction_%d'%(label,n,prediction),'w') |
| 143 | + f.write('1') |
| 144 | + f.close() |
| 145 | + weight_ = w3.eval() |
| 146 | + f=open(home+'/save/%s/%d/weight.pk'%(label,n),'wb') |
| 147 | + pk.dump(weight_,f) |
| 148 | + f.close() |
| 149 | + |
| 150 | + for i in range(1000): |
| 151 | + get(i,'base') |
| 152 | + |
| 153 | + sys.exit() |
| 154 | + |
| 155 | +def main(argv=None): |
| 156 | + mnist = input_data.read_data_sets("/home/liuyang/workspace/n_fold_superposition/src/MNIST_data", one_hot=True) |
| 157 | + train(mnist) |
| 158 | + |
| 159 | + |
| 160 | +if __name__ == '__main__': |
| 161 | + main() |
0 commit comments