博客專欄

        EEPW首頁(yè) > 博客 > Batch Normalization原理與實(shí)戰(zhàn)(1)

        Batch Normalization原理與實(shí)戰(zhàn)(1)

        發(fā)布人:計(jì)算機(jī)視覺工坊 時(shí)間:2022-09-26 來(lái)源:工程師 發(fā)布文章

        作者丨天雨粟@知乎

        來(lái)源丨h(huán)ttps://zhuanlan.zhihu.com/p/34879333

        編輯丨江大白

        圖片導(dǎo)讀本文主要從理論與實(shí)戰(zhàn)的視角,對(duì)深度學(xué)習(xí)中的Batch Normalization的思路進(jìn)行講解、歸納和總結(jié),并輔以代碼讓小伙伴兒們對(duì)Batch Normalization的作用有更加直觀的了解。

        前言

        本文主要分為兩大部分。第一部分是理論板塊,主要從背景、算法、效果等角度對(duì)Batch Normalization進(jìn)行詳解;第二部分是實(shí)戰(zhàn)板塊,主要以MNIST數(shù)據(jù)集作為整個(gè)代碼測(cè)試的數(shù)據(jù),通過(guò)比較加入Batch Normalization前后網(wǎng)絡(luò)的性能來(lái)讓大家對(duì)Batch Normalization的作用與效果有更加直觀的感知。

        (一)理論板塊

        理論板塊將從以下四個(gè)方面對(duì)Batch Normalization進(jìn)行詳解:

        • 提出背景
        • BN算法思想
        • 測(cè)試階段如何使用BN
        • BN的優(yōu)勢(shì)

        理論部分主要參考2015年Google的Sergey Ioffe與Christian Szegedy的論文內(nèi)容,并輔以吳恩達(dá)Coursera課程與其它博主的資料。所有參考內(nèi)容鏈接均見于文章最后參考鏈接部分。

        1 提出背景1.1 煉丹的困擾

        在深度學(xué)習(xí)中,由于問題的復(fù)雜性,我們往往會(huì)使用較深層數(shù)的網(wǎng)絡(luò)進(jìn)行訓(xùn)練,相信很多煉丹的朋友都對(duì)調(diào)參的困難有所體會(huì),尤其是對(duì)深層神經(jīng)網(wǎng)絡(luò)的訓(xùn)練調(diào)參更是困難且復(fù)雜。在這個(gè)過(guò)程中,我們需要去嘗試不同的學(xué)習(xí)率、初始化參數(shù)方法(例如Xavier初始化)等方式來(lái)幫助我們的模型加速收斂。深度神經(jīng)網(wǎng)絡(luò)之所以如此難訓(xùn)練,其中一個(gè)重要原因就是網(wǎng)絡(luò)中層與層之間存在高度的關(guān)聯(lián)性與耦合性。下圖是一個(gè)多層的神經(jīng)網(wǎng)絡(luò),層與層之間采用全連接的方式進(jìn)行連接。

        圖片

        我們規(guī)定左側(cè)為神經(jīng)網(wǎng)絡(luò)的底層,右側(cè)為神經(jīng)網(wǎng)絡(luò)的上層。那么網(wǎng)絡(luò)中層與層之間的關(guān)聯(lián)性會(huì)導(dǎo)致如下的狀況:隨著訓(xùn)練的進(jìn)行,網(wǎng)絡(luò)中的參數(shù)也隨著梯度下降在不停更新。一方面,當(dāng)?shù)讓泳W(wǎng)絡(luò)中參數(shù)發(fā)生微弱變化時(shí),由于每一層中的線性變換與非線性激活映射,這些微弱變化隨著網(wǎng)絡(luò)層數(shù)的加深而被放大(類似蝴蝶效應(yīng));另一方面,參數(shù)的變化導(dǎo)致每一層的輸入分布會(huì)發(fā)生改變,進(jìn)而上層的網(wǎng)絡(luò)需要不停地去適應(yīng)這些分布變化,使得我們的模型訓(xùn)練變得困難。上述這一現(xiàn)象叫做Internal Covariate Shift。

        1.2 什么是Internal Covariate Shift

        Batch Normalization的原論文作者給了Internal Covariate Shift一個(gè)較規(guī)范的定義:在深層網(wǎng)絡(luò)訓(xùn)練的過(guò)程中,由于網(wǎng)絡(luò)中參數(shù)變化而引起內(nèi)部結(jié)點(diǎn)數(shù)據(jù)分布發(fā)生變化的這一過(guò)程被稱作Internal Covariate Shift。

        這句話該怎么理解呢? 我們同樣以1.1中的圖為例, 我們定義每一層的線性變換為  input , 其中  代表層數(shù); 非線性變換為 , 其中  為 第  層的激活函數(shù)。

        隨著梯度下降的進(jìn)行, 每一層的參數(shù)  與  都會(huì)被更新, 那么  的分布也就發(fā)生了改變, 進(jìn)而  也同樣出現(xiàn)分布的改變。而  作為第  層的輸入, 意味著  層就需要去不停 適應(yīng)這種數(shù)據(jù)分布的變化, 這一過(guò)程就被叫做Internal Covariate Shift。

        1.3 Internal Covariate Shift會(huì)帶來(lái)什么問題?

        (1)上層網(wǎng)絡(luò)需要不停調(diào)整來(lái)適應(yīng)輸入數(shù)據(jù)分布的變化,導(dǎo)致網(wǎng)絡(luò)學(xué)習(xí)速度的降低

        我們?cè)谏厦嫣岬搅颂荻认陆档倪^(guò)程會(huì)讓每一層的參數(shù)  和  發(fā)生變化,進(jìn)而使得每一層的線性與非線性計(jì)算結(jié)果分布產(chǎn)生變化。后層網(wǎng)絡(luò)就要不停地去適應(yīng)這種分布變化,這個(gè)時(shí)候就會(huì)使得整個(gè)網(wǎng)絡(luò)的學(xué)習(xí)速率過(guò)慢。

        (2)網(wǎng)絡(luò)的訓(xùn)練過(guò)程容易陷入梯度飽和區(qū),減緩網(wǎng)絡(luò)收斂速度

        當(dāng)我們?cè)谏窠?jīng)網(wǎng)絡(luò)中采用飽和激活函數(shù) (saturated activation function) 時(shí), 例如sigmoid, tanh 激活函數(shù), 很容易使得模型訓(xùn)練陷入梯度飽和區(qū) (saturated regime)。隨著模型訓(xùn)練的進(jìn)行, 我 們的參數(shù)  會(huì)逐漸更新并變大, 此時(shí)  就會(huì)隨之變大, 并且  還受 到更底層網(wǎng)絡(luò)參數(shù)  的影響, 隨著網(wǎng)絡(luò)層數(shù)的加深,  很容易陷入梯度 飽和區(qū), 此時(shí)梯度會(huì)變得很小甚至接近于 0 , 參數(shù)的更新速度就會(huì)減慢, 進(jìn)而就會(huì)放慢網(wǎng)絡(luò)的收玫 速度。

        對(duì)于激活函數(shù)梯度飽和問題,有兩種解決思路。第一種就是更為非飽和性激活函數(shù),例如線性整流函數(shù)ReLU可以在一定程度上解決訓(xùn)練進(jìn)入梯度飽和區(qū)的問題。另一種思路是,我們可以讓激活函數(shù)的輸入分布保持在一個(gè)穩(wěn)定狀態(tài)來(lái)盡可能避免它們陷入梯度飽和區(qū),這也就是Normalization的思路。

        1.4 我們?nèi)绾螠p緩Internal Covariate Shift?

        要緩解ICS的問題,就要明白它產(chǎn)生的原因。ICS產(chǎn)生的原因是由于參數(shù)更新帶來(lái)的網(wǎng)絡(luò)中每一層輸入值分布的改變,并且隨著網(wǎng)絡(luò)層數(shù)的加深而變得更加嚴(yán)重,因此我們可以通過(guò)固定每一層網(wǎng)絡(luò)輸入值的分布來(lái)對(duì)減緩ICS問題。

        (1)白化(Whitening)

        白化(Whitening)是機(jī)器學(xué)習(xí)里面常用的一種規(guī)范化數(shù)據(jù)分布的方法,主要是PCA白化與ZCA白化。白化是對(duì)輸入數(shù)據(jù)分布進(jìn)行變換,進(jìn)而達(dá)到以下兩個(gè)目的:

        • 使得輸入特征分布具有相同的均值與方差。 其中PCA白化保證了所有特征分布均值為0,方差為1;而ZCA白化則保證了所有特征分布均值為0,方差相同;
        • 去除特征之間的相關(guān)性。

        通過(guò)白化操作,我們可以減緩ICS的問題,進(jìn)而固定了每一層網(wǎng)絡(luò)輸入分布,加速網(wǎng)絡(luò)訓(xùn)練過(guò)程的收斂(LeCun et al.,1998b;Wiesler&Ney,2011)。

        (2)Batch Normalization提出

        既然白化可以解決這個(gè)問題,為什么我們還要提出別的解決辦法?當(dāng)然是現(xiàn)有的方法具有一定的缺陷,白化主要有以下兩個(gè)問題:

        • 白化過(guò)程計(jì)算成本太高, 并且在每一輪訓(xùn)練中的每一層我們都需要做如此高成本計(jì)算的白化操作;
        • 白化過(guò)程由于改變了網(wǎng)絡(luò)每一層的分布,因而改變了網(wǎng)絡(luò)層中本身數(shù)據(jù)的表達(dá)能力。底層網(wǎng)絡(luò)學(xué)習(xí)到的參數(shù)信息會(huì)被白化操作丟失掉。

        既然有了上面兩個(gè)問題,那我們的解決思路就很簡(jiǎn)單,一方面,我們提出的normalization方法要能夠簡(jiǎn)化計(jì)算過(guò)程;另一方面又需要經(jīng)過(guò)規(guī)范化處理后讓數(shù)據(jù)盡可能保留原始的表達(dá)能力。于是就有了簡(jiǎn)化+改進(jìn)版的白化——Batch Normalization。

        2 Batch Normalization2.1 思路

        既然白化計(jì)算過(guò)程比較復(fù)雜,那我們就簡(jiǎn)化一點(diǎn),比如我們可以嘗試單獨(dú)對(duì)每個(gè)特征進(jìn)行normalizaiton就可以了,讓每個(gè)特征都有均值為0,方差為1的分布就OK。

        另一個(gè)問題,既然白化操作減弱了網(wǎng)絡(luò)中每一層輸入數(shù)據(jù)表達(dá)能力,那我就再加個(gè)線性變換操作,讓這些數(shù)據(jù)再能夠盡可能恢復(fù)本身的表達(dá)能力就好了。

        因此,基于上面兩個(gè)解決問題的思路,作者提出了Batch Normalization,下一部分來(lái)具體講解這個(gè)算法步驟。

        2.2 算法

        在深度學(xué)習(xí)中,由于采用full batch的訓(xùn)練方式對(duì)內(nèi)存要求較大,且每一輪訓(xùn)練時(shí)間過(guò)長(zhǎng);我們一般都會(huì)采用對(duì)數(shù)據(jù)做劃分,用mini-batch對(duì)網(wǎng)絡(luò)進(jìn)行訓(xùn)練。因此,Batch Normalization也就在mini-batch的基礎(chǔ)上進(jìn)行計(jì)算。

        2.2.1 參數(shù)定義

        我們依舊以下圖這個(gè)神經(jīng)網(wǎng)絡(luò)為例。我們定義網(wǎng)絡(luò)總共有 LL 層(不包含輸入層)并定義如下符號(hào):

        圖片

        參數(shù)相關(guān):

        •  : 網(wǎng)絡(luò)中的層標(biāo)號(hào)
        •  : 網(wǎng)絡(luò)中的最后一層或總層數(shù)
        •  : 第  層的維度, 即神經(jīng)元結(jié)點(diǎn)數(shù)
        •  : 第  層的權(quán)重矩陣,
        •  第  層的偏置向量,
        •  : 第  層的線性計(jì)算結(jié)果,  input
        •  : 第  層的激活函數(shù)
        •  : 第  層的非線性激活結(jié)果,

        樣本相關(guān):

        •  : 訓(xùn)練樣本的數(shù)量
        •  : 訓(xùn)練樣本的特征數(shù)
        •  : 訓(xùn)練樣本集,  (注意這里  的一列是一個(gè) 樣本)
        •  : batch size, 即每個(gè)batch中樣本的數(shù)量
        •  : 第  個(gè)mini-batch的訓(xùn)練數(shù)據(jù), , 其中
        2.2.2 算法步驟

        介紹算法思路沿襲前面BN提出的思路來(lái)講。第一點(diǎn), 對(duì)每個(gè)特征進(jìn)行獨(dú)立的normalization。我們考慮一個(gè)batch的訓(xùn)練, 傳入  個(gè)訓(xùn)練樣本, 并關(guān)注網(wǎng)絡(luò)中的某一層, 忽略上標(biāo)  。

        我們關(guān)注當(dāng)前層的第  個(gè)維度, 也就是第  個(gè)神經(jīng)元結(jié)點(diǎn), 則有  。我們當(dāng)前維度進(jìn)行規(guī)范化:


        其中  是為了防止方差為0產(chǎn)生無(wú)效計(jì)算。

        下面我們?cè)賮?lái)結(jié)合個(gè)具體的例子來(lái)進(jìn)行計(jì)算。下圖我們只關(guān)注第  層的計(jì)算結(jié)果, 左邊的矩陣是  線性計(jì)算結(jié)果, 還末進(jìn)行激活函數(shù)的非線性變換。此時(shí)每一列是一個(gè)樣本, 圖中可以看到共有8列, 代表當(dāng)前訓(xùn)練樣本的batch中共有8個(gè)樣本, 每一行代表當(dāng)前  層神經(jīng)元的一個(gè)節(jié)點(diǎn), 可以看到當(dāng)前  層共有4個(gè)神經(jīng)元結(jié)點(diǎn), 即第  層維度為4。我們可以看到, 每行的數(shù)據(jù)分布都不同。

        圖片

        對(duì)于第一個(gè)神經(jīng)元, 我們求得  (其中  ), 此時(shí)我們利用  對(duì)第一行數(shù)據(jù)(第一個(gè)維度)進(jìn)行normalization得到新的值 。同理我們可以計(jì)算出其他輸入維度歸一化后的值。如下圖:

        圖片

        通過(guò)上面的變換,我們解決了第一個(gè)問題,即用更加簡(jiǎn)化的方式來(lái)對(duì)數(shù)據(jù)進(jìn)行規(guī)范化,使得第 ll 層的輸入每個(gè)特征的分布均值為0,方差為1。

        如同上面提到的,Normalization操作我們雖然緩解了ICS問題,讓每一層網(wǎng)絡(luò)的輸入數(shù)據(jù)分布都變得穩(wěn)定,但卻導(dǎo)致了數(shù)據(jù)表達(dá)能力的缺失。也就是我們通過(guò)變換操作改變了原有數(shù)據(jù)的信息表達(dá)(representation ability of the network),使得底層網(wǎng)絡(luò)學(xué)習(xí)到的參數(shù)信息丟失。另一方面,通過(guò)讓每一層的輸入分布均值為0,方差為1,會(huì)使得輸入在經(jīng)過(guò)sigmoid或tanh激活函數(shù)時(shí),容易陷入非線性激活函數(shù)的線性區(qū)域。

        因此, BN又引入了兩個(gè)可學(xué)習(xí) (learnable) 的參數(shù)  與  。這兩個(gè)參數(shù)的引入是為了恢復(fù)數(shù)據(jù)本 身的表達(dá)能力, 對(duì)規(guī)范化后的數(shù)據(jù)進(jìn)行線性變換, 即  。特別地, 當(dāng)  時(shí), 可以實(shí)現(xiàn)等價(jià)變換(identity transform)并且保留了原始輸入特征的分布信 息。

        通過(guò)上面的步驟,我們就在一定程度上保證了輸入數(shù)據(jù)的表達(dá)能力。

        以上就是整個(gè)Batch Normalization在模型訓(xùn)練中的算法和思路。

        補(bǔ)充:在進(jìn)行normalization的過(guò)程中, 由于我們的規(guī)范化操作會(huì)對(duì)減去均值, 因此, 偏置項(xiàng)  可以被忽略掉或可以被置為0, 即

        2.2.3 公式

        對(duì)于神經(jīng)網(wǎng)絡(luò)中的第 ll 層,我們有:


        3 測(cè)試階段如何使用Batch Normalization?

        我們知道BN在每一層計(jì)算的  與  都是基于當(dāng)前batch中的訓(xùn)練數(shù)據(jù), 但是這就帶來(lái)了一個(gè)問 題: 我們?cè)陬A(yù)測(cè)階段, 有可能只需要預(yù)測(cè)一個(gè)樣本或很少的樣本, 沒有像訓(xùn)練樣本中那么多的數(shù) 據(jù), 此時(shí)  與  的計(jì)算一定是有偏估計(jì), 這個(gè)時(shí)候我們?cè)撊绾芜M(jìn)行計(jì)算呢?

        利用BN訓(xùn)練好模型后, 我們保留了每組mini-batch訓(xùn)練數(shù)據(jù)在網(wǎng)絡(luò)中每一層的  與  。此時(shí)我們使用整個(gè)樣本的統(tǒng)計(jì)量來(lái)對(duì)Test數(shù)據(jù)進(jìn)行歸一化,具體來(lái)說(shuō)使用均值與方差的無(wú)偏估計(jì):


        得到每個(gè)特征的均值與方差的無(wú)偏估計(jì)后, 我們對(duì)test數(shù)據(jù)采用同樣的normalization方法:


        另外,除了采用整體樣本的無(wú)偏估計(jì)外。吳恩達(dá)在Coursera上的Deep Learning課程指出可以對(duì)train階段每個(gè)batch計(jì)算的mean/variance采用指數(shù)加權(quán)平均來(lái)得到test階段mean/variance的估計(jì)。

        4 Batch Normalization的優(yōu)勢(shì)

        Batch Normalization在實(shí)際工程中被證明了能夠緩解神經(jīng)網(wǎng)絡(luò)難以訓(xùn)練的問題,BN具有的有事可以總結(jié)為以下三點(diǎn):

        (1)BN使得網(wǎng)絡(luò)中每層輸入數(shù)據(jù)的分布相對(duì)穩(wěn)定,加速模型學(xué)習(xí)速度

        BN通過(guò)規(guī)范化與線性變換使得每一層網(wǎng)絡(luò)的輸入數(shù)據(jù)的均值與方差都在一定范圍內(nèi),使得后一層網(wǎng)絡(luò)不必不斷去適應(yīng)底層網(wǎng)絡(luò)中輸入的變化,從而實(shí)現(xiàn)了網(wǎng)絡(luò)中層與層之間的解耦,允許每一層進(jìn)行獨(dú)立學(xué)習(xí),有利于提高整個(gè)神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)速度。

        (2)BN使得模型對(duì)網(wǎng)絡(luò)中的參數(shù)不那么敏感,簡(jiǎn)化調(diào)參過(guò)程,使得網(wǎng)絡(luò)學(xué)習(xí)更加穩(wěn)定

        在神經(jīng)網(wǎng)絡(luò)中,我們經(jīng)常會(huì)謹(jǐn)慎地采用一些權(quán)重初始化方法(例如Xavier)或者合適的學(xué)習(xí)率來(lái)保證網(wǎng)絡(luò)穩(wěn)定訓(xùn)練。

        當(dāng)學(xué)習(xí)率設(shè)置太高時(shí), 會(huì)使得參數(shù)更新步伐過(guò)大, 容易出現(xiàn)震蕩和不收斂。但是使用BN的網(wǎng)絡(luò)將 不會(huì)受到參數(shù)數(shù)值大小的影響。例如, 我們對(duì)參數(shù)  進(jìn)行縮放得到  。對(duì)于縮放前的值  , 我們?cè)O(shè)其均值為 , 方差為 ; 對(duì)于縮放值 , 設(shè)其均值為 , 方差為 , 則我們有:


        我們忽略 , 則有:


        注:公式中的  是當(dāng)前層的輸入,也是前一層的輸出;不是下標(biāo)啊旁友們!

        我們可以看到, 經(jīng)過(guò)BN操作以后, 權(quán)重的縮放值會(huì)被“抺去”, 因此保證了輸入數(shù)據(jù)分布穩(wěn)定在一定范圍內(nèi)。另外, 權(quán)重的縮放并不會(huì)影響到對(duì)  的梯度計(jì)算; 并且當(dāng)權(quán)重越大時(shí), 即  越大,  越小,意味著權(quán)重  的梯度反而越小,這樣BN就保證了梯度不會(huì)依賴于參數(shù)的scale, 使得參數(shù)的更新處在更加穩(wěn)定的狀態(tài)。

        因此,在使用Batch Normalization之后,抑制了參數(shù)微小變化隨著網(wǎng)絡(luò)層數(shù)加深被放大的問題,使得網(wǎng)絡(luò)對(duì)參數(shù)大小的適應(yīng)能力更強(qiáng),此時(shí)我們可以設(shè)置較大的學(xué)習(xí)率而不用過(guò)于擔(dān)心模型divergence的風(fēng)險(xiǎn)。

        (3)BN允許網(wǎng)絡(luò)使用飽和性激活函數(shù)(例如sigmoid,tanh等),緩解梯度消失問題

        在不使用BN層的時(shí)候, 由于網(wǎng)絡(luò)的深度與復(fù)雜性, 很容易使得底層網(wǎng)絡(luò)變化累積到上層網(wǎng)絡(luò)中, 導(dǎo)致模型的訓(xùn)練很容易進(jìn)入到激活函數(shù)的梯度飽和區(qū); 通過(guò)normalize操作可以讓激活函數(shù)的輸入數(shù)據(jù)落在梯度非飽和區(qū), 緩解梯度消失的問題; 另外通過(guò)自適應(yīng)學(xué)習(xí)  與  又讓數(shù)據(jù)保留更多的原始信息。

        (4)BN具有一定的正則化效果

        在Batch Normalization中,由于我們使用mini-batch的均值與方差作為對(duì)整體訓(xùn)練樣本均值與方差的估計(jì),盡管每一個(gè)batch中的數(shù)據(jù)都是從總體樣本中抽樣得到,但不同mini-batch的均值與方差會(huì)有所不同,這就為網(wǎng)絡(luò)的學(xué)習(xí)過(guò)程中增加了隨機(jī)噪音,與Dropout通過(guò)關(guān)閉神經(jīng)元給網(wǎng)絡(luò)訓(xùn)練帶來(lái)噪音類似,在一定程度上對(duì)模型起到了正則化的效果。

        另外,原作者通過(guò)也證明了網(wǎng)絡(luò)加入BN后,可以丟棄Dropout,模型也同樣具有很好的泛化效果。


        (二)實(shí)戰(zhàn)板塊

        經(jīng)過(guò)了上面了理論學(xué)習(xí),我們對(duì)BN有了理論上的認(rèn)知。“Talk is cheap, show me the code”。接下來(lái)我們就通過(guò)實(shí)際的代碼來(lái)對(duì)比加入BN前后的模型效果。實(shí)戰(zhàn)部分使用MNIST數(shù)據(jù)集作為數(shù)據(jù)基礎(chǔ),并使用TensorFlow中的Batch Normalization結(jié)構(gòu)來(lái)進(jìn)行BN的實(shí)現(xiàn)。

        數(shù)據(jù)準(zhǔn)備:MNIST手寫數(shù)據(jù)集

        代碼地址:我的GitHub (https://github.com/NELSONZHAO/zhihu/tree/master/batch_normalization_discussion)

        注:TensorFlow版本為1.6.0

        實(shí)戰(zhàn)板塊主要分為兩部分:

        • 網(wǎng)絡(luò)構(gòu)建與輔助函數(shù)
        • BN測(cè)試
        1 網(wǎng)絡(luò)構(gòu)建與輔助函數(shù)

        首先我們先定義一下神經(jīng)網(wǎng)絡(luò)的類,這個(gè)類里面主要包括了以下方法:

        • build_network:前向計(jì)算
        • fully_connected:全連接計(jì)算
        • train:訓(xùn)練模型
        • test:測(cè)試模型
        1.1 build_network

        我們首先通過(guò)構(gòu)造函數(shù),把權(quán)重、激活函數(shù)以及是否使用BN這些變量傳入,并生成一個(gè)training_accuracies來(lái)記錄訓(xùn)練過(guò)程中的模型準(zhǔn)確率變化。這里的initial_weights是一個(gè)list,list中每一個(gè)元素是一個(gè)矩陣(二維tuple),存儲(chǔ)了每一層的權(quán)重矩陣。build_network實(shí)現(xiàn)了網(wǎng)絡(luò)的構(gòu)建,并調(diào)用了fully_connected函數(shù)(下面會(huì)提)進(jìn)行計(jì)算。要注意的是,由于MNIST是多分類,在這里我們不需要對(duì)最后一層進(jìn)行激活,保留計(jì)算的logits就好。

        圖片

        1.2 fully_connected

        這里的fully_connected主要用來(lái)每一層的線性與非線性計(jì)算。通過(guò)self.use_batch_norm來(lái)控制是否使用BN。

        圖片

        另外,值得注意的是,tf.layers.batch_normalization接口中training參數(shù)非常重要,官方文檔中描述為:

        training: Either a Python boolean, or a TensorFlow boolean scalar tensor (e.g. a placeholder). Whether to return the output in training mode (normalized with statistics of the current batch) or in inference mode (normalized with moving statistics). NOTE: make sure to set this parameter correctly, or else your training/inference will not work properly.

        當(dāng)我們訓(xùn)練時(shí),要設(shè)置為True,保證在訓(xùn)練過(guò)程中使用的是mini-batch的統(tǒng)計(jì)量進(jìn)行normalization;在Inference階段,使用False,也就是使用總體樣本的無(wú)偏估計(jì)。

        1.3 train

        train函數(shù)主要用來(lái)進(jìn)行模型的訓(xùn)練。除了要定義label,loss以及optimizer以外,我們還需要注意,官方文檔指出在使用BN時(shí)的事項(xiàng):

        Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op.

        因此當(dāng)self.use_batch_norm為True時(shí),要使用tf.control_dependencies保證模型正常訓(xùn)練。

        圖片

        注意:在訓(xùn)練過(guò)程中batch_size選了60(mnist.train.next_batch(60)),這里是因?yàn)锽N的原paper中用的60。( We trained the network for 50000 steps, with 60 examples per mini-batch.)

        1.4 test

        test階段與train類似,只是要設(shè)置self.is_training=False,保證Inference階段BN的正確。

        圖片

        經(jīng)過(guò)上面的步驟,我們的框架基本就搭好了,接下來(lái)我們?cè)賹懸粋€(gè)輔助函數(shù)train_and_test以及plot繪圖函數(shù)就可以開始對(duì)BN進(jìn)行測(cè)試?yán)病rain_and_test以及plot函數(shù)見GitHub代碼中,這里不再贅述。


        *博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。



        關(guān)鍵詞: AI

        相關(guān)推薦

        技術(shù)專區(qū)

        關(guān)閉
        主站蜘蛛池模板: 汤阴县| 偃师市| 万安县| 光山县| 新泰市| 崇信县| 平和县| 宁都县| 衡水市| 喀喇| 正宁县| 如皋市| 阿拉善左旗| 揭东县| 武鸣县| 郯城县| 杨浦区| 嘉善县| 望都县| 高陵县| 井冈山市| 灯塔市| 盐池县| 赤壁市| 东辽县| 乐昌市| 枞阳县| 勐海县| 弋阳县| 元阳县| 瑞安市| 荔浦县| 彭阳县| 吉水县| 永顺县| 同心县| 荃湾区| 庆城县| 石泉县| 元谋县| 阳西县|