当前位置:网站首页>pytorch优化器: optim.SGD && optimizer.zero_grad()
pytorch优化器: optim.SGD && optimizer.zero_grad()
2022-07-22 05:22:00 【ZwaterZ】
在神经网络优化器中,主要为了优化我们的神经网络,使神经网络在我们的训练过程中快起来,节省时间。在pytorch中提供了 torch.optim方法优化我们的神经网络,torch.optim 是实现各种优化算法的包。最常用的方法都已经支持,接口很常规,所以以后也可以很容易地集成更复杂的方法。
SGD就是optim中的一个算法(优化器):随机梯度下降算法
要使用torch.optim,你必须构造一个optimizer对象,这个对象能保存当前的参数状态并且基于计算梯度进行更新。
构建一个优化器
要构造一个优化器,你必须给他一个包含参数(必须都是variable对象)进行优化,然后可以指定optimizer的参数选项,比如学习率,权重衰减。具体参考torch.optim文档。
optimizer = optim.SGD(model.parameters(),
lr=1e-2,
momentum=0.9,
weight_decay=0.0005)
optimizer.zero_grad()
参数
1、model.parameters()是获取model网络的参数,构建好神经网络后,网络的参数都保存在parameters()函数当中。params (iterable) – 待优化参数的iterable(w和b的迭代) 或者是定义了参数组的dict
2、lr (float) – 学习率
3、momentum (float, 可选) – 动量因子(默认:0)
4、weight_decay (float, 可选) – 权重衰减(L2惩罚)(默认:0)
5、dampening (float, 可选) – 动量的抑制因子(默认:0)
6、nesterov (bool, 可选) – 使用Nesterov动量(默认:False)
learning rate
1、学习率较小时,收敛到极值的速度较慢。
2、学习率较大时,容易在搜索过程中发生震荡。
weight decay
为了有效限制模型中的自由参数数量以避免过度拟合,可以调整成本函数。
一个简单的方法是通过在权重上引入零均值高斯先验值,这相当于将代价函数改变为E〜(w)= E(w)+λ2w2。
在实践中,这会惩罚较大的权重,并有效地限制模型中的自由度。
正则化参数λ决定了如何将原始成本E与大权重惩罚进行折衷。
learning rate decay
1、decay越小,学习率衰减地越慢,当decay = 0时,学习率保持不变。
2、decay越大,学习率衰减地越快,当decay = 1时,学习率衰减最快。
momentum
“冲量”这个概念源自于物理中的力学,表示力对时间的积累效应。
在普通的情况下x的更新 在加上冲量后就是在普通的情况下加上上次更新的x的与mom[0,1]的乘积
当本次梯度下降- dx * lr的方向与上次更新量v的方向相同时,上次的更新量能够对本次的搜索起到一个正向加速的作用。
当本次梯度下降- dx * lr的方向与上次更新量v的方向相反时,上次的更新量能够对本次的搜索起到一个减速的作用。
optimizer.zero_grad()
上图为一个简单的梯度下降示意图。以SGD为例,是算一个batch计算一次梯度,然后进行一次梯度更新。这里梯度值就是对应偏导数的计算结果。显然,我们进行下一次batch梯度计算的时候,前一个batch的梯度计算结果,没有保留的必要了。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。
lose
我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。
backward
调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。
optimizer.step
optimizer.step用来更新参数,就是图片中下半部分的w和b的参数更新操作。
边栏推荐
- Oracle stored procedure parameter understanding
- arm64环境下,golang的第三方库hajimehoshi/oto依赖alsa-lib和cgo的解决方案
- Analyzing the upsurge of participating in robot education competition
- Grasp the development trend of robot education towards AI intelligence
- Use ffmpeg to push and pull streams
- GD32F470之can0收发+接收中断配置以及波特率计算(详细)
- SAP wper (POS interface monitor) idco posting voucher ALV Report
- 【Redis】分布式场景下Redis高可用部署方案
- Model compression, acceleration and mobile deployment
- Analyzing and optimizing robot course system and teaching strategy
猜你喜欢
Data Lake: evolution of data Lake Technology Architecture
跨域问题(CORS)详细说明和解决
Crack PLSQL by deleting the registry
还有人不会这些数据分析小案例吗?一招教你招聘数据可视化~
使用OpenCV实现哈哈镜效果
5.SSH远程服务
16_ Response status code
codeforce D2. RGB Substring (hard version) 滑動窗口
Cartopy绘图入门指南
[open hand] hande enterprise PAAS platform hzero heavy open source!
随机推荐
Tiktok tiktok get Tiktok video details interface
ES6中的一些新特性
汉得数字平台体系及试用知多少?
15_额外的模型
Hande apaas low code platform Feida 2.3.0 release was officially released!
Hande enterprise digital PAAS platform hzero version 1.9.0 was officially released!
How to connect Youxuan database on this computer
1840. 最高建筑高度 贪心
汉得aPaaS低代码平台 飞搭 2.3.0 RELEASE正式发布!
UE4 创建一个工程
线程系列协程原理
ig,ax = plt.subplots()
【红队】ATT&CK - 浏览器扩展实现持久化
UE4 用灰度图构建地形
shell(一)(更新中)
Oracle stored procedure parameter understanding
Visual system design example (Halcon WinForm) -8. matching search
[論文翻譯] Generalized Radiograph Representation Learning via Cross-Supervision Between Images
FPGA - 7系列 FPGA内部结构之Memory Resources -02- FIFO资源
Is there anyone who can't analyze these data cases? A trick to teach you how to visualize recruitment data~