当前位置:网站首页>LSTM stock price forecast pytorch
LSTM stock price forecast pytorch
2022-07-20 09:13:00 【hhllxx1121】
1. View the data
2. result
3. Code
#!/usr/bin/env python
# coding: utf-8
# In[1]:
# Turn off Jedi-based tab completion when this notebook export runs inside
# IPython/Jupyter. Under a plain Python interpreter `get_ipython` is not
# defined, so the setting is skipped instead of crashing the script.
try:
    get_ipython().run_line_magic('config', 'Completer.use_jedi = False')
except NameError:
    pass
# In[2]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
from sklearn.preprocessing import MinMaxScaler
# In[3]:
# CSV file loaded with pandas further down; its 'close' column is the series.
filePath = './000001SH_index.csv'

# Window length (rows per input sequence) and mini-batch size.
seq_len, batch_size = 8, 64

# LSTM dimensions: features per time step, hidden width, stacked layers.
input_size, hidden_size, num_layers = 1, 8, 1
# In[4]:
# Scales closing prices into [-1, 1]; the same scaler is reused later to map
# model outputs back to price units.
sc = MinMaxScaler(feature_range=(-1, 1))
# In[5]:
raw_data = pd.read_csv(filePath)
data = raw_data.loc[:, ['close']]

# One sample per sliding window: seq_len consecutive closes as the input and
# the close that follows them as the target.
n_samples = max(data.shape[0] - seq_len, 0)
split = int(n_samples * 0.9)  # first 90% of the windows are used for training

# Fit the scaler only on the rows the training windows can see (their inputs
# and targets end at row split + seq_len - 1). Fitting on the full series
# would leak the test period's min/max into training.
sc.fit(data.iloc[:split + seq_len])
data = pd.DataFrame(sc.transform(data))

# Slice a single float32 array instead of indexing the DataFrame per window.
values = data.to_numpy(dtype=np.float32)
X = [values[i:i + seq_len, :].copy() for i in range(n_samples)]
# 0-d float32 arrays, so a DataLoader batch of targets collates to 1-D.
Y = [np.array(values[i + seq_len, 0], dtype=np.float32) for i in range(n_samples)]

x_train, x_test = X[:split], X[split:]
y_train, y_test = Y[:split], Y[split:]
print(np.array(x_test).shape)
print(np.array(x_train).shape)
# In[6]:
# Draw the closing-price series as two consecutive segments on one figure,
# cut at 90% of the number of windowed samples.
cut = int(len(Y) * 0.9)
closes = raw_data.close
for segment in (closes[:cut], closes[cut:]):
    segment.plot(figsize=(16, 4))
plt.show()
# In[7]:
class MyDataSet(Dataset):
    """Pairs two index-aligned sequences as (input, target) samples."""

    def __init__(self, X, Y):
        super(MyDataSet, self).__init__()
        self.xx = X
        self.yy = Y

    def __len__(self):
        # The number of samples is taken from the targets.
        return len(self.yy)

    def __getitem__(self, index):
        sample = self.xx[index]
        target = self.yy[index]
        return sample, target
