LSTM stock price prediction (PyTorch)
2022-07-19 05:12:00 【hhllxx1121】
1. View the data
2. Results
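The result plots compare predicted and actual closes only visually. Below is a minimal sketch for putting numbers on that comparison. `regression_metrics` is a hypothetical helper and is not part of the original code. The toy inputs are made-up values, not the article's results.

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """Return (RMSE, MAE, MAPE in %) for two equally long series of prices."""
    y_true = np.asarray(y_true, dtype=np.float64).reshape(-1)
    y_pred = np.asarray(y_pred, dtype=np.float64).reshape(-1)
    err = y_pred - y_true
    rmse = float(np.sqrt(np.mean(err ** 2)))
    mae = float(np.mean(np.abs(err)))
    # MAPE divides by the true values, so it assumes none of them is zero
    mape = float(np.mean(np.abs(err / y_true)) * 100.0)
    return rmse, mae, mape


# Made-up toy values for illustration only
rmse, mae, mape = regression_metrics([100.0, 200.0], [110.0, 190.0])
print(rmse, mae, mape)
```

The script at the end of section 3 predicts on every window, training and test alike. To score only the held-out tail, you could pass the last `len(y_test)` rows, e.g. `regression_metrics(y_label[-len(y_test):], np.array(y_pred)[-len(y_test):])`. That keeps windows the model was trained on out of the score.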
3. Code
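The `forward` method in the script feeds the LSTM's final hidden state into a linear layer. A quick shape check shows what that state looks like. `nn.LSTM` returns `output` and `(h_n, c_n)`. Even with `batch_first=True`, `h_n` is laid out as (num_layers, batch, hidden_size). The minimal sketch below uses random input and a two-layer LSTM; the sizes are arbitrary illustration values, and the script itself uses one layer. With `num_layers=1`, `hidden` and `hidden[-1]` give the same predictions after reshaping. With stacked layers, only `hidden[-1]` (the top layer) yields one prediction per sample. Passing the whole tensor would give num_layers * batch rows.

```python
import torch
import torch.nn as nn

# Illustration sizes; num_layers=2 shows the stacked case
batch, seq_len, input_size, hidden_size, num_layers = 4, 8, 1, 8, 2
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
               num_layers=num_layers, batch_first=True)
x = torch.randn(batch, seq_len, input_size)

output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([4, 8, 8]): top layer's output at every time step
print(h_n.shape)     # torch.Size([2, 4, 8]): final hidden state of each layer
# For a unidirectional LSTM, the last time step of output equals the top layer of h_n
print(torch.allclose(output[:, -1, :], h_n[-1]))
```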
#!/usr/bin/env python
# coding: utf-8
# In[1]:
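# IPython-only setting (turns off jedi autocompletion); skipped when run as a plain script
try:
    get_ipython().run_line_magic('config', 'Completer.use_jedi = False')
except NameError:
    pass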
# In[2]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
from sklearn.preprocessing import MinMaxScaler
# In[3]:
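filePath = './000001SH_index.csv'  # price data for index 000001.SH; must contain a 'close' column
seq_len = 8        # number of past closes in each input window
batch_size = 64
input_size = 1     # one feature per time step (the scaled close)
hidden_size = 8    # size of the LSTM hidden state
num_layers = 1     # number of stacked LSTM layers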
# In[4]:
sc = MinMaxScaler(feature_range=(-1,1))
# In[5]:
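raw_data = pd.read_csv(filePath)
data = raw_data.loc[:,['close']]
# Scale closes to [-1, 1]; note the scaler is fit on the whole series, test period included
data = pd.DataFrame(sc.fit_transform(data))
X = []
Y = []
# Sliding window: each sample is seq_len consecutive scaled closes, its label is the next close
for i in range(data.shape[0]-seq_len):
    X.append(np.array(data.iloc[i:i+seq_len,:].values,dtype=np.float32))
    Y.append(np.array(data.iloc[i+seq_len,0],dtype=np.float32))
# Chronological split: first 90% of samples for training, last 10% for testing
x_train,x_test = X[:int(len(Y)*0.9)],X[int(len(Y)*0.9):]
y_train,y_test = Y[:int(len(Y)*0.9)],Y[int(len(Y)*0.9):]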
print(np.array(x_test).shape)
print(np.array(x_train).shape)
# In[6]:
raw_data.close[:int(len(Y)*0.9)].plot(figsize=(16,4))
raw_data.close[int(len(Y)*0.9):].plot(figsize=(16,4))
plt.show()
# In[7]:
class MyDataSet(Dataset):
    """Wraps the list of input windows (X) and next-step targets (Y)."""
    def __init__(self,X,Y):
        super().__init__()
        self.xx,self.yy = X,Y
    def __getitem__(self,index):
        return self.xx[index],self.yy[index]
    def __len__(self):
        return len(self.yy)
trainDataSet = MyDataSet(x_train,y_train)
testDataSet = MyDataSet(x_test,y_test)
trainLoader = DataLoader(dataset=trainDataSet, batch_size=batch_size,shuffle=True)
# Keep test batches in time order
testLoader = DataLoader(dataset=testDataSet, batch_size=batch_size,shuffle=False)
# In[8]:
class LSTM(nn.Module):
    def __init__(self,input_size,hidden_size,num_layers):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size = self.input_size,hidden_size = self.hidden_size,num_layers=self.num_layers,batch_first=True)
        self.linear = nn.Linear(in_features=self.hidden_size,out_features=1)
    def forward(self, x):
        # x: (batch, seq_len, input_size); hidden: (num_layers, batch, hidden_size)
        _, (hidden, cell) = self.lstm(x)
        # Use the top layer's final hidden state so there is one prediction per sample
        out = self.linear(hidden[-1])
        return out.reshape(-1,1)
model = LSTM(input_size,hidden_size,num_layers)
# In[9]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_function = nn.MSELoss()
# In[10]:
epochs = 200
# data_x: (batch_size, seq_len, input_size=1)
# data_y: (batch_size,)
for epoch in range(epochs):
    for i,(data_x,data_y) in enumerate(trainLoader):
        pred = model(data_x)
        pred = pred.reshape(-1)
        loss = loss_function(pred, data_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Report the loss of the epoch's last batch every 10 epochs
    if epoch % 10 ==0:
        print("epoch: {}, loss: {}".format(epoch,loss.item()))
# In[11]:
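y_pred = []
y_label = []
model.eval()
# Predict on every window (training and test) and map the values back to price units
with torch.no_grad():
    pred = model(torch.from_numpy(np.array(X)))
    y_pred.extend(sc.inverse_transform(pred.numpy()))
y_label = np.array(Y).reshape(-1,1)
y_label = sc.inverse_transform(y_label)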
plt.figure(figsize=(16,4))
plt.plot(y_label)
plt.plot(y_pred)
plt.show()
plt.figure(figsize=(16,4))
plt.plot(y_label[0:100],label='raw data')
plt.plot(y_pred[0:100], label='pred data')
plt.legend()
plt.show()
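The script fits `MinMaxScaler` on the whole close series before splitting. As a result, the test period's minimum and maximum also shape how the training data is scaled. Below is a minimal sketch of an alternative that fits the scaling only on rows the training windows touch, then builds the same sliding windows. `make_windows` is a hypothetical helper. It uses plain NumPy in place of scikit-learn and reproduces MinMaxScaler's formula (x - min) / (max - min) * (hi - lo) + lo. The synthetic prices are for illustration only.

```python
import numpy as np


def make_windows(series, seq_len, train_frac=0.9, feature_range=(-1.0, 1.0)):
    """Build (X, y) sliding windows from a 1-D series, scaling with training rows only.

    Hypothetical helper; assumes the training rows are not all equal.
    """
    series = np.asarray(series, dtype=np.float64).reshape(-1)
    n_samples = len(series) - seq_len
    n_train = int(n_samples * train_frac)
    # The last training target sits at row n_train + seq_len - 1,
    # so nothing after it is used to fit the scaling
    fit_rows = series[: n_train + seq_len]
    lo, hi = feature_range
    s_min, s_max = fit_rows.min(), fit_rows.max()
    scale = (hi - lo) / (s_max - s_min)
    # Same formula as MinMaxScaler: (x - min) / (max - min) * (hi - lo) + lo
    scaled = (series - s_min) * scale + lo
    # X[i] = scaled[i : i + seq_len] with a trailing feature axis, y[i] = scaled[i + seq_len]
    X = np.stack([scaled[i : i + seq_len] for i in range(n_samples)])[..., None]
    y = scaled[seq_len:]
    X, y = X.astype(np.float32), y.astype(np.float32)

    def inverse(values):
        """Map scaled values back to price units."""
        return (np.asarray(values, dtype=np.float64) - lo) / scale + s_min

    return (X[:n_train], y[:n_train]), (X[n_train:], y[n_train:]), inverse


# 20 synthetic prices for illustration, not the article's data
prices = np.arange(100, 120, dtype=np.float64)
(x_tr, y_tr), (x_te, y_te), inverse = make_windows(prices, seq_len=8)
print(x_tr.shape, y_tr.shape, x_te.shape, y_te.shape)  # (10, 8, 1) (10,) (2, 8, 1) (2,)
```

Test targets can now fall outside [-1, 1]; here they are above 1 because the synthetic prices keep rising. That is the expected consequence of not letting the test period inform the scaling. The returned `inverse` would take the place of `sc.inverse_transform` when plotting.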