PyTorch fold and unfold explained in detail
2022-07-20 05:32:00 【daimashiren】
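The conclusion first: conv = unfold + matmul + fold. That is, a convolution is equivalent to unfolding the input, performing a matrix multiplication (matmul), and then folding the result. The steps are as follows: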
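The unfold function expands an input Tensor of shape (N, C, H, W) into a Tensor of shape (N, C * K1 * K2, Blocks), where the kernel has shape (K1, K2) and Blocks is the total number of blocks. In other words, every kernel-sized sliding window of the input is flattened into one vector, giving Blocks vectors in total. Blocks is computed as: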
$$Blocks = H_{blocks} \times W_{blocks}$$
where (for dilation 1, rounding down because a window that does not fit entirely inside the padded input is dropped):
$$H_{blocks} = \left\lfloor \frac{H + 2 \times padding[0] - kernel[0]}{stride[0]} \right\rfloor + 1$$
$$W_{blocks} = \left\lfloor \frac{W + 2 \times padding[1] - kernel[1]}{stride[1]} \right\rfloor + 1$$
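The block-count formula can be checked with a few lines of plain Python. This is only a minimal sketch: the helper name `num_blocks` and its arguments are made up for illustration and are not part of PyTorch. It assumes dilation 1 and uses floor division, since a window that does not fit completely is dropped.

```python
def num_blocks(h, w, kernel, stride=(1, 1), padding=(0, 0)):
    """Hypothetical helper: return (H_blocks, W_blocks, Blocks) for dilation 1."""
    h_blocks = (h + 2 * padding[0] - kernel[0]) // stride[0] + 1
    w_blocks = (w + 2 * padding[1] - kernel[1]) // stride[1] + 1
    return h_blocks, w_blocks, h_blocks * w_blocks


# 10 x 12 input, 4 x 5 kernel, stride 1, no padding
print(num_blocks(10, 12, (4, 5)))  # (7, 8, 56)

# stride 2 and padding 1: (10 + 2 - 4) // 2 + 1 = 5 and (12 + 2 - 5) // 2 + 1 = 5
print(num_blocks(10, 12, (4, 5), stride=(2, 2), padding=(1, 1)))  # (5, 5, 25)
```

The second call shows where the floor matters: (12 + 2 - 5) / 2 is 4.5, which is rounded down to 4 before adding 1.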
Code example:
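import torch

inp = torch.randn(1, 3, 10, 12)   # input: N=1, C=3, H=10, W=12
w = torch.randn(2, 3, 4, 5)       # weight: 2 output channels, 3 input channels, 4 x 5 kernel
inp_unf = torch.nn.functional.unfold(inp, (4, 5))  # shape of inp_unf is (1, 3*4*5, 7*8) = (1, 60, 56)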
The shape of inp_unf is computed as follows:
$$H_{blocks} = \frac{10 - 4}{1} + 1 = 7$$
$$W_{blocks} = \frac{12 - 5}{1} + 1 = 8$$
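To make the layout of the unfolded tensor concrete, the following sketch compares individual columns of the unfold output with the windows they come from. The variable names `first_window`, `window`, `i` and `j` are chosen here for illustration only.

```python
import torch
import torch.nn.functional as F

inp = torch.randn(1, 3, 10, 12)
inp_unf = F.unfold(inp, (4, 5))  # (1, 60, 56)

# Column 0 is the top-left 4 x 5 window, flattened over (C, K1, K2).
first_window = inp[0, :, 0:4, 0:5].reshape(-1)
print(torch.equal(inp_unf[0, :, 0], first_window))  # True

# Blocks are enumerated row by row, so the window whose top-left corner
# is at (i, j) ends up in column i * W_blocks + j (stride 1, no padding).
i, j, w_blocks = 2, 3, 8
window = inp[0, :, i:i + 4, j:j + 5].reshape(-1)
print(torch.equal(inp_unf[0, :, i * w_blocks + j], window))  # True
```

Because each column is ordered as (C, K1, K2), it lines up with a weight of shape (C_out, C, K1, K2) flattened to (C_out, C * K1 * K2), which is what makes the matrix multiplication equal to a convolution.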
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
#shape of out_unf is (1,2,56)
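The code above amounts to the following: inp_unf (1, 60, 56) is transposed to (1, 56, 60) and multiplied by w, which is reshaped to (2, 3 * 4 * 5) and transposed to (60, 2). This gives out_unf (1, 56, 2), and the final transpose gives out_unf (1, 2, 56).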
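With unfold + matmul done, the last step is fold. Fold is essentially the reverse of unfold: it folds the column vectors back into matrix form. Here the kernel size passed to fold is (1, 1), so each of the 56 columns becomes one position of the (7, 8) output.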
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
#out.size() = (1,2,7,8)
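Fold reverses unfold exactly only when the blocks do not overlap, as with the (1, 1) kernel in the call above. When blocks overlap, fold sums the values that land on the same position. The sketch below illustrates this with the 4 x 5 kernel; the names `folded` and `divisor` are chosen here for illustration.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 10, 12)
k = (4, 5)

# unfold then fold with the same kernel: overlapping positions are summed
folded = F.fold(F.unfold(x, k), output_size=(10, 12), kernel_size=k)

# folding an all-ones tensor counts how many blocks cover each pixel
divisor = F.fold(F.unfold(torch.ones_like(x), k), output_size=(10, 12), kernel_size=k)

print(torch.allclose(folded, x * divisor, atol=1e-5))  # True
print(torch.allclose(folded / divisor, x, atol=1e-5))  # True
```

Dividing by `divisor` is therefore one way to recover the original tensor after an overlapping unfold and fold round trip.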
The unfold + matmul + fold pipeline is equivalent to applying Conv directly, so:
(torch.nn.functional.conv2d(inp, w) - out).abs().max()
# e.g. tensor(1.9073e-06); the exact value depends on the random input
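As shown, the result of the convolution and the result of unfold + matmul + fold differ by about 10^-6, which is floating-point rounding error, so the two can be regarded as equal.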
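The same steps also work with stride and padding. The following is a minimal sketch under stated assumptions: `conv2d_via_unfold` is a hypothetical helper written for this illustration (not a PyTorch function), and it handles only integer stride and padding, dilation 1, a single group, and no bias.

```python
import torch
import torch.nn.functional as F


def conv2d_via_unfold(inp, w, stride=1, padding=0):
    """Hypothetical helper: 2-D convolution via unfold + matmul + fold."""
    _, _, h, width = inp.shape
    c_out, _, k1, k2 = w.shape
    # (N, C*K1*K2, Blocks)
    inp_unf = F.unfold(inp, (k1, k2), stride=stride, padding=padding)
    # (C_out, C*K1*K2) @ (N, C*K1*K2, Blocks) -> (N, C_out, Blocks)
    out_unf = w.reshape(c_out, -1) @ inp_unf
    h_blocks = (h + 2 * padding - k1) // stride + 1
    w_blocks = (width + 2 * padding - k2) // stride + 1
    # with a (1, 1) kernel, fold only rearranges the columns into a grid
    return F.fold(out_unf, (h_blocks, w_blocks), (1, 1))


inp = torch.randn(2, 3, 10, 12)
w = torch.randn(4, 3, 4, 5)
out = conv2d_via_unfold(inp, w, stride=2, padding=1)
ref = F.conv2d(inp, w, stride=2, padding=1)
print(out.shape)                            # torch.Size([2, 4, 5, 5])
print(torch.allclose(out, ref, atol=1e-4))  # True
```

The comparison uses a tolerance rather than exact equality because the two code paths accumulate the products in a different order.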
Summary
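Combining fold and unfold in PyTorch implements a Conv-like sliding window. If every block of the same image uses the same parameters, the parameters are shared and the result is a standard convolutional layer. If every block has its own parameters, the parameters are not shared, and the result is usually called a locally connected layer.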
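The difference between shared and unshared parameters can be sketched with unfold and `torch.einsum`. This is an illustration under assumptions, not a reference implementation: the tensor names (`w_shared`, `w_local`, `w_tied`) and the weight layout (Blocks, C_out, C * K1 * K2) for the locally connected case are choices made here.

```python
import torch
import torch.nn.functional as F

inp = torch.randn(1, 3, 10, 12)
inp_unf = F.unfold(inp, (4, 5))  # (1, 60, 56)
blocks = inp_unf.size(-1)        # 56

# Shared parameters (convolution): one (C_out, C*K1*K2) weight for all blocks.
w_shared = torch.randn(2, 60)
conv_unf = torch.einsum("oc,ncl->nol", w_shared, inp_unf)   # (1, 2, 56)

# Unshared parameters (locally connected): a separate weight for every block.
w_local = torch.randn(blocks, 2, 60)
local_unf = torch.einsum("loc,ncl->nol", w_local, inp_unf)  # (1, 2, 56)

conv_out = F.fold(conv_unf, (7, 8), (1, 1))    # (1, 2, 7, 8)
local_out = F.fold(local_unf, (7, 8), (1, 1))  # (1, 2, 7, 8)

# Giving every block the same weight reduces the locally connected
# computation to the convolution.
w_tied = w_shared.expand(blocks, 2, 60)
tied_unf = torch.einsum("loc,ncl->nol", w_tied, inp_unf)
print(torch.allclose(tied_unf, conv_unf, atol=1e-4))  # True
```

In this layout the locally connected weight holds Blocks times as many parameters as the convolution weight (56 * 2 * 60 versus 2 * 60).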
References
https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
https://pytorch.org/docs/stable/generated/torch.nn.Fold.html
https://blog.csdn.net/LoseInVain/article/details/88139435