当前位置:网站首页>【PyTorch教程】05-如何使用PyTorch训练神经网络模型 (2022年最新)
【PyTorch教程】05-如何使用PyTorch训练神经网络模型 (2022年最新)
2022-07-21 05:10:00 【自牧君】

使用PyTorch训练神经网络:torch.autograd
1. 神经网络背景
神经网络 (Neural Networks, NN) 是在输入数据上执行的嵌套函数的集合。这些函数由参数 (由权重 weights
和偏差 biases
组成) 定义,这些参数在PyTorch中存储在张量中。
训练神经网络分为两步:
- 正向传播:在正向传播中,神经网络对正确的输出结果尽力作出最佳猜测。输入数据要在神经网络中的每个函数都运行一遍,才能得出推理结果。
- 反向传播:在反向传播中,神经网络根据输出结果的误差来调整、纠正其网络参数。从输出结果开始,从后往前遍历,收集有关网络参数 (梯度) 的误差的导数,并使用梯度下降法优化网络参数。
2. 加载预训练模型(有重大更新)
最近 (2022年7月) 安装或者更新了 PyTorch 和 torchvision 的同志们可能跑代码时遇到了下面的报错:
UserWarning: The parameter ‘pretrained’ is deprecated since 0.13 and will be removed in 0.15, please use ‘weights’ instead.
UserWarning: Arguments other than a weight enum or
None
for ‘weights’ are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passingweights=ResNet50_Weights.IMAGENET1K_V1
. You can also useweights=ResNet50_Weights.DEFAULT
to get the most up-to-date weights.Expected type ‘Optional[ResNet50_Weights]’, got ‘str’ instead
这是因为 torchvision 0.13对预训练模型加载方式作出了重大更新造成的。今天一次性就可以把上面3条Bug全部消灭。
从 torchvision 0.13开始,torchvision提供一个全新的多权重支持API (Multi-weight support API) ,支持将不同版本的权重参数文件加载到模型中。
2.1 新老版本写法对比
从 torchvision 0.13开始,加载预训练模型函数的参数从
pretrained = True
改为weights=预训练模型参数版本
。且旧版本的写法将在未来的torchvision 0.15版本中被Deprecated 。
举个例子:
from torchvision import models
# 旧版本的写法,将在未来的torchvision 0.15版本中被Deprecated
model_old = models.resnet50(pretrained=True) # deprecated
model_old = models.resnet50(True) # deprecated
# torchvision 0.13及以后的新版本写法
model_new = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
# 没有预训练模型加载
model = models.resnet50(weights=None)
model = models.resnet50()
其中,第7行代码的 IMAGENET1K_V1
表示的是 ResNet-50 在 ImageNet 数据集上进行预训练的第一个版本的权重参数文件。是一个版本标识符。
2.2 新写法的好处
在旧版本的写法 pretrained = True
中,对于预训练权重参数我们没有太多选择的余地,一执行起来就要使用默认的预训练权重文件版本。但问题是,现在深度学习的发展日新月异,很快就有性能更强的模型横空出世。
而使用新版本写法 weights=预训练模型参数版本
,相当于我们掌握了预训练权重参数文件的选择权。我们就可以尽情地使用更准更快更强更新的预训练权重参数文件,帮助我们的研究更上一层楼。
举个例子:
from torchvision import models
# 加载精度为76.130%的旧权重参数文件V1
model_v1 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# 等价写法
model_v1 = models.resnet50(weights="IMAGENET1K_V1")
# 加载精度为80.858%的新权重参数文件V2
model_v2 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# 等价写法
model_v1 = models.resnet50(weights="IMAGENET1K_V2")
如果你不知道哪个权重文件的版本是最新的,没关系,直接选择默认DEFAULT即可。官方会随着 torchvision 的升级而让 DEFAULT 权重文件版本保持在最新。如下代码所示:
from torchvision import models
# 如果你不知道哪个版本是最新, 直接选择默认DEFAULT即可
model_new = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
2.3 图像数据的预处理
在使用预训练模型之前,往往需要对输入图片进行预处理。但图片预处理的方式多种多样,且不同的模型之间的数据预处理方式也不一样,甚至同一个模型的不同版本之间的预处理方式都可能不相同。如果没有采用正确的预处理,可能将导致模型精度下降甚至输出错误结果。
基于这个背景,torchvision 很贴心地帮我们把每种模型的数据预处理方式都集成到其对于的模型权重文件中。这样,我们就可以很轻松地采用合适的方式对输入图片进行数据预处理了。通过 weights.transforms()
函数实现。
举个例子:
from torchvision import models
# 初始化预处理方式preprocess
weights = models.ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
# 应用到输入图片img
img_transformed = preprocess(img)
2.4 训练模式和验证模式之间的转换
一些模块拥有不同的训练和验证过程,例如,batch normalization 。在 torchvision中,可以通过 train()
和 eval()
两个函数在训练和验证模式之间转换。
举个例子:
from torchvision import models
# 初始化模型
weights = models.ResNet50_Weights.DEFAULT
model = models.resnet50(weights=weights)
# 将模型切换至验证模式
model.eval()
3. 正向传播
首先来看训练的其中一个步骤:正向传播。
3.1 初始化输入数据、标签和模型
从 torchvision
加载ResNet-18的预训练模型。然后,创建一个形状为 ( 3 × 64 × 64 ) (3\times 64\times 64) (3×64×64) 三维随机值张量来代表一张 64 × 64 64\times 64 64×64 像素的三通道的彩色图片。其对应的标签 label
初始化为一个形状为 ( 1 , 1000 ) (1, 1000) (1,1000) 的随机张量。
shape = (1, 3, 64, 64) # 三通道图片的形状,1是batch-size
img = torch.rand(shape).to('cuda') # 三通道图片
label = torch.rand(1, 1000).to('cuda') # 标签,对于1000个类别
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to('cuda') # 创建ResNet-18模型并加载预训练权重参数
【注意】
如果想让GPU进行正向传播和反向传播的运算,则必须同时把输入图片
img
和神经网络模型model
移动到GPU才行。
3.2 预测
接下来,输入图片 img
将通过模型的每一层神经网络,从而作出预测。这就是前向传播 (也叫正向传播) 。
# 前向传播,作出预测
prediction = model(img)
我们可以观察预测结果 prediction
的形状和所在设备:
print(f"Shape of prediction: {
prediction.shape}")
print(f"Device of prediction: {
prediction.device}")
输出:
Shape of prediction: torch.Size([1, 1000])
Device of prediction: cuda:0
4. 反向传播
4.1 计算损失函数
使用刚刚计算得到的预测结果 prediction
和输入图片对应的标签 label
来计算预测的误差,也称为损失函数 loss
。下一个步骤就是通过网络反向传播这个误差,当我们在误差张量上调用 .backward()
函数时,就开始反向传播。然后 Autograd
就会计算每个模型参数的梯度,并将这个梯度存储在参数的 .grad
属性中。
接着上面的例子:
loss = (prediction - labels).sum() # 损失函数
loss.backward() # 逆推计算
其中,为了对新手友好,这里的损失函数 loss
采用简单的与标签 labels
对应元素的查,再求和的方式来求。在后面实际中使用的损失函数大多数比这个复杂一些。
我们可以把误差 loss
打印出来看一下:
Loss result: -501.713134765625
4.2 加载优化器
紧接着,我们加载一个优化器 (optimizer) ,在这个例子中使用随机梯度下降法 SGD (Stochastic Gradient Descent) ,学习率 (learning rate) 为 0.01,动量 (momentum) 为 0.9 。把模型的所有参数都输入到优化器中:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
4.3 开始优化参数
最后一步,调用 .step()
方法来初始化随机梯度下降SGD。优化器 optimizer
会根据存储在 .grad
中的梯度来自动调整每个参数。
optimizer.step() # 梯度下降
至此,恭喜你已经成功掌握了训练你的神经网络所需的所有技能。如果你还想了解 autograd
背后的机制和原理,请移步到这一篇博文。
完整训练步骤如下:
import torch
from torchvision import models
# 伪造输入图片
shape = (1, 3, 64, 64) # 三通道图片的形状
img = torch.rand(shape).to('cuda') # 三通道图片
# 伪造标签
labels = torch.rand(1, 1000).to('cuda') # 标签
# 初始化模型
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to('cuda') # 创建ResNet-18模型并加载预训练权重参数
# 前向传播,作出预测
prediction = model(img)
# 计算损失函数,反向传播
loss = (prediction - labels).sum() # 损失函数
loss.backward() # 逆推计算
# 加载优化器
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
# 启动优化器
optimizer.step() # 梯度下降
边栏推荐
- Redis(六) - Redis企业实战之商户查询缓存
- 易语言学习笔记(二)
- 投票 | 选出您希望Navicat支持的数据库
- 并发编程(二十七) - JUC之原子类
- 2021-10-23
- Extract a subset from a point cloud
- 2021-08-11
- peoplecode 定义的名字引用
- 并发编程(十九)-JUC之AQS
- boost::this_ Thread:: sleep (boost:: posix_time:: microseconds (100000)) reports an error "this_thread": "the symbol on the left side of":: "must be of a type
猜你喜欢
Extract a subset from a point cloud
并发编程(三十一) - ReetrantReadWriteLock 读写锁原理
Concurrent programming (XXXI) - principle of reetrantreadwritelock
Peoplecode 运算符
Deep understanding of pointers (bubble sorting simulation implements qsort, and function pointers implement callback functions)
干货 | RDBMS 索引类型概述
Concurrent programming (XXII) -reentrantlock condition variable implementation principle
并发编程(十九)-JUC之AQS
[permission promotion] MySQL authorization raising method
yum check 时报错libmysqlclient.so.18()(64bit)
随机推荐
Redis(六) - Redis企业实战之商户查询缓存
Pointnet++ training record partseg
Deep analysis of data storage in memory
Fossage 2.0-metaforce force force chain operation race tutorial
浅析 SQL Server 的 CROSS APPLY 和 OUTER APPLY 查询 - 第二部分
PCL runtime ucrtbased Exception thrown by DLL
peoplesoft 更新表接口程序
Centos7上的PostgreSQL开启SSL配置
数据库监控的重要性
Force buckle 26 Delete duplicates in the ordered array 88 Merge two ordered arrays and 189 Rotate array
How to use iterative closest point ICP
并发编程(二十二)-ReentrantLock 条件变量实现原理
2021-08-11
Matlab2021a configuration GPU encountered error c1083: unable to open include file: "gpu/mxgpuarray.h": no such file or directory
SQL Server | Unicode 和非 Unicode 字符串数据类型
Redis(五) - Redis企业实战之短信登录
Oracle RAC镜像恢复的单实例数据库Redo日志增量抽取报错: ORA-01291 & 删除日志组报错: ORA-01567
SQL select 语句
Concurrent programming (XXVII) - Atomic classes of JUC
浅析 SQL Server 的 CROSS APPLY 和 OUTER APPLY 查询 - 第一部分