当前位置:网站首页>pytorch 数据增强cutmix的实现
pytorch 数据增强cutmix的实现
2022-07-19 05:02:00 【视觉盛宴】
摘要
cutmix 和 mixup 都是比较重要的数据增强手段。普通的数据增强只是对单张照片做变换,提升网络提取特征图的能力;而 cutmix 这类方法还会混合 label,从而增强了全连接层(fc)的学习能力。
cutmix 的思想:
主要是利用两张照片,从另一张照片中随机截取一块区域,粘贴到当前照片的相同位置,label 也按照粘贴区域所占的面积比例随之混合。
本来有四种花,我放上去的是两张交换区域比较明显的图片:图片内容发生了改变,label 也随之改变,变成了带小数的形式,这里大家可能会有疑问。这是我在线下做的测试。在实际训练时,原本 one-hot label 中的 1 会按保留面积的比例 λ 变成小数,而被粘贴进来的那张图片所属类别的位置上会出现 1−λ。由于 1 所在的下标代表类别,即使它变成了小数,原类别的下标位置也不变,所以分类 label 的含义并没有改变。
读取照片
第一步是把一个文件夹下的照片读取出来,模仿 dataloader 按批次加载。请保证 ./data 目录下至少有四张 .jpg 照片,且尺寸必须一致。
import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np

plt.rcParams['figure.figsize'] = [10, 10]

# Folder holding the test images; they must all have the same size so they can
# be stacked into one batch array later.
data_folder = "./data/"
# Every .jpg in the folder.
filenames = glob.glob(f"{data_folder}*.jpg")
# Size of our fake "batch"; the 2x2 plot grid below assumes 4.
n_images = 4
image_paths = filenames[:n_images]
# Display a sample image
# plt.imshow(cv2.cvtColor(cv2.imread(image_paths[0]), cv2.COLOR_BGR2RGB)); plt.show();
print(image_paths)
if len(image_paths) < n_images:
    raise FileNotFoundError(
        f"need at least {n_images} .jpg files in {data_folder}, "
        f"found {len(image_paths)}"
    )

# Read each image; OpenCV loads BGR, matplotlib expects RGB.
image_batch = []
for path in image_paths:
    bgr = cv2.imread(path)
    if bgr is None:  # cv2.imread reports failure by returning None, not raising
        raise FileNotFoundError(f"could not read image: {path}")
    image_batch.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

# One-hot labels: image i belongs to class i.
image_batch_labels = np.eye(n_images, dtype=int)

# Show the batch as a 2x2 grid.
for i in range(2):
    for j in range(2):
        plt.subplot(2, 2, 2 * i + j + 1)
        plt.imshow(image_batch[2 * i + j])
plt.show()

