forked from TeamCookCaps/ImageClassification
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path model.py
More file actions
94 lines (68 loc) · 3.29 KB
/
model.py
File metadata and controls
94 lines (68 loc) · 3.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, GlobalAveragePooling2D
from tensorflow.keras.applications import Xception
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from datetime import datetime
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import numpy as np
import cv2
import glob
import matplotlib.pyplot as plt
# Input resolution fed to the network. Keras image APIs take sizes in
# (height, width) order, so the constants are passed in that order below.
IMG_WIDTH = 224
IMG_HEIGHT = 224
# Number of target categories. Must equal the number of class sub-directories
# in the training data and the length of the class_name list used for display.
NUM_CLASSES = 11

# Xception backbone with ImageNet weights and without its ImageNet
# classification head, so a task-specific head can be attached below.
# input_shape is (height, width, channels) in Keras.
# The backbone is not frozen here, so fit() also updates its weights.
base_model = Xception(weights='imagenet', include_top=False,
                      input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))

model = Sequential()
model.add(base_model)
# Collapse the backbone's spatial feature map into one feature vector per image.
model.add(GlobalAveragePooling2D())
# New classifier head.
model.add(Dense(60, activation='relu'))
model.add(Dropout(0.25))
# One softmax output node per category.
model.add(Dense(NUM_CLASSES, activation='softmax'))
model.summary()
# NOTE(review): these paths appear to point at the cats-vs-dogs sample
# dataset, while the model head and the class_name display list describe
# 11 categories. Point them at the matching dataset (one sub-directory per
# class) before training.
train_dir = 'cats_and_dogs_filtered/train'
test_dir = 'cats_and_dogs_filtered/validation'

# Both generators rescale pixel values from [0, 255] to [0, 1] as images are read.
# The training generator additionally applies light random augmentation.
train_data_gen = ImageDataGenerator(rescale=1./255,
                                    rotation_range=10,
                                    width_shift_range=0.1,
                                    height_shift_range=0.1,
                                    shear_range=0.1,
                                    zoom_range=0.1)
test_data_gen = ImageDataGenerator(rescale=1./255)

# target_size is (height, width) in Keras. class_mode='categorical' yields
# one-hot labels; class indices follow the sorted sub-directory names.
train_data = train_data_gen.flow_from_directory(train_dir, batch_size=32,
                                                color_mode='rgb', shuffle=True,
                                                class_mode='categorical',
                                                target_size=(IMG_HEIGHT, IMG_WIDTH))
# Validation batches are not shuffled: shuffling does not change validation
# metrics, and a fixed order keeps test_data.classes aligned with predictions.
test_data = test_data_gen.flow_from_directory(test_dir, batch_size=32,
                                              color_mode='rgb', shuffle=False,
                                              class_mode='categorical',
                                              target_size=(IMG_HEIGHT, IMG_WIDTH))
# Fail fast with a readable message when the dataset does not match the model
# head; otherwise Keras raises an opaque shape error from inside fit().
if train_data.num_classes != model.output_shape[-1]:
    raise ValueError(
        f"Training data has {train_data.num_classes} classes but the model "
        f"outputs {model.output_shape[-1]}; check train_dir and the final Dense layer.")

# Presumably a small learning rate because the pre-trained backbone is
# trained along with the new head (it is not frozen).
model.compile(loss='categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(2e-5),
              metrics=['accuracy'])

# NOTE(review): with a single epoch, EarlyStopping(patience=5) can never trigger.
EPOCHS = 1

save_file_name = './cats_and_dogs_filtered_Xception_Colab.h5'
# Write the model to disk only when val_loss improves on the best value so far.
checkpoint = ModelCheckpoint(save_file_name, monitor='val_loss',
                             verbose=1, save_best_only=True, mode='auto')
# Stop after 5 epochs without val_loss improvement. restore_best_weights rolls
# the in-memory model back to its best-val_loss weights when stopping triggers,
# so predictions made after fit() do not use a worse last-epoch state.
# NOTE(review): depending on the Keras version, weights may not be restored if
# training ends without early stopping; reloading save_file_name is the
# version-independent alternative.
earlystopping = EarlyStopping(monitor='val_loss', patience=5,
                              restore_best_weights=True)
hist = model.fit(train_data, epochs=EPOCHS, validation_data=test_data,
                 callbacks=[checkpoint, earlystopping])
### End of training

# Load every readable file in test_image_dir and preprocess it for the model:
# resize to the network input size, convert OpenCV's BGR channel order to RGB,
# and scale pixel values to [0, 1].
# test_image_name_list[i] is the file that produced test_imge_list[i].
test_imge_list = []
test_image_name_list = []
# Sorted so the prediction order is deterministic across runs and file systems.
for image_path in sorted(glob.glob('test_image_dir/*')):
    src_img = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if src_img is None:
        # imread returns None (instead of raising) for unreadable/non-image files.
        print(f'Skipping unreadable image: {image_path}')
        continue
    # cv2.resize takes dsize as (width, height).
    src_img = cv2.resize(src_img, dsize=(IMG_WIDTH, IMG_HEIGHT))
    dst_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
    test_imge_list.append(dst_img / 255.0)
    test_image_name_list.append(image_path)

if test_imge_list:
    pred = model.predict(np.array(test_imge_list))
else:
    # No usable images: use an empty (0, n_outputs) result so downstream
    # per-image loops simply do nothing instead of model.predict failing.
    pred = np.empty((0, model.output_shape[-1]))
# Display names for the model's output indices, in index order: gifticon
# (mobile gift voucher), animal, plant, receipt, food, person, interior,
# vehicle, screenshot, fashion, landscape. flow_from_directory assigns class
# indices in sorted sub-directory order, so this list must follow that order.
class_name = ['기프티콘','동물','식물','영수증','음식','인물','인테리어','차량','캡쳐화면','패션','풍경']

# Show each test image titled with its top-1 class and confidence (in %).
# The grid keeps a 7-column, at-least-6-row layout but adds rows as needed,
# so more than 42 images no longer make plt.subplot raise.
GRID_COLS = 7
grid_rows = max(6, -(-len(pred) // GRID_COLS))  # ceiling division
plt.figure(figsize=(8, grid_rows))
for idx, (scores, image) in enumerate(zip(pred, test_imge_list), start=1):
    plt.subplot(grid_rows, GRID_COLS, idx)
    prediction = class_name[np.argmax(scores)]
    probability = '{0:0.2f}'.format(100 * np.max(scores))
    title_str = prediction + " . " + probability + '%'
    print(title_str)
    plt.axis('off')
    # NOTE(review): Matplotlib's default font (DejaVu Sans) has no Hangul
    # glyphs; set a Korean-capable font via rcParams['font.family'] or the
    # titles render as empty boxes.
    plt.title(title_str)
    plt.imshow(image)
plt.show()