diff --git a/09_Recurrent_Neural_Networks/02_Implementing_RNN_for_Spam_Prediction/02_implementing_rnn.py b/09_Recurrent_Neural_Networks/02_Implementing_RNN_for_Spam_Prediction/02_implementing_rnn.py index 54c06958f..fd31212f0 100644 --- a/09_Recurrent_Neural_Networks/02_Implementing_RNN_for_Spam_Prediction/02_implementing_rnn.py +++ b/09_Recurrent_Neural_Networks/02_Implementing_RNN_for_Spam_Prediction/02_implementing_rnn.py @@ -111,8 +111,9 @@ def clean_text(text_string): output = tf.nn.dropout(output, dropout_keep_prob) # Get output of RNN sequence -output = tf.transpose(output, [1, 0, 2]) -last = tf.gather(output, int(output.get_shape()[0]) - 1) +# output = tf.transpose(output, [1, 0, 2]) +# last = tf.gather(output, int(output.get_shape()[0]) - 1) +last = output[:,-1,:] weight = tf.Variable(tf.truncated_normal([rnn_size, 2], stddev=0.1)) bias = tf.Variable(tf.constant(0.1, shape=[2])) @@ -183,4 +184,4 @@ def clean_text(text_string): plt.xlabel('Epochs') plt.ylabel('Accuracy') plt.legend(loc='upper left') -plt.show() \ No newline at end of file +plt.show()