博客專欄

        EEPW首頁 > 博客 > 知識蒸餾綜述:代碼整理(1)

        知識蒸餾綜述:代碼整理(1)

        發布人:計算機視覺工坊 時間:2022-01-16 來源:工程師 發布文章

        作者 | PPRP 

        來源 | GiantPandaCV

        編輯 | 極市平臺

        導讀

        本文收集自RepDistiller中的蒸餾方法,盡可能簡單解釋蒸餾用到的策略,并提供了實現源碼。

        1. KD: Knowledge Distillation

        全稱:Distilling the Knowledge in a Neural Network

        鏈接:https://arxiv.org/pdf/1503.02531.pd3f

        發表:NIPS14

        最經典的,也是明確提出知識蒸餾概念的工作,通過使用帶溫度的softmax函數來軟化教師網絡的邏輯層輸出作為學生網絡的監督信息,

        使用KL divergence來衡量學生網絡與教師網絡的差異,具體流程如下圖所示(來自Knowledge Distillation A Survey)

        1.jpg

        對學生網絡來說,一部分監督信息來自hard label標簽,另一部分來自教師網絡提供的soft label。代碼實現:

        class DistillKL(nn.Module):
            """Distilling the Knowledge in a Neural Network"""
            def __init__(self, T):
                super(DistillKL, self).__init__()
                self.T = T
            def forward(self, y_s, y_t):
                p_s = F.log_softmax(y_s/self.T, dim=1)
                p_t = F.softmax(y_t/self.T, dim=1)
                loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
                return loss

        核心就是一個kl_div函數,用于計算學生網絡和教師網絡的分布差異。

        2. FitNet: Hints for thin deep nets

        全稱:Fitnets: hints for thin deep nets

        鏈接:https://arxiv.org/pdf/1412.6550.pdf

        發表:ICLR 15 Poster

        對中間層進行蒸餾的開山之作,通過將學生網絡的feature map擴展到與教師網絡的feature map相同尺寸以后,使用均方誤差MSE Loss來衡量兩者差異。

        2.jpg

        實現如下:

        class HintLoss(nn.Module):
            """Fitnets: hints for thin deep nets, ICLR 2015"""
            def __init__(self):
                super(HintLoss, self).__init__()
                self.crit = nn.MSELoss()
            def forward(self, f_s, f_t):
                loss = self.crit(f_s, f_t)
                return loss

        實現核心就是MSELoss。

        3. AT: Attention Transfer

        全稱:Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer

        鏈接:https://arxiv.org/pdf/1612.03928.pdf

        發表:ICLR16

        為了提升學生模型性能提出使用注意力作為知識載體進行遷移,文中提到了兩種注意力,一種是activation-based attention transfer,另一種是gradient-based attention transfer。實驗發現第一種方法既簡單效果又好。

        3.jpg

        實現如下:

        class Attention(nn.Module):
            """Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks
            via Attention Transfer
            code: https://github.com/szagoruyko/attention-transfer"""
            def __init__(self, p=2):
                super(Attention, self).__init__()
                self.p = p
            def forward(self, g_s, g_t):
                return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
            def at_loss(self, f_s, f_t):
                s_H, t_H = f_s.shape[2], f_t.shape[2]
                if s_H > t_H:
                    f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
                elif s_H < t_H:
                    f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
                else:
                    pass
                return (self.at(f_s) - self.at(f_t)).pow(2).mean()
            def at(self, f):
                return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))

        首先使用avgpool將尺寸調整一致,然后使用MSE Loss來衡量兩者差距。

        4. SP: Similarity-Preserving

        全稱:Similarity-Preserving Knowledge Distillation

        鏈接:https://arxiv.org/pdf/1907.09682.pdf

        發表:ICCV19SP

        歸屬于基于關系的知識蒸餾方法。文章思想是提出相似性保留的知識,使得教師網絡和學生網絡會對相同的樣本產生相似的激活。可以從下圖看出處理流程,教師網絡和學生網絡對應feature map通過計算內積,得到bsxbs的相似度矩陣,然后使用均方誤差來衡量兩個相似度矩陣。

        4.jpg

        最終Loss為:

        G代表的就是bsxbs的矩陣。實現如下:

        class Similarity(nn.Module):
            """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
            def __init__(self):
                super(Similarity, self).__init__()
            def forward(self, g_s, g_t):
                return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
            def similarity_loss(self, f_s, f_t):
                bsz = f_s.shape[0]
                f_s = f_s.view(bsz, -1)
                f_t = f_t.view(bsz, -1)
                G_s = torch.mm(f_s, torch.t(f_s))
                # G_s = G_s / G_s.norm(2)
                G_s = torch.nn.functional.normalize(G_s)
                G_t = torch.mm(f_t, torch.t(f_t))
                # G_t = G_t / G_t.norm(2)
                G_t = torch.nn.functional.normalize(G_t)
                G_diff = G_t - G_s
                loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
                return loss

        5. CC: Correlation Congruence

        全稱:Correlation Congruence for Knowledge Distillation

        鏈接:https://arxiv.org/pdf/1904.01802.pdf

        發表:ICCV19

        CC也歸屬于基于關系的知識蒸餾方法。不應該僅僅引導教師網絡和學生網絡單個樣本向量之間的差異,還應該學習兩個樣本之間的相關性,而這個相關性使用的是Correlation Congruence 教師網絡雨學生網絡相關性之間的歐氏距離。

        整體Loss如下:

        實現如下:

        class Correlation(nn.Module):
            """Similarity-preserving loss. My origianl own reimplementation 
            based on the paper before emailing the original authors."""
            def __init__(self):
                super(Correlation, self).__init__()
            def forward(self, f_s, f_t):
                return self.similarity_loss(f_s, f_t)
            def similarity_loss(self, f_s, f_t):
                bsz = f_s.shape[0]
                f_s = f_s.view(bsz, -1)
                f_t = f_t.view(bsz, -1)
                G_s = torch.mm(f_s, torch.t(f_s))
                G_s = G_s / G_s.norm(2)
                G_t = torch.mm(f_t, torch.t(f_t))
                G_t = G_t / G_t.norm(2)
                G_diff = G_t - G_s
                loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
                return loss


        *博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。



        關鍵詞: AI

        相關推薦

        技術專區

        關閉
        主站蜘蛛池模板: 渭南市| 昭通市| 滨海县| 巢湖市| 威海市| 宁波市| 体育| 九江县| 什邡市| 称多县| 张家川| 广德县| 芜湖县| 岳阳县| 德昌县| 丰都县| 开平市| 宁波市| 达日县| 德庆县| 神池县| 南郑县| 越西县| 同江市| 盐边县| 黄山市| 延庆县| 高要市| 漾濞| 郎溪县| 桃江县| 兰溪市| 织金县| 宕昌县| 交城县| 乌拉特中旗| 辽阳县| 将乐县| 海林市| 龙井市| 桦南县|