-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwoundrecognition-mobilenet.py
42 lines (35 loc) · 1.83 KB
/
woundrecognition-mobilenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# -*- coding: utf-8 -*-
"""
20190131
Created on Sat Nov 3 22:32:39 2018
@author: adi
"""
from keras.layers import Conv2D,MaxPooling2D,Dense,Dropout,Flatten,Activation,Input,GlobalAveragePooling2D,BatchNormalization,Add
from keras.models import Sequential,Model
from keras.utils import np_utils
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint,TensorBoard
from keras.applications.mobilenet_v2 import MobileNetV2
from keras.applications.nasnet import NASNetMobile
from weightnorm import AdamWithWeightnorm
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
# Data pipeline: both iterators read images from class sub-directories,
# resize them to 224x224 and yield them in batches of 10.
train_path = 'D:/dataset-wound/train_data'
test_path = 'D:/dataset-wound/test_data'


def _directory_flow(generator, directory):
    """Return a directory iterator over *directory* driven by *generator*.

    The same target size, class list and batch size are used for the
    training and the test iterator, so they are fixed here.
    """
    return generator.flow_from_directory(
        directory,
        target_size=(224, 224),
        classes=['s', 'b', 'c', 'x'],
        batch_size=10,
    )


# Training images get random rotations and shifts in addition to the
# 1/255 rescaling.
# NOTE(review): how a shift range >= 1 is interpreted (pixels vs. fraction
# of the image size) depends on the installed Keras version - confirm.
train_augmenter = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=3.0,
    rescale=1/255.,
    height_shift_range=3.0,
)
traind = _directory_flow(train_augmenter, train_path)

# Test images are only rescaled, never augmented.
test_scaler = ImageDataGenerator(rescale=1/255.)
testd = _directory_flow(test_scaler, test_path)
# Classifier: a MobileNetV2 backbone loaded with ImageNet weights and
# without its original top, followed by a small 4-way softmax head.
# The model is built directly on the backbone's own input tensor, so no
# separate Input layer is needed.
# NOTE(review): the Keras MobileNetV2 ImageNet weights are documented to
# expect inputs scaled with mobilenet_v2.preprocess_input ([-1, 1]); the
# generators in this script rescale to [0, 1] - confirm this is intended.
mobile = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224,224,3))

# Head: pooled backbone features -> 1024-unit ReLU layer -> dropout -> softmax.
x = GlobalAveragePooling2D()(mobile.output)
x = Dense(1024, activation='relu')(x)
x = Dropout(0.5)(x)
# Four outputs, matching the four entries of the generators' `classes` list.
x = Dense(4, activation='softmax')(x)

model = Model(mobile.input, x)
model.summary()
# Optimiser and loss. No layer of the model is frozen before compiling,
# so the pretrained backbone is updated together with the new head.
adam = Adam(lr=1e-5)
model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])

# Only the model with the lowest validation loss seen so far is kept on disk.
# NOTE(review): the file name says "weightnorm" and AdamWithWeightnorm is
# imported, but plain Adam is what is actually used - confirm which is meant.
checkpoint = ModelCheckpoint('wound0131mobileweightnorm.h5', monitor='val_loss', save_best_only=True)
tbCallBack = TensorBoard(log_dir='./Graph', histogram_freq=0, write_graph=True, write_images=True)

# The model summary is already printed right after the model is built and
# nothing changes the architecture in between, so it is not printed again.

# NOTE(review): with the generators' batch_size of 10, 8 training steps and
# 13 validation steps cover 80 and 130 images per epoch respectively -
# confirm these match the dataset sizes.
model.fit_generator(
    traind,
    steps_per_epoch=8,
    validation_data=testd,
    validation_steps=13,
    epochs=200,
    verbose=1,
    callbacks=[checkpoint, tbCallBack],
)