当前位置:网站首页>(四)PyTorch深度学习:PytTorch实现线性回归
(四)PyTorch深度学习:PytTorch实现线性回归
2022-07-20 05:32:00 【Kkh_8686】
PytTorch实现线性回归
1、线性模型可以直接通过 torch 中的模块 torch.nn.Module 来继承获得,线性模型是所有神经网络模块最基本的类。
- 把线性模型构造一个类(常见的方法),构造的LinearModel模型类继承自torch.nn.Module(因为 Module 里面有很多方法,在模型训练中用到)。
- 这个类至少有下面两个def定义的函数 。第1个函数(构造函数)的作用:初始化对象,默认调用的函数;第2个函数的作用:进行前馈的过程中所要进行的计算。
- 没有写 backward函数,是因为使用了 torch.nn.Module 构造出来的对象会自动的根据
计算模型(函数模型)自动完成了backward操作了。 - 但是如果我们想要模型由基本的PyTorch支持的运算,可以封装成Module,然后实例化调用,自动反向传播。但有时候PyTorch计算图效率不高,如果你有更加快的方法求导,可以从 function 来继承(也是PyTorch里面的一个类)。
class LinearModel(torch.nn.MOdule):
def __init__(self):
super(LinearModel, self).__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
y_pred = self.linear(x)
return y_pred
model = LinearModel()
如上代码中:nn.Linear(weight, bias) # 分别是权重、编移。
y = w * x + b
(weight:w; bia:b)
2、损失函数:
size_average=False(不需要将加起来的损失求均值)
criterion = torch.nn.MSELoss(size_average=False)
3、优化器:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)中
第一个参数:model.parameters()。在我们线性模型(LinearmModel)里面,没有定义相对应的权重,只有一个Linear成员(但是Linear里面有两个权重参数)。参数:model.parameters(),会检查model里面的所有成员(w,b),加到需要训练的对应参数上。
第二个参数:lr,学习率
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
4、训练过程:
for epoch in range(1000):
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
print(epoch, loss.item())
optimizer.zero_grad() # 梯度归零
loss.backward()
optimizer.step() # 参数优化跟新
5、完整代码:
import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel, self).__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
y_pred = self.linear(x)
return y_pred
model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(1000):
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
print(epoch, loss.item())
optimizer.zero_grad() # 梯度归零
loss.backward()
optimizer.step() # 参数优化跟新
print('w = ',model.linear.weight.item())
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
边栏推荐
猜你喜欢
随机推荐
JS event flow (capture phase, target phase, bubble phase) cancels the default bubble behavior of the browser
TypeScript
Virtual DOM 的实现原理
Los Angeles: t226229 arithmetic series
页面性能:如何系统地优化页面?
[dish of learning notes dog learning C] detailed array name
ORALCE mapping CLOB
ECMAScript新特性
A. Log Chopping
使用反射的方式将RDD转换为DataFrame
阿里矢量图库 当前页全选
[dish of learning notes dog learning C] if branch statement and switch branch statement
【学习笔记之菜Dog学C】if分支语句与switch分支语句
Flink SQL自定义解析 Map和Array数据类型
ES6常用的新特性
[dish of learning notes dog learning C] exercise
Redis持久化
Redis发布与订阅
Configuration of Visual Studio development environment
mysql_账号授权权限回收、账号锁定解锁、账号创建删除