当前位置:网站首页>(一)PyTorch深度学习:线性模型训练
(一)PyTorch深度学习:线性模型训练
2022-07-20 05:32:00 【Kkh_8686】
(一)PyTorch教程:线性模型训练
1、如下列表,我们有X、Y的数据:
2、想要预测x=4时,y的值?我们就需要提出一个模型,使得输入X的值经过模型后得到输出值Y_pred与真实很接近,这是我们的目标。
提出线性模型:y_pred = x * w,式中y_pred是经过了线性模型得到的预测值。
3、训练损失函数:
损失函数:loss = (y_pred - y)^2 = (x * w - y),式中y是真实值,y_pred是预测值。
4、代码:
import numpy as np
import matplotlib.pyplot as plt
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# 定义模型:y = x * w
def forward(x):
return x * w
# 定义损失函数:loss = (y_pred - y) * (y_pred - y)
def loss(x, y):
y_pred = forward(x)
return (y_pred - y) * (y_pred - y)
w_list = [] # 存放权重 w 的列表
mse_list = [] # 存放权重 w 对应的损失值列表
for w in np.arange(0.0, 4.1, 0.1): # 选择权重 w 以0.1间隔从0.0到4选值
print('W = ', w)
l_sum = 0
for x_val, y_val in zip(x_data, y_data):
y_pred_val = forward(x_val) # 输入x_val,经过模型得到输出值y_pred_val
loss_val = loss(x_val, y_val) # 模型得到输出值y_pred_val减去实际值所得到的值的平方,得到损失值
l_sum += loss_val # 损失值求和
print('\t', x_val, y_val, y_pred_val, loss_val)
print('MSE = ', l_sum / 3) # 损失值的均值
w_list.append(w)
mse_list.append(l_sum / 3) # 添加损失值的均值到mse_list列表中
plt.plot(w_list, mse_list) # 绘制权重 w 和 权重对应的损失值的均值的图
plt.ylabel('Loss')
plt.xlabel('W')
plt.show()
5、运行结果:
边栏推荐
- Swift中struct与class的区别
- xshell安装完,启动报错:由于找不到 mfc110.dll,无法继续执行代码。重新安装程序可能会解决此问题
- [dish of learning notes dog learning C] chain access, function declaration and definition, goto statement
- 列表元素相加
- 【学习笔记之菜Dog学C】if分支语句与switch分支语句
- 达梦ODBC安装
- OpenLayers Draw绘制时删除最后一个点
- mongoose使用validate验证, 获取自定义验证信息
- 二分图
- koa2 接收不到post方法提交的formData数据(值: {})
猜你喜欢
Knapsack problem (01 knapsack / full knapsack explanation)
04—— el 和 data 的两种写法
P7354 "pmoi-1" Knight chess
mysql 通过dts迁移至达梦
(九)PyTorch深度学习:卷积神经网络( GoogleNet网络架构中的 inception module 模块为本次卷积神经网络架构)
[dish of learning notes, dog learning C] first learn operators and original code, inverse code, complement code
【学习笔记之菜Dog学C】数组
Flink 分流之 Filter/Split/SideOutPut 比较
微服务理论介绍
[dish of learning notes dog learning C] chain access, function declaration and definition, goto statement
随机推荐
Los Angeles: t226229 arithmetic series
达梦免密登录
Netcat 简单的小工具模拟客户端/服务端
mysql 通过dts迁移至达梦
Antd mobile form validation RC form usage
数据湖定义
如何使用 IDEA 打 jar 包
分别用递归和非递归的方式实现二叉树先序、中序和后序遍历
达梦ODBC安装
idea配置
Flutter 小结
mysql_user表_字段含义
node 查询目标 目录下所有(文件或文件夹)名为 filename 的文件路径
Redis主从复制
kettle_配置数据库连接_报错
list类型转String类型
UITableView之性能优化
mysql_备份还原_指定表_备份表_还原表_innobackup
页面性能:如何系统地优化页面?
Shell 之 if/for/while/case 案例