74
74
from deepcpg import data as dat
75
75
from deepcpg import metrics as met
76
76
from deepcpg import models as mod
77
+ from deepcpg .models .utils import is_input_layer , is_output_layer
77
78
from deepcpg .data import hdf , OUTPUT_SEP
78
79
from deepcpg .utils import format_table , make_dir , EPS
79
80
86
87
87
88
88
89
def remove_outputs (model ):
89
- while model .layers [- 1 ] in model . output_layers :
90
+ while is_output_layer ( model .layers [- 1 ], model ) :
90
91
model .layers .pop ()
91
92
model .outputs = [model .layers [- 1 ].output ]
92
93
model .layers [- 1 ].outbound_nodes = []
@@ -97,7 +98,7 @@ def rename_layers(model, scope=None):
97
98
if not scope :
98
99
scope = model .scope
99
100
for layer in model .layers :
100
- if layer in model . input_layers or layer .name .startswith (scope ):
101
+ if is_input_layer ( layer ) or layer .name .startswith (scope ):
101
102
continue
102
103
layer .name = '%s/%s' % (scope , layer .name )
103
104
@@ -595,7 +596,7 @@ def build_model(self):
595
596
remove_outputs (stem )
596
597
597
598
outputs = mod .add_output_layers (stem .outputs [0 ], output_names )
598
- model = Model (stem .inputs , outputs , stem .name )
599
+ model = Model (inputs = stem .inputs , outputs = outputs , name = stem .name )
599
600
return model
600
601
601
602
def set_trainability (self , model ):
@@ -622,17 +623,18 @@ def set_trainability(self, model):
622
623
table ['layer' ] = []
623
624
table ['trainable' ] = []
624
625
for layer in model .layers :
625
- if layer not in model .input_layers + model .output_layers :
626
- if not hasattr (layer , 'trainable' ):
627
- continue
628
- for regex in not_trainable :
629
- if re .match (regex , layer .name ):
630
- layer .trainable = False
631
- for regex in trainable :
632
- if re .match (regex , layer .name ):
633
- layer .trainable = True
634
- table ['layer' ].append (layer .name )
635
- table ['trainable' ].append (layer .trainable )
626
+ if is_input_layer (layer ) or is_output_layer (layer , model ):
627
+ continue
628
+ if not hasattr (layer , 'trainable' ):
629
+ continue
630
+ for regex in not_trainable :
631
+ if re .match (regex , layer .name ):
632
+ layer .trainable = False
633
+ for regex in trainable :
634
+ if re .match (regex , layer .name ):
635
+ layer .trainable = True
636
+ table ['layer' ].append (layer .name )
637
+ table ['trainable' ].append (layer .trainable )
636
638
print ('Layer trainability:' )
637
639
print (format_table (table ))
638
640
print ()
@@ -713,9 +715,7 @@ def main(self, name, opts):
713
715
mod .save_model (model , os .path .join (opts .out_dir , 'model.json' ))
714
716
715
717
log .info ('Computing output statistics ...' )
716
- output_names = []
717
- for output_layer in model .output_layers :
718
- output_names .append (output_layer .name )
718
+ output_names = model .output_names
719
719
720
720
output_stats = OrderedDict ()
721
721
0 commit comments