Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ This repository contains code for the following Keras models:
- VGG16
- VGG19
- ResNet50
- Inception v1 (GoogLeNet)
- Inception v3
- CRNN for music tagging

Expand Down Expand Up @@ -77,6 +78,8 @@ block4_pool_features = model.predict(x)
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556) - please cite this paper if you use the VGG models in your work.
- [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) - please cite this paper if you use the ResNet model in your work.
- [Rethinking the Inception Architecture for Computer Vision](http://arxiv.org/abs/1512.00567) - please cite this paper if you use the Inception v3 model in your work.
- [Going deeper with convolutions](http://arxiv.org/abs/1409.4842v1) - please cite this paper if you use the Inception v1 model in your work.

- [Music-auto_tagging-keras](https://github.com/keunwoochoi/music-auto_tagging-keras)

Additionally, don't forget to [cite Keras](https://keras.io/getting-started/faq/#how-should-i-cite-keras) if you use these models.
Expand All @@ -87,4 +90,5 @@ Additionally, don't forget to [cite Keras](https://keras.io/getting-started/faq/
- All code in this repository is under the MIT license as specified by the LICENSE file.
- The ResNet50 weights are ported from the ones [released by Kaiming He](https://github.com/KaimingHe/deep-residual-networks) under the [MIT license](https://github.com/KaimingHe/deep-residual-networks/blob/master/LICENSE).
- The VGG16 and VGG19 weights are ported from the ones [released by VGG at Oxford](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) under the [Creative Commons Attribution License](https://creativecommons.org/licenses/by/4.0/).
- The Inception v1 weights are derived from the weight files provided by Google for the TensorFlow-slim model zoo (which is Apache 2 licensed).
- The Inception v3 weights are trained by ourselves and are released under the MIT license.
319 changes: 319 additions & 0 deletions inception_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
# -*- coding: utf-8 -*-
"""Inception V1 model for Keras.

Note that the input preprocessing function is different from the the VGG16 and ResNet models (same as Xception).

Also that (currently) the output predictions are for 1001 classes (with the 0 class being 'background'),
so require a shift compared to the other models here.

# Reference

- [Going deeper with convolutions](http://arxiv.org/abs/1409.4842v1)

"""
from __future__ import print_function
from __future__ import absolute_import

import warnings
import numpy as np

from keras.models import Model
from keras import layers
from keras.layers import Input
from keras.layers import Conv2D
from keras.layers import Activation
from keras.layers import BatchNormalization
from keras.layers import MaxPooling2D
from keras.layers import AveragePooling2D
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers import GlobalAveragePooling2D
from keras.layers import GlobalMaxPooling2D
from keras.engine.topology import get_source_inputs
from keras.utils.layer_utils import convert_all_kernels_in_model
from keras.utils.data_utils import get_file
from keras import backend as K
from keras.applications.imagenet_utils import decode_predictions
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.preprocessing import image

WEIGHTS_PATH = 'http://redcatlabs.com/downloads/inception_v1_weights_tf_dim_ordering_tf_kernels.h5'
WEIGHTS_PATH_NO_TOP = 'http://redcatlabs.com/downloads/inception_v1_weights_tf_dim_ordering_tf_kernels_notop.h5'

# conv2d_bn is similar to (but updated from) inception_v3 version
def conv2d_bn(x,
filters,
num_row,
num_col,
padding='same',
strides=(1, 1),
normalizer=True,
activation='relu',
name=None):
"""Utility function to apply conv + BN.
Arguments:
x: input tensor.
filters: filters in `Conv2D`.
num_row: height of the convolution kernel.
num_col: width of the convolution kernel.
padding: padding mode in `Conv2D`.
strides: strides in `Conv2D`.
name: name of the ops; will become `name + '_conv'`
for the convolution, `name + '_bn'` for the
batch norm layer and `name + '_act'` for the
activation layer.
Returns:
Output tensor after applying `Conv2D` and `BatchNormalization`.
"""
if name is not None:
conv_name = name + '_conv'
bn_name = name + '_bn'
act_name = name + '_act'
else:
conv_name = None
bn_name = None
act_name = None
if K.image_data_format() == 'channels_first':
bn_axis = 1
else:
bn_axis = 3
x = Conv2D(
filters, (num_row, num_col),
strides=strides, padding=padding,
use_bias=False, name=conv_name)(x)
if normalizer:
x = BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x)
if activation:
x = Activation(activation, name=act_name)(x)
return x

# Convenience function for 'standard' Inception concatenated blocks
def concatenated_block(x, specs, channel_axis, name):
(br0, br1, br2, br3) = specs # ((64,), (96,128), (16,32), (32,))

branch_0 = conv2d_bn(x, br0[0], 1, 1, name=name+"_Branch_0_a_1x1")

branch_1 = conv2d_bn(x, br1[0], 1, 1, name=name+"_Branch_1_a_1x1")
branch_1 = conv2d_bn(branch_1, br1[1], 3, 3, name=name+"_Branch_1_b_3x3")

branch_2 = conv2d_bn(x, br2[0], 1, 1, name=name+"_Branch_2_a_1x1")
branch_2 = conv2d_bn(branch_2, br2[1], 3, 3, name=name+"_Branch_2_b_3x3")

branch_3 = MaxPooling2D( (3, 3), strides=(1, 1), padding='same', name=name+"_Branch_3_a_max")(x)
branch_3 = conv2d_bn(branch_3, br3[0], 1, 1, name=name+"_Branch_3_b_1x1")

x = layers.concatenate(
[branch_0, branch_1, branch_2, branch_3],
axis=channel_axis,
name=name+"_Concatenated")
return x


def InceptionV1(include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1001):
"""Instantiates the Inception v1 architecture.

This architecture is defined in:
Going deeper with convolutions
Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
http://arxiv.org/abs/1409.4842v1

Optionally loads weights pre-trained
on ImageNet. Note that when using TensorFlow,
for best performance you should set
`image_data_format="channels_last"` in your Keras config
at ~/.keras/keras.json.
The model and the weights are compatible with both
TensorFlow and Theano. The data format
convention used by the model is the one
specified in your Keras config file.
Note that the default input image size for this model is 224x224.
Arguments:
include_top: whether to include the fully-connected
layer at the top of the network.
weights: one of `None` (random initialization)
or "imagenet" (pre-training on ImageNet).
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
has to be `(224, 224, 3)` (with `channels_last` data format)
or `(3, 224, 224)` (with `channels_first` data format).
It should have exactly 3 inputs channels,
and width and height should be no smaller than 139.
E.g. `(150, 150, 3)` would be one valid value.
pooling: Optional pooling mode for feature extraction
when `include_top` is `False`.
- `None` means that the output of the model will be
the 4D tensor output of the
last convolutional layer.
- `avg` means that global average pooling
will be applied to the output of the
last convolutional layer, and thus
the output of the model will be a 2D tensor.
- `max` means that global max pooling will
be applied.
classes: optional number of classes to classify images
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified.
Returns:
A Keras model instance.
Raises:
ValueError: in case of invalid argument for `weights`,
or invalid input shape.
"""
if weights not in {'imagenet', None}:
raise ValueError('The `weights` argument should be either '
'`None` (random initialization) or `imagenet` '
'(pre-training on ImageNet).')

if weights == 'imagenet' and include_top and classes != 1001:
raise ValueError('If using `weights` as imagenet with `include_top`'
' as true, `classes` should be 1001')

# Determine proper input shape
input_shape = _obtain_input_shape(
input_shape,
#default_size=299,
default_size=224,
min_size=139,
data_format=K.image_data_format(),
include_top=include_top)

if input_tensor is None:
img_input = Input(shape=input_shape)
else:
img_input = Input(tensor=input_tensor, shape=input_shape)

if K.image_data_format() == 'channels_first':
channel_axis = 1
else:
channel_axis = 3

# 'Sequential bit at start'
x = img_input
x = conv2d_bn(x, 64, 7, 7, strides=(2, 2), padding='same', name='Conv2d_1a_7x7')

x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='MaxPool_2a_3x3')(x)

x = conv2d_bn(x, 64, 1, 1, strides=(1, 1), padding='same', name='Conv2d_2b_1x1')
x = conv2d_bn(x, 192, 3, 3, strides=(1, 1), padding='same', name='Conv2d_2c_3x3')

x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='MaxPool_3a_3x3')(x)

# Now the '3' level inception units
x = concatenated_block(x, (( 64,), ( 96,128), (16, 32), ( 32,)), channel_axis, 'Mixed_3b')
x = concatenated_block(x, ((128,), (128,192), (32, 96), ( 64,)), channel_axis, 'Mixed_3c')

x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='MaxPool_4a_3x3')(x)

# Now the '4' level inception units
x = concatenated_block(x, ((192,), ( 96,208), (16, 48), ( 64,)), channel_axis, 'Mixed_4b')
x = concatenated_block(x, ((160,), (112,224), (24, 64), ( 64,)), channel_axis, 'Mixed_4c')
x = concatenated_block(x, ((128,), (128,256), (24, 64), ( 64,)), channel_axis, 'Mixed_4d')
x = concatenated_block(x, ((112,), (144,288), (32, 64), ( 64,)), channel_axis, 'Mixed_4e')
x = concatenated_block(x, ((256,), (160,320), (32,128), (128,)), channel_axis, 'Mixed_4f')

x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='MaxPool_5a_2x2')(x)