c = image_batch[0]
print(c.shape)
这里输出的单张图片 shape 是 (500, 500, 3);一个批次有 4 张照片,堆叠成数组后应为 (4, 500, 500, 3),即 (N, H, W, C)。PyTorch 的 dataloader 输出的是 (N, C, H, W),维度顺序和这里不一样。这里只是做测试。
随机截取
截取的矩形框只要不超出图片边界即可,代码很容易看懂:先由 λ 计算框的边长(边长按 sqrt(1−λ) 缩放,因此面积约占整张图的 1−λ),再随机选取框的中心点,最后用 np.clip 把坐标限制在图片范围内。
def rand_bbox(size, lamb):
    """Pick a random box whose area is roughly a (1 - lamb) fraction of the image.

    Args:
        size: image shape; only size[0] and size[1] are used. For an
            (H, W, C) numpy image these are the row and column extents. They
            are named W and H below (kept from the original code), so the
            bbx* bounds index axis 0 and the bby* bounds index axis 1.
        lamb: mixing ratio in [0, 1]; each box side is scaled by sqrt(1 - lamb).

    Returns:
        (bbx1, bby1, bbx2, bby2): half-open bounds clipped to the image, so
        ``img[bbx1:bbx2, bby1:bby2]`` selects the box.
    """
    W = size[0]
    H = size[1]
    cut_rat = np.sqrt(1. - lamb)
    # builtin int(): np.int was only a deprecated alias and is gone in NumPy >= 1.24
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    # Box centre is uniform over the image; edges are then clipped, so the
    # actual box can be smaller than cut_w x cut_h.
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
换位置
# Parameter of the symmetric Beta(beta, beta) distribution for the mixing
# ratio. This excerpt never defined it; 1.0 is the value the full code passes
# to generate_cutmix_image.
beta = 1.0
lam = np.random.beta(beta, beta)
rand_index = np.random.permutation(len(image_batch))  # shuffled batch order, e.g. [1, 0, 2, 3]
target_a = image_batch_labels
target_b = np.array(image_batch_labels)[rand_index]
print('img.shape', image_batch[0].shape)
bbx1, bby1, bbx2, bby2 = rand_bbox(image_batch[0].shape, lam)
print('bbx1', bbx1)
print('bby1', bby1)
print('bbx2', bbx2)
print('bby2', bby2)
# image_batch is still a list here; batch-wide slicing needs an ndarray.
image_batch = np.array(image_batch)
image_batch_updated = image_batch.copy()
# rand_bbox derives bbx* from axis 0 of a single image (axis 1 of the batch)
# and bby* from axis 1 (axis 2 of the batch); each slice pairs start with stop
# of the same coordinate: bbx1:bbx2 and bby1:bby2.
image_batch_updated[:, bbx1:bbx2, bby1:bby2, :] = image_batch[rand_index, bbx1:bbx2, bby1:bby2, :]
# Adjust lambda to the exact fraction of pixels each image kept (the box may
# have been clipped at the border), and mix the labels by the same ratio.
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (image_batch.shape[1] * image_batch.shape[2]))
label = target_a * lam + target_b * (1. - lam)
全部代码
import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np

plt.rcParams['figure.figsize'] = [10, 10]

# Folder holding the test images; they must all have the same size so they can
# be stacked into one batch array later.
data_folder = "./data/"
# Every .jpg in the folder.
filenames = glob.glob(f"{data_folder}*.jpg")
# Size of our fake "batch"; the 2x2 plot grids below assume 4.
n_images = 4
image_paths = filenames[:n_images]
# Display a sample image
# plt.imshow(cv2.cvtColor(cv2.imread(image_paths[0]), cv2.COLOR_BGR2RGB)); plt.show();
print(image_paths)
if len(image_paths) < n_images:
    raise FileNotFoundError(
        f"need at least {n_images} .jpg files in {data_folder}, "
        f"found {len(image_paths)}"
    )

# Read each image; OpenCV loads BGR, matplotlib expects RGB.
image_batch = []
for path in image_paths:
    bgr = cv2.imread(path)
    if bgr is None:  # cv2.imread reports failure by returning None, not raising
        raise FileNotFoundError(f"could not read image: {path}")
    image_batch.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

# One-hot labels: image i belongs to class i.
image_batch_labels = np.eye(n_images, dtype=int)

# Optional preview of the loaded batch as a 2x2 grid:
# for i in range(2):
#     for j in range(2):
#         plt.subplot(2,2,2*i+j+1)
#         plt.imshow(image_batch[2*i+j])
# plt.show()
def rand_bbox(size, lamb):
    """Pick a random box whose area is roughly a (1 - lamb) fraction of the image.

    Args:
        size: image shape; only size[0] and size[1] are used. For an
            (H, W, C) numpy image these are the row and column extents. They
            are named W and H below (kept from the original code), so the
            bbx* bounds index axis 0 and the bby* bounds index axis 1.
        lamb: mixing ratio in [0, 1]; each box side is scaled by sqrt(1 - lamb).

    Returns:
        (bbx1, bby1, bbx2, bby2): half-open bounds clipped to the image, so
        ``img[bbx1:bbx2, bby1:bby2]`` selects the box.
    """
    W = size[0]
    H = size[1]
    cut_rat = np.sqrt(1. - lamb)
    # builtin int(): np.int was only a deprecated alias and is gone in NumPy >= 1.24
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    # Box centre is uniform over the image; edges are then clipped, so the
    # actual box can be smaller than cut_w x cut_h.
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
# Single demo image (first file of the batch), converted from OpenCV's BGR to RGB.
bgr_image = cv2.imread(image_paths[0])
image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
size = image.shape
# Fixed mixing ratio for the random-box preview below.
lamb = 0.3
print('size', size)

