From b1d54373e34cd313568a64d587fb855906875af9 Mon Sep 17 00:00:00 2001 From: Think tank of Sichuan University Date: Wed, 28 Nov 2018 21:27:56 +0800 Subject: [PATCH] the original method seems too complicated For beginners, the original method seems too complicated. In fact, it only needs simple tensor slicing operation to achieve the same effect. --- .../02_implementing_rnn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/09_Recurrent_Neural_Networks/02_Implementing_RNN_for_Spam_Prediction/02_implementing_rnn.py b/09_Recurrent_Neural_Networks/02_Implementing_RNN_for_Spam_Prediction/02_implementing_rnn.py index 54c06958f..fd31212f0 100644 --- a/09_Recurrent_Neural_Networks/02_Implementing_RNN_for_Spam_Prediction/02_implementing_rnn.py +++ b/09_Recurrent_Neural_Networks/02_Implementing_RNN_for_Spam_Prediction/02_implementing_rnn.py @@ -111,8 +111,9 @@ def clean_text(text_string): output = tf.nn.dropout(output, dropout_keep_prob) # Get output of RNN sequence -output = tf.transpose(output, [1, 0, 2]) -last = tf.gather(output, int(output.get_shape()[0]) - 1) +# output = tf.transpose(output, [1, 0, 2]) +# last = tf.gather(output, int(output.get_shape()[0]) - 1) +last = output[:,-1,:] weight = tf.Variable(tf.truncated_normal([rnn_size, 2], stddev=0.1)) bias = tf.Variable(tf.constant(0.1, shape=[2])) @@ -183,4 +184,4 @@ def clean_text(text_string): plt.xlabel('Epochs') plt.ylabel('Accuracy') plt.legend(loc='upper left') -plt.show() \ No newline at end of file +plt.show()