博客專欄

        EEPW首頁 > 博客 > Transformer取代者登場!微軟、清華剛推出RetNet:成本低、速度快、性能強(2)

        Transformer取代者登場!微軟、清華剛推出RetNet:成本低、速度快、性能強(2)

        發布人:計算機視覺工坊 時間:2023-07-19 來源:工程師 發布文章

        Retentive 網絡


        RetNet 由 L 個相同的塊堆疊而成,其布局與 Transformer 類似(即殘差連接和 pre-LayerNorm)。每個 RetNet 塊包含兩個模塊:多尺度retention(MSR)和前饋網絡(FFN)。


        給定輸入序列圖片,RetNet 以自回歸方式對序列進行編碼。輸入向量圖片首先被封裝為圖片 ,其中圖片是隱藏維度。然后,計算上下文向量表征圖片。

        Retention

        RetNet 具有循環和并行雙重形式的 retention 機制,因此能夠并行地訓練模型,同時循環地進行推理。


        給定輸入圖片,將其投影為一維函數 v (n) = X_n - w_V。考慮一個序列建模問題,通過狀態 s_n 映射 v (n) → o (n)。


        為簡單起見,讓 v_n, o_n 表示 v (n),o (n)。此處以循環的方式對映射進行表述:


        其中,將 v_n 映射到狀態向量 s_n,然后實現線性變換,對序列信息進行循環編碼。


        接下來,使投影 Q_n, K_n 具有內容感知能力:


        其中圖片是可學習矩陣。

        將矩陣圖片對角化,其中圖片然后得到圖片。通過將 Λ 吸收到 W_Q 和 W_K 中,可以將方程(1)重寫為


        圖片

        其中,圖片為 xPos,即為 Transformer 提出的相對位置嵌入。進一步將 γ 簡化為標量,公式(3)則變為


        圖片


        其中?為共軛轉置。該公式很容易在訓練實例中并行化。


        總之,從公式 (1) 所示的循環建模開始,然后推導出公式 (4) 中的并行公式。將原始映射 v (n) →o (n) 視為向量,得到如下的 retention 機制:


        1)Retention 的并行表征


        如圖 3a 所示,Retention 層定義為:


        圖片


        與自注意力類似,并行表征使得能夠使用 GPU 高效地訓練模型。


        2)Retention 的循環表征


        圖片


        如圖 3b 所示,所提出機制也可以寫成循環神經網絡(RNN),這有利于推理。對于第 n 個時間步,循環得到的輸出為


        這里的 Q, K, V, γ 和公式 5 相同。


        3)Retention 分塊循環表征


        并行表征和循環表征的混合形式可以加速訓練,特別是對于長序列。此處將輸入序列劃分為若干小塊。在每個塊內,按照并行表征(公式(5))進行計算。相反,跨塊信息則按照循環表征(公式(6))進行傳遞。具體來說,讓 B 表示塊長度。通過以下方式計算第 i 個分塊的 retention 輸出:


        圖片

        其中 [i] 表示第 i 個數據塊,例如圖片。

        門控多尺度 Retention


        在每個層中,研究者使用 h = d_model/d 個 retention 頭,其中 d 是頭的維度。這些頭使用不同的參數矩陣 W_Q、W_K、W_V ∈ R^(d×d)。此外,多尺度 retention(MSR)為每個頭分配不同的 γ。為了簡化,研究者將 γ 設置為在不同層之間相同并保持固定。另外,他們添加了一個 swish 門 [RZL17] 來增加層的非線性性。形式上,給定輸入 X,研究者將該層定義為:


        圖片

        其中,圖片為可學習參數,GroupNorm [WH18] 對每個頭的輸出進行歸一化,遵循 [SPP^+19] 中提出的 SubLN。注意,這些頭使用多個 γ 尺度,這會帶來不同的方差統計結果。所以研究者分別對頭的輸出進行歸一化。

        retention 的偽代碼如圖 4 所示。

         

        圖片


        Retention Score 歸一化


        研究者利用 GroupNorm 的尺度不變性來提高 retention 層的數值精度。具體而言,在 GroupNorm 中乘以一個標量值不會影響輸出和反向梯度,即 GroupNorm (α ? head_i) = GroupNorm (head_i)。研究者在公式(5)中實現了三個歸一化因子。首先,他們將 QK^? 歸一化為 QK^? / √ d。其次,他們將 D 替換為圖片。第三,他們用 R 表示 retention scores R = QK^? ⊙ D,將其歸一化為圖片。然后,retention 輸出變為 圖片。由于尺度不變的特性,上述技巧不會影響最終的結果,同時穩定了正向和反向傳遞的數值流動。


        Retention 網絡總體結構


        對于一個 L 層的 retention 網絡,研究者堆疊多尺度 retention (MSR) 和前饋網絡(FFN)來構建模型。形式上,輸入序列圖片通過一個詞嵌入層被轉換為向量。研究者使用打包后的嵌入圖片作為輸入,并計算模型的輸出 X^L:

        圖片

        其中,LN (?) 為 LayerNorm [BKH16]。FFN 部分計算為 FFN (X) = gelu (XW_1) W_2,其中 W_1、W_2 為參數矩陣。

        訓練:研究者在訓練過程中使用了并行(公式 5)表示和塊循環(公式 7)表示。序列或塊內的并行有效地利用了 GPU 來加速計算。更有利的是,塊循環對于長序列訓練特別有用,這在 FLOPs 和內存消耗方面都是有效的。

        推理:在推理過程中,研究者采用了循環表示(公式 6),這非常適合自回歸解碼。O (1) 的復雜度減少了內存占用和推理延遲,同時實現了相當的結果。


        *博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。



        關鍵詞: AI

        相關推薦

        技術專區

        關閉
        主站蜘蛛池模板: 尖扎县| 沁阳市| 贵南县| 牡丹江市| 郴州市| 云霄县| 娄烦县| 郓城县| 贵南县| 大化| 鄂尔多斯市| 都江堰市| 万全县| 平阴县| 札达县| 灵武市| 佛教| 彰化县| 调兵山市| 福鼎市| 无为县| 昆山市| 鄂伦春自治旗| 尉氏县| 赤峰市| 班玛县| 恩施市| 怀宁县| 兴和县| 广河县| 霞浦县| 土默特左旗| 乐都县| 泰宁县| 电白县| 乌苏市| 射洪县| 禹州市| 泸西县| 积石山| 西宁市|