# Optional visual check of rand_bbox (disabled). Note: the preview treats x as
# a column coordinate (cv2.rectangle takes (col, row) points and the crop is
# image[y1:y2, x1:x2]), whereas rand_bbox takes its x range from size[0], the
# row extent, so on non-square images the preview box is transposed.
# bbox = rand_bbox(size, lamb)
# im = image.copy()
# x1, y1, x2, y2 = bbox
# cv2.rectangle(im, (x1, y1), (x2, y2), (255, 0, 0), 3)
# plt.imshow(im)
# plt.title('Original image with random bounding box')
# plt.show()
# Show the cropped region:
# plt.imshow(image[y1:y2, x1:x2])
# plt.title('Cropped image')
# plt.show()
def generate_cutmix_image(image_batch, image_batch_labels, beta):
    """Apply CutMix to a batch of same-sized images and mix their labels.

    One random box (shared by the whole batch) in every image is replaced by
    the same region of a partner image picked via a random permutation of the
    batch; each label becomes an area-weighted mix of its own and its
    partner's label.

    Args:
        image_batch: sequence of equally shaped images, e.g. a list of
            (H, W, C) numpy arrays. It is not modified.
        image_batch_labels: one label row per image (e.g. one-hot vectors);
            a list or an ndarray.
        beta: parameter of the symmetric Beta(beta, beta) distribution from
            which the initial mixing ratio is drawn.

    Returns:
        (image_batch_updated, label): the mixed images as an (N, H, W, C)
        ndarray and the mixed labels ``lam * own + (1 - lam) * partner``,
        where lam is the fraction of pixels each image kept.
    """
    lam = np.random.beta(beta, beta)
    rand_index = np.random.permutation(len(image_batch))
    # asarray so that list labels support the arithmetic below
    target_a = np.asarray(image_batch_labels)
    target_b = target_a[rand_index]
    print('img.shape', image_batch[0].shape)
    bbx1, bby1, bbx2, bby2 = rand_bbox(image_batch[0].shape, lam)
    print('bbx1', bbx1)
    print('bby1', bby1)
    print('bbx2', bbx2)
    print('bby2', bby2)
    # np.array copies, so the caller's images stay untouched.
    image_batch = np.array(image_batch)
    image_batch_updated = image_batch.copy()
    # rand_bbox derives bbx* from axis 0 of one image (axis 1 of the batch)
    # and bby* from axis 1 (axis 2 of the batch); each slice pairs the start
    # and stop of the same coordinate.
    image_batch_updated[:, bbx1:bbx2, bby1:bby2, :] = image_batch[rand_index, bbx1:bbx2, bby1:bby2, :]
    # Recompute lambda from the actual (possibly border-clipped) box area.
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (image_batch.shape[1] * image_batch.shape[2]))
    label = target_a * lam + target_b * (1. - lam)
    return image_batch_updated, label
# Scratch check kept for reference: paste a fixed 190x190 patch using a
# hand-picked permutation instead of a random box.
# image_batch=np.array(image_batch)
# image_batch_updated = image_batch.copy()
# c=[1,0,2,3]
# mm=np.array(image_batch_updated)
# mm[:, 10:200, 10:200, :] = image_batch[c, 10:200, 10:200, :]

# Run CutMix on the whole batch with beta = 1.0 (Beta(1, 1) is uniform on [0, 1]).
# input_image keeps a handle on the first image of the batch.
input_image = image_batch[0]
image_batch_updated, image_batch_labels_updated = generate_cutmix_image(
    image_batch, image_batch_labels, 1.0
)

