Keras Deep Learning in Practice: Gender Classification Based on the ResNet Model
2022-07-22 07:23:00 【Hope Xiaohui】
0. Preface
From VGG16 to VGG19, the most significant change is the increase in the number of network layers. Generally speaking, the deeper a neural network is, the better it performs. If higher performance could be obtained simply by adding layers, things would be easy: we could keep adding layers until the model reached its best performance.
Unfortunately, this is not the case. As the number of layers increases, the vanishing gradient problem appears: the gradients in the network become very small, which makes the weights difficult to adjust, and network performance degrades as a result.
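The vanishing gradient effect can be illustrated with a small numerical sketch. It assumes a toy chain of scalar sigmoid layers with unit weights (an illustrative assumption, not the actual VGG or ResNet setup) and multiplies the per-layer derivative factors the way backpropagation would; `gradient_through_chain` is a hypothetical helper name.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gradient_through_chain(depth, z=0.0):
    """Gradient magnitude after backpropagating through `depth` sigmoid layers.

    Toy assumption: every layer is a scalar sigmoid with weight 1 that sees the
    same pre-activation z, so each layer contributes a factor sigmoid'(z) <= 0.25.
    """
    local_grad = sigmoid(z) * (1.0 - sigmoid(z))
    return local_grad ** depth

# The gradient shrinks geometrically as the chain gets deeper
for depth in (2, 8, 16, 19):
    print(depth, gradient_through_chain(depth))
```

Real networks use other activations and learned weights, so the decay is not this regular, but the multiplicative structure is the same.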
The deep residual network (ResNet) was proposed to solve this problem. In ResNet, if a block has nothing to learn, its convolutional layers can do nothing and simply pass the output of the previous layer on to the next layer. If, however, the block needs to learn additional features, the convolutional layers take the output of the previous layer as input and learn the features needed to complete the target task.
1. A brief introduction to the ResNet architecture
In mathematical statistics, a residual is the difference between an actual observed value and its estimated (fitted) value. The classic ResNet architecture is as follows:
As the diagram above shows, the model contains skip connections: each one routes the input of a stack of conventional convolutional layers around that stack and adds it to the stack's output. More formally, the input $x$ passes through the convolutions to produce the transformed features $F(x)$, which are then added element-wise to the input $x$ to give the final output $H(x)$:
$$H(x) = x + F(x)$$
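The relation H(x) = x + F(x) can be sketched in a few lines of NumPy. This is a minimal sketch, not the ResNet implementation: dense matrices stand in for the convolutions, and `residual_block`, `w1` and `w2` are hypothetical names introduced for illustration.

```python
import numpy as np

def residual_block(x, w1, w2):
    """Compute H(x) = x + F(x), where F is a toy linear -> ReLU -> linear transform."""
    f = np.maximum(x @ w1, 0.0) @ w2  # F(x); dense layers stand in for convolutions
    return x + f                      # element-wise addition of the skip connection

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))

# With all-zero weights F(x) = 0, so the block passes its input through unchanged
zeros = np.zeros((8, 8))
h_identity = residual_block(x, zeros, zeros)
print(np.allclose(h_identity, x))

# With non-zero weights the block adds learned features on top of the input
w1 = rng.normal(size=(8, 8)) * 0.1
w2 = rng.normal(size=(8, 8)) * 0.1
h = residual_block(x, w1, w2)
print(h.shape)
```

The zero-weight case shows why a residual block that has nothing to learn behaves like an identity mapping.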
A comparison of the VGG module and the residual module is shown below:
2. Gender classification based on a pre-trained ResNet50 model
In "Transfer Learning", we saw that transfer learning makes it possible to train a well-performing model from only a small number of samples, and we used a pre-trained VGG16 model for gender classification. In this section, we use a pre-trained ResNet50 for the same gender classification task. The 50 in ResNet50 indicates that the network has 50 layers.
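The figure of 50 layers can be checked with a little arithmetic. The sketch below assumes the standard ResNet50 configuration (four residual stages with 3, 4, 6 and 3 bottleneck blocks), which comes from the usual ResNet design rather than from this article, and counts only weighted layers on the main path.

```python
# Layer count of ResNet50 under the usual convention (weighted layers on the main path)
blocks_per_stage = [3, 4, 6, 3]  # bottleneck blocks in the four residual stages
convs_per_block = 3              # 1x1, 3x3 and 1x1 convolutions in each bottleneck block
stem_conv = 1                    # initial 7x7 convolution
fc_layer = 1                     # final fully connected classifier

total_layers = stem_conv + sum(blocks_per_stage) * convs_per_block + fc_layer
print(total_layers)  # 50
```

Pooling layers, batch normalization and the convolutions on the shortcut paths are not included in this count.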
2.1 Training the gender classification model
First, import the required libraries and load the pre-trained ResNet50 model:
from keras.applications import ResNet50
from keras.applications.resnet50 import preprocess_input
from glob import glob
from skimage import io
import cv2
import numpy as np
model = ResNet50(include_top=False, weights='imagenet', input_shape=(256, 256, 3))
Next, create the input and output datasets. Note that the ImageNet weights of ResNet50 were trained on 224 x 224 images; with include_top=False, Keras also accepts other input sizes (width and height of at least 32), so the 256 x 256 images used here work with the pre-trained model. We reuse the dataset and data-loading code from "Gender Classification Using Convolutional Neural Networks":
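x = []
y = []
# Class 0: the first 800 images in man_woman/a_resized
for i in glob('man_woman/a_resized/*.jpg')[:800]:
    try:
        image = io.imread(i)
        x.append(image)
        y.append(0)
    except Exception:
        # Skip files that cannot be read
        continue
# Class 1: the first 800 images in man_woman/b_resized
for i in glob('man_woman/b_resized/*.jpg')[:800]:
    try:
        image = io.imread(i)
        x.append(image)
        y.append(1)
    except Exception:
        continue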
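x_resnet50 = []
for i in range(len(x)):
    img = x[i]
    # Add a batch dimension and apply the ResNet50 preprocessing
    img = preprocess_input(img.reshape((1, 256, 256, 3)))
    # Extract features with the pre-trained convolutional base
    img_feature = model.predict(img)
    x_resnet50.append(img_feature)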
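To make the preprocessing step less of a black box, here is a NumPy sketch of what preprocess_input is understood to do for ResNet50: convert RGB to BGR and subtract per-channel ImageNet means, without scaling. The mean values and the helper name `caffe_style_preprocess` are assumptions for illustration; in practice, call the Keras function.

```python
import numpy as np

# Per-channel ImageNet means in BGR order (assumed values for this preprocessing mode)
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def caffe_style_preprocess(rgb_batch):
    """Convert an RGB batch of shape (N, H, W, 3) to zero-centered BGR floats."""
    x = np.asarray(rgb_batch, dtype=np.float32)  # work on a float copy of the input
    bgr = x[..., ::-1]                           # reverse the channel axis: RGB -> BGR
    return bgr - IMAGENET_MEAN_BGR               # subtract the mean of each channel

# A pure red 256 x 256 image with a batch dimension
batch = np.zeros((1, 256, 256, 3), dtype=np.uint8)
batch[..., 0] = 255
out = caffe_style_preprocess(batch)
print(out.shape, out[0, 0, 0])
```

Because the conversion works on a float copy, the original uint8 images stay unchanged and can still be displayed later.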
Build the input and output numpy arrays, and split the dataset into training and test sets:
x_resnet50 = np.array(x_resnet50)
# Drop the singleton batch axis: (N, 1, H, W, C) -> (N, H, W, C)
x_resnet50 = x_resnet50.reshape(x_resnet50.shape[0], x_resnet50.shape[2], x_resnet50.shape[3], x_resnet50.shape[4])
y = np.array(y)
from sklearn.model_selection import train_test_split
# A fixed random_state makes the split reproducible
x_train, x_test, y_train, y_test = train_test_split(x_resnet50, y, test_size=0.2, random_state=0)
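The reshape step is easier to follow on a small example. The sketch below uses hypothetical stand-in feature maps of shape (1, 2, 2, 3) instead of the real ResNet50 outputs, and shows that the reshape is equivalent to squeezing the batch axis or concatenating along it.

```python
import numpy as np

# Stand-ins for per-image feature maps: each predict call returns shape (1, H, W, C).
# Small hypothetical sizes are used here instead of the real feature-map shape.
features = [np.full((1, 2, 2, 3), i, dtype=np.float32) for i in range(5)]

stacked = np.array(features)  # shape (5, 1, 2, 2, 3)
reshaped = stacked.reshape(stacked.shape[0], stacked.shape[2],
                           stacked.shape[3], stacked.shape[4])

# Two equivalent ways to remove the extra axis
squeezed = np.squeeze(stacked, axis=1)
concatenated = np.concatenate(features, axis=0)

print(stacked.shape, reshaped.shape)
```

All three variants give an array of shape (5, 2, 2, 3) with identical contents.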
Build a fine-tuning model on top of the output of the pre-trained ResNet50 model:
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dropout, Dense
model_fine_tuning = Sequential()
# Convolution over the ResNet50 feature maps
model_fine_tuning.add(Conv2D(2048,
                             kernel_size=(3, 3),
                             activation='relu',
                             input_shape=(x_train.shape[1], x_train.shape[2], x_train.shape[3])))
model_fine_tuning.add(MaxPooling2D(pool_size=(2, 2)))
model_fine_tuning.add(Flatten())
model_fine_tuning.add(Dense(1024, activation='relu'))
model_fine_tuning.add(Dropout(0.5))
# A single sigmoid unit for binary classification
model_fine_tuning.add(Dense(1, activation='sigmoid'))
model_fine_tuning.summary()
A summary of the model architecture is as follows:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 6, 6, 2048) 37750784
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 3, 3, 2048) 0
_________________________________________________________________
flatten (Flatten) (None, 18432) 0
_________________________________________________________________
dense (Dense) (None, 1024) 18875392
_________________________________________________________________
dropout (Dropout) (None, 1024) 0
_________________________________________________________________
dense_1 (Dense) (None, 1) 1025
=================================================================
Total params: 56,627,201
Trainable params: 56,627,201
Non-trainable params: 0
_________________________________________________________________
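The shapes and parameter counts in the summary can be reproduced by hand. The sketch below assumes 8 x 8 x 2048 input feature maps (consistent with the 6 x 6 output of the unpadded 3 x 3 convolution), and `conv2d_output_size` is a hypothetical helper.

```python
def conv2d_output_size(size, kernel, stride=1):
    """Spatial output size of an unpadded ('valid') convolution."""
    return (size - kernel) // stride + 1

feature_size, channels = 8, 2048                       # assumed ResNet50 feature maps: 8 x 8 x 2048
conv_size = conv2d_output_size(feature_size, 3)        # 3 x 3 convolution -> 6
pool_size = conv_size // 2                             # 2 x 2 max pooling -> 3
flat_units = pool_size * pool_size * 2048              # flattened length -> 18432

conv_params = 3 * 3 * channels * 2048 + 2048           # kernel weights + biases
dense_params = flat_units * 1024 + 1024                # weights + biases
output_params = 1024 * 1 + 1                           # weights + bias of the sigmoid unit
total_params = conv_params + dense_params + output_params

print(conv_size, pool_size, flat_units)
print(conv_params, dense_params, output_params, total_params)
```

The results match the Param # column and the total of 56,627,201 in the summary.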
Compile and fit the fine-tuning model:
model_fine_tuning.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])
history = model_fine_tuning.fit(x_train, y_train,
                                batch_size=32,
                                epochs=20,
                                verbose=1,
                                validation_data=(x_test, y_test))
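For reference, the loss and metric named in the compile call can be written out in NumPy. This is a minimal sketch of the standard binary cross-entropy and thresholded accuracy formulas, not the Keras implementation; the function names and the sample values are made up for illustration.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    losses = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.mean(losses))

def accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of samples whose thresholded prediction equals the label."""
    labels = (np.asarray(y_pred) > threshold).astype(int)
    return float(np.mean(labels == np.asarray(y_true)))

# Hypothetical labels and sigmoid outputs
y_true = [0, 1, 1, 0]
y_prob = [0.1, 0.9, 0.4, 0.2]
loss = binary_crossentropy(y_true, y_prob)
acc = accuracy(y_true, y_prob)
print(loss, acc)
```

The third sample is predicted as class 0 although its label is 1, so three of the four predictions are correct.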
During training, the accuracy and loss of the model on the training and test datasets change as follows:
As you can see, the gender classification model built on the pre-trained ResNet50 reaches an accuracy of about 95%.
2.2 Examples of misclassified images
The following code displays examples of misclassified images:
import matplotlib.pyplot as plt
x = np.array(x)
from sklearn.model_selection import train_test_split
# Re-split the raw images. This reproduces the training-time split only if
# both calls to train_test_split use the same random_state.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)
x_test_resnet50 = []
for i in range(len(x_test)):
    img = x_test[i]
    img = preprocess_input(img.reshape((1, 256, 256, 3)))
    img_feature = model.predict(img)
    x_test_resnet50.append(img_feature)
x_test_resnet50 = np.array(x_test_resnet50)
x_test_resnet50 = x_test_resnet50.reshape(x_test_resnet50.shape[0], x_test_resnet50.shape[2], x_test_resnet50.shape[3], x_test_resnet50.shape[4])
y_pred = model_fine_tuning.predict(x_test_resnet50)
# Sort test samples by absolute prediction error (ascending),
# so the last indices are the worst predictions
wrong = np.argsort(np.abs(y_pred.flatten()-y_test))
print(wrong)
y_test_char = np.where(y_test==0,'M','F')
y_pred_char = np.where(y_pred>0.5,'F','M')
# Show the four test images with the largest prediction error in a 2 x 2 grid
for k in range(1, 5):
    idx = wrong[-k]
    plt.subplot(2, 2, k)
    plt.imshow(x_test[idx])
    plt.title('Actual: '+str(y_test_char[idx])+', '+'Predicted: '+str(y_pred_char[idx][0]))
plt.show()
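The ranking by prediction error is easier to see on toy data. The values below are made up for illustration; the sketch sorts by absolute error in the same way and also builds an explicit mask of misclassified samples.

```python
import numpy as np

# Hypothetical sigmoid outputs of shape (N, 1) and true labels of shape (N,)
y_pred = np.array([[0.9], [0.2], [0.6], [0.05]])
y_test = np.array([0, 0, 1, 0])

errors = np.abs(y_pred.flatten() - y_test)  # absolute error per sample
wrong = np.argsort(errors)                  # ascending: worst predictions come last
worst_two = wrong[::-1][:2]                 # indices of the two largest errors

# A sample is misclassified only if its thresholded prediction differs from the label
misclassified = (y_pred.flatten() > 0.5).astype(int) != y_test

print(wrong, worst_two, misclassified)
```

Sample 2 has the second largest error yet is classified correctly, so the images with the largest errors are not guaranteed to be misclassified; filtering with a mask like `misclassified` would restrict the display to genuine mistakes.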
Compared with VGG16, VGG19 and Inception v3, there is no significant difference in accuracy among the gender classification models built on the different pre-trained networks. A possible reason is that the image features extracted by these pre-trained models are fairly generic and are not optimized for gender-related features. To check this, we could train a ResNet50 from scratch and examine its performance.
Related links
Keras Deep Learning in Practice (7): Convolutional Neural Networks Explained and Implemented
Keras Deep Learning in Practice (9): Limitations of Convolutional Neural Networks
Keras Deep Learning in Practice (10): Transfer Learning
Keras Deep Learning in Practice: Gender Classification Using Convolutional Neural Networks
Keras Deep Learning in Practice: Gender Classification Based on the VGG19 Model
Keras Deep Learning in Practice: Gender Classification Based on Inception v3