Introduction to machine learning: support vector machine-6
2022-07-22 20:21:00 【Pentium wanderer】
Introduction to machine learning: support vector machine
1. Description of the experiment
- This experiment provides a set of handwritten digit images, uses an SVM to recognize the handwritten digit dataset, and displays the recognition results as a figure. Cross-validation is then used to find the optimal parameters of the algorithm.
- Duration of the experiment: 45 minutes
- Main steps:
- Load the training data
- Load the test data
- Preprocess the data set
- Plot grayscale images of the training and test data
- Train the model
- Predict with the model
- Evaluate the model
- Use cross-validation to determine the optimal parameters
2. Experimental environment
- Number of virtual machines: 1
- OS version: CentOS 7.5
- scikit-learn version: 0.19.2
- numpy version: 1.15.1
- matplotlib version: 2.2.3
- Python version: 3.5
- IPython version: 6.5.0
3. Relevant skills
Python programming
scikit-learn programming
Matplotlib programming
NumPy programming
SVM modeling
4. Related knowledge
- Support vector machines
- Feature space
- Kernel space
- Cross-validation
- Model creation
- Model training
- Model prediction
5. Results
- The figure below shows how the support vector machine recognizes the handwritten digit images:
6. Experimental steps
6.1 Support vector machine concepts:
6.1.1 Suppose we want to separate data into two classes with a line.
6.1.1.1 There are infinitely many lines that can do this.
6.1.1.2 An SVM looks for the optimal separating line: the one whose margin to the nearest points on both sides is as large as possible.
6.1.1.3 The few data points lying on the edge of the margin are the ones that determine the position of the line. They are called support vectors, which is where the name of this classification algorithm comes from.
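A minimal sketch of the maximum-margin idea, assuming scikit-learn's `svm.SVC` with a linear kernel and a made-up set of six 2-D points (toy data, not from the experiment). A very large `C` is used so the fit behaves almost like a hard-margin SVM; only the points closest to the other class should end up in `support_vectors_`.

```python
import numpy as np
from sklearn import svm

# Two linearly separable clusters (toy data, for illustration only)
X = np.array([[1.0, 1.0], [2.0, 1.0], [1.0, 2.0],
              [4.0, 4.0], [5.0, 4.0], [4.0, 5.0]])
y = np.array([0, 0, 0, 1, 1, 1])

# A large C approximates a hard-margin SVM on separable data
clf = svm.SVC(kernel='linear', C=1000.0)
clf.fit(X, y)

# Only the points on the edge of the margin define the separating line
print("support vectors:\n", clf.support_vectors_)

# For a linear SVM the full margin width is 2 / ||w||
w = clf.coef_[0]
margin_width = 2.0 / np.linalg.norm(w)
print("margin width: %.3f" % margin_width)
```

Here the closest points between the two clusters are (2, 1), (1, 2) and (4, 4), so the margin width should come out close to 5/√2 ≈ 3.54, and moving any of the other three points (without crossing the margin) would not change the line.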
6.1.2 Kernel functions:
6.1.2.1 A kernel function maps the original input space into a new feature space, so that samples that are not linearly separable in the original space can become separable in the kernel space.
6.1.2.2 Common kernel functions: the polynomial kernel, the Gaussian (RBF) kernel and the sigmoid kernel. In practice, an effective kernel is usually chosen using prior domain knowledge or schemes such as cross-validation. If there is no further prior information, use the Gaussian kernel.
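To illustrate why a kernel helps, here is a small sketch assuming scikit-learn's `svm.SVC` and the classic XOR pattern (four toy points, not part of the experiment data). No straight line separates XOR, so the linear kernel cannot fit all four points, while the Gaussian (RBF) kernel can. The values `gamma=1.0` and `C=10.0` are assumptions chosen for this toy case.

```python
import numpy as np
from sklearn import svm

# XOR pattern: no single straight line separates the two classes
X = np.array([[0.0, 0.0], [1.0, 1.0], [0.0, 1.0], [1.0, 0.0]])
y = np.array([0, 0, 1, 1])

linear = svm.SVC(kernel='linear', C=10.0).fit(X, y)
rbf = svm.SVC(kernel='rbf', gamma=1.0, C=10.0).fit(X, y)

# Training accuracy: the linear kernel cannot fit XOR, the RBF kernel can
print("linear kernel accuracy:", linear.score(X, y))
print("RBF kernel accuracy:   ", rbf.score(X, y))
```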
6.2 Enter the virtual environment "ML" created in Anaconda
6.2.1 Copy the data files required for the experiment from /home/zkpk/experiment to the zkpk home directory
[zkpk@master ~]$ cd
[zkpk@master ~]$ cp /home/zkpk/experiment/optdigits.tra /home/zkpk
[zkpk@master ~]$ cp /home/zkpk/experiment/optdigits.tes /home/zkpk
6.2.2 Data set introduction: the data are handwritten digits from 43 people; the digits of 30 people are used for training and those of the other 13 people for testing.
6.2.2.1 The training set contains 3823 images and the test set contains 1797 images.
6.2.2.2 Each image is an 8×8 grayscale image with pixel values from 0 to 16, where 16 means fully black and 0 means fully white.
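If the optdigits files are not at hand, the description above can be checked against scikit-learn: according to its documentation, `sklearn.datasets.load_digits()` is a copy of the test set of the UCI handwritten digits data, which should correspond to optdigits.tes. This sketch only inspects shapes and the pixel range; it is a stand-in, not a replacement for the files used in the experiment.

```python
from sklearn.datasets import load_digits

# scikit-learn bundles a copy of the optdigits test set (no download needed)
digits = load_digits()

print(digits.data.shape)    # one row of 64 pixel values per image
print(digits.images.shape)  # the same data arranged as 8x8 grids
print(digits.data.min(), digits.data.max())  # pixel range 0..16
```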
6.2.3 Run the following commands in the zkpk home directory
[zkpk@master ~]$ cd
[zkpk@master ~]$ source activate ML
(ML) [zkpk@master ~]$
6.2.4 You are now in the virtual environment. Type the following command to enter the IPython interactive programming environment
(ML) [zkpk@master ~]$ ipython
Python 3.5.4 |Anaconda, Inc.| (default, Nov 3 2017, 20:01:27)
Type 'copyright', 'credits' or 'license' for more information
IPython 6.2.1 -- An enhanced Interactive Python. Type '?' for help.
In [1]:
6.3 Start the experiment in the IPython interactive programming environment
6.3.1 Import the packages required for the experiment
In [1]: import numpy as np
...: import pandas as pd
...: from sklearn import svm
...: import matplotlib.colors
...: import matplotlib.pyplot as plt
...: from PIL import Image  # PIL: Python Imaging Library
...: from sklearn.metrics import accuracy_score
...: import os
...: from sklearn.model_selection import train_test_split
...: from sklearn.model_selection import GridSearchCV
...: from time import time
6.3.2 Load the training set
6.3.2.1 Specify ',' as the separator
In [13]: data = np.loadtxt('optdigits.tra', dtype=float, delimiter=',')  # loadtxt returns an ndarray; dtype sets the element type of the returned array; delimiter sets the separator
...: data.shape
6.3.2.2 Split the data into the features x and the labels y
In [14]: x, y = np.split(data, (-1, ), axis=1)  # np.split cuts the ndarray; axis=1 splits along the columns at index -1, so the last column (the label) goes to y and the rest to x
...: x.shape
...: y.shape
6.3.2.3 Convert the y labels to integers
In [15]: y = y.ravel().astype(int)  # ravel flattens the ndarray into a 1-D array; astype(int) converts every element to int
6.3.2.4 See which label values y contains
In [16]: np.unique(y)
6.3.2.5 See which distinct values x contains
In [17]: np.unique(x)
6.3.2.6 Reshape each image to 8×8
In [21]: images = x.reshape(-1, 8, 8)  # reshape each 64-pixel row into an 8*8 image
...: images.shape
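The `np.split(data, (-1,), axis=1)` call above can be puzzling at first. This sketch uses a tiny made-up array (4 "pixel" columns plus 1 label column, instead of the real 64 + 1) to show that it is equivalent to slicing off the last column, and what `ravel()` and `astype(int)` do to the labels.

```python
import numpy as np

# Tiny stand-in for the loaded array: 4 "pixel" columns + 1 label column
data = np.array([[0., 5., 16., 3., 7.],
                 [1., 2., 3., 4., 9.]])

# Split at index -1 along axis=1 (columns): everything before the last
# column becomes x, the last column becomes y
x, y = np.split(data, (-1,), axis=1)
print(x.shape, y.shape)  # (2, 4) (2, 1)

# Equivalent to plain slicing
assert np.array_equal(x, data[:, :-1])
assert np.array_equal(y, data[:, -1:])

# ravel() flattens the (n, 1) label column; astype(int) turns 7.0 into 7
labels = y.ravel().astype(int)
print(labels)  # [7 9]
```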
6.3.3 Do the same for the test set
In [22]: print('Load Test Data Start...')
...: data = np.loadtxt('optdigits.tes', dtype=float, delimiter=',')
...: x_test, y_test = np.split(data, (-1, ), axis=1)
...: print(y_test.shape)
...: images_test = x_test.reshape(-1, 8, 8)
...: y_test = y_test.ravel().astype(int)
...: print('Load Data OK...')
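Since the training and test files go through exactly the same steps, they could be wrapped in one function. `load_optdigits` below is a hypothetical helper (not part of the original code); the demo feeds it a one-row in-memory file built with `io.StringIO`, since `np.loadtxt` accepts file-like objects. `ndmin=2` keeps a single-row file two-dimensional.

```python
import io
import numpy as np

def load_optdigits(source):
    """Hypothetical helper: load an optdigits-style CSV (64 pixels + label).

    Returns (x, images, y): flat pixel rows, 8x8 images, integer labels.
    """
    data = np.loadtxt(source, dtype=float, delimiter=',', ndmin=2)
    x, y = np.split(data, (-1,), axis=1)
    return x, x.reshape(-1, 8, 8), y.ravel().astype(int)

# With the real files this would be:
#   x, images, y = load_optdigits('optdigits.tra')
#   x_test, images_test, y_test = load_optdigits('optdigits.tes')
# Here a one-row in-memory file stands in for them:
demo = io.StringIO(",".join([str(i % 17) for i in range(64)] + ["5"]))
x_demo, images_demo, y_demo = load_optdigits(demo)
print(x_demo.shape, images_demo.shape, y_demo)
```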
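6.3.4 Split the data set
6.3.4.1 Note that this step overwrites the x_test and y_test loaded from optdigits.tes: from here on, the model is trained on 60% of optdigits.tra and evaluated on the remaining 40%. Because train_test_split shuffles the rows, images and images_test are rebuilt from the split arrays so that each image stays paired with its label. To evaluate on the official test set instead, skip this step.
In [23]: x, x_test, y, y_test = train_test_split(x, y, test_size=0.4, random_state=1)
...: images = x.reshape(-1, 8, 8)  # rebuild so that images[i] matches y[i] after shuffling
...: images_test = x_test.reshape(-1, 8, 8)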
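A sketch of what `train_test_split` does with `test_size=0.4`, on made-up data where row k starts with pixel value k*64 and has label k. It shows the 60/40 split sizes and why images must be rebuilt from the split arrays: rows and labels are shuffled together, so only arrays produced by the split stay aligned.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Made-up stand-in data: 10 samples with 64 pixels each, labels 0..9
x = np.arange(10 * 64, dtype=float).reshape(10, 64)
y = np.arange(10)

x_tr, x_te, y_tr, y_te = train_test_split(x, y, test_size=0.4, random_state=1)
print(len(x_tr), len(x_te))  # 6 4

# Rebuild the images from the split array, not from the pre-split x
images_tr = x_tr.reshape(-1, 8, 8)
for img, label in zip(images_tr, y_tr):
    # In this toy data, row k starts at pixel value k*64, so this
    # checks that each image is still paired with its own label
    assert img[0, 0] == label * 64
```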
6.3.5 View the grayscale images of the training and test data
In [24]: # Training data
...: for index, image in enumerate(images[:16]):  # look at the first 16 images
...:     plt.subplot(4, 8, index + 1)  # add a subplot at the given position in a 4x8 grid
...:     plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')  # imshow displays an image; interpolation='nearest' uses nearest-neighbor interpolation
...:     plt.title(u'train_images: %i' % y[index])
...:
...: # Test data
...: for index, image in enumerate(images_test[:16]):
...:     plt.subplot(4, 8, index + 17)
...:     plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
...:     plt.title(u'test_images: %i' % y_test[index])
...: plt.tight_layout()
...: plt.show()
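The two plotting loops above differ only in their data and subplot offset, so they could share one function. `plot_digits` is a hypothetical helper, sketched here with the non-interactive `Agg` backend so it runs without a display, and with `load_digits()` standing in for the optdigits arrays.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits

def plot_digits(images, labels, title_prefix, n=16, cols=8):
    """Hypothetical helper: draw the first n images in a grid with their labels."""
    rows = (n + cols - 1) // cols
    fig = plt.figure(figsize=(cols * 1.5, rows * 1.8))
    for index, (image, label) in enumerate(zip(images[:n], labels[:n])):
        ax = fig.add_subplot(rows, cols, index + 1)
        ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        ax.set_title('%s: %i' % (title_prefix, label), fontsize=8)
        ax.axis('off')
    fig.tight_layout()
    return fig

digits = load_digits()  # stand-in for the optdigits arrays
fig = plot_digits(digits.images, digits.target, 'test')
# fig.savefig('digits_preview.png') would write the grid to a file
```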
6.4 Use cross-validation to determine the optimal parameters of the model
6.4.1 Define the model parameter dictionary
6.4.1.1 The SVM model has two very important parameters, C and gamma. C is the penalty factor, i.e. the tolerance for errors. The larger C is, the less error is tolerated and the easier it is to overfit; the smaller C is, the easier it is to underfit. A C that is either too large or too small leads to poor generalization.
6.4.1.2 gamma is the parameter of the RBF function when it is chosen as the kernel. It determines how the data are distributed after being mapped into the new feature space: the larger gamma is, the smaller the region each training sample influences, giving a more complex decision boundary that is prone to overfitting; the smaller gamma is, the smoother the boundary. gamma also affects the number of support vectors, and the number of support vectors affects the speed of training and prediction.
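For reference, the RBF kernel is K(a, b) = exp(-gamma * ||a - b||²). This sketch computes it by hand and checks it against `sklearn.metrics.pairwise.rbf_kernel` for two made-up points; it shows that the same pair of points looks far less similar as gamma grows, which is why a large gamma makes each sample's influence more local.

```python
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel

# RBF kernel: K(a, b) = exp(-gamma * ||a - b||^2)
a = np.array([[0.0, 0.0]])
b = np.array([[1.0, 1.0]])  # squared distance to a is 2

for gamma in (0.01, 1.0, 10.0):
    manual = np.exp(-gamma * np.sum((a - b) ** 2))
    library = rbf_kernel(a, b, gamma=gamma)[0, 0]
    print("gamma=%-5g K=%.6f" % (gamma, library))
    assert np.isclose(manual, library)
```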
In [31]: params = {'C': np.logspace(0, 3, 7), 'gamma': np.logspace(-5, 0, 11)}  # logspace creates a geometric (log-spaced) sequence
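A quick sketch of what this parameter dictionary expands to, using `sklearn.model_selection.ParameterGrid` to count the candidates. It gives a feel for the search cost: every (C, gamma) pair is fitted once per cross-validation fold.

```python
import numpy as np
from sklearn.model_selection import ParameterGrid

params = {'C': np.logspace(0, 3, 7), 'gamma': np.logspace(-5, 0, 11)}

print(np.round(params['C'], 2))   # 10**0, 10**0.5, ..., 10**3
print(params['gamma'][[0, -1]])   # endpoints 1e-05 and 1.0

n_candidates = len(ParameterGrid(params))
print(n_candidates)        # 7 * 11 = 77 (C, gamma) combinations
print(n_candidates * 3)    # fits performed by 3-fold CV, before the final refit
```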
6.4.2 Build the model with GridSearchCV
6.4.2.1 Use the 'rbf' kernel and 3-fold cross-validation.
In [32]: model = GridSearchCV(svm.SVC(kernel='rbf'), param_grid=params, cv=3)
6.4.3 Train the model
6.4.3.1 Train the model on the training set
In [33]: print('Start Learning...')
...: t0 = time()
...: model.fit(x, y)
...: t1 = time()  # record the end time right after fitting, not in a later cell
...:
6.4.3.2 Print training-related information
In [35]: t = t1 - t0
...: print('Training + CV time: %d min %.3f s' % (int(t/60), t - 60*int(t/60)))
6.4.3.3 Print the optimal parameters of the model
In [36]: print('Optimal parameters:\t', model.best_params_)
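A self-contained sketch of the same GridSearchCV workflow that can be run without the optdigits files. It assumes `load_digits()` as a stand-in dataset and uses a deliberately small, assumed grid so the search finishes in seconds; the best parameters it reports are for this toy grid, not the result of the experiment.

```python
from sklearn import svm
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV, train_test_split

# load_digits() stands in for the optdigits files
digits = load_digits()
x, x_test, y, y_test = train_test_split(
    digits.data, digits.target, test_size=0.4, random_state=1)

# A deliberately small grid (assumed values) so the search is quick
params = {'C': [1.0, 10.0], 'gamma': [1e-4, 1e-3]}
model = GridSearchCV(svm.SVC(kernel='rbf'), param_grid=params, cv=3)
model.fit(x, y)

print('best params:', model.best_params_)
print('mean CV accuracy of best params: %.3f' % model.best_score_)
```

`best_score_` is the mean accuracy over the 3 validation folds; after the search, `model` is refitted on all of `x` with the best parameters, so `model.predict` can be used directly.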
6.5 Evaluate the model
In [37]: print('Learning is OK...')
...: print('Training set accuracy:', accuracy_score(y, model.predict(x)))
...: y_hat = model.predict(x_test)
...: print('Test set accuracy:', accuracy_score(y_test, y_hat))
...: print(y_hat)
...: print(y_test)
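Printing `y_hat` and `y_test` side by side is hard to read for hundreds of samples. A confusion matrix summarizes which digits get mistaken for which. This sketch assumes `load_digits()` as stand-in data and fixed, assumed values `C=10.0`, `gamma=0.001` instead of the grid-search result.

```python
import numpy as np
from sklearn import svm
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

digits = load_digits()  # stand-in for the optdigits files
x, x_test, y, y_test = train_test_split(
    digits.data, digits.target, test_size=0.4, random_state=1)

# C and gamma are assumed values, not the grid-search result from the text
clf = svm.SVC(kernel='rbf', C=10.0, gamma=0.001).fit(x, y)
y_hat = clf.predict(x_test)

acc = accuracy_score(y_test, y_hat)
cm = confusion_matrix(y_test, y_hat)  # rows: true digit, columns: predicted digit
print('test accuracy: %.3f' % acc)
print(cm)

# Indices of the test samples where prediction and truth disagree
wrong = np.flatnonzero(y_hat != y_test)
print('misclassified samples:', len(wrong))
```

The diagonal of `cm` counts correct predictions, so its trace divided by the number of test samples equals the accuracy.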
7. Reference answer
- Code listing: svm_case.py
8. Summary
Completing this experiment requires understanding and mastering the relevant knowledge of support vector machines, as well as hands-on programming. At the end of the experiment, we used GridSearchCV to search for the optimal parameters of the model, which improved the model's accuracy. When trying it yourself, you can configure different parameter dictionaries to find the optimal parameters of the model. This experiment also involves a lot of code, so practice programming hands-on; for methods or classes you do not understand, consult the corresponding API documentation.