当前位置:网站首页>【PyTorch深度学习实践】学习笔记 第三节 梯度下降
【PyTorch深度学习实践】学习笔记 第三节 梯度下降
2022-07-22 10:55:00 【咯吱咯吱咕嘟咕嘟】
开头
- 去年三月份学习的PyTorch深度学习实践课程,当时在有道笔记做了笔记并且实践了。现在好久没接触已经忘了。。。orz
回顾了下当时写的毕设路线—pytorch环境下的深度学习的高光谱图像分类问题文章,决定整理一下笔记,也为了能快速复习。希望按照这里面的顺序,把坑都填上,立个flag,这一周把坑都回顾一遍。Let’s
go! - 这一系列的博客都是PyTorch深度学习实践该课程的学习笔记。因为老师上传了ppt,我就不再截取ppt介绍原理了。
- 建议听一节课程,跟着练习实践一节,注重代码实现细节。这样点滴积累后就可以上手复杂的项目了。
这节课是最基础的一节之一,是自己定义的损失函数cost和计算梯度grad,能更好的理解原理。在今后的深度学习项目都是直接用的torch里的packages的封装函数了。
import matplotlib.pyplot as plt
# prepare the training set
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# initial guess of weight
w = 1.0
# define the model linear model y = w*x
def forward(x):
return x*w
#define the cost function MSE
def cost(xs, ys):
cost = 0
for x, y in zip(xs,ys):
y_pred = forward(x)
cost += (y_pred - y)**2
return cost / len(xs) # cost算的是全部样本的 MSE 平均loss
# define the gradient function gd
def gradient(xs,ys):
grad = 0
for x, y in zip(xs,ys):
grad += 2*x*(x*w - y) #所有样本预测与实际的loss对w的求导 导函数
return grad / len(xs) #对所有样本的梯度求平均
epoch_list = []
cost_list = []
print('predict (before training)', 4, forward(4))
for epoch in range(100):
cost_val = cost(x_data, y_data)
grad_val = gradient(x_data, y_data)
w-= 0.01 * grad_val # 0.01 learning rate 在这里更新w 为了下个epoch重新计算predict结果
print('epoch:', epoch, 'w=', w, 'loss=', cost_val)
epoch_list.append(epoch)
cost_list.append(cost_val)
print('predict (after training)', 4, forward(4))
plt.plot(epoch_list,cost_list)
plt.ylabel('cost')
plt.xlabel('epoch')
plt.show()
运行结果:
predict (before training) 4 4.0
predict (after training) 4 7.999998569488525
补充
上面设计的是梯度下降法,即计算的是全部样本的平均损失。
随机梯度下降法(SGD)和梯度下降法的主要区别在于:
1、损失函数由cost()更改为loss()。cost是计算所有训练数据的损失(循环中需要cost+=),loss是计算一个训练函数的损失。对应于源代码则是少了两个for循环。(在def loss和grad函数中去掉for循环,而在训练的epoch中变成两个for的循环了。)
2、梯度函数gradient()由计算所有训练数据的梯度更改为计算一个训练数据的梯度。
3、本算法中的随机梯度主要是指,每次拿一个训练数据来训练,然后更新梯度参数。本算法中梯度总共更新100(epoch)x3 = 300次(每个样本计算一次就要更新一下w)。梯度下降法中梯度总共更新100(epoch)次(算的是一轮中所有样本的平均)。
import matplotlib.pyplot as plt
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0
def forward(x):
return x*w
# calculate loss function
def loss(x, y):
y_pred = forward(x)
return (y_pred - y)**2
# define the gradient function sgd
def gradient(x, y):
return 2*x*(x*w - y)
epoch_list = []
loss_list = []
print('predict (before training)', 4, forward(4))
for epoch in range(100):
for x,y in zip(x_data, y_data):
grad = gradient(x,y)
w = w - 0.01*grad # update weight by every grad of sample of training set
print("\tgrad:", x, y,grad)
l = loss(x,y)
print("progress:",epoch,"w=",w,"loss=",l)
epoch_list.append(epoch)
loss_list.append(l)
print('predict (after training)', 4, forward(4))
plt.plot(epoch_list,loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
结果:
可以看出SGD的收敛更快,训练周期短。
by 小李
如果你坚持到这里了,请一定不要停,山顶的景色更迷人!好戏还在后面呢。加油!
边栏推荐
- Set colSpan invalidation for TD of table
- 项目中手机、姓名、身份证信息等在日志和响应数据中脱敏操作
- 使用js写个3d banner
- 解决TraceCompass网站打不开和Stackoverflow显示不全的问题
- 微信小程序Cannot read property 'setData' of null错误
- Airtest test framework construction
- BUUCTF闖關日記--[網鼎杯 2020 青龍組]AreUSerialz
- 多线程04--线程的原子性、CAS
- Buuctf entry diary -- [mrctf2020] how about you (super detailed)
- 多线程01--创建线程和线程状态
猜你喜欢
随机推荐
Set colSpan invalidation for TD of table
BUUCTF闯关日记--[极客大挑战 2019]HardSQL1
Bash基本功能—通配符和其他特殊符号
Style writing in next
Buuctf breakthrough diary -- [netding cup 2020 Qinglong group]areuserialz
Pytest testing framework built quickly
使用简单的js实现圆弧布局
L'applet Wechat ne peut pas lire la propriété 'setdata' de NULL Error
BUUCTF闯关日记--[网鼎杯 2020 青龙组]AreUSerialz
Bash基本功能—历史命令与补全
第四章:minio的presigned URLs上传文件
Airtest conducts webui automated testing (selenium)
pycharm设置
RPM包管理—YUM在线管理-IP地址配置和网络YUM源
selenium测试框架快速搭建(ui自动化测试)
Redis 系列14--Redis Cluster
Seata 初识
[LTTng学习之旅]------Trace控制--进阶
Bash变量--位置参数变量
Redis series 14 -- redis cluster