計算機(jī)視覺中的知識蒸餾
作者:ppog@知乎(已授權(quán)轉(zhuǎn)載)
編輯:CV技術(shù)指南
原文:https://zhuanlan.zhihu.com/p/497067556
前段時間熬完畢設(shè)的工作,趁著空閑想寫一篇關(guān)于知識蒸餾的博客,這是本人讀研期間的一個研究方向,但這篇博客不會過于深入,內(nèi)容大概簡短說說自己對于知識蒸餾的一些看法,大多數(shù)內(nèi)容來源于四月份看到的兩篇paper。鄙人愚見,有不當(dāng)之處歡迎批評!
文中涉及到的三篇論文
《Distilling the Knowledge in a Neural Network》
paper:arxiv.org/pdf/1503.0253
code:github.com/labmlai/anno
《Solving ImageNet: a Unified Scheme for Training any Backbone to Top Results》
paper:arxiv.org/pdf/2204.0347
code:github.com/Alibaba-MIIL
《Decoupled Knowledge Distillation》
paper:arxiv.org/abs/2203.0867
code:github.com/megvii-resea
1、知識鋪墊one hot 編碼
one-hot 編碼(one-hot encoding)類似于虛擬變量(dummy variables),是一種將分類變量轉(zhuǎn)換為幾個二進(jìn)制列的方法,即一種硬編碼形式,類似非黑即白。其中 1 代表某個輸入樣本屬于該類別。
深度學(xué)習(xí)領(lǐng)域中,通常將數(shù)據(jù)標(biāo)注為hard label,但事實(shí)上同一個數(shù)據(jù)包含不同類別的信息,直接標(biāo)注為hard label無法顯示圖像數(shù)據(jù)間的相關(guān)性,例如分類任務(wù)中,數(shù)據(jù)樣本(下圖)的hard label是【sheep:1】,而實(shí)際上,樣本中包含了一條狗,對應(yīng)的soft label可能是【sheep:0.90;dog:0.10】。
基于上述事實(shí):
hard label會根據(jù)照片,告訴我們這就是羊,其他都不是;
soft label會告訴我們,這張照片大概率是羊,存在一定概率是狗。
但在實(shí)際應(yīng)用中,兩者均有其所長:hard label雖然更容易標(biāo)注,但是會丟失類內(nèi)、類間的關(guān)聯(lián)。而soft label能給模型帶來更強(qiáng)的泛化,攜帶更多的信息,但是獲取難度會比hard label大。
總的來說,兩者都屬于知識遷移的一種,知識蒸餾是模型層面的遷移方式,而遷移學(xué)習(xí)是數(shù)據(jù)層面的遷移方式。
具體而言,兩個在一定程度下都可以實(shí)現(xiàn)漲點(diǎn),以ImageNet-1K、ImageNet-21K、ResNet18、ResNet31為例(假設(shè)驗(yàn)證集恒不變):
對于遷移學(xué)習(xí),我們使用ResNet18在ImageNet-21K上進(jìn)行預(yù)訓(xùn)練,訓(xùn)練完后將模型遷移到ImageNet-1K上微調(diào),在驗(yàn)證集不變的情況,精度會更高。
對于知識蒸餾,我們使用ResNet32作為Teacher模型在ImageNet-1K上進(jìn)行訓(xùn)練,ResNet18作為Student模型同樣也在ImageNet-1K上訓(xùn)練,但會引入訓(xùn)練完后的Teacher模型做監(jiān)督,往往精度也會提高。
但兩種方式都會帶來一些問題,例如訓(xùn)練周期更長,更大的計算開銷,更嚴(yán)重的資源占用等等。
《Distilling the Knowledge in a Neural Network》是知識蒸餾的開山鼻祖,于2015年提出,目前引用量快超過10k。其提出來的帶溫度的kl散度損失是最早的分類算法蒸餾方案,由于是基于logits的蒸餾方式,易于復(fù)現(xiàn),后續(xù)也有許多在KL散度上進(jìn)行改進(jìn)的版本。
Knowledge Distillation 的整體示意如上圖所示(基于logits):
Teacher model:結(jié)構(gòu)較為復(fù)雜,特征提取能力更強(qiáng)的大模型,如ResNet31
Student model:結(jié)構(gòu)較為簡單,特征提取能力一般的小模型,如ResNet18
Hard label:輸入數(shù)據(jù)所對應(yīng)的類別,上文開頭解釋過了,常規(guī)的訓(xùn)練一般都是使用的Hard label
Soft label:輸入數(shù)據(jù)通過Teacher模型softmax層的輸出,蒸餾訓(xùn)練附加的loss基于此得來
distill loss:蒸餾采用的損失可能是KL、MSE、CE等,該論文采用的是基于溫度T的KL Loss
KD常見步驟
圍繞這幾個基本點(diǎn),共進(jìn)行步驟如下(假設(shè)數(shù)據(jù)集為cifer):
① ResNet31在cifer數(shù)據(jù)上訓(xùn)練得到的教師模型
② 將教師模型的prediction軟化,即輸入數(shù)據(jù)通過teacher model所得到的softmax層的輸出:
③ 得到軟化的預(yù)測向量后,通過KL散度損失進(jìn)行下一步計算:
可以看下代碼實(shí)現(xiàn):
# y_s: student output logits
# y_t: teacher output logits
# T: temperature for KD
# teacher model: resnet31
# student model: resnet18
class DistillKL(nn.Module):
"""Distilling the Knowledge in a Neural Network"""
def __init__(self, T):
super(DistillKL, self).__init__()
self.T = T
def forward(self, y_s, y_t):
KLDLoss = nn.KLDivLoss(reduction="none")
p_s = F.log_softmax(y_s/self.T, dim=1)
p_t = F.log_softmax(y_t/self.T, dim=1)
loss = KLDLoss(p_s, p_t) * (self.T**2)
return loss
④ 在計算出蒸餾的loss后,將這個kl_loss附加在原始的分類損失(假設(shè)是CE loss)上:
在經(jīng)過知識蒸餾的操作后,模型精度得到了提升,但當(dāng)時開展的相關(guān)實(shí)驗(yàn)比較少,畢竟是在2014年,各方面條件都有所限制,且文中作者也沒有十分詳細(xì)地解釋蒸餾能讓模型提升的具體原因。
時間來到了2022年,在4月7日,阿里達(dá)摩院在arvix上掛上了《Solving ImageNet》,該論文主要針對目前的計算機(jī)視覺模型,提出通用的訓(xùn)練方案USI,并且該方案主要基于KD蒸餾的訓(xùn)練方式。
不過在我看來,該論文展示了許多豐富的實(shí)驗(yàn)及結(jié)果,并且驗(yàn)證和解釋了為何KD是有效的,更像是對14年提出的KD進(jìn)行詳盡的補(bǔ)充。
文中提到,目前的計算機(jī)視覺模型大致下可以分為四類:
類似ResNet的常規(guī)CNN模型(ResNet-like)
面向移動端的輕量模型(Mobile-oriented)
Transformer模型(Transformer-base)
僅包含MLP的模型(MLP-only)
該作者對上述四種架構(gòu)的計算機(jī)視覺模型抽樣進(jìn)行了實(shí)驗(yàn),有意思的是,使用基于KD方式的訓(xùn)練方案的模型在Top-1上均獲得了不同程度的提高,特別是Mobile-oriented類的輕量模型。
為了更深入地了解KD對模型結(jié)果的影響,作者在下圖中展示了一些教師模型預(yù)測的標(biāo)簽,與ImageNet真實(shí)標(biāo)簽的對比。
圖片(a)包含了大量明顯的釘子,教師模型的預(yù)測是99.9%,而第二和第三個預(yù)測也與釘子(螺絲和錘子)相關(guān),但概率值可以忽略不計。
圖片(b)中包含了一架客機(jī),教師模型的最高預(yù)測是客機(jī)(83.6%)。然而,教師模型也有一些不能忽視的概率(11.3%)。這并非是錯誤,因?yàn)轱w機(jī)上有機(jī)翼。這里的教師模型減輕了實(shí)際情況與真實(shí)標(biāo)簽相互排斥的情況(即要么是1,要么是0),并提供了關(guān)于圖像內(nèi)容更準(zhǔn)確的信息(打個比方,前面提到的一張圖基本都是羊,但有一條狗,數(shù)據(jù)集的分類標(biāo)簽是羊,但teacher教師預(yù)測時會留出部分概率給了狗)。
圖片(c)中包含了一只母雞。然而,母雞的信息并非很明顯,教師模型的預(yù)測反映了這一點(diǎn),通過識別出一只概率較低的母雞(55.5%),還給出了一定的概率給公雞( 大約8.9%.)。雖然這是教師模型的錯誤,但實(shí)際上就算是人,這么小的目標(biāo)似乎也很難一下子分得清。
在圖片(d)中,教師模型認(rèn)為真實(shí)標(biāo)簽是錯誤的。真實(shí)標(biāo)簽是冰棍,而教師模型預(yù)測概率最大的是狗。作者認(rèn)為教師模型的預(yù)測反而是對的,因?yàn)楣吩趫D片中的信息更為突出。
從上面的例子中可以看到,教師模型的預(yù)測比簡單( 0或1)的真實(shí)標(biāo)簽包含了更豐富的信息,soft label解釋了類別之間的相關(guān)性。不僅如此,KD更能代表增強(qiáng)過后圖像的正確信息,能更好處理strong augmentations的問題。由于上述提到的原因,與僅使用hard label的訓(xùn)練相比,使用教師模型的soft label進(jìn)行訓(xùn)練會提供有更有效的監(jiān)督,訓(xùn)練會變得更有效、更穩(wěn)健。
上邊講到,KD有作用,但究竟是哪部分起作用,作用多大,是否存在負(fù)優(yōu)化,值得思考!
在今年的3月16日,曠視對KD(KL Loss)進(jìn)行了更加深入的剖析,提出了解耦蒸餾(《Decoupled Knowledge Distillation》,DKD),這篇文章很精彩,對14年提出的KD(KL Loss)進(jìn)行了多方位的解析,也開展了許多實(shí)驗(yàn)。
如上圖所示,研究者將 logits 拆解成兩部分,藍(lán)色部分指目標(biāo)類別(target class)的 score,綠色部分指非目標(biāo)類別(Non-target class)的 score。并且將KD重新表述為兩部分的加權(quán)和,即 TCKD 和 NCKD。
上述定義和數(shù)學(xué)關(guān)系將幫助我們得到 KL Loss 的新表達(dá)形式:
對于公式的補(bǔ)充解釋:
更有說服力的實(shí)驗(yàn)
為了觀察TCKD 和 NCKD 對蒸餾性能的影響,作者做了大量實(shí)驗(yàn),并試圖通過實(shí)驗(yàn)剖析TCKD 和 NCKD 的作用。
上圖為TCKD 和 NCKD在CIFAR-100 上進(jìn)行的實(shí)驗(yàn),作者初步得出以下結(jié)論:
同時使用 TCKD + NCKD = KD 的蒸餾方式,Student模型均獲得不同程度的提升;
單獨(dú)使用 TCKD 進(jìn)行蒸餾,會對蒸餾效果產(chǎn)生較大的損害,原因在于高溫系數(shù)(T)會導(dǎo)致?lián)p失附加上很大的梯度,增加非目標(biāo)類的 logits ,這會損害學(xué)生預(yù)測的正確性;
單獨(dú)使用 NCKD 進(jìn)行蒸餾,和 KD 效果差不多;
基于上述結(jié)論,是否 NCKD 更加有效,而 TCKD 存在負(fù)優(yōu)化?作者給出了進(jìn)一步的探討。
作者認(rèn)為 TCKD 受限于數(shù)據(jù)集的難易程度,假設(shè)一個樣本經(jīng)過教師模型后輸出概率是0.99,說明這個樣本是易樣本,數(shù)據(jù)集是容易分辨的,而當(dāng)概率只有0.75,甚至是0.55,那么樣本會陷入到模棱兩可的狀態(tài),模型也沒有把握認(rèn)定它就是所謂的那個它(你那么愛它,為什么不把它留下),數(shù)據(jù)集難度增加。
作者補(bǔ)充了以下三個實(shí)驗(yàn):更重的數(shù)據(jù)增強(qiáng);更多的噪聲;更復(fù)雜的數(shù)據(jù)。
1、更重的數(shù)據(jù)增強(qiáng)
上表顯示Teacher模型為ResNet32×4,Student模型為ShuffleNet-V1和ResNet8×4的實(shí)驗(yàn)結(jié)果,在使用 AutoAugment數(shù)據(jù)增強(qiáng)方法的情況下,訓(xùn)練集難樣本系數(shù)增大,此時使用 TCKD 可以達(dá)到較大的提升。
2、更多的噪聲
而通過引入噪聲,當(dāng)噪聲比例增大,TCKD 的提升程度也加強(qiáng)。
3、更復(fù)雜的數(shù)據(jù)
使用ResNet34作為Teacher模型,ResNet18作為Student模型,作者發(fā)現(xiàn)學(xué)生模型的Top-1增加了0.32個點(diǎn)。
最后,作者給出的結(jié)論是,通過嘗試各種策略來增加訓(xùn)練數(shù)據(jù)的復(fù)雜度(例如重的數(shù)據(jù)增強(qiáng)、更多的噪聲、困難的任務(wù))來證明 TCKD 的有效性。結(jié)果證實(shí),在對更具挑戰(zhàn)性的訓(xùn)練數(shù)據(jù)進(jìn)行知識蒸餾時,訓(xùn)練樣本“復(fù)雜度(難度)”的提升對于 TCKD 可能更有增益,說明 TCKD 對于數(shù)據(jù)集中復(fù)雜任務(wù)的監(jiān)督能力更強(qiáng)。
而上上上部分,作者也證實(shí)了NCKD 能力出眾,這也反映了一個事實(shí):說明非目標(biāo)類之間的知識對logits的蒸餾方式至關(guān)重要,它們可以比喻為能力出眾的“暗部成員”(知道卡卡西嗎?),論文中稱之為“暗知識”(dark knowledge)。
如何理解?大家可以把目標(biāo)類別的logits看作是light knowledge,按照我們慣有的思維,目標(biāo)類別是最重要的,我想要識別出一條狗,那么我就會找一大堆關(guān)于該目標(biāo)類別的樣本,不斷填充和豐富它的logits信息,而非目標(biāo)類別則顯得不那么重要,因?yàn)槲覀兿胍猭ill的名單中沒有他們,但不可置否,dark knowledge對于模型泛化性也非常關(guān)鍵。
依據(jù) Teacher 模型預(yù)測的置信度,作者對cifer訓(xùn)練集上的樣本做了排序,根據(jù)排序結(jié)果對數(shù)據(jù)集進(jìn)行切分,置信度0.5-1為一塊,置信度為0-0.5為一塊,實(shí)驗(yàn)結(jié)果如下:
在前 50% 的樣本上使用 NCKD 可以獲得更好的性能,這表明預(yù)測良好的樣本所攜帶的知識比其他樣本更豐富。然而,預(yù)測良好的樣本的損失權(quán)重被教師的高置信度所抑制。這也說明了,置信度高的樣本對蒸餾的效果更加顯著,應(yīng)當(dāng)采取措施讓它們不被抑制。
分類任務(wù)
作者使用DKD和KD進(jìn)行對比,效果都要優(yōu)于KD(KL Loss)的方式,在不同模型上實(shí)現(xiàn)了1-2,甚至是3個點(diǎn)的提升。
并且,作者對一些細(xì)節(jié)也進(jìn)行了補(bǔ)充,通常a設(shè)置為1時效果較好,而實(shí)際應(yīng)用中變動較大的為Beta,當(dāng)具體調(diào)為何值,需要根據(jù)實(shí)際的業(yè)務(wù)數(shù)據(jù)進(jìn)行實(shí)驗(yàn)。
檢測任務(wù)
作者使用了Faster rcnn作為baseline,通過替換不同的backbone以此作為teacher和student,可以看出,DKD的方式帶來的提升均超過了原始KD的方式,而將DKD與基于Feature蒸餾結(jié)合起來組成的DKD+ReviewDKD提升更大。這也證明了,檢測任務(wù)十分依賴于feature的定位能力,而logits這種high level的信息并不具備這種能力,這也使得基于logits的蒸餾方式效果差于feature的蒸餾,但總的來說,KD的解耦型DKD還是展示了更加優(yōu)越的性能。
這篇博客從三個層面講述了KD是什么?為什么有效?突然想寫這篇博客,原因在于四月份看到的兩篇論文解答了我之前在這個方向上的不少疑惑,隨整理出來。但由于本人并未涉略過深,仍會有很多理解不足的地方,也歡迎各位大佬批評指正!
參考文獻(xiàn)[1] pprp:知識蒸餾綜述:代碼整理
[2] medium.com/analytics-vi
[3] 從標(biāo)簽平滑和知識蒸餾理解Soft Label
[4] [論文閱讀]知識蒸餾(Distilling the Knowledge in a Neural Network)
[5] Distilling the Knowledge in a Neural Network 論文筆記
[6] oldsummer:2021 《Knowledge Distillation: A Survey》
[7] CVPR 2022|解耦知識蒸餾!曠視提出DKD:讓Hinton在7年前提出的方法重回SOTA行列!
[8] 阿里巴巴提出USI 讓AI煉丹自動化了,訓(xùn)練任何Backbone無需超參配置,實(shí)現(xiàn)大一統(tǒng)!
本文僅做學(xué)術(shù)分享,如有侵權(quán),請聯(lián)系刪文。
*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點(diǎn),如有侵權(quán)請聯(lián)系工作人員刪除。