Knowledge Distillation for Object Detection Networks
The paper "Learning Efficient Object Detection Models with Knowledge Distillation" uses knowledge distillation together with hint learning to improve the inference accuracy of multi-class object detection networks with slimmed-down backbones (the paper takes Faster RCNN as its example), such as Faster RCNN-AlexNet and Faster RCNN-VGGM; the overall framework is shown in the framework figure of the paper.
The dark knowledge extracted from the teacher network covers three aspects: hints from intermediate feature maps; dark knowledge from the classification layers of the RPN and RCN; and dark knowledge from the regression layers of the RPN and RCN. The details are as follows.
When guiding the student network, the classification loss of both the RPN and the RCN consists of the cross-entropy between the student classification layer's softmax output and the hard target, plus the cross-entropy between that softmax output and the teacher's soft target; a reconstruction of the combined loss is sketched below.
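Based on the definitions in the paper, the combined classification loss can be written roughly as follows; P_s and P_t denote the student's and teacher's softmax outputs, y the hard label, and μ the balancing weight (the exact notation may differ from the paper):

$$L_{cls} = \mu\, L_{hard}(P_s, y) + (1 - \mu)\, L_{soft}(P_s, P_t)$$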
Because the classes the detector has to distinguish are imbalanced, the per-class cross-entropy terms in L_soft are given different weights: the background class receives the larger weight of 1.5, and every other class a weight of 1.0. A weighted form consistent with this description is sketched below.
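A class-weighted soft-target cross entropy consistent with the description above (w_c is the per-class weight, 1.5 for the background class and 1.0 otherwise; c indexes the classes):

$$L_{soft}(P_s, P_t) = -\sum_{c} w_c\, P_t^{(c)} \log P_s^{(c)}$$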
The regression loss of the RPN and RCN consists of the usual smooth L1 loss plus the teacher bounded regression loss defined in the paper.
Here L_sL1 denotes the usual smooth L1 loss and L_b denotes the teacher bounded regression loss defined in the paper. When the squared L2 distance between the student's regression output and the ground truth, plus a margin, exceeds the squared L2 distance between the teacher's regression output and the ground truth, L_b takes the student's squared L2 distance to the ground truth; otherwise L_b is set to 0. The teacher therefore only serves as an upper bound: the student is pushed toward the ground truth rather than toward the teacher's regression output. A reconstruction of the loss is sketched below.
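Reconstructed from the description above (R_s and R_t are the student's and teacher's regression outputs, y the ground-truth regression target, m the margin, corresponding to reg_m in the code below, and ν the weight of the bounded term):

$$L_{reg} = L_{sL1}(R_s, y) + \nu\, L_b(R_s, R_t, y), \qquad L_b = \begin{cases} \lVert R_s - y \rVert_2^2, & \text{if } \lVert R_s - y \rVert_2^2 + m > \lVert R_t - y \rVert_2^2 \\ 0, & \text{otherwise} \end{cases}$$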
Hint learning computes an L2 loss between the intermediate feature maps of the teacher network and those of the student network. A learnable adaptation layer is added to the student so that the feature maps produced by its guided layer match the dimensions of the teacher's hint; a minimal sketch is given below.
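A minimal sketch of hint learning, assuming a 1x1 convolution as the adaptation layer; the class name HintLoss, the channel arguments and the use of mse_loss are illustrative assumptions rather than the paper's code:

import torch.nn as nn
import torch.nn.functional as F

class HintLoss(nn.Module):
    """L2 loss between the teacher's hint and the adapted output of the student's guided layer."""
    def __init__(self, student_channels, teacher_channels):
        super(HintLoss, self).__init__()
        # learnable 1x1 conv adaptation layer: maps the student's channel
        # dimension onto the teacher's hint dimension
        self.adaptation = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        # student_feat: [N, C_s, H, W]; teacher_feat (hint): [N, C_t, H, W]
        # assumes matching spatial resolution between the two feature maps
        adapted = self.adaptation(student_feat)
        return F.mse_loss(adapted, teacher_feat.detach(), reduction='sum')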
Through knowledge distillation and hint learning, the slimmed-down network generalizes better and converges faster, and the paper reports good experimental results; see the experiments section of the paper for details.
Taking SSD as an example, the KD loss and the teacher bounded L2 loss can be designed as follows:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..box_utils import match, log_sum_exp
eps = 1e-5
def KL_div(p, q, pos_w, neg_w):
    """Class-weighted KL divergence KL(p || q) between the teacher distribution p
    and the student distribution q; column 0 (background) is weighted by neg_w,
    the foreground columns by pos_w."""
    p = p + eps
    q = q + eps
    log_p = p * torch.log(p / q)
    log_p[:, 0] *= neg_w   # background class
    log_p[:, 1:] *= pos_w  # foreground classes
    return torch.sum(log_p)
class MultiBoxLoss(nn.Module):
    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 cfg, use_gpu=True, neg_w=1.5, pos_w=1.0, Temp=1., reg_m=0.):
        super(MultiBoxLoss, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes                    # 21
        self.threshold = overlap_thresh                   # 0.5
        self.background_label = bkg_label                 # 0
        self.encode_target = encode_target                # False
        self.use_prior_for_matching = prior_for_matching  # True
        self.do_neg_mining = neg_mining                   # True
        self.negpos_ratio = neg_pos                       # 3
        self.neg_overlap = neg_overlap                    # 0.5
        self.variance = cfg['variance']
        # soft-target (KD) loss hyper-parameters
        self.neg_w = neg_w    # weight of the background class in the soft loss
        self.pos_w = pos_w    # weight of the foreground classes in the soft loss
        self.Temp = Temp      # distillation temperature
        self.reg_m = reg_m    # margin m of the teacher bounded regression loss
    def forward(self, predictions, pred_t, targets):
        """Multibox loss with knowledge distillation.
        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
                and prior boxes from the student SSD net.
                conf shape: torch.Size(batch_size, num_priors, num_classes)
                loc shape: torch.Size(batch_size, num_priors, 4)
                priors shape: torch.Size(num_priors, 4)
            pred_t (list): list of (loc, conf) prediction tuples from the teacher(s)
            targets (tensor): Ground truth boxes and labels for a batch,
                shape: [batch_size, num_objs, 5] (last idx is the label).
        Returns:
            loss_l, loss_c, soft_loss, l2_loss: hard localization loss, hard
                confidence loss, KD soft-target loss and teacher bounded L2 loss.
        """
        loc_data, conf_data, priors = predictions
        num = loc_data.size(0)
        priors = priors[:loc_data.size(1), :]
        num_priors = priors.size(0)
        num_classes = self.num_classes
        # predictions of teachers
        loc_teach1, conf_teach1 = pred_t[0]
        # match priors (default boxes) and ground truth boxes
        loc_t = torch.Tensor(num, num_priors, 4)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :-1].data
            labels = targets[idx][:, -1].data
            defaults = priors.data
            match(self.threshold, truths, defaults, self.variance, labels,
                  loc_t, conf_t, idx)
        # wrap targets
        with torch.no_grad():
            if self.use_gpu:
                loc_t = loc_t.cuda(non_blocking=True)
                conf_t = conf_t.cuda(non_blocking=True)
        pos = conf_t > 0  # (1, 0, 1, ...)
        num_pos = pos.sum(dim=1, keepdim=True)  # [num, 1], number of positives
        # Localization Loss (Smooth L1) over the positive priors
        # Shape: [batch, num_priors, 4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)  # [batch, num_priors, 1] before expand_as
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
        # knowledge transfer for loc regression: teacher bounded regression loss.
        # Only boxes where the student (plus margin reg_m) is farther from the
        # ground truth than the teacher contribute; the student is pushed towards
        # the ground truth, not towards the teacher's output.
        loc_teach1_p = loc_teach1[pos_idx].view(-1, 4)
        l2_dis_s = (loc_p - loc_t).pow(2).sum(1)          # squared L2 distance student -> GT
        l2_dis_s_m = l2_dis_s + self.reg_m                # add margin m
        l2_dis_t = (loc_teach1_p - loc_t).pow(2).sum(1)   # squared L2 distance teacher -> GT
        l2_num = l2_dis_s_m > l2_dis_t                    # mask where the student is worse than the teacher
        l2_loss_teach1 = l2_dis_s[l2_num].sum()
        l2_loss = l2_loss_teach1
        # Compute max conf across batch for hard negative mining
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf.float()) - batch_conf.gather(1, conf_t.view(-1, 1)).float()
        # Hard Negative Mining: filter out positive boxes, then keep the
        # highest-loss negatives at a ratio of negpos_ratio : 1
        loss_c[pos.view(-1, 1)] = 0
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)
        # Confidence Loss Including Positive and Negative Examples
        # CrossEntropy loss
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)  # [batch, num_priors, num_classes]
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[pos_idx | neg_idx].view(-1, self.num_classes)
        targets_weighted = conf_t[pos | neg]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
        # soft-target (KD) loss on the same positive + hard-negative set
        conf_p_teach = conf_teach1[pos_idx | neg_idx].view(-1, self.num_classes)
        pt = F.softmax(conf_p_teach / self.Temp, dim=1)  # teacher soft targets
        if self.neg_w > 1.:
            # class-weighted KL divergence (background weighted by neg_w)
            ps = F.softmax(conf_p / self.Temp, dim=1)
            soft_loss1 = KL_div(pt, ps, self.pos_w, self.neg_w) * (self.Temp ** 2)
        else:
            # standard (unweighted) KL divergence
            ps = F.log_softmax(conf_p / self.Temp, dim=1)
            soft_loss1 = nn.KLDivLoss(reduction='sum')(ps, pt) * (self.Temp ** 2)
        soft_loss = soft_loss1
        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        N = num_pos.sum().float()
        loss_l = loss_l.float()
        loss_c = loss_c.float()
        loss_l /= N
        loss_c /= N
        l2_loss /= N
        soft_loss /= N
        return loss_l, loss_c, soft_loss, l2_loss
PyTorch implementation of SSD: https://github.com/amdegroot/ssd.pytorch
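For completeness, a hedged sketch of how this loss might be wired into a training step; cfg, data_loader, ssd_student and ssd_teacher are assumed placeholders following ssd.pytorch conventions, and the unit weights between the four loss terms are illustrative rather than values from the article:

# criterion combining the hard losses, the KD soft loss and the teacher bounded L2 loss
criterion = MultiBoxLoss(num_classes=21, overlap_thresh=0.5, prior_for_matching=True,
                         bkg_label=0, neg_mining=True, neg_pos=3, neg_overlap=0.5,
                         encode_target=False, cfg=cfg, use_gpu=torch.cuda.is_available(),
                         neg_w=1.5, pos_w=1.0, Temp=1., reg_m=0.)

# assumes model, images and targets already live on the same device
images, targets = next(iter(data_loader))      # one training batch
out_s = ssd_student(images)                    # (loc, conf, priors) from the student
with torch.no_grad():
    out_t = ssd_teacher(images)                # frozen teacher predictions
loss_l, loss_c, soft_loss, l2_loss = criterion(out_s, [out_t[:2]], targets)
loss = loss_l + loss_c + soft_loss + l2_loss   # per-term weights can be tuned
loss.backward()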