A detailed explanation of PyTorch fold and unfold
2022-07-21 03:44:00 【daimashiren】
Conclusion first: conv = unfold + matmul + fold. That is, a convolution is equivalent to first unfolding the input with unfold, then performing a matrix multiplication (matmul), and finally folding the result back with fold. The process is as follows.
unfold takes an input tensor of shape (N, C, H, W) and unfolds it into a tensor of shape (N, C*K1*K2, Blocks), where the kernel has shape (K1, K2) and Blocks is the total number of blocks. In other words, every kernel-sized window of the input is flattened into a column vector of length C*K1*K2, and there are Blocks such columns. With dilation 1, the number of blocks is:
$$Blocks = H_{blocks} \times W_{blocks}$$
where:
$$H_{blocks} = \left\lfloor \frac{H + 2 \cdot padding[0] - kernel[0]}{stride[0]} \right\rfloor + 1$$
$$W_{blocks} = \left\lfloor \frac{W + 2 \cdot padding[1] - kernel[1]}{stride[1]} \right\rfloor + 1$$
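As a quick check of the formula, the sketch below computes the block count for one spatial dimension and compares it with the shape F.unfold actually returns. The helper name num_blocks is hypothetical, and the padding and stride values are arbitrary choices for illustration. The dilation term follows the general formula in the PyTorch nn.Unfold documentation and reduces to the formula above when dilation is 1.

```python
import torch
import torch.nn.functional as F

def num_blocks(size, kernel, padding=0, stride=1, dilation=1):
    """Hypothetical helper: number of window positions along one spatial dim.

    With dilation=1 this equals floor((size + 2*padding - kernel) / stride) + 1.
    """
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

x = torch.randn(1, 3, 10, 12)
k, p, s = (4, 5), (1, 2), (2, 3)          # kernel, padding, stride per (H, W)
cols = F.unfold(x, kernel_size=k, padding=p, stride=s)

h_blocks = num_blocks(10, k[0], p[0], s[0])   # floor(8 / 2) + 1 = 5
w_blocks = num_blocks(12, k[1], p[1], s[1])   # floor(11 / 3) + 1 = 4
print(h_blocks, w_blocks, cols.shape)          # 5 4 torch.Size([1, 60, 20])
```

Here (12 + 4 - 5) / 3 is not an integer, which is why the floor matters once the stride does not divide evenly.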
Applied to a concrete example:
import torch
inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)
inp_unf = torch.nn.functional.unfold(inp, (4, 5))  # shape of inp_unf is (1, 3*4*5, 7*8) = (1, 60, 56)
Here the shape of inp_unf follows from the formula above, with padding 0 and stride 1:
$$H_{blocks} = \frac{10 - 4}{1} + 1 = 7$$
$$W_{blocks} = \frac{12 - 5}{1} + 1 = 8$$
so Blocks = 7 × 8 = 56, and each column has length 3 × 4 × 5 = 60.
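To see what each of the 56 columns holds, the sketch below checks one column against the matching 3 × 4 × 5 window of the input. It assumes, as the later matmul with w.view(2, -1) relies on, that each column is flattened in (C, K1, K2) order and that blocks are numbered row by row across the 7 × 8 grid of window positions. The column index j = 13 is an arbitrary choice.

```python
import torch
import torch.nn.functional as F

inp = torch.randn(1, 3, 10, 12)
inp_unf = F.unfold(inp, (4, 5))            # (1, 60, 56)

# Column j is the window whose top-left corner sits at
# row j // W_blocks, column j % W_blocks (row-major order).
W_blocks = 8
j = 13
r, c = divmod(j, W_blocks)                 # r = 1, c = 5
patch = inp[0, :, r:r + 4, c:c + 5]        # (3, 4, 5) window of the input
print(torch.equal(inp_unf[0, :, j], patch.reshape(-1)))  # True
```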
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
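# shape of out_unf is (1, 2, 56)
Step by step: inp_unf (1, 60, 56) is transposed to (1, 56, 60); w (2, 3, 4, 5) is flattened to (2, 60) and transposed to (60, 2); their product has shape (1, 56, 2), which is transposed back to give out_unf with shape (1, 2, 56).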
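The two transposes only exist to make the shapes line up for matmul. The sketch below shows two possibly simpler ways to write the same contraction: relying on torch.matmul broadcasting a 2-D matrix over the batch dimension, or spelling the sum out with torch.einsum. All three should agree up to floating-point rounding.

```python
import torch
import torch.nn.functional as F

inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)
inp_unf = F.unfold(inp, (4, 5))                       # (1, 60, 56)
w_flat = w.view(w.size(0), -1)                        # (2, 60)

# Transpose-based form from the text: (1, 56, 60) @ (60, 2) -> (1, 56, 2) -> (1, 2, 56)
out_a = inp_unf.transpose(1, 2).matmul(w_flat.t()).transpose(1, 2)

# (2, 60) @ (1, 60, 56): matmul broadcasts the 2-D weight over the batch -> (1, 2, 56)
out_b = w_flat @ inp_unf

# Same contraction with einsum: sum over the 60 = C*K1*K2 axis
out_c = torch.einsum('ok,nkl->nol', w_flat, inp_unf)

print(out_a.shape, out_b.shape, out_c.shape)
print(torch.allclose(out_a, out_b, atol=1e-4), torch.allclose(out_a, out_c, atol=1e-4))
```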
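This completes the unfold + matmul part; the last step is fold. fold performs the reverse of unfold: it folds the column vectors back into spatial (matrix) form. Strictly speaking, fold sums the values of overlapping blocks, so it is an exact inverse of unfold only when the blocks do not overlap. Here fold is called with a (1, 1) kernel, so it simply reshapes out_unf from (1, 2, 56) to (1, 2, 7, 8).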
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
# out.size() = (1, 2, 7, 8)
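The difference between non-overlapping and overlapping blocks can be seen on a small 3 × 3 input chosen for illustration. With a 1 × 1 kernel, fold exactly undoes unfold; with an overlapping 2 × 2 kernel, fold adds up every copy of a pixel, and dividing by fold(unfold(ones)) (the number of blocks covering each pixel) recovers the input.

```python
import torch
import torch.nn.functional as F

x = torch.arange(1., 10.).view(1, 1, 3, 3)    # 3x3 input holding 1..9

# Non-overlapping case (1x1 kernel, as in the fold above): fold only reshapes back.
cols = F.unfold(x, kernel_size=1)              # (1, 1, 9)
print(torch.equal(F.fold(cols, output_size=(3, 3), kernel_size=1), x))  # True

# Overlapping case: 2x2 kernel, stride 1 -> 4 blocks that share pixels.
cols = F.unfold(x, kernel_size=2)              # (1, 4, 4)
y = F.fold(cols, output_size=(3, 3), kernel_size=2)

# How many blocks cover each pixel; fold sums that many copies of it.
ones = torch.ones_like(x)
count = F.fold(F.unfold(ones, kernel_size=2), output_size=(3, 3), kernel_size=2)
print(count[0, 0])                             # [[1, 2, 1], [2, 4, 2], [1, 2, 1]]
print(torch.equal(y / count, x))               # True
```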
Running conv2d directly on the same inp and w gives the same result:
(torch.nn.functional.conv2d(inp, w) - out).abs().max()
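# tensor(1.9073e-06) in the original run; the exact value varies because the inputs are random
The maximum absolute difference between conv2d and unfold + matmul + fold is on the order of 1e-6, which is floating-point rounding error, so the two results can be considered equal.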
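Putting the three steps together, the sketch below wraps unfold + matmul + fold into one function and compares it with F.conv2d. The function name conv2d_unfold is hypothetical; it additionally handles a bias, stride and padding, but assumes dilation 1 and groups 1. As in the text, the final fold uses a (1, 1) kernel, so it only reshapes the (N, out_c, L) result into (N, out_c, H_out, W_out).

```python
import torch
import torch.nn.functional as F

def conv2d_unfold(x, weight, bias=None, stride=1, padding=0):
    """Hypothetical sketch: conv2d as unfold + matmul + fold (dilation=1, groups=1)."""
    _, _, h, w_in = x.shape
    out_c, _, kh, kw = weight.shape
    # 1) unfold: (N, C*kh*kw, L)
    cols = F.unfold(x, (kh, kw), padding=padding, stride=stride)
    # 2) matmul: (out_c, C*kh*kw) @ (N, C*kh*kw, L) -> (N, out_c, L)
    out = weight.view(out_c, -1) @ cols
    if bias is not None:
        out = out + bias.view(1, -1, 1)
    # 3) fold with a 1x1 kernel: (N, out_c, L) -> (N, out_c, H_out, W_out)
    s = (stride, stride) if isinstance(stride, int) else stride
    p = (padding, padding) if isinstance(padding, int) else padding
    h_out = (h + 2 * p[0] - kh) // s[0] + 1
    w_out = (w_in + 2 * p[1] - kw) // s[1] + 1
    return F.fold(out, (h_out, w_out), (1, 1))

x = torch.randn(2, 3, 10, 12)
weight = torch.randn(4, 3, 3, 5)
bias = torch.randn(4)
ref = F.conv2d(x, weight, bias, stride=2, padding=1)
mine = conv2d_unfold(x, weight, bias, stride=2, padding=1)
print(ref.shape, (ref - mine).abs().max())
```

Using the broadcasting matmul instead of the transpose chain keeps the batch dimension N intact without any reshuffling.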
Summary
Combining fold and unfold in PyTorch reproduces the sliding-window behavior of a Conv layer. If every block of the same image uses the same parameters, the parameters are shared and the result is a standard convolution layer; if each block has its own parameters, there is no parameter sharing, and the layer is generally called a locally connected layer.
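A locally connected layer can be sketched with the same unfold + fold pattern by giving every block its own filter bank instead of sharing one. The helper locally_connected2d and its weight layout (L, out_c, C*K1*K2) are hypothetical choices made for illustration, with stride 1 and no padding assumed. Repeating one filter for every block should reduce it to F.conv2d, which is the parameter-sharing case.

```python
import torch
import torch.nn.functional as F

def locally_connected2d(x, weight, kernel_size, out_hw):
    """Hypothetical helper: locally connected layer (no parameter sharing).

    weight: (L, out_c, C*kh*kw), one separate filter bank per block,
    where L = H_out * W_out. Stride 1 and padding 0 are assumed.
    """
    cols = F.unfold(x, kernel_size)                   # (N, C*kh*kw, L)
    # For each block l: out[n, o, l] = sum_k weight[l, o, k] * cols[n, k, l]
    out = torch.einsum('lok,nkl->nol', weight, cols)  # (N, out_c, L)
    return F.fold(out, out_hw, (1, 1))                # (N, out_c, H_out, W_out)

x = torch.randn(1, 3, 10, 12)
L = 7 * 8
weight = torch.randn(L, 2, 3 * 4 * 5)                 # a different filter per block
y = locally_connected2d(x, weight, (4, 5), (7, 8))
print(y.shape)                                        # torch.Size([1, 2, 7, 8])

# Sharing one filter across all blocks gives an ordinary convolution.
w_shared = torch.randn(2, 3, 4, 5)
shared = w_shared.view(1, 2, -1).expand(L, -1, -1)
print(torch.allclose(locally_connected2d(x, shared, (4, 5), (7, 8)),
                     F.conv2d(x, w_shared), atol=1e-4))
```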