trainDataSet = MyDataSet(x_train, y_train)
testDataSet = MyDataSet(x_test, y_test)
# Training batches are reshuffled every epoch.
trainLoader = DataLoader(dataset=trainDataSet, batch_size=batch_size, shuffle=True)
# Evaluation data keeps its chronological order so that batches line up with
# the time axis and repeated runs iterate identically.
testLoader = DataLoader(dataset=testDataSet, batch_size=batch_size, shuffle=False)
# In[8]:
class LSTM(nn.Module):
    """Sequence regressor: an LSTM followed by a linear layer with one output.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the LSTM hidden state.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )
        # Attribute name kept as-is so existing checkpoints still load.
        self.liner = nn.Linear(in_features=self.hidden_size, out_features=1)

    def forward(self, x):
        """Return one prediction per input sequence, shaped (batch, 1).

        Args:
            x: input for a batch_first LSTM, i.e. (batch, seq_len, input_size).
        """
        _, (hidden, _cell) = self.lstm(x)
        # `hidden` stacks the final state of every layer along dim 0. Only the
        # top layer is projected; feeding all layers would yield
        # num_layers * batch rows whenever num_layers > 1.
        out = self.liner(hidden[-1])
        return out.reshape(-1, 1)
# Network built from the hyperparameters defined above.
model = LSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
)
# In[9]:
# Mean-squared-error objective, optimised with Adam at a 1e-3 learning rate.
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# In[10]:
epochs = 200

# Batches yielded by trainLoader:
#   data_x: (batch_size, seq_len, input_size=1)
#   data_y: expected to be 1-D (batch_size,) -- TODO confirm; the model output
#           is flattened below so both operands of the loss have one shape.
model.train()
for epoch in range(epochs):
    for i, (data_x, data_y) in enumerate(trainLoader):
        pred = model(data_x)
        pred = pred.reshape(-1)
        loss = loss_function(pred, data_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Log every 10th epoch with the loss of its last mini-batch. The label is
    # the actual epoch index (it used to be offset by 10).
    if epoch % 10 == 0:
        print("epoch:{},loss========{}".format(epoch, loss.item()))
# In[11]:
# Predict every window (training and test alike) in a single forward pass.
model.eval()
with torch.no_grad():
    pred = model(torch.from_numpy(np.array(X)))

# Give scikit-learn a plain ndarray instead of relying on implicit tensor
# conversion, then map predictions and targets back to price units.
y_pred = sc.inverse_transform(pred.numpy())
y_label = sc.inverse_transform(np.array(Y).reshape(-1, 1))

# Full series: targets against predictions.
plt.figure(figsize=(16, 4))
plt.plot(y_label)
plt.plot(y_pred)
plt.show()

# Close-up of the first 100 samples.
plt.figure(figsize=(16, 4))
plt.plot(y_label[0:100], label='raw data')
plt.plot(y_pred[0:100], label='pred data')
plt.legend()
plt.show()
边栏推荐
- JS motion function encapsulation function, involving whether the speed is uniform, target value, JSON parameters, etc
- 手写校验框架
- Crudapi add, delete, modify and check the successful case of interface zero code product: Chamber of Commerce Alliance card project
- Tensorflow学习笔记--张量与会话
- When programmers have no Internet, how can they continue to learn to write code
- Redis data types and application scenarios
- ShardingJDBC
- MySQL log module
- Create playablead with cocos Creator
- ShardingJDBC
猜你喜欢
Solution to Android Studio throwing com.android.builder.errors.EvalIssueException when executing Kotlin
YOLOv1详解
get post 区别 以及get 为啥比post要快
【深度学习】-Imdb数据集情感分析之模型对比(1)-RNN
CPU架构兼容
使用Cocos Creator制作试玩广告(PlayableAd)
说说如何安装 Openfire
Want to try Web3 work? It's enough to read this article
uniapp中引入自定义图标
【论文导读】Selecting Data Augmentation for Simulating Interventions
随机推荐
js Qrcode.js实现文字内容通过二维码展示
ShardingJDBC
TCP三次握手和四次挥手
【一些有关GraN-DAG的知识点总结】
对于因果模型的常见评估函数:SHD 和 FDR
【论文导读】Continuity Scaling: A Rigorous Framework for Detecting and Quantifying Causality Accurately
说说 Redis 缓存删除策略
YOLOv2论文中英文对照翻译
【论文导读】Causal Protein-Signaling Networks Derived from Multiparameter Single-Cell Data
Want to try Web3 work? It's enough to read this article
ashx aspx
航天信息开电子发票 3.0 以及4.0(电子发票)
说说如何安装 Openfire
开发日常异常问题汇总
静态库.a文件和.framework文件
Create playablead with cocos Creator
Asp. NET <%=%> <%#%> <% %> <%@%>
Android Studio 执行 Kotlin 抛出 com.android.builder.errors.EvalIssueException 问题的解决方法
uniapp 微信小程序分享、分享朋友圈功能
[untitled] MySQL binlog data recovery process