MNIST Handwritten Digit Recognition in Practice with TensorFlow 2.0
2022-07-22 19:00:00 【Sand is sand】
1. Import libraries
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Jupyter magic: render plots inline (remove this line when running as a plain .py script)
%matplotlib inline
print("TensorFlow version:", tf.__version__)
2. Load the dataset
# Downloads MNIST on first use: 60,000 training and 10,000 test images, each 28x28 pixels
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
3. Split the dataset
3.1 Split off a validation set
total_num = len(train_images)
valid_split = 0.2  # proportion of the training data used for validation (20%)
train_num = int(total_num * (1 - valid_split))  # number of training samples
train_x = train_images[:train_num]  # the first 80% is the training set
train_y = train_labels[:train_num]
valid_x = train_images[train_num:]  # the last 20% is the validation set
valid_y = train_labels[train_num:]
test_x = test_images
test_y = test_labels
valid_x.shape  # (12000, 28, 28)
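As a quick check of the split arithmetic: the MNIST training set has 60,000 images, so a 20% validation split leaves 48,000 for training and 12,000 for validation. The NumPy sketch below reproduces this on a dummy all-zeros array (an assumption standing in for the real images; only the shapes matter here).

```python
import numpy as np

# Dummy stand-in for the real training images: 60,000 images of 28x28 pixels
train_images = np.zeros((60000, 28, 28), dtype=np.uint8)

total_num = len(train_images)                   # 60000
valid_split = 0.2                               # 20% held out for validation
train_num = int(total_num * (1 - valid_split))  # 48000

train_x = train_images[:train_num]  # first 80% for training
valid_x = train_images[train_num:]  # last 20% for validation
print(train_num, train_x.shape, valid_x.shape)
```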
4. Reshape the data
# Flatten each 28x28 image into a single row of 784 values
train_x = train_x.reshape(-1, 784)
valid_x = valid_x.reshape(-1, 784)
test_x = test_x.reshape(-1, 784)
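To see what `reshape(-1, 784)` does, here is a small NumPy sketch on two dummy images. The `-1` tells NumPy to infer that dimension from the total size, and the flattening is row-major, so pixel (r, c) ends up at column r * 28 + c.

```python
import numpy as np

# Two dummy 28x28 "images" with distinct pixel values so the ordering is visible
images = np.arange(2 * 28 * 28).reshape(2, 28, 28)

# -1 lets NumPy infer the number of rows (here 2); each row is one flattened image
flat = images.reshape(-1, 784)
print(flat.shape)  # (2, 784)

# Row-major order: pixel (r, c) of image i lands at column r * 28 + c
assert flat[1, 3 * 28 + 5] == images[1, 3, 5]
```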
5. Normalize the feature data
# Scale pixel values from 0-255 to 0-1 and convert them to float32
train_x = tf.cast(train_x / 255.0, tf.float32)
valid_x = tf.cast(valid_x / 255.0, tf.float32)
test_x = tf.cast(test_x / 255.0, tf.float32)
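A minimal NumPy sketch of the same scaling, using a few hypothetical pixel values: dividing by 255.0 maps 0 to 0.0 and 255 to 1.0, and the final cast mirrors `tf.cast(..., tf.float32)`.

```python
import numpy as np

# Hypothetical uint8 pixels spanning the full 0-255 range, as in MNIST
pixels = np.array([[0, 51, 128, 255]], dtype=np.uint8)

# Dividing by 255.0 promotes to float64; then cast to float32
scaled = (pixels / 255.0).astype(np.float32)
print(scaled, scaled.dtype)
```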
6. One-hot encoding
# Example: encode the labels 3 and 4 as 10-dimensional one-hot vectors
x = [3, 4]
tf.one_hot(x, depth=10)
# One-hot encode the label data
train_y = tf.one_hot(train_y, depth=10)
valid_y = tf.one_hot(valid_y, depth=10)
test_y = tf.one_hot(test_y, depth=10)
train_y
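For intuition, `tf.one_hot(x, depth=10)` with in-range labels gives the same rows as indexing a 10x10 identity matrix. The NumPy sketch below is an illustration, not how TensorFlow implements it; it also shows that `argmax` recovers the original labels, which the accuracy function later relies on.

```python
import numpy as np

labels = np.array([3, 4])
# Row k of the 10x10 identity matrix is the one-hot vector for class k
one_hot = np.eye(10, dtype=np.float32)[labels]
print(one_hot)  # row 0 has its 1 at index 3, row 1 at index 4

# argmax along each row inverts the encoding
print(np.argmax(one_hot, axis=1))
```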
7. Create variables
# Define the trainable parameters
W = tf.Variable(tf.random.normal([784, 10], mean=0.0, stddev=1.0, dtype=tf.float32))  # weights
B = tf.Variable(tf.zeros([10]), dtype=tf.float32)  # biases
# Here the weights W are initialized with normally distributed random numbers, and the bias B with the constant 0
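The loss and accuracy code call `model(x, w, b)`, and the shapes of W (784x10) and B (10) together with the categorical cross-entropy loss suggest softmax regression: probabilities = softmax(x · W + B). The NumPy sketch below assumes that model and uses a dummy batch to check the shapes; the `softmax` helper is written here for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Shapes mirror the tutorial: W is 784x10, B has 10 entries
W = rng.normal(loc=0.0, scale=1.0, size=(784, 10)).astype(np.float32)
B = np.zeros(10, dtype=np.float32)
x = rng.random((50, 784), dtype=np.float32)  # dummy batch of 50 flattened images in [0, 1)

def softmax(z):
    # Subtracting the row max avoids overflow in exp and does not change the result
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Assumed model: one row of 10 class probabilities per image
probs = softmax(x @ W + B)
print(probs.shape)  # (50, 10)
```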
8. Define the cross-entropy loss function
# model(x, w, b) is used below but not shown in the original listing; assumed softmax regression:
def model(x, w, b):
    return tf.nn.softmax(tf.matmul(x, w) + b)  # class probabilities, shape (batch, 10)

def loss(x, y, w, b):
    pred = model(x, w, b)  # model predictions (probabilities)
    loss_ = tf.keras.losses.categorical_crossentropy(y_true=y, y_pred=pred)  # per-sample loss
    return tf.reduce_mean(loss_)  # mean cross-entropy over the batch
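The categorical cross-entropy for one sample is -Σ_k y_k · log(p_k); with one-hot labels this is just -log of the probability assigned to the true class. The NumPy sketch below uses three classes and hand-picked probabilities for readability. Keras's own implementation handles extra details such as its clipping epsilon, so treat this as an illustration of the formula rather than a drop-in replacement.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Per-sample cross-entropy -sum_k y_k * log(p_k); clipping avoids log(0)
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -np.sum(y_true * np.log(y_pred), axis=1)

y_true = np.array([[0, 1, 0], [1, 0, 0]], dtype=np.float64)  # one-hot labels
y_pred = np.array([[0.1, 0.8, 0.1], [0.5, 0.25, 0.25]])     # predicted probabilities

per_sample = categorical_crossentropy(y_true, y_pred)  # [-log(0.8), -log(0.5)]
mean_loss = per_sample.mean()                           # same role as tf.reduce_mean(loss_)
print(per_sample, mean_loss)
```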
9. Set hyperparameters
training_epochs = 20  # number of training epochs
batch_size = 50  # number of samples per training step (batch size)
learning_rate = 0.001  # learning rate
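These hyperparameters fix how many optimizer updates the training loop performs. With 48,000 training images and `batch_size = 50`, each epoch has 960 steps, so 20 epochs make 19,200 updates:

```python
train_num = 48000      # 80% of the 60,000 MNIST training images
batch_size = 50
training_epochs = 20

total_step = int(train_num / batch_size)    # optimizer updates per epoch
total_updates = total_step * training_epochs
print(total_step, total_updates)  # 960 19200
```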
10. Define the gradient function
# Compute the gradient of the loss on samples (x, y) with respect to the parameters (w, b)
def grad(x, y, w, b):
    with tf.GradientTape() as tape:
        loss_ = loss(x, y, w, b)
    return tape.gradient(loss_, [w, b])  # returns the gradients [dL/dw, dL/db]
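`tf.GradientTape` computes these gradients automatically. For the assumed softmax-regression model the gradient also has a closed form: with p = softmax(xW + b) and mean cross-entropy L over N samples, dL/dz = (p - y)/N, so dL/dW = x^T (p - y)/N and dL/db is the column sum of (p - y)/N. The NumPy sketch below implements this by hand on tiny dummy data and checks one entry against a central finite difference; it illustrates the math, not TensorFlow's internals.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def loss(x, y, w, b):
    # Mean cross-entropy of the assumed softmax-regression model
    p = softmax(x @ w + b)
    return -np.mean(np.sum(y * np.log(p), axis=1))

def grad(x, y, w, b):
    # Closed form: dL/dz = (p - y) / N, then the chain rule through z = x @ w + b
    n = x.shape[0]
    dz = (softmax(x @ w + b) - y) / n
    return x.T @ dz, dz.sum(axis=0)

rng = np.random.default_rng(1)
x = rng.random((4, 6))          # 4 dummy samples with 6 features
y = np.eye(3)[[0, 2, 1, 2]]     # one-hot labels over 3 classes
w = rng.normal(size=(6, 3))
b = np.zeros(3)

dw, db = grad(x, y, w, b)

# Central finite-difference check on one weight entry
h = 1e-6
w_plus, w_minus = w.copy(), w.copy()
w_plus[2, 1] += h
w_minus[2, 1] -= h
numeric = (loss(x, y, w_plus, b) - loss(x, y, w_minus, b)) / (2 * h)
print(dw[2, 1], numeric)
```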
11. Choose the optimizer
# Adam optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
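Adam keeps running averages of each gradient (m) and of its square (v), corrects their startup bias, and scales the step per parameter. The NumPy sketch below follows Algorithm 1 of the Adam paper (Kingma & Ba) with `beta1=0.9`, `beta2=0.999` and `eps=1e-7`, matching the Keras default arguments; Keras places the epsilon term slightly differently internally, so its results can differ in the last digits. One property worth noticing: on the first step each weight moves by about `learning_rate` opposite to the sign of its gradient, whatever the gradient's magnitude.

```python
import numpy as np

def adam_step(w, g, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-7):
    # Update the biased first and second moment estimates
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    # Bias correction (t counts steps starting at 1)
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

w = np.array([1.0, -2.0])
g = np.array([0.5, -3.0])   # dummy gradient
m = np.zeros_like(w)
v = np.zeros_like(w)
w, m, v = adam_step(w, g, m, v, t=1)
print(w)  # about [0.999, -1.999]
```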
12. Define accuracy
def accuracy(x, y, w, b):
    pred = model(x, w, b)  # model predictions (class probabilities)
    # Compare the predicted class tf.argmax(pred, 1) with the true class tf.argmax(y, 1)
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Accuracy: convert the booleans to floats and take the mean
    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
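The accuracy computation is easy to check by hand. In the NumPy sketch below (made-up predictions for four samples), the predicted classes are [1, 0, 2, 1] and the true classes are [1, 0, 1, 1], so three of four match and the accuracy is 0.75.

```python
import numpy as np

pred = np.array([[0.1, 0.7, 0.2],
                 [0.6, 0.3, 0.1],
                 [0.2, 0.2, 0.6],
                 [0.3, 0.4, 0.3]])
y = np.eye(3)[[1, 0, 1, 1]]  # one-hot true labels

correct = np.argmax(pred, axis=1) == np.argmax(y, axis=1)  # [True, True, False, True]
acc = correct.astype(np.float32).mean()
print(acc)  # 0.75
```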
13. Train the model
total_step = int(train_num / batch_size)  # number of batches per epoch

loss_list_train = []  # training-set loss per epoch
loss_list_valid = []  # validation-set loss per epoch
acc_list_train = []   # training-set accuracy per epoch
acc_list_valid = []   # validation-set accuracy per epoch

for epoch in range(training_epochs):
    for step in range(total_step):
        xs = train_x[step * batch_size:(step + 1) * batch_size]
        ys = train_y[step * batch_size:(step + 1) * batch_size]
        grads = grad(xs, ys, W, B)  # compute the gradients
        optimizer.apply_gradients(zip(grads, [W, B]))  # the optimizer updates W and B from the gradients
    loss_train = loss(train_x, train_y, W, B).numpy()  # training loss after this epoch
    loss_valid = loss(valid_x, valid_y, W, B).numpy()  # validation loss after this epoch
    acc_train = accuracy(train_x, train_y, W, B).numpy()
    acc_valid = accuracy(valid_x, valid_y, W, B).numpy()
    loss_list_train.append(loss_train)
    loss_list_valid.append(loss_valid)
    acc_list_train.append(acc_train)
    acc_list_valid.append(acc_valid)
    print("epoch={:3d}, train_loss={:.4f}, train_acc={:.4f}, val_loss={:.4f}, val_acc={:.4f}".format(
        epoch + 1, loss_train, acc_train, loss_valid, acc_valid))
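The training loop walks through the training set in fixed consecutive slices, always in the same order. The sketch below factors that slicing into a hypothetical helper, `iterate_minibatches`, and adds an optional per-epoch shuffle, which is a common variant and not part of the original code. It uses NumPy arrays; with the `tf.Tensor` objects in this tutorial, selecting rows by an index array would need `tf.gather(x, sel)` instead of `x[sel]`.

```python
import numpy as np

def iterate_minibatches(x, y, batch_size, shuffle=False, rng=None):
    # Yield consecutive (xs, ys) batches; a remainder smaller than batch_size is dropped,
    # matching total_step = int(train_num / batch_size) in the training loop
    n = len(x)
    idx = rng.permutation(n) if shuffle else np.arange(n)
    for step in range(n // batch_size):
        sel = idx[step * batch_size:(step + 1) * batch_size]
        yield x[sel], y[sel]

x = np.arange(10)  # 10 dummy samples
y = x * 10         # dummy labels tied to the samples
batches = list(iterate_minibatches(x, y, batch_size=3))
print([xs.tolist() for xs, _ in batches])  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

# Shuffled variant: a different order each call, but samples stay paired with labels
shuffled = list(iterate_minibatches(x, y, batch_size=3, shuffle=True, rng=np.random.default_rng(0)))
```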
The printed output shows the loss decreasing over the epochs while the accuracy keeps rising.
14. Plot the training history
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.plot(loss_list_train, 'blue', label="Train Loss")
plt.plot(loss_list_valid, 'red', label="Valid Loss")
plt.legend(loc=1)  # loc=1 places the legend in the upper-right corner
plt.show()  # finish the loss plot before drawing the accuracy plot

plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.plot(acc_list_train, 'blue', label='Train Acc')
plt.plot(acc_list_valid, 'red', label='Valid Acc')
plt.legend(loc=1)  # loc=1 places the legend in the upper-right corner
plt.show()
15. Evaluate the model
After training, evaluate the model's accuracy on the test set:
acc_test = accuracy(test_x, test_y, W, B).numpy()
print("Test accuracy:", acc_test)
16. Apply and visualize the model
16.1 Apply the model
Once the model has been built and trained, and its accuracy is acceptable, it can be used for prediction.
# Define the prediction function
def predict(x, w, b):
    pred = model(x, w, b)  # model predictions (class probabilities)
    result = tf.argmax(pred, 1).numpy()  # index of the most probable class for each sample
    return result

pred_test = predict(test_x, W, B)
pred_test[0]  # predicted digit for the first test image
16.2 Define the visualization function
import matplotlib.pyplot as plt
import numpy as np

def plot_images_label_prediction(images,   # list of images
                                 labels,   # list of true labels
                                 preds,    # list of predicted values
                                 index=0,  # index of the first image to show
                                 num=10):  # number of images to show (at most 10)
    fig = plt.gcf()  # get the current figure
    fig.set_size_inches(10, 4)  # width x height in inches (1 inch = 2.54 cm)
    if num > 10:
        num = 10
    for i in range(0, num):
        ax = plt.subplot(2, 5, i + 1)
        ax.imshow(np.reshape(images[index], (28, 28)), cmap="binary")  # show the index-th image
        title = 'label=' + str(labels[index])  # build the title shown above the image
        if len(preds) > 0:
            title += ",predict=" + str(preds[index])
        ax.set_title(title, fontsize=10)  # display the title
        ax.set_xticks([])  # hide the axis ticks
        ax.set_yticks([])
        index = index + 1
    plt.show()

# Show 10 test images starting at index 10, with their true labels and predictions
plot_images_label_prediction(test_images, test_labels, pred_test, 10, 10)
17. Final result
18. Source code: GitHub - shazi4399/TensorFlow (demos written while studying TensorFlow)
19. You are welcome to follow the author's official account: Sand is sand