模型压缩

目标检测网络的知识蒸馏

“Learning Efficient Object Detection Models with Knowledge Distillation”这篇文章通过知识蒸馏(Knowledge Distillation)与Hint指导学习(Hint Learning),提升了主干精简的多分类目标检测网络的推理精度(文章以Faster RCNN为例),例如Faster RCNN-Alexnet、Faster-RCNN-VGGM等,具体框架如下图所示: 教师网络的暗知识提取分为三点:中间层Feature Maps的Hint;RPN/RCN中分类层的暗知识;以及RPN/RCN中回归层的暗知识。具体如下: 具体指导学生网络学习时,RPN与RCN的分类损失由分类层softmax输出与hard target的交叉熵loss、以及分类层softmax输出与soft target的交叉熵loss构成: 由于检测器需要鉴别的不同类别之间存在样本不均衡(imbalance),因此在L_soft中需要对不同类别的交叉熵分配不同的权重,其中背景类的权重为1.5(较大的比例),其他分类的权重均为1.0: RPN与RCN的回归损失由正常的smooth L1 loss、以及文章所定义的teacher bounded regression loss构成: 其中Ls_L1表示正常的smooth L1 loss,Lb表示文章定义的teacher bounded regression loss。当学生网络的位置回归与ground truth的L2距离超过教师网络的位置回归与ground truth的L2距离、且大于某一阈值时,Lb取学生网络的位置回归与ground truth之间的L2距离,否则Lb置0。 Hint learning需要计算教师网络与学生网络中间层输出的Feature Maps之间的L2 loss,并且在学生网络中需要添加可学习的适配层(adaptation layer),以确保guided layer输出的Feature Maps与教师网络输出的Hint维度一致: 通过知识蒸馏、Hint指导学习,提升了精简网络的泛化性、并有助于加快收敛,最后取得了良好的实验结果,具体见文章实验部分。 以SSD为例,KD loss与Teacher bounded L2 loss设计如下: # -*- coding: utf-8 -*- import torch import torch.nn as nn import torch.nn.functional as F from ..box_utils import match, log_sum_exp eps = 1e-5 def KL_div(p, q, pos_w, neg_w): p = p + eps q = q + eps log_p = p * torch.

知识蒸馏(Knowledge Distillation)

1、Distilling the Knowledge in a Neural Network Hinton的文章”Distilling the Knowledge in a Neural Network”首次提出了知识蒸馏(暗知识提取)的概念,通过引入与教师网络(teacher network:复杂、但推理性能优越)相关的软目标(soft-target)作为total loss的一部分,以诱导学生网络(student network:精简、低复杂度)的训练,实现知识迁移(knowledge transfer)。 如上图所示,教师网络(左侧)的预测输出除以温度参数(Temperature)之后、再做softmax变换,可以获得软化的概率分布(软目标),数值介于0~1之间,取值分布较为缓和。Temperature数值越大,分布越缓和;而Temperature数值减小,容易放大错误分类的概率,引入不必要的噪声。针对较困难的分类或检测任务,Temperature通常取1,确保教师网络中正确预测的贡献。硬目标则是样本的真实标注,可以用one-hot矢量表示。total loss设计为软目标与硬目标所对应的交叉熵的加权平均(表示为KD loss与CE loss),其中软目标交叉熵的加权系数越大,表明迁移诱导越依赖教师网络的贡献,这对训练初期阶段是很有必要的,有助于让学生网络更轻松的鉴别简单样本,但训练后期需要适当减小软目标的比重,让真实标注帮助鉴别困难样本。另外,教师网络的推理性能通常要优于学生网络,而模型容量则无具体限制,且教师网络推理精度越高,越有利于学生网络的学习。 教师网络与学生网络也可以联合训练,此时教师网络的暗知识及学习方式都会影响学生网络的学习,具体如下(式中三项分别为教师网络softmax输出的交叉熵loss、学生网络softmax输出的交叉熵loss、以及教师网络数值输出与学生网络softmax输出的交叉熵loss): 联合训练的Paper地址:https://arxiv.org/abs/1711.05852 2、Exploring Knowledge Distillation of Deep Neural Networks for Efficient Hardware Solutions 这篇文章将total loss重新定义如下: GitHub地址:https://github.com/peterliht/knowledge-distillation-pytorch total loss的Pytorch代码如下,引入了精简网络输出与教师网络输出的KL散度,并在诱导训练期间,先将teacher network的预测输出缓存到CPU内存中,可以减轻GPU显存的overhead: def loss_fn_kd(outputs, labels, teacher_outputs, params): """ Compute the knowledge-distillation (KD) loss given outputs, labels. "Hyperparameters": temperature and alpha NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher and student expects the input tensor to be log probabilities!