@@ -237,7 +237,8 @@ def _construct_model(self) -> None:
         model_loc = self._parameters["model_path"]

         self._model: tf.keras.Model = tf.keras.models.load_model(model_loc)
-        softmax_output_layer_name = self._model.outputs[0].name.split("/")[0]
+        self._model = tf.keras.Model(self._model.inputs, self._model.outputs)
+        softmax_output_layer_name = self._model.output_names[0]
         softmax_layer_ind = cast(
             int,
             labeler_utils.get_tf_layer_index_from_name(
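Note on the output-name change: the removed line recovered the softmax layer's name by splitting a TF1-style tensor name ("layer/op:0"). Keras 3 KerasTensors no longer carry such names, while Model.output_names lists the output layer names directly; the extra tf.keras.Model(self._model.inputs, self._model.outputs) rebuild presumably normalizes the loaded model into a functional Model so output_names is populated. A minimal sketch on a hypothetical toy model:

    import tensorflow as tf

    inputs = tf.keras.Input(shape=(8,))
    outputs = tf.keras.layers.Dense(
        3, activation="softmax", name="softmax_output"
    )(inputs)
    model = tf.keras.Model(inputs, outputs)

    # Output layer names come straight from the model, no tensor-name parsing.
    print(model.output_names)  # ['softmax_output']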
@@ -252,21 +253,28 @@ def _construct_model(self) -> None:
             num_labels, activation="softmax", name="softmax_output"
         )(self._model.layers[softmax_layer_ind - 1].output)

-        # Output the model into a .pb file for TensorFlow
-        argmax_layer = tf.keras.backend.argmax(new_softmax_layer)
+        # Add argmax layer to get labels directly as an output
+        argmax_layer = tf.keras.ops.argmax(new_softmax_layer, axis=2)

         argmax_outputs = [new_softmax_layer, argmax_layer]
         self._model = tf.keras.Model(self._model.inputs, argmax_outputs)
+        self._model = tf.keras.Model(self._model.inputs, self._model.outputs)

         # Compile the model w/ metrics
-        softmax_output_layer_name = self._model.outputs[0].name.split("/")[0]
+        softmax_output_layer_name = self._model.output_names[0]
         losses = {softmax_output_layer_name: "categorical_crossentropy"}

         # use f1 score metric
         f1_score_training = labeler_utils.F1Score(
             num_classes=num_labels, average="micro"
         )
-        metrics = {softmax_output_layer_name: ["acc", f1_score_training]}
+        metrics = {
+            softmax_output_layer_name: [
+                "categorical_crossentropy",
+                "acc",
+                f1_score_training,
+            ]
+        }

         self._model.compile(loss=losses, optimizer="adam", metrics=metrics)

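Note on the argmax change: tf.keras.backend.argmax is gone in Keras 3, and keras.ops.argmax is the backend-agnostic replacement. The explicit axis=2 selects the highest-probability label along the label dimension, assuming the softmax output is shaped (batch, sequence, num_labels). A small sketch under that shape assumption:

    import tensorflow as tf

    # Fake softmax output: 2 samples, 5 character positions, 4 labels.
    probs = tf.nn.softmax(tf.random.uniform((2, 5, 4)), axis=-1)

    labels = tf.keras.ops.argmax(probs, axis=2)
    print(labels.shape)  # (2, 5): one integer label id per character position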
@@ -294,30 +302,33 @@ def _reconstruct_model(self) -> None:
         num_labels = self.num_labels
         default_ind = self.label_mapping[self._parameters["default_label"]]

-        # Remove the 2 output layers ('softmax', 'tf_op_layer_ArgMax')
-        for _ in range(2):
-            self._model.layers.pop()
-
         # Add the final Softmax layer to the previous spot
+        # self._model.layers[-2] to skip: original softmax
         final_softmax_layer = tf.keras.layers.Dense(
             num_labels, activation="softmax", name="softmax_output"
-        )(self._model.layers[-4].output)
+        )(self._model.layers[-2].output)

-        # Output the model into a .pb file for TensorFlow
-        argmax_layer = tf.keras.backend.argmax(final_softmax_layer)
+        # Add argmax layer to get labels directly as an output
+        argmax_layer = tf.keras.ops.argmax(final_softmax_layer, axis=2)

         argmax_outputs = [final_softmax_layer, argmax_layer]
         self._model = tf.keras.Model(self._model.inputs, argmax_outputs)

         # Compile the model
-        softmax_output_layer_name = self._model.outputs[0].name.split("/")[0]
+        softmax_output_layer_name = self._model.output_names[0]
         losses = {softmax_output_layer_name: "categorical_crossentropy"}

         # use f1 score metric
         f1_score_training = labeler_utils.F1Score(
             num_classes=num_labels, average="micro"
         )
-        metrics = {softmax_output_layer_name: ["acc", f1_score_training]}
+        metrics = {
+            softmax_output_layer_name: [
+                "categorical_crossentropy",
+                "acc",
+                f1_score_training,
+            ]
+        }

         self._model.compile(loss=losses, optimizer="adam", metrics=metrics)

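Note on the layer surgery: popping entries off model.layers never rewired the graph (the property returns a fresh list each call), so the removed loop was effectively a no-op. The replacement instead taps layers[-2].output, the layer just below the original softmax per the new comment, and rebuilds a fresh functional Model. The same re-wiring pattern on a toy model with illustrative names:

    import tensorflow as tf

    inputs = tf.keras.Input(shape=(16,))
    hidden = tf.keras.layers.Dense(32, activation="relu", name="hidden")(inputs)
    old_head = tf.keras.layers.Dense(
        5, activation="softmax", name="old_softmax"
    )(hidden)
    model = tf.keras.Model(inputs, old_head)

    # model.layers[-2] is "hidden", the layer directly below the softmax head.
    new_head = tf.keras.layers.Dense(
        7, activation="softmax", name="softmax_output"
    )(model.layers[-2].output)
    model = tf.keras.Model(model.inputs, new_head)

    print([layer.name for layer in model.layers])
    # [..., 'hidden', 'softmax_output']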
@@ -370,7 +381,7 @@ def fit(
         f1_report: dict = {}

         self._model.reset_metrics()
-        softmax_output_layer_name = self._model.outputs[0].name.split("/")[0]
+        softmax_output_layer_name = self._model.output_names[0]

         start_time = time.time()
         batch_id = 0
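Note on the compile blocks: dict-valued loss and metrics are matched to outputs by name, so only the softmax head is trained while the argmax head rides along purely for inference; listing "categorical_crossentropy" under metrics also surfaces the cross-entropy value in that output's metric logs. A sketch of the pattern with hypothetical shapes and names:

    import tensorflow as tf

    inputs = tf.keras.Input(shape=(10, 8))
    softmax_out = tf.keras.layers.Dense(
        4, activation="softmax", name="softmax_output"
    )(inputs)
    argmax_out = tf.keras.ops.argmax(softmax_out, axis=2)
    model = tf.keras.Model(inputs, [softmax_out, argmax_out])

    # Only the named softmax output receives a loss; the argmax output is
    # left out of the dicts and therefore untrained.
    model.compile(
        optimizer="adam",
        loss={"softmax_output": "categorical_crossentropy"},
        metrics={"softmax_output": ["categorical_crossentropy", "acc"]},
    )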