# Draw the batch before and after mixing, each as its own 2x2 grid.
for title, batch in (("Original Images", image_batch),
                     ("CutMix Images", image_batch_updated)):
    print(title)
    for k in range(4):
        plt.subplot(2, 2, k + 1)
        plt.imshow(batch[k])
    plt.show()

# Compare the one-hot labels with their area-weighted mixtures.
print('Original labels:')
print(image_batch_labels)
print('Updated labels')
print(image_batch_labels_updated)
上面的写法是为了方便大家观看,和直接写在 PyTorch 训练代码里的实现还有一些差距。
下面这段线上版本我还未测试,大家可以对照下面的代码操作一下。
pytorch使用
def rand_bbox(size, lam):
    """Pick a random box covering roughly a (1 - lam) fraction of a batch tensor's plane.

    Args:
        size: batch shape such as ``tensor.size()``; only size[2] and size[3]
            are used (for an (N, C, H, W) tensor these are H and W). They are
            named W and H below, kept from the original code, so the bbx*
            bounds index dim 2 and the bby* bounds index dim 3.
        lam: mixing ratio in [0, 1]; each box side is scaled by sqrt(1 - lam).

    Returns:
        (bbx1, bby1, bbx2, bby2): half-open bounds clipped to the plane, so
        ``x[:, :, bbx1:bbx2, bby1:bby2]`` selects the box.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # builtin int(): np.int was only a deprecated alias and is gone in NumPy >= 1.24
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    # Box centre is uniform; edges are then clipped to the plane.
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
def cutmix(data, targets1, targets2, targets3, alpha):
    """CutMix for a batch whose model has three label heads.

    One random box (shared by the whole batch) in every sample is overwritten
    with the same region of a partner sample chosen by a random permutation.

    Args:
        data: batch tensor; dims 2 and 3 are the box axes (presumably
            (N, C, H, W) -- TODO confirm against the data loader). It is
            modified IN PLACE and also returned.
        targets1, targets2, targets3: per-sample label tensors for the three
            heads, indexed along dim 0 like ``data``.
        alpha: parameter of the symmetric Beta(alpha, alpha) distribution for
            the initial mixing ratio.

    Returns:
        (data, targets) with targets laid out as
        [t1, t1_shuffled, t2, t2_shuffled, t3, t3_shuffled, lam], where lam is
        the fraction of each sample that kept its own pixels -- the format
        cutmix_criterion unpacks.
    """
    indices = torch.randperm(data.size(0))
    # (An unused full-batch copy `data[indices]` used to be made here.)
    shuffled_targets1 = targets1[indices]
    shuffled_targets2 = targets2[indices]
    shuffled_targets3 = targets3[indices]
    lam = np.random.beta(alpha, alpha)
    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
    # Advanced indexing on the right-hand side materialises a copy first, so
    # reading from and writing to `data` in one statement is safe.
    data[:, :, bbx1:bbx2, bby1:bby2] = data[indices, :, bbx1:bbx2, bby1:bby2]
    # adjust lambda to exactly match pixel ratio (the box may have been clipped)
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))
    targets = [targets1, shuffled_targets1, targets2, shuffled_targets2, targets3, shuffled_targets3, lam]
    return data, targets
def mixup(data, targets1, targets2, targets3, alpha):
    """Mixup for a batch whose model has three label heads.

    Each sample is blended with a partner from a random permutation of the
    batch: ``lam * x + (1 - lam) * x_partner`` with lam ~ Beta(alpha, alpha).
    The input tensor is not modified.

    Returns:
        (mixed_data, targets) with targets laid out as
        [t1, t1[perm], t2, t2[perm], t3, t3[perm], lam] -- the format
        mixup_criterion unpacks.
    """
    perm = torch.randperm(data.size(0))
    partner = data[perm]
    shuffled = [labels[perm] for labels in (targets1, targets2, targets3)]
    lam = np.random.beta(alpha, alpha)
    mixed = data * lam + partner * (1 - lam)
    return mixed, [targets1, shuffled[0], targets2, shuffled[1], targets3, shuffled[2], lam]
def cutmix_criterion(preds1,preds2,preds3, targets):
    """Loss for a CutMix batch over three prediction heads.

    For each head k this adds ``lam * CE(preds_k, y_k) +
    (1 - lam) * CE(preds_k, y_k_shuffled)`` using mean-reduced cross-entropy,
    summing the six terms in head order.

    Args:
        preds1, preds2, preds3: logits of the three heads.
        targets: list as returned by cutmix:
            [y1, y1_shuffled, y2, y2_shuffled, y3, y3_shuffled, lam].
    """
    lam = targets[6]
    criterion = nn.CrossEntropyLoss(reduction='mean')
    terms = []
    # Even positions hold each head's own labels, odd positions the shuffled ones.
    for preds, own, partner in zip((preds1, preds2, preds3), targets[0:6:2], targets[1:6:2]):
        terms.append(lam * criterion(preds, own))
        terms.append((1 - lam) * criterion(preds, partner))
    loss = terms[0]
    for term in terms[1:]:
        loss = loss + term
    return loss
def mixup_criterion(preds1,preds2,preds3, targets):
    """Loss for a mixup batch over three prediction heads.

    Each head contributes ``lam * CE(logits, y) + (1 - lam) * CE(logits, y_mixed)``
    with mean-reduced cross-entropy; the six terms are summed in head order.

    Args:
        preds1, preds2, preds3: logits of the three heads.
        targets: list as returned by mixup:
            [y1, y1_mixed, y2, y2_mixed, y3, y3_mixed, lam].
    """
    y1, y1_mix, y2, y2_mix, y3, y3_mix, lam = (targets[k] for k in range(7))
    loss_fn = nn.CrossEntropyLoss(reduction='mean')
    heads = ((preds1, y1, y1_mix), (preds2, y2, y2_mix), (preds3, y3, y3_mix))
    total = None
    for logits, label_a, label_b in heads:
        for weight, label in ((lam, label_a), (1 - lam, label_b)):
            term = weight * loss_fn(logits, label)
            total = term if total is None else total + term
    return total
# One pass over the training data, applying mixup or CutMix (50/50) per batch.
# NOTE(review): model, device, data_loader_train and optimizer are not defined
# in this snippet; they are assumed to come from the surrounding training script.
for i, (image_id, images, label1, label2, label3) in enumerate(data_loader_train):
    images = images.to(device)
    label1 = label1.to(device)
    label2 = label2.to(device)
    label3 = label3.to(device)
    # print (image_id, label1, label2, label3)
    # Each augmentation is paired with the criterion that unpacks its
    # 7-item targets list.
    if np.random.rand() < 0.5:
        augment, criterion = mixup, mixup_criterion
    else:
        augment, criterion = cutmix, cutmix_criterion
    images, targets = augment(images, label1, label2, label3, 0.4)
    output1, output2, output3 = model(images)
    loss = criterion(output1, output2, output3, targets)
    # Back-propagate and update; without this the loss is computed but the
    # model's weights never change.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
总结
GitHub 上也有完整版的实现,可以搜索后测试一下。这类混合型数据增强在分类任务中改起来比较容易;而在目标检测中一般不混合 label,只是把多张图片拼合成新照片。
边栏推荐
猜你喜欢
随机推荐
开发提测标准(简易版)
China Astragalus Injection Market Evaluation and investment strategy report (2022 Edition)
有关编码表的基础知识
整数的分划问题
太卷了, 某公司把自家运营多年的核心系统(智慧系统)完全开源了....
Basic page status code
单实例Mongo升级为副本集
three. JS endless pipeline Perspective
MySQL ten million level sub table optimization
Rlib learning [2] --env definition + env rollout
Basic knowledge about coding table
Codeforces 429E 2-SAT
cmd执行命令出现SecurityError: (:) [],ParentContainsErrorRecordException
线性结构理解
selnium 获取js内容
C#中的Explicit和Implicit了解一下吧
上海域格4G模块PPP拨号相关问题
HCIP --- 重发布
Mongo sort exceeds maximum memory error
一个开源的网页画板,真的太方便了