Keras Deep Learning Practice —— Gender Classification Based on Inception v3
0. Preface
We have already implemented gender classification based on the VGG16 and VGG19 architectures. Beyond these, there are many other, more carefully designed deep neural network architectures, such as Inception, which greatly reduce the number of model parameters while preserving model quality. In this section, we first introduce the core ideas of the Inception model and then use a pre-trained Inception architecture to implement gender classification.
1. Inception structure
To better understand the core idea of the Inception model, first consider the following scenario: in a dataset, the object of interest occupies most of the image in some samples but only a small part of the image in others. If we use convolution kernels of a single size in both cases, the model will have difficulty learning to recognize both the smaller and the larger objects at the same time.
To solve this problem, we can use several convolution kernels of different sizes in the same layer. In this case, the network essentially becomes wider rather than deeper, as shown below:
In the diagram above, several convolution kernels of different sizes are used in a given layer. The Inception v1 architecture stacks nine such Inception modules linearly, as shown below:
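As a rough illustration of such a module, the following is a minimal Keras sketch of an Inception-style block (not the exact GoogLeNet definition; the filter counts are illustrative assumptions), where parallel branches with different kernel sizes are concatenated along the channel axis:
from keras.layers import Input, Conv2D, MaxPooling2D, concatenate
from keras.models import Model

def inception_block(x, f1=64, f3=128, f5=32, fp=32):
    # branch 1: 1 x 1 convolution
    b1 = Conv2D(f1, (1, 1), padding='same', activation='relu')(x)
    # branch 2: 1 x 1 bottleneck followed by a 3 x 3 convolution
    b2 = Conv2D(f3 // 2, (1, 1), padding='same', activation='relu')(x)
    b2 = Conv2D(f3, (3, 3), padding='same', activation='relu')(b2)
    # branch 3: 1 x 1 bottleneck followed by a 5 x 5 convolution
    b3 = Conv2D(f5 // 2, (1, 1), padding='same', activation='relu')(x)
    b3 = Conv2D(f5, (5, 5), padding='same', activation='relu')(b3)
    # branch 4: 3 x 3 max pooling followed by a 1 x 1 convolution
    b4 = MaxPooling2D((3, 3), strides=(1, 1), padding='same')(x)
    b4 = Conv2D(fp, (1, 1), padding='same', activation='relu')(b4)
    # concatenate all branches along the channel axis
    return concatenate([b1, b2, b3, b4], axis=-1)

inputs = Input(shape=(256, 256, 3))
outputs = inception_block(inputs)
print(Model(inputs, outputs).output_shape)    # (None, 256, 256, 256)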
1.1 Inception v1 Loss function
From the Inception v1 architecture diagram, you can see that the architecture is both deep and wide, which is likely to cause vanishing gradients. To address the vanishing gradient problem, Inception v1 adds two auxiliary classifiers that branch off intermediate Inception modules; their losses are added to the main loss so that the total loss of the network is minimized, as shown below:
total_loss = real_loss + 0.3 * aux_loss_1 + 0.3 * aux_loss_2
Note that the auxiliary losses are only used during training; they are ignored when the model is tested.
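As a hypothetical illustration (the layers and shapes below are assumptions, not the actual Inception v1 definition), a Keras model with one main output and two auxiliary outputs can apply exactly these loss weights through the loss_weights argument:
from keras.layers import Input, Conv2D, GlobalAveragePooling2D, Dense
from keras.models import Model

inputs = Input(shape=(256, 256, 3))
x = Conv2D(64, (3, 3), activation='relu')(inputs)
# hypothetical auxiliary classifier branching off an intermediate layer
aux_1 = Dense(1, activation='sigmoid', name='aux_1')(GlobalAveragePooling2D()(x))
x = Conv2D(128, (3, 3), activation='relu')(x)
aux_2 = Dense(1, activation='sigmoid', name='aux_2')(GlobalAveragePooling2D()(x))
x = Conv2D(256, (3, 3), activation='relu')(x)
main = Dense(1, activation='sigmoid', name='main')(GlobalAveragePooling2D()(x))

model_aux = Model(inputs, [main, aux_1, aux_2])
# total_loss = real_loss + 0.3 * aux_loss_1 + 0.3 * aux_loss_2
model_aux.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  loss_weights=[1.0, 0.3, 0.3])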
1.2 Inception v2 and Inception v3
Inception v2 and Inception v3 are improvements over the Inception v1 architecture. In Inception v2, the authors optimize the convolution operations so that images are processed faster; in Inception v3, the authors additionally introduce 7 x 7 convolution kernels on top of the original kernels and connect them in series. In short, the main contributions of Inception are as follows:
- Use Inception modules to capture multi-scale details of the image
- Use 1 x 1 convolutions as bottleneck layers (see the sketch after this list)
- Replace the fully connected layers with an average pooling layer, reducing the number of model parameters
- Use auxiliary branches to avoid vanishing gradients
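As a minimal sketch of the 1 x 1 bottleneck idea (the channel counts are illustrative assumptions), a 1 x 1 convolution first reduces the number of channels so that the following 3 x 3 convolution needs far fewer parameters than a direct 3 x 3 convolution on the full input:
from keras.layers import Input, Conv2D
from keras.models import Model

inputs = Input(shape=(32, 32, 256))
# direct 3 x 3 convolution on 256 input channels
direct = Conv2D(256, (3, 3), padding='same')(inputs)
# 1 x 1 bottleneck down to 64 channels, then 3 x 3 back up to 256 channels
bottleneck = Conv2D(64, (1, 1), padding='same')(inputs)
bottleneck = Conv2D(256, (3, 3), padding='same')(bottleneck)

print(Model(inputs, direct).count_params())       # 590080
print(Model(inputs, bottleneck).count_params())   # 164160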
2. Implementing gender classification with a pre-trained Inception v3 model
In 《Transfer Learning》, we learned that with transfer learning only a small number of samples are needed to train a model with good performance, and we put this into practice by using a pre-trained VGG16 model for gender classification. In this section, we use a pre-trained Inception v3 model to recognize the gender of the person in an image.
2.1 Model implementation
First, import the required libraries and load the pre-trained Inception v3 model:
from keras.applications import InceptionV3
from keras.applications.inception_v3 import preprocess_input
from glob import glob
from skimage import io
import cv2
import numpy as np

# load Inception v3 pre-trained on ImageNet, without the top classification layers
model = InceptionV3(include_top=False, weights='imagenet', input_shape=(256, 256, 3))
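Because include_top=False removes the final classification layers, the loaded network outputs convolutional feature maps rather than class probabilities; a quick sanity check of the output shape:
print(model.output_shape)    # (None, 6, 6, 2048) for a 256 x 256 x 3 input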
Create the input and output datasets:
x = []
y = []
# class 0: images from the a_resized folder (male)
for i in glob('man_woman/a_resized/*.jpg')[:8000]:
    try:
        image = io.imread(i)
        x.append(image)
        y.append(0)
    except:
        continue
# class 1: images from the b_resized folder (female)
for i in glob('man_woman/b_resized/*.jpg')[:8000]:
    try:
        image = io.imread(i)
        x.append(image)
        y.append(1)
    except:
        continue
# extract features for each image with the pre-trained Inception v3 model
x_inception_v3 = []
for i in range(len(x)):
    img = x[i]
    img = preprocess_input(img.reshape((1, 256, 256, 3)))
    img_feature = model.predict(img)
    x_inception_v3.append(img_feature)
Convert the inputs and outputs to numpy arrays and split the dataset into training and test sets:
x_inception_v3 = np.array(x_inception_v3)
# drop the extra batch dimension added by model.predict
x_inception_v3 = x_inception_v3.reshape(x_inception_v3.shape[0], x_inception_v3.shape[2], x_inception_v3.shape[3], x_inception_v3.shape[4])
y = np.array(y)
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x_inception_v3, y, test_size=0.2)
Build a fine-tuning model on top of the features produced by the pre-trained model:
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dropout, Dense
model_fine_tuning = Sequential()
model_fine_tuning.add(Conv2D(2048,
kernel_size=(3, 3),
activation='relu',
input_shape=(x_train.shape[1], x_train.shape[2], x_train.shape[3])))
model_fine_tuning.add(MaxPooling2D(pool_size=(2, 2)))
model_fine_tuning.add(Flatten())
model_fine_tuning.add(Dense(1024, activation='relu'))
model_fine_tuning.add(Dropout(0.5))
model_fine_tuning.add(Dense(1, activation='sigmoid'))
model_fine_tuning.summary()
The summary of the fine-tuning model is shown below:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_94 (Conv2D) (None, 4, 4, 2048) 37750784
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 2, 2, 2048) 0
_________________________________________________________________
flatten (Flatten) (None, 8192) 0
_________________________________________________________________
dense (Dense) (None, 1024) 8389632
_________________________________________________________________
dropout (Dropout) (None, 1024) 0
_________________________________________________________________
dense_1 (Dense) (None, 1) 1025
=================================================================
Total params: 46,141,441
Trainable params: 46,141,441
Non-trainable params: 0
_________________________________________________________________
Finally, compile and fit the model:
model_fine_tuning.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])
history = model_fine_tuning.fit(x_train, y_train,
                                batch_size=32,
                                epochs=20,
                                verbose=1,
                                validation_data=(x_test, y_test))
During training, the accuracy and loss of the model on the training and test datasets change as follows:
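A minimal sketch for plotting these curves from the history object returned by fit (assuming the 'acc' metric name used in the compile call above):
import matplotlib.pyplot as plt

epochs = range(1, len(history.history['acc']) + 1)
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.plot(epochs, history.history['acc'], label='train')
plt.plot(epochs, history.history['val_acc'], label='validation')
plt.title('Accuracy')
plt.legend()
plt.subplot(122)
plt.plot(epochs, history.history['loss'], label='train')
plt.plot(epochs, history.history['val_loss'], label='validation')
plt.title('Loss')
plt.legend()
plt.show()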
As can be seen, the gender classification model based on the pre-trained Inception v3 reaches an accuracy of about 95%.
2.2 Examples of misclassified images
The following code finds and plots examples of misclassified images:
import matplotlib.pyplot as plt

# re-split the raw images so that the original (unprocessed) pictures can be displayed
x = np.array(x)
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)

# extract Inception v3 features for the test images
x_test_inception_v3 = []
for i in range(len(x_test)):
    img = x_test[i]
    img = preprocess_input(img.reshape((1, 256, 256, 3)))
    img_feature = model.predict(img)
    x_test_inception_v3.append(img_feature)
x_test_inception_v3 = np.array(x_test_inception_v3)
x_test_inception_v3 = x_test_inception_v3.reshape(x_test_inception_v3.shape[0], x_test_inception_v3.shape[2], x_test_inception_v3.shape[3], x_test_inception_v3.shape[4])

# sort test samples by prediction error; the last entries are the worst mistakes
y_pred = model_fine_tuning.predict(x_test_inception_v3)
wrong = np.argsort(np.abs(y_pred.flatten() - y_test))
print(wrong)

y_test_char = np.where(y_test == 0, 'M', 'F')
y_pred_char = np.where(y_pred > 0.5, 'F', 'M')
plt.subplot(221)
plt.imshow(x_test[wrong[-1]])
plt.title('Actual: ' + str(y_test_char[wrong[-1]]) + ', ' + 'Predicted: ' + str(y_pred_char[wrong[-1]][0]))
plt.subplot(222)
plt.imshow(x_test[wrong[-2]])
plt.title('Actual: ' + str(y_test_char[wrong[-2]]) + ', ' + 'Predicted: ' + str(y_pred_char[wrong[-2]][0]))
plt.subplot(223)
plt.imshow(x_test[wrong[-3]])
plt.title('Actual: ' + str(y_test_char[wrong[-3]]) + ', ' + 'Predicted: ' + str(y_pred_char[wrong[-3]][0]))
plt.subplot(224)
plt.imshow(x_test[wrong[-4]])
plt.title('Actual: ' + str(y_test_char[wrong[-4]]) + ', ' + 'Predicted: ' + str(y_pred_char[wrong[-4]][0]))
plt.show()
Related links
Keras Deep Learning Practice (7) —— Detailed Explanation and Implementation of Convolutional Neural Networks
Keras Deep Learning Practice (9) —— Limitations of Convolutional Neural Networks
Keras Deep Learning Practice (10) —— Transfer Learning
Keras Deep Learning Practice —— Gender Classification with Convolutional Neural Networks
Keras Deep Learning Practice —— Gender Classification Based on the VGG19 Model