Pytorch implements retinanet (III) definition and training of loss
2022-07-20 08:24:00 【Visual feast】
Abstract
This version of the code is very concise: the loss definition and the training loop read much like those of an image-classification task, so the learning curve is much gentler. Everything that could be trimmed has been trimmed; the main goal is to convey the essence of object detection and to show how to write the training and testing parts. The mAP calculation is missing here; I will cover it separately.
Focal loss definition
In the first two parts, part one mainly covered the 5 feature maps RetinaNet extracts for prediction, and part two explained how to encode the data into the labels we need, ready to be fed into the loss for training. My earlier blog post already explains focal loss in great detail, so in this part I focus mainly on the data transformations, to make the loss easier to understand.
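As a quick reminder from that post: for the predicted probability p_t of the true class, focal loss down-weights well-classified examples,

FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t)

with alpha = 0.25 and gamma = 2 in the code below, so easy negatives contribute almost nothing to the loss.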
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import one_hot_embedding
from torch.autograd import Variable


class FocalLoss(nn.Module):
    def __init__(self, num_classes=20):
        super(FocalLoss, self).__init__()
        self.num_classes = num_classes

    def focal_loss(self, x, y):  # see my earlier blog post for the details
        '''Focal loss.

        Args:
          x: (tensor) predictions, sized [N,D].
          y: (tensor) target labels, sized [N,].

        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25
        gamma = 2
        t = one_hot_embedding(y.data.cpu(), 1+self.num_classes)  # [N,21]
        t = t[:,1:]             # exclude background
        t = Variable(t).cuda()  # [N,20]
        p = x.sigmoid()
        pt = p*t + (1-p)*(1-t)         # pt = p if t > 0 else 1-p
        w = alpha*t + (1-alpha)*(1-t)  # w = alpha if t > 0 else 1-alpha
        w = w * (1-pt).pow(gamma)
        return F.binary_cross_entropy_with_logits(x, t, w, size_average=False)

    def focal_loss_alt(self, x, y):  # see the earlier blog post
        '''Focal loss alternative.

        Args:
          x: (tensor) predictions, sized [N,D].
          y: (tensor) target labels, sized [N,].

        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25
        t = one_hot_embedding(y.data.cpu(), 1+self.num_classes)
        t = t[:,1:]
        t = Variable(t).cuda()
        xt = x*(2*t-1)           # xt = x if t > 0 else -x
        pt = (2*xt+1).sigmoid()
        w = alpha*t + (1-alpha)*(1-t)
        loss = -w*pt.log() / 2
        return loss.sum()
    def forward(self, loc_preds, loc_targets, cls_preds, cls_targets):  # focus on the data transformations here
        '''Compute loss between (loc_preds, loc_targets) and (cls_preds, cls_targets).

        Args:
          loc_preds: (tensor) predicted locations, sized [batch_size, #anchors, 4].
          loc_targets: (tensor) encoded target locations, sized [batch_size, #anchors, 4].
          cls_preds: (tensor) predicted class confidences, sized [batch_size, #anchors, #classes].
          cls_targets: (tensor) encoded target labels, sized [batch_size, #anchors].

        Returns:
          (tensor) loss = SmoothL1Loss(loc_preds, loc_targets) + FocalLoss(cls_preds, cls_targets).
        '''
        batch_size, num_boxes = cls_targets.size()
        pos = cls_targets > 0            # [N,#anchors] anchors that are not background
        num_pos = pos.data.long().sum()  # number of non-background anchors

        ################################################################
        # loc_loss = SmoothL1Loss(pos_loc_preds, pos_loc_targets)
        ################################################################
        mask = pos.unsqueeze(2).expand_as(loc_preds)       # [N,#anchors,4] add a dimension to pos, then expand it to the shape of loc_preds
        masked_loc_preds = loc_preds[mask].view(-1,4)      # [#pos,4] keep only the non-background boxes
        masked_loc_targets = loc_targets[mask].view(-1,4)  # [#pos,4] the targets also keep only the non-background boxes, so predictions and targets line up for the loss
        loc_loss = F.smooth_l1_loss(masked_loc_preds, masked_loc_targets, size_average=False)  # the location loss uses smooth L1

        ################################################################
        # cls_loss = FocalLoss(cls_preds, cls_targets)
        ################################################################
        pos_neg = cls_targets > -1  # exclude ignored anchors, which removes another portion
        mask = pos_neg.unsqueeze(2).expand_as(cls_preds)              # expand to the same shape
        masked_cls_preds = cls_preds[mask].view(-1,self.num_classes)  # filter out the ignored anchors
        cls_loss = self.focal_loss_alt(masked_cls_preds, cls_targets[pos_neg])

        # print('loc_loss: %.3f | cls_loss: %.3f' % (loc_loss.data[0]/num_pos, cls_loss.data[0]/num_pos), end=' | ')
        loss = (loc_loss+cls_loss)/num_pos
        return loss
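The loss imports one_hot_embedding from utils. For reference, here is a minimal sketch of such a helper, assuming the common index-into-identity-matrix trick (the actual utils.py may differ):

import torch

def one_hot_embedding(labels, num_classes):
    '''Embed integer labels as one-hot vectors.

    Args:
      labels: (LongTensor) class labels, sized [N,].
      num_classes: (int) number of classes D.

    Returns:
      (tensor) one-hot encoded labels, sized [N,D].
    '''
    y = torch.eye(num_classes)  # identity matrix [D,D]
    return y[labels]            # row labels[i] is the one-hot vector of class labels[i] -> [N,D]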
Before training starts, the data is mainly filtered: most samples unsuitable for training are removed, for example anchors whose IoU with any ground truth is too small are treated as background. Here, because focal loss is used, all background anchors are kept as training samples, whereas most object detectors train with a positive-to-negative ratio of about 1:3. Let's look at how the data shape changes at each step:
for images, loc_targets, cls_targets in dataloader:
    print(images.size())
    print(loc_targets.size())
    print(cls_targets.size())
    loc_preds, cls_preds = net(images)  # the original snippet assumes these already exist; they come from the network
    pos = cls_targets > 0
    num_pos = pos.data.long().sum()
    print('num', num_pos)
    mask = pos.unsqueeze(2).expand_as(loc_preds)
    print('mask', mask.shape)
    masked_loc_preds = loc_preds[mask].view(-1,4)      # [#pos,4]
    print(masked_loc_preds.shape)
    masked_loc_targets = loc_targets[mask].view(-1,4)  # [#pos,4]
    print(masked_loc_targets.shape)
    pos_neg = cls_targets > -1  # exclude ignored anchors
    mask = pos_neg.unsqueeze(2).expand_as(cls_preds)
    masked_cls_preds = cls_preds[mask].view(-1, 20)
    print(masked_cls_preds.shape)
    a = cls_targets[pos_neg]
    print(a.shape)
torch.Size([1, 3, 300, 300])
torch.Size([1, 17451, 4])
torch.Size([1, 17451])
num tensor(82)
mask torch.Size([1, 17451, 4])
torch.Size([82, 4])
torch.Size([82, 4])
torch.Size([17277, 20])
torch.Size([17277])
torch.Size([1, 3, 300, 300])
torch.Size([1, 17451, 4])
torch.Size([1, 17451])
num tensor(35)
mask torch.Size([1, 17451, 4])
torch.Size([35, 4])
torch.Size([35, 4])
torch.Size([17383, 20])
torch.Size([17383])
There are only two photos here with a batch size of 1, so the loader runs twice. You can see that very few boxes enter the localization loss, basically never more than 100; in the prediction phase NMS will then keep only the box with the best probability. The classification data is barely filtered at all. You can adjust this to suit your own dataset.
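As a sanity check on the 17451 anchors above: with a 300x300 input, FPN strides of 8, 16, 32, 64 and 128, and 9 anchors per feature-map cell (3 scales x 3 aspect ratios, as in RetinaNet), the count works out exactly, assuming feature-map sizes are rounded up:

import math

input_size = 300
strides = [8, 16, 32, 64, 128]                         # FPN levels
fm_sizes = [math.ceil(input_size/s) for s in strides]  # [38, 19, 10, 5, 3]
num_anchors = sum(s*s for s in fm_sizes) * 9           # 9 anchors per cell
print(num_anchors)                                     # 17451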
Training part
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from loss import FocalLoss
from retinanet import RetinaNet
from datagen import ListDataset
from torch.autograd import Variable
parser = argparse.ArgumentParser(description='PyTorch RetinaNet Training')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()
assert torch.cuda.is_available(), 'Error: CUDA not found!'
best_loss = float('inf') # best test loss
start_epoch = 0 # start from epoch 0 or last epoch
# Data
print('==> Preparing data..')
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
])
dataset = ListDataset(root='G:\detection\image',
list_file='G:\detection\image/test.txt', train=True, transform=transform, input_size=300)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn)
testset = ListDataset(root='/search/odin/liukuang/data/voc_all_images',
list_file='./data/voc12_val.txt', train=False, transform=transform, input_size=600)
testloader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=False, num_workers=8, collate_fn=testset.collate_fn)
# Model
net = RetinaNet()
net.load_state_dict(torch.load('./model/net.pth'))
if args.resume:
    print('==> Resuming from checkpoint..')
    checkpoint = torch.load('./checkpoint/ckpt.pth')
    net.load_state_dict(checkpoint['net'])
    best_loss = checkpoint['loss']
    start_epoch = checkpoint['epoch']
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
net.cuda()
criterion = FocalLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    net.module.freeze_bn()
    train_loss = 0
    for batch_idx, (inputs, loc_targets, cls_targets) in enumerate(dataloader):
        inputs = Variable(inputs.cuda())
        loc_targets = Variable(loc_targets.cuda())
        cls_targets = Variable(cls_targets.cuda())

        optimizer.zero_grad()
        loc_preds, cls_preds = net(inputs)
        loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.data[0]
        print('train_loss: %.3f | avg_loss: %.3f' % (loss.data[0], train_loss/(batch_idx+1)))
def test(epoch):
    print('\nTest')
    net.eval()
    test_loss = 0
    for batch_idx, (inputs, loc_targets, cls_targets) in enumerate(testloader):
        inputs = Variable(inputs.cuda(), volatile=True)
        loc_targets = Variable(loc_targets.cuda())
        cls_targets = Variable(cls_targets.cuda())

        loc_preds, cls_preds = net(inputs)
        loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
        test_loss += loss.data[0]
        print('test_loss: %.3f | avg_loss: %.3f' % (loss.data[0], test_loss/(batch_idx+1)))

    # Save checkpoint
    global best_loss
    test_loss /= len(testloader)
    if test_loss < best_loss:
        print('Saving..')
        state = {
            'net': net.module.state_dict(),
            'loss': test_loss,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_loss = test_loss


for epoch in range(start_epoch, start_epoch+200):
    train(epoch)
    test(epoch)
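Note that train() calls net.module.freeze_bn() before each epoch. That method lives in the RetinaNet class; a minimal sketch of the usual pattern, assuming it simply switches BatchNorm layers to eval mode (with detection batch sizes this small, the batch statistics are too noisy to update):

import torch.nn as nn

class RetinaNet(nn.Module):
    # ... backbone, FPN and prediction heads omitted ...

    def freeze_bn(self):
        '''Keep BatchNorm layers in eval mode so their running statistics are not updated.'''
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()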
This part of the code is basically the same as training an image classifier. Other object detection codebases mostly focus on the mAP value; the code here is simpler and easier to follow. The main remaining work is the prediction part, where we need to post-process the predicted values to recover the box positions.
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
from retinanet import RetinaNet
from encoder import DataEncoder
from PIL import Image, ImageDraw
print('Loading model..')
net = RetinaNet()
net.load_state_dict(torch.load('./checkpoint/params.pth'))
net.eval()
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
])
print('Loading image..')
img = Image.open('./image/000001.jpg')
w = h = 600
img = img.resize((w,h))
print('Predicting..')
x = transform(img)
x = x.unsqueeze(0)
x = Variable(x, volatile=True)
loc_preds, cls_preds = net(x)
print('Decoding..')
encoder = DataEncoder()
boxes, labels = encoder.decode(loc_preds.data.squeeze(), cls_preds.data.squeeze(), (w,h))
draw = ImageDraw.Draw(img)
for box in boxes:
    draw.rectangle(list(box), outline='red')
img.show()
The focus is on encoder.decode: training does not regress xywh directly, so the raw predictions have to be decoded back relative to the anchor boxes.
def decode(self, loc_preds, cls_preds, input_size):
    '''Decode outputs back to bounding box locations and class labels.

    Args:
      loc_preds: (tensor) predicted locations, sized [#anchors, 4].
      cls_preds: (tensor) predicted class labels, sized [#anchors, #classes].
      input_size: (int/tuple) model input size of (w,h).

    Returns:
      boxes: (tensor) decoded box locations, sized [#obj,4].
      labels: (tensor) class labels for each box, sized [#obj,].
    '''
    CLS_THRESH = 0.5
    NMS_THRESH = 0.5

    input_size = torch.Tensor([input_size,input_size]) if isinstance(input_size, int) \
                 else torch.Tensor(input_size)
    anchor_boxes = self._get_anchor_boxes(input_size)  # get all candidate boxes

    loc_xy = loc_preds[:,:2]
    loc_wh = loc_preds[:,2:]
    xy = loc_xy * anchor_boxes[:,2:] + anchor_boxes[:,:2]  # the real center xy; this simply inverts the encoding used in training. Plenty of experiments show that regressing offsets relative to anchors like this gives a smaller error than predicting xywh directly.
    wh = loc_wh.exp() * anchor_boxes[:,2:]
    boxes = torch.cat([xy-wh/2, xy+wh/2], 1)  # [#anchors,4] convert (cx,cy,w,h) to (x1,y1,x2,y2)

    score, labels = cls_preds.sigmoid().max(1)  # [#anchors,]
    ids = score > CLS_THRESH
    ids = ids.nonzero().squeeze()               # [#obj,]
    keep = box_nms(boxes[ids], score[ids], threshold=NMS_THRESH)  # nms filters out the many duplicate boxes
    return boxes[ids][keep], labels[ids][keep]
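A tiny worked example of the two decode equations, with made-up numbers for a single anchor:

import torch

anchor = torch.tensor([[100., 100., 64., 64.]])     # anchor as (cx, cy, w, h)
loc_pred = torch.tensor([[0.1, -0.2, 0.0, 0.693]])  # predicted (tx, ty, tw, th)

xy = loc_pred[:,:2] * anchor[:,2:] + anchor[:,:2]   # center shifts by a fraction of the anchor size -> [[106.4, 87.2]]
wh = loc_pred[:,2:].exp() * anchor[:,2:]            # log-space scaling keeps w,h positive -> [[64.0, ~128.0]]
box = torch.cat([xy-wh/2, xy+wh/2], 1)              # (x1,y1,x2,y2) ~ [[74.4, 23.2, 138.4, 151.2]]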
NMS
NMS is mainly used to remove boxes with a high degree of overlap: one object usually ends up covered by many boxes, and we use NMS to keep only the one with the best score.
def box_nms(bboxes, scores, threshold=0.5, mode='union'):
    '''Non maximum suppression.

    Args:
      bboxes: (tensor) bounding boxes, sized [N,4].
      scores: (tensor) bbox scores, sized [N,].
      threshold: (float) overlap threshold.
      mode: (str) 'union' or 'min'.

    Returns:
      keep: (tensor) selected indices.

    Reference:
      https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py
    '''
    x1 = bboxes[:,0]
    y1 = bboxes[:,1]
    x2 = bboxes[:,2]
    y2 = bboxes[:,3]

    areas = (x2-x1+1) * (y2-y1+1)               # area of every box
    _, order = scores.sort(0, descending=True)  # sort so the best-scoring box comes first

    keep = []
    while order.numel() > 0:  # loop round after round, each time keeping the best remaining box
        i = order[0]
        keep.append(i)

        if order.numel() == 1:
            break

        xx1 = x1[order[1:]].clamp(min=x1[i])  # easiest to see with a sketch: these clamps give the diagonal corners of the intersection rectangle, which is then used to compute the overlap area
        yy1 = y1[order[1:]].clamp(min=y1[i])
        xx2 = x2[order[1:]].clamp(max=x2[i])
        yy2 = y2[order[1:]].clamp(max=y2[i])

        w = (xx2-xx1+1).clamp(min=0)
        h = (yy2-yy1+1).clamp(min=0)
        inter = w*h

        if mode == 'union':
            ovr = inter / (areas[i] + areas[order[1:]] - inter)
        elif mode == 'min':  # I don't use this mode
            ovr = inter / areas[order[1:]].clamp(max=areas[i])
        else:
            raise TypeError('Unknown nms mode: %s.' % mode)

        ids = (ovr<=threshold).nonzero().squeeze()
        if ids.numel() == 0:
            break
        order = order[ids+1]
    return torch.LongTensor(keep)
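A quick sanity check with made-up boxes: two near-duplicates and one disjoint box, so NMS should keep the higher-scoring duplicate plus the separate box:

import torch

bboxes = torch.tensor([[ 10.,  10., 110., 110.],   # box A
                       [ 12.,  12., 112., 112.],   # near-duplicate of A, lower score
                       [200., 200., 300., 300.]])  # far away from A
scores = torch.tensor([0.9, 0.8, 0.7])

keep = box_nms(bboxes, scores, threshold=0.5)
print(keep)  # tensor([0, 2]) -- the duplicate at index 1 is suppressed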
Summary
That wraps up RetinaNet. Of the dozens of source codes I have read, this RetinaNet is the one best suited for getting started with object detection. The most important thing is to get a feel for the overall pipeline of object detection and its finer details, and the code is also easy to use. Next I will dig into the efficientnet and efficientdet source code; efficientnet's results on classification are still very impressive.