# Now the '5' level inception units
x = concatenated_block(x, ((256,), (160,320), (32,128), (128,)), channel_axis, 'Mixed_5b')
x = concatenated_block(x, ((384,), (192,384), (48,128), (128,)), channel_axis, 'Mixed_5c')


if include_top:
# Classification block

# 'AvgPool_0a_7x7'
x = AveragePooling2D((7, 7), strides=(1, 1), padding='valid')(x)

# 'Dropout_0b'
x = Dropout(0.2)(x) # slim has keep_prob (@0.8), keras uses drop_fraction

#logits = conv2d_bn(x, classes, 1, 1, strides=(1, 1), padding='valid', name='Logits',
# normalizer=False, activation=None, )

# Write out the logits explictly, since it is pretty different
x = Conv2D(classes, (1, 1), strides=(1,1), padding='valid', use_bias=True, name='Logits')(x)

x = Flatten(name='Logits_flat')(x)
#x = x[:, 1:] # ??Shift up so that first class ('blank background') vanishes
# Would be more efficient to strip off position[0] from the weights+bias terms directly in 'Logits'

x = Activation('softmax', name='Predictions')(x)
else:
if pooling == 'avg':
x = GlobalAveragePooling2D(name='global_pooling')(x)
elif pooling == 'max':
x = GlobalMaxPooling2D( name='global_pooling')(x)

# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = get_source_inputs(input_tensor)
else:
inputs = img_input

# Finally : Create model
model = Model(inputs, x, name='inception_v1')

# LOAD model weights
if weights == 'imagenet':
if K.image_data_format() == 'channels_first':
if K.backend() == 'tensorflow':
warnings.warn('You are using the TensorFlow backend, yet you '
'are using the Theano '
'image data format convention '
'(`image_data_format="channels_first"`). '
'For best performance, set '
'`image_data_format="channels_last"` in '
'your Keras config '
'at ~/.keras/keras.json.')
if include_top:
weights_path = get_file(
'inception_v1_weights_tf_dim_ordering_tf_kernels.h5',
WEIGHTS_PATH,
cache_subdir='models',
md5_hash='723bf2f662a5c07db50d28c8d35b626d')
else:
weights_path = get_file(
'inception_v1_weights_tf_dim_ordering_tf_kernels_notop.h5',
WEIGHTS_PATH_NO_TOP,
cache_subdir='models',
md5_hash='6fa8ecdc5f6c402a59909437f0f5c975')
model.load_weights(weights_path)
if K.backend() == 'theano':
convert_all_kernels_in_model(model)

return model



def preprocess_input(x):
x /= 255.
x -= 0.5
x *= 2.
return x


if __name__ == '__main__':
model = InceptionV1(include_top=True, weights='imagenet')

img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)

x = preprocess_input(x)

preds = model.predict(x)

# Extra shift to remove 'background' as entry0
preds_1000 = preds[:, 1:]

print('Predicted:', decode_predictions(preds_1000))