当前位置:网站首页>[machine learning] how pytorch loads custom datasets and divides them
[Machine Learning] How to load a custom dataset in PyTorch and split it
2022-07-22 17:16:00 【Fish and jade meet rain】
Example dataset: a genetic-disease association dataset
The dataset contains sample IDs, sample labels, and the genetic-disease features (formed from two feature vectors, 256 dimensions in total).
1. Split 7:3 with torch.utils.data.random_split
Subclass torch.utils.data.Dataset
class Mirna_die_Dataset(Dataset):
    """Map-style dataset over a header-less miRNA-disease association CSV.

    Each row is read positionally: column 1 is the label and the
    ``feature_dim`` columns starting at column 2 are the features (columns
    2..257 with the default ``feature_dim=256``). Column 0 and any columns
    past the feature block are ignored.

    Args:
        x, y: Unused; kept so existing callers keep working.
        data_file: Path of the CSV to load. Defaults to
            ``data/labelmirnadiease/wfy_label_mirna_die.csv``.
        feature_dim: Number of feature columns to read after the label.
    """

    def __init__(self, x=None, y=None, data_file=None, feature_dim=256):
        if data_file is None:
            data_file = os.path.join('data', 'labelmirnadiease', 'wfy_label_mirna_die.csv')
        data_frame = pd.read_csv(data_file, header=None)
        print(len(data_frame))
        labels = data_frame.iloc[:, 1]
        features = data_frame.iloc[:, 2:2 + feature_dim]
        # Features: float32 matrix, one row per sample.
        # Labels: kept 1-D, shape (N,), float32.
        self.x = torch.tensor(features.values, dtype=torch.float32)
        self.y = torch.tensor(labels.values, dtype=torch.float32)

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` tensor pair at position ``idx``."""
        return self.x[idx], self.y[idx]
Instantiate the Dataset class
md_dataset = Mirna_die_Dataset()
# Split into a training set and a test set (70% / 30%).
train_size = int(len(md_dataset) * 0.7)
test_size = len(md_dataset) - train_size
# Seed the split so train/test membership is identical on every run.
# Without a seed, re-running the script (e.g. to evaluate a saved model)
# draws a new test set that overlaps the samples the model was trained on.
train_dataset, test_dataset = torch.utils.data.random_split(
    md_dataset, [train_size, test_size],
    generator=torch.Generator().manual_seed(42))
# Create the data iterators (DataLoader). Only the training data needs
# shuffling; evaluation metrics are normally order-independent.
train_iter = DataLoader(train_dataset, batch_size=5,
                        shuffle=True, num_workers=0)
test_iter = DataLoader(test_dataset, batch_size=5,
                       shuffle=False, num_workers=0)
Train using the data iterators
def train_epoch_ch3(net, train_iter, loss, updater):  #@save
    """Run one training epoch of ``net`` over ``train_iter`` (as in chapter 3).

    Args:
        net: The model. A ``torch.nn.Module`` is switched to training mode first.
        train_iter: Iterable yielding ``(X, y)`` mini-batches.
        loss: Callable ``loss(y_hat, y)``.
        updater: Either a ``torch.optim.Optimizer`` (then ``loss`` must return
            a scalar, since ``backward()`` is called on it directly) or a custom
            callable that receives the batch size.

    Returns:
        ``(training loss per sample, training accuracy)`` for the epoch.
    """
    if isinstance(net, torch.nn.Module):
        net.train()
    # Running sums: training loss, accuracy() totals, number of samples.
    metric = Accumulator(3)
    for num, (X, y) in enumerate(train_iter, start=1):
        y_hat = net(X)
        l = loss(y_hat, y)
        # Debug trace (wfy): 1-based batch counter and this batch's loss.
        print('batch num:', num, str(l) + '\n')
        if isinstance(updater, torch.optim.Optimizer):
            # Built-in PyTorch optimizer: ``l`` is a scalar (presumably the
            # batch mean), so it is scaled by the batch size before summing.
            updater.zero_grad()
            l.backward()
            updater.step()
            metric.add(float(l) * len(y), accuracy(y_hat, y), y.size().numel())
        else:
            # Custom updater and loss: ``l`` is summed, so it presumably holds
            # per-example losses.
            l.sum().backward()
            updater(X.shape[0])
            metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    return metric[0] / metric[2], metric[1] / metric[2]
2. Split directly into 5 folds (sklearn-style k-fold), save them, and reload
Split into 5 folds and save them
import pandas as pd
import os
import random
# Load the node embeddings produced by a representation-learning algorithm
# (judging by the file name, the DeepWalk-sampling variant).
data_file = os.path.join(
    'data',
    'labelmirnadieasedeepwalksampling',
    'wfy_label_mirna_die_hdmm_deepwalk_sampling.csv',
)
train = pd.read_csv(data_file, header=None)  # the CSV has no header row
train.info()  # print a summary (columns, dtypes, row count) as a sanity check
def _assign_folds(n_rows, k):
    """Randomly partition row positions ``0..n_rows-1`` into ``k`` disjoint lists."""
    remaining = set(range(n_rows))
    fold_size = n_rows // k  # integer division avoids float truncation surprises
    folds = []
    for _ in range(k - 1):
        fold = random.sample(sorted(remaining), fold_size)
        folds.append(fold)
        remaining -= set(fold)
    # The last fold absorbs the remainder when n_rows is not divisible by k.
    folds.append(sorted(remaining))
    return folds


def k_fold_split(train_df, k, out_dir="data/node2vec_5_fold"):
    """Randomly split ``train_df`` into ``k`` folds and save k train/val CSV pairs.

    For each fold ``i`` (0-based), the rows of fold ``i`` are written to
    ``<out_dir>/val_{i}.csv`` and the rows of the other ``k - 1`` folds to
    ``<out_dir>/train_{i}.csv``. Files have no header and no index, so each
    pair has a train:val ratio of roughly ``(k - 1):1``. Rows are selected
    positionally (``iloc``). Randomness comes from the ``random`` module, so
    call ``random.seed`` beforehand for a reproducible split.

    Args:
        train_df: DataFrame to split.
        k: Number of folds; must be at least 2.
        out_dir: Output directory; created if it does not exist.

    Raises:
        ValueError: If ``k < 2`` or ``train_df`` has fewer than ``k`` rows.
    """
    if k < 2:
        raise ValueError("k must be at least 2, got {}".format(k))
    n_rows = train_df.shape[0]
    if n_rows < k:
        raise ValueError("cannot split {} rows into {} folds".format(n_rows, k))
    folds = _assign_folds(n_rows, k)
    os.makedirs(out_dir, exist_ok=True)
    for i, dev in enumerate(folds):
        print(" The first {} fold ........".format(i + 1))
        # Training rows = every fold except fold i.
        tra = []
        for j, fold in enumerate(folds):
            if j != i:
                tra.extend(fold)
        train_df.iloc[tra].to_csv(os.path.join(out_dir, "train_{}.csv".format(i)),
                                  index=False, header=False)
        train_df.iloc[dev].to_csv(os.path.join(out_dir, "val_{}.csv".format(i)),
                                  index=False, header=False)
    print("done!")
if __name__ == "__main__":
    # Script entry point: write 5 train/val fold pairs for the loaded table.
    k_fold_split(train_df=train, k=5)
Subclass the Dataset class
class Mirna_die_Dataset(Dataset):
    """Dataset that reads one header-less fold CSV (e.g. a train/val file).

    Column 1 is the label and columns 2..257 (up to 256 columns) are the
    features. Labels are stored as an ``(N, 1)`` float32 column.

    Args:
        x, y: Unused.
        data_file: Path of the CSV to read.
    """

    def __init__(self, x=None, y=None, data_file=None):
        # Project helper defined outside this snippet; called with the path
        # before loading (presumably to record which file is used).
        wfy_utils.save_result(data_file)
        data_frame = pd.read_csv(data_file, header=None)  # read the file directly
        print(len(data_frame))
        label_col = data_frame.iloc[:, 1]
        feature_cols = data_frame.iloc[:, 2:258]
        self.x = torch.tensor(feature_cols.values, dtype=torch.float32)
        self.y = torch.tensor(label_col.values.reshape(-1, 1), dtype=torch.float32)

    def __len__(self):
        """Number of rows loaded."""
        return len(self.x)

    def __getitem__(self, idx):
        """``(features, label)`` tensors for row ``idx``."""
        return self.x[idx], self.y[idx]
Instantiate the Dataset class and wrap it in a DataLoader to create the data iterators
# Fold 0 of the saved 5-fold split: train on train_0.csv, validate on val_0.csv.
# NOTE(review): this reads from data/hmdd20_deepwalk_5_fold, while k_fold_split
# above writes to data/node2vec_5_fold — confirm the intended directory.
train_data_file = os.path.join('data', 'hmdd20_deepwalk_5_fold', 'train_0.csv')
val_data_file = os.path.join('data', 'hmdd20_deepwalk_5_fold', 'val_0.csv')
train_dataset = Mirna_die_Dataset(data_file=train_data_file)
test_dataset = Mirna_die_Dataset(data_file=val_data_file)
# Mini-batches of ``batch_size`` (must be defined earlier) for training;
# the whole validation fold as one batch for evaluation.
train_iter = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_iter = DataLoader(
    test_dataset, batch_size=len(test_dataset), shuffle=True, num_workers=0)
边栏推荐
- [vs] trying to load a program with incorrect format
- Under fitting and over fitting (regularization)
- Is there anyone who can't analyze these data cases? A trick to teach you how to visualize recruitment data~
- Hande enterprise digital PAAS platform hzero version 1.9.0 was officially released!
- Tutorial update 20220719
- The three formats of the log "binlog" in MySQL are so interesting
- 一种跳板机的实现思路
- final、finally、finalize的区别
- Atomicinteger class is used in multithreading to ensure thread safety
- [reprint] UE4 interview Basics (II)
猜你喜欢
【图文并茂】在线一键重装win7系统详细教程
Simple use of UE4 terrain tool
16_ Response status code
NFS网络文件系统
Realization of a springboard machine
UE4 enters the designated area to realize the trigger acceleration function
[MySQL] SQL tuning practice teaching
FPGA - memory resources of internal structure of 7 Series FPGA -02- FIFO resources
Ffmpeg-rk3399 ffplay learning analysis
vim入门
随机推荐
2022/7/19-日报
Vivo official website app full model UI adaptation scheme
Blob URL DataURL
写一个定序器插件 sequence 字幕(一)
【C语言趣味实验】
BigInteger: what does new BigInteger (tokenjason. Getbytes()). ToString (16) mean
线程和进程
Gd32f470 serial port idle interrupt +dma
Zen administrator forgets password and retrieves password
Win11开机只有鼠标显示怎么办?
UE4 面试基础知识(三)
修复版动态视频壁纸微信小程序源码下载,支持多种类型流量主收益
What is I18N and what is its function
Ffmpeg-rk3399 ffplay learning analysis
UE4 modify the default cache path
GMT 0009-2012 data format go language operation
Codeforce d2. RGB substring (hard version) sliding window
Tutorial update 20220719
UE4 set collision body
LVS, this is enough