Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions keras_vggface/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,21 @@
from keras import layers


def get_name_formatted(name, prefix):
return [name if prefix == '' else '{}_{}'.format(name, prefix)][0]


def apply_name_formatting(model, prefix):
for each_layer in model.layers:
each_layer.name = get_name_formatted(each_layer.name, prefix)

return model


def VGG16(include_top=True, weights='vggface',
input_tensor=None, input_shape=None,
pooling=None,
classes=2622):
classes=2622, prefix=''):
input_shape = _obtain_input_shape(input_shape,
default_size=224,
min_size=48,
Expand All @@ -48,15 +59,12 @@ def VGG16(include_top=True, weights='vggface',
x = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(x)

# Block 2
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='conv2_1')(
x)
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='conv2_2')(
x)
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='conv2_1')(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='conv2_2')(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(x)

# Block 3
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='conv3_1')(
x)
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='conv3_1')(x)
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='conv3_2')(
x)
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='conv3_3')(
Expand Down Expand Up @@ -135,6 +143,9 @@ def VGG16(include_top=True, weights='vggface',
'`image_data_format="channels_last"` in '
'your Keras config '
'at ~/.keras/keras.json.')

model = apply_name_formatting(model, prefix)

return model


Expand Down Expand Up @@ -207,7 +218,7 @@ def resnet_conv_block(input_tensor, kernel_size, filters, stage, block,
def RESNET50(include_top=True, weights='vggface',
input_tensor=None, input_shape=None,
pooling=None,
classes=8631):
classes=8631, prefix=''):
input_shape = _obtain_input_shape(input_shape,
default_size=224,
min_size=32,
Expand Down Expand Up @@ -306,6 +317,8 @@ def RESNET50(include_top=True, weights='vggface',
elif weights is not None:
model.load_weights(weights)

model = apply_name_formatting(model, prefix)

return model


Expand Down Expand Up @@ -412,7 +425,7 @@ def senet_identity_block(input_tensor, kernel_size,
def SENET50(include_top=True, weights='vggface',
input_tensor=None, input_shape=None,
pooling=None,
classes=8631):
classes=8631, prefix=''):
input_shape = _obtain_input_shape(input_shape,
default_size=224,
min_size=197,
Expand Down Expand Up @@ -513,4 +526,6 @@ def SENET50(include_top=True, weights='vggface',
elif weights is not None:
model.load_weights(weights)

model = apply_name_formatting(model, prefix)

return model
8 changes: 4 additions & 4 deletions keras_vggface/vggface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def VGGFace(include_top=True, model='vgg16', weights='vggface',
input_tensor=None, input_shape=None,
pooling=None,
classes=None):
classes=None, prefix=''):
"""Instantiates the VGGFace architectures.
Optionally loads weights pre-trained
on VGGFace datasets. Note that when using TensorFlow,
Expand Down Expand Up @@ -78,7 +78,7 @@ def VGGFace(include_top=True, model='vgg16', weights='vggface',
return VGG16(include_top=include_top, input_tensor=input_tensor,
input_shape=input_shape, pooling=pooling,
weights=weights,
classes=classes)
classes=classes, prefix=prefix)


if model == 'resnet50':
Expand All @@ -94,7 +94,7 @@ def VGGFace(include_top=True, model='vgg16', weights='vggface',
return RESNET50(include_top=include_top, input_tensor=input_tensor,
input_shape=input_shape, pooling=pooling,
weights=weights,
classes=classes)
classes=classes, prefix=prefix)

if model == 'senet50':

Expand All @@ -109,4 +109,4 @@ def VGGFace(include_top=True, model='vgg16', weights='vggface',
return SENET50(include_top=include_top, input_tensor=input_tensor,
input_shape=input_shape, pooling=pooling,
weights=weights,
classes=classes)
classes=classes, prefix=prefix)