Python demo: a model that predicts the next number from the previous three
2022-07-22 02:54:00 【wzw12315】
Background:
- Using scatter_ for one-hot encoding
scatter_(dim, index, value) writes value into the tensor it is called on, at the positions along dimension dim given by index. Called on an all-zero tensor with value 1, it produces a one-hot encoding.
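The one-hot use of scatter_ can be sketched on a small case. The labels and the number of classes below are made up for illustration and do not come from the post.

```python
import torch

# Hypothetical class labels and class count, chosen only for illustration
labels = torch.tensor([0, 2, 1])
num_classes = 4

# Start from a (3, 4) tensor of zeros and, for each row i,
# write 1 into column labels[i] (dim=1 means "index along columns")
one_hot = torch.zeros(len(labels), num_classes).scatter_(1, labels.unsqueeze(1), 1)
print(one_hot)
```

The index tensor must have the same number of dimensions as the target, which is why the labels are reshaped to a column with unsqueeze(1).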
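- Return values of LSTM and RNN
nn.LSTM returns output plus a tuple holding the final hidden state and cell state:
output, (hidden, cell) = self.LSTM(x)
nn.RNN returns output and the final hidden state:
output, hidden = self.RNN(x)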
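A small sketch of the difference in return values, using arbitrary sizes (sequence length 5, batch 2, 3 input features, hidden size 8, 2 layers) that are assumptions for illustration rather than the values used in this demo:

```python
import torch
from torch import nn

# Illustrative input of shape (seq_len, batch, input_size)
x = torch.randn(5, 2, 3)

lstm = nn.LSTM(input_size=3, hidden_size=8, num_layers=2)
rnn = nn.RNN(input_size=3, hidden_size=8, num_layers=2)

# LSTM: output plus a (hidden, cell) tuple
lstm_out, (lstm_hidden, lstm_cell) = lstm(x)
# RNN: output plus a single hidden state
rnn_out, rnn_hidden = rnn(x)

print(lstm_out.shape)     # (seq_len, batch, hidden_size)
print(lstm_hidden.shape)  # (num_layers, batch, hidden_size)
print(lstm_cell.shape)    # (num_layers, batch, hidden_size)
print(rnn_out.shape)      # (seq_len, batch, hidden_size)
print(rnn_hidden.shape)   # (num_layers, batch, hidden_size)
```

In both cases output holds the top layer's hidden state at every time step, which is what the demo feeds into its fully connected layer.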
# -----------------------------------
# Module import
import numpy
import torch
from torch import nn
# -----------------------------------
# Data preprocessing
# Use 30 numbers (0..29) and predict each number from the three before it, e.g. 1, 2, 3 -> 4
data_length = 30
# Each input window therefore has length 3
seq_length = 3
number = [i for i in range(data_length)]
li_x = []
li_y = []
for i in range(0, data_length - seq_length):
    x = number[i: i + seq_length]
    y = number[i + seq_length]
    li_x.append(x)
    li_y.append(y)
#number: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
#li_x: [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6], [5, 6, 7], [6, 7, 8], [7, 8, 9], [8, 9, 10],
# [9, 10, 11], [10, 11, 12], [11, 12, 13], [12, 13, 14], [13, 14, 15], [14, 15, 16], [15, 16, 17], [16, 17, 18], [17, 18, 19],
# [18, 19, 20], [19, 20, 21], [20, 21, 22], [21, 22, 23], [22, 23, 24], [23, 24, 25], [24, 25, 26], [25, 26, 27], [26, 27, 28]]
#li_y: [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
# Input shape expected by nn.LSTM: (seq_len, batch, input_size), here (27, 1, 3)
data_x = numpy.reshape(li_x, (len(li_x), 1, seq_length))
# Normalize the inputs to the range [0, 1)
data_x = torch.from_numpy(data_x / float(data_length)).float()
# One-hot encode the targets with scatter_, giving shape (27, 30)
data_y = torch.zeros(len(li_y), data_length).scatter_(1, torch.tensor(li_y).unsqueeze_(dim=1), 1).float()
# -----------------------------------
# Define the network model
class net(nn.Module):
    # Model structure: LSTM + fully connected layer + Softmax
    def __init__(self, input_size, hidden_size, output_size, num_layer):
        super(net, self).__init__()
        # nn.LSTM returns output and (hidden, cell)
        # nn.RNN returns output and hidden
        self.layer1 = nn.LSTM(input_size, hidden_size, num_layer)
        self.layer2 = nn.Linear(hidden_size, output_size)
        # dim=1 normalizes over the class axis; calling nn.Softmax() without dim is deprecated
        self.layer3 = nn.Softmax(dim=1)

    def forward(self, x):
        x, _ = self.layer1(x)
        # x has shape [27, 1, 32]: (seq_len, batch, hidden_size); the 27 samples lie along the first axis
        sample, batch, hidden = x.size()
        # Flatten to a 2-D matrix so the fully connected layer can process it
        x = x.reshape(-1, hidden)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


model = net(seq_length, 32, data_length, 4)
# -----------------------------------
# Define the loss function and optimizer
loss_fun = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# -----------------------------------
# Train the model
# Before training, you can inspect the predictions produced by the freshly initialized parameters:
# result = model(data_x)
# for target, pred in zip(data_y, result):
#     print("{} -> {}".format(target.argmax().data, pred.argmax().data))
# Train for 1000 epochs
for epoch in range(1000):
    output = model(data_x)
    # Conventional argument order is (prediction, target)
    loss = loss_fun(output, data_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 50 == 0:
        print('Epoch: {}, Loss: {}'.format(epoch + 1, loss.item()))
# -----------------------------------
# Prediction results
result = model(data_x)
for target, pred in zip(data_y, result):
    print("Correct result: {}, prediction: {}".format(target.argmax().item(), pred.argmax().item()))
# Fraction of samples whose predicted class matches the target
accuracy = (result.argmax(dim=1) == data_y.argmax(dim=1)).float().mean()
# Output:
# Correct result: 3, prediction: 3
# Correct result: 4, prediction: 4
# Correct result: 5, prediction: 5
# Correct result: 6, prediction: 6
# Correct result: 7, prediction: 7
# Correct result: 8, prediction: 8
# Correct result: 9, prediction: 9
# Correct result: 10, prediction: 10
# Correct result: 11, prediction: 11
# Correct result: 12, prediction: 12
# Correct result: 13, prediction: 13
# Correct result: 14, prediction: 14
# Correct result: 15, prediction: 15
# Correct result: 16, prediction: 16
# Correct result: 17, prediction: 21
# Correct result: 18, prediction: 18
# Correct result: 19, prediction: 27
# Correct result: 20, prediction: 21
# Correct result: 21, prediction: 21
# Correct result: 22, prediction: 21
# Correct result: 23, prediction: 21
# Correct result: 24, prediction: 24
# Correct result: 25, prediction: 25
# Correct result: 26, prediction: 26
# Correct result: 27, prediction: 27
# Correct result: 28, prediction: 28
# Correct result: 29, prediction: 29
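The printed results can be reduced to a single accuracy figure. The sketch below does not rerun the model; it copies the targets and predictions by hand from the listing above, so the numbers apply to that particular run only.

```python
# Targets are the numbers 3..29, matching li_y in the script
targets = list(range(3, 30))
# Predictions transcribed from the printed results of the run shown above
preds = list(range(3, 17)) + [21, 18, 27, 21, 21, 21, 21] + list(range(24, 30))

# Count exact matches and collect the mismatched (target, prediction) pairs
correct = sum(t == p for t, p in zip(targets, preds))
misses = [(t, p) for t, p in zip(targets, preds) if t != p]

print("{} of {} correct ({:.1%})".format(correct, len(targets), correct / len(targets)))
print("Misses:", misses)
```

In this run the errors cluster in the range 17 to 23, and most of them collapse onto the single class 21.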