32
32
import json
33
33
import os
34
34
35
- import keras
35
+ import tensorflow as tf
36
36
import tensorflowjs as tfjs
37
37
38
38
@@ -50,7 +50,7 @@ def get_word_index(reverse=False):
50
50
Returns:
51
51
The word index as a `dict`.
52
52
"""
53
- word_index = keras .datasets .imdb .get_word_index ()
53
+ word_index = tf . keras .datasets .imdb .get_word_index ()
54
54
if reverse :
55
55
word_index = dict ((word_index [key ], key ) for key in word_index )
56
56
return word_index
@@ -85,10 +85,10 @@ def get_imdb_data(vocabulary_size, max_len):
85
85
y_test: Same as `y_train`, but for test.
86
86
"""
87
87
print ("Getting IMDB data with vocabulary_size %d" % vocabulary_size )
88
- (x_train , y_train ), (x_test , y_test ) = keras .datasets .imdb .load_data (
88
+ (x_train , y_train ), (x_test , y_test ) = tf . keras .datasets .imdb .load_data (
89
89
num_words = vocabulary_size )
90
- x_train = keras .preprocessing .sequence .pad_sequences (x_train , maxlen = max_len )
91
- x_test = keras .preprocessing .sequence .pad_sequences (x_test , maxlen = max_len )
90
+ x_train = tf . keras .preprocessing .sequence .pad_sequences (x_train , maxlen = max_len )
91
+ x_test = tf . keras .preprocessing .sequence .pad_sequences (x_test , maxlen = max_len )
92
92
return x_train , y_train , x_test , y_test
93
93
94
94
@@ -122,38 +122,30 @@ def train_model(model_type,
122
122
ValueError: on invalid model type.
123
123
"""
124
124
125
- model = keras .Sequential ()
126
- model .add (keras .layers .Embedding (vocabulary_size , embedding_size ))
125
+ model = tf . keras .Sequential ()
126
+ model .add (tf . keras .layers .Embedding (vocabulary_size , embedding_size ))
127
127
if model_type == 'bidirectional_lstm' :
128
128
# TODO(cais): Uncomment the following once bug b/74429960 is fixed.
129
- # model.add(keras.layers.Embedding(
129
+ # model.add(tf. keras.layers.Embedding(
130
130
# vocabulary_size, 128, input_length=maxlen))
131
- # model.add(keras.layers.Bidirectional(
132
- # keras.layers.LSTM(64,
133
- # kernel_initializer='glorot_normal',
134
- # recurrent_initializer ='glorot_normal')))
135
- # model.add(keras.layers.Dropout(0.5))
131
+ # model.add(tf.keras.layers.Bidirectional(
132
+ # tf.keras.layers.LSTM(64))
133
+ # model.add(tf.keras.layers.Dropout(0.5))
136
134
raise NotImplementedError ()
137
135
elif model_type == 'cnn' :
138
- model .add (keras .layers .Dropout (0.2 ))
139
- model .add (keras .layers .Conv1D (250 ,
136
+ model .add (tf . keras .layers .Dropout (0.2 ))
137
+ model .add (tf . keras .layers .Conv1D (250 ,
140
138
3 ,
141
139
padding = 'valid' ,
142
140
activation = 'relu' ,
143
141
strides = 1 ))
144
- model .add (keras .layers .GlobalMaxPooling1D ())
145
- model .add (keras .layers .Dense (250 , activation = 'relu' ))
142
+ model .add (tf . keras .layers .GlobalMaxPooling1D ())
143
+ model .add (tf . keras .layers .Dense (250 , activation = 'relu' ))
146
144
elif model_type == 'lstm' :
147
- model .add (
148
- keras .layers .LSTM (
149
- 128 ,
150
- kernel_initializer = 'glorot_normal' ,
151
- recurrent_initializer = 'glorot_normal' ))
152
- # TODO(cais): Remove glorot_normal and use the default orthogonal once
153
- # SVD is available.
145
+ model .add (tf .keras .layers .LSTM (128 ))
154
146
else :
155
147
raise ValueError ("Invalid model type: '%s'" % model_type )
156
- model .add (keras .layers .Dense (1 , activation = 'sigmoid' ))
148
+ model .add (tf . keras .layers .Dense (1 , activation = 'sigmoid' ))
157
149
158
150
model .compile ('adam' , 'binary_crossentropy' , metrics = ['accuracy' ])
159
151
model .fit (x_train , y_train ,
@@ -210,7 +202,7 @@ def main():
210
202
tfjs .converters .save_keras_model (model , FLAGS .artifacts_dir )
211
203
print ('\n Saved model artifacts in directory: %s' % FLAGS .artifacts_dir )
212
204
213
-
205
+
214
206
if __name__ == '__main__' :
215
207
parser = argparse .ArgumentParser ('IMDB sentiment classification model' )
216
208
parser .add_argument (
0 commit comments