博客專欄

        EEPW首頁 > 博客 > 知識蒸餾綜述:代碼整理(3)

        知識蒸餾綜述:代碼整理(3)

        發布人:計算機視覺工坊 時間:2022-01-16 來源:工程師 發布文章

        11. FSP: Flow of Solution Procedure

        全稱:A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning

        鏈接:https://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf

        發表:CVPR17

        FSP認為教學生網絡不同層輸出的feature之間的關系比教學生網絡結果好

        9.jpg

        定義了FSP矩陣來定義網絡內部特征層之間的關系,是一個Gram矩陣反映老師教學生的過程。

        10.jpg

        使用的是L2 Loss進行約束FSP矩陣。實現如下:

        class FSP(nn.Module):
            """A Gift from Knowledge Distillation:
            Fast Optimization, Network Minimization and Transfer Learning"""
            def __init__(self, s_shapes, t_shapes):
                super(FSP, self).__init__()
                assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
                s_c = [s[1] for s in s_shapes]
                t_c = [t[1] for t in t_shapes]
                if np.any(np.asarray(s_c) != np.asarray(t_c)):
                    raise ValueError('num of channels not equal (error in FSP)')
            def forward(self, g_s, g_t):
                s_fsp = self.compute_fsp(g_s)
                t_fsp = self.compute_fsp(g_t)
                loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)]
                return loss_group
            @staticmethod
            def compute_loss(s, t):
                return (s - t).pow(2).mean()
            @staticmethod
            def compute_fsp(g):
                fsp_list = []
                for i in range(len(g) - 1):
                    bot, top = g[i], g[i + 1]
                    b_H, t_H = bot.shape[2], top.shape[2]
                    if b_H > t_H:
                        bot = F.adaptive_avg_pool2d(bot, (t_H, t_H))
                    elif b_H < t_H:
                        top = F.adaptive_avg_pool2d(top, (b_H, b_H))
                    else:
                        pass
                    bot = bot.unsqueeze(1)
                    top = top.unsqueeze(2)
                    bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1)
                    top = top.view(top.shape[0], top.shape[1], top.shape[2], -1)
                    fsp = (bot * top).mean(-1)
                    fsp_list.append(fsp)
                return fsp_list

        12. NST: Neuron Selectivity Transfer

        全稱:Like what you like: knowledge distill via neuron selectivity transfer

        鏈接:https://arxiv.org/pdf/1707.01219.pdf

        發表:CoRR17

        使用新的損失函數最小化教師網絡與學生網絡之間的Maximum Mean Discrepancy(MMD), 文中選擇的是對其教師網絡與學生網絡之間神經元選擇樣式的分布。

        11.jpg

        使用核技巧(對應下面poly kernel)并進一步展開以后可得:

        實際上提供了Linear Kernel、Poly Kernel、Gaussian Kernel三種,這里實現只給了Poly這種,這是因為Poly這種方法可以與KD進行互補,這樣整體效果會非常好。實現如下:

        class NSTLoss(nn.Module):
            """like what you like: knowledge distill via neuron selectivity transfer"""
            def __init__(self):
                super(NSTLoss, self).__init__()
                pass
            def forward(self, g_s, g_t):
                return [self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
            def nst_loss(self, f_s, f_t):
                s_H, t_H = f_s.shape[2], f_t.shape[2]
                if s_H > t_H:
                    f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
                elif s_H < t_H:
                    f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
                else:
                    pass
                f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)
                f_s = F.normalize(f_s, dim=2)
                f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
                f_t = F.normalize(f_t, dim=2)
                # set full_loss as False to avoid unnecessary computation
                full_loss = True
                if full_loss:
                    return (self.poly_kernel(f_t, f_t).mean().detach() + self.poly_kernel(f_s, f_s).mean()
                            - 2 * self.poly_kernel(f_s, f_t).mean())
                else:
                    return self.poly_kernel(f_s, f_s).mean() - 2 * self.poly_kernel(f_s, f_t).mean()
            def poly_kernel(self, a, b):
                a = a.unsqueeze(1)
                b = b.unsqueeze(2)
                res = (a * b).sum(-1).pow(2)
                return res

        13. CRD: Contrastive Representation Distillation

        全稱:Contrastive Representation Distillation

        鏈接:https://arxiv.org/abs/1910.10699v2

        發表:ICLR20

        將對比學習引入知識蒸餾中,其目標修正為:學習一個表征,讓正樣本對的教師網絡與學生網絡盡可能接近,負樣本對教師網絡與學生網絡盡可能遠離。構建的對比學習問題表示如下:

        整體的蒸餾Loss表示如下:

        實現如下:https://github.com/HobbitLong/RepDistiller

        class ContrastLoss(nn.Module):
            """
            contrastive loss, corresponding to Eq (18)
            """
            def __init__(self, n_data):
                super(ContrastLoss, self).__init__()
                self.n_data = n_data
            def forward(self, x):
                bsz = x.shape[0]
                m = x.size(1) - 1
                # noise distribution
                Pn = 1 / float(self.n_data)
                # loss for positive pair
                P_pos = x.select(1, 0)
                log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()
                # loss for K negative pair
                P_neg = x.narrow(1, 1, m)
                log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()
                loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz
                return loss
        class CRDLoss(nn.Module):
            """CRD Loss function
            includes two symmetric parts:
            (a) using teacher as anchor, choose positive and negatives over the student side
            (b) using student as anchor, choose positive and negatives over the teacher side
            Args:
                opt.s_dim: the dimension of student's feature
                opt.t_dim: the dimension of teacher's feature
                opt.feat_dim: the dimension of the projection space
                opt.nce_k: number of negatives paired with each positive
                opt.nce_t: the temperature
                opt.nce_m: the momentum for updating the memory buffer
                opt.n_data: the number of samples in the training set, therefor the memory buffer is: opt.n_data x opt.feat_dim
            """
            def __init__(self, opt):
                super(CRDLoss, self).__init__()
                self.embed_s = Embed(opt.s_dim, opt.feat_dim)
                self.embed_t = Embed(opt.t_dim, opt.feat_dim)
                self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m)
                self.criterion_t = ContrastLoss(opt.n_data)
                self.criterion_s = ContrastLoss(opt.n_data)
            def forward(self, f_s, f_t, idx, contrast_idx=None):
                """
                Args:
                    f_s: the feature of student network, size [batch_size, s_dim]
                    f_t: the feature of teacher network, size [batch_size, t_dim]
                    idx: the indices of these positive samples in the dataset, size [batch_size]
                    contrast_idx: the indices of negative samples, size [batch_size, nce_k]
                Returns:
                    The contrastive loss
                """
                f_s = self.embed_s(f_s)
                f_t = self.embed_t(f_t)
                out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)
                s_loss = self.criterion_s(out_s)
                t_loss = self.criterion_t(out_t)
                loss = s_loss + t_loss
                return loss

        14. Overhaul

        全稱:A Comprehensive Overhaul of Feature Distillation鏈接:http://openaccess.thecvf.com/content_ICCV_2019/papers/發表:CVPR19

        teacher transform中提出使用margin RELU激活函數。

        12.jpg

        student transform中提出使用1x1卷積。

        distillation feature postion選擇Pre-ReLU。

        13.jpg

        distance function部分提出了Partial L2 損失函數。

        14.jpg

        部分實現如下:

        class OFD(nn.Module):
          '''
          A Comprehensive Overhaul of Feature Distillation
          http://openaccess.thecvf.com/content_ICCV_2019/papers/
          Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
          '''
          def __init__(self, in_channels, out_channels):
            super(OFD, self).__init__()
            self.connector = nn.Sequential(*[
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_channels)
              ])
            for m in self.modules():
              if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                  nn.init.constant_(m.bias, 0)
              elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
          def forward(self, fm_s, fm_t):
            margin = self.get_margin(fm_t)
            fm_t = torch.max(fm_t, margin)
            fm_s = self.connector(fm_s)
            mask = 1.0 - ((fm_s <= fm_t) & (fm_t <= 0.0)).float()
            loss = torch.mean((fm_s - fm_t)**2 * mask)
            return loss
          def get_margin(self, fm, eps=1e-6):
            mask = (fm < 0.0).float()
            masked_fm = fm * mask
            margin = masked_fm.sum(dim=(0,2,3), keepdim=True) / (mask.sum(dim=(0,2,3), keepdim=True)+eps)
            return margin

        參考文獻

        https://blog.csdn.net/weixin_44579633/article/details/119350631

        https://blog.csdn.net/winycg/article/details/105297089

        https://blog.csdn.net/weixin_46239293/article/details/120289163

        https://blog.csdn.net/DD_PP_JJ/article/details/121578722

        https://blog.csdn.net/DD_PP_JJ/article/details/121714957

        https://zhuanlan.zhihu.com/p/344881975

        https://blog.csdn.net/weixin_44633882/article/details/108927033

        https://blog.csdn.net/weixin_46239293/article/details/120266111

        https://blog.csdn.net/weixin_43402775/article/details/109011296

        https://blog.csdn.net/m0_37665984/article/details/103288582

        https://blog.csdn.net/m0_37665984/article/details/103269740

        本文僅做學術分享,如有侵權,請聯系刪文。

        *博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。



        關鍵詞: AI

        相關推薦

        技術專區

        關閉
        主站蜘蛛池模板: 托克逊县| 青田县| 平阴县| 长垣县| 伊金霍洛旗| 阿克陶县| 玉山县| 榆林市| 洪洞县| 仙居县| 万源市| 平乐县| 汕尾市| 通山县| 虹口区| 阿拉善左旗| 百色市| 郎溪县| 静乐县| 九龙城区| 错那县| 临湘市| 阿城市| 宜兰市| 宁化县| 石屏县| 浦城县| 县级市| 新余市| 普兰县| 阳城县| 崇阳县| 旬邑县| 紫云| 东辽县| 宁都县| 吉水县| 浮梁县| 罗山县| 天峨县| 阳山县|