Transformer取代者登場!微軟、清華剛推出RetNet:成本低、速度快、性能強(2)
Retentive 網絡
RetNet 由 L 個相同的塊堆疊而成,其布局與 Transformer 類似(即殘差連接和 pre-LayerNorm)。每個 RetNet 塊包含兩個模塊:多尺度retention(MSR)和前饋網絡(FFN)。
給定輸入序列,RetNet 以自回歸方式對序列進行編碼。輸入向量
首先被封裝為
,其中
是隱藏維度。然后,計算上下文向量表征
。
Retention
RetNet 具有循環和并行雙重形式的 retention 機制,因此能夠并行地訓練模型,同時循環地進行推理。
給定輸入,將其投影為一維函數 v (n) = X_n - w_V。考慮一個序列建模問題,通過狀態 s_n 映射 v (n) → o (n)。
為簡單起見,讓 v_n, o_n 表示 v (n),o (n)。此處以循環的方式對映射進行表述:
其中,將 v_n 映射到狀態向量 s_n,然后實現線性變換,對序列信息進行循環編碼。
接下來,使投影 Q_n, K_n 具有內容感知能力:
其中是可學習矩陣。
將矩陣對角化,其中
。然后得到
。通過將 Λ 吸收到 W_Q 和 W_K 中,可以將方程(1)重寫為
其中,稱為 xPos,即為 Transformer 提出的相對位置嵌入。進一步將 γ 簡化為標量,公式(3)則變為
其中?為共軛轉置。該公式很容易在訓練實例中并行化。
總之,從公式 (1) 所示的循環建模開始,然后推導出公式 (4) 中的并行公式。將原始映射 v (n) →o (n) 視為向量,得到如下的 retention 機制:
1)Retention 的并行表征
如圖 3a 所示,Retention 層定義為:
與自注意力類似,并行表征使得能夠使用 GPU 高效地訓練模型。
2)Retention 的循環表征
如圖 3b 所示,所提出機制也可以寫成循環神經網絡(RNN),這有利于推理。對于第 n 個時間步,循環得到的輸出為
這里的 Q, K, V, γ 和公式 5 相同。
3)Retention 分塊循環表征
并行表征和循環表征的混合形式可以加速訓練,特別是對于長序列。此處將輸入序列劃分為若干小塊。在每個塊內,按照并行表征(公式(5))進行計算。相反,跨塊信息則按照循環表征(公式(6))進行傳遞。具體來說,讓 B 表示塊長度。通過以下方式計算第 i 個分塊的 retention 輸出:
其中 [i] 表示第 i 個數據塊,例如。
門控多尺度 Retention
在每個層中,研究者使用 h = d_model/d 個 retention 頭,其中 d 是頭的維度。這些頭使用不同的參數矩陣 W_Q、W_K、W_V ∈ R^(d×d)。此外,多尺度 retention(MSR)為每個頭分配不同的 γ。為了簡化,研究者將 γ 設置為在不同層之間相同并保持固定。另外,他們添加了一個 swish 門 [RZL17] 來增加層的非線性性。形式上,給定輸入 X,研究者將該層定義為:
其中,為可學習參數,GroupNorm [WH18] 對每個頭的輸出進行歸一化,遵循 [SPP^+19] 中提出的 SubLN。注意,這些頭使用多個 γ 尺度,這會帶來不同的方差統計結果。所以研究者分別對頭的輸出進行歸一化。
retention 的偽代碼如圖 4 所示。
Retention Score 歸一化
研究者利用 GroupNorm 的尺度不變性來提高 retention 層的數值精度。具體而言,在 GroupNorm 中乘以一個標量值不會影響輸出和反向梯度,即 GroupNorm (α ? head_i) = GroupNorm (head_i)。研究者在公式(5)中實現了三個歸一化因子。首先,他們將 QK^? 歸一化為 QK^? / √ d。其次,他們將 D 替換為。第三,他們用 R 表示 retention scores R = QK^? ⊙ D,將其歸一化為
。然后,retention 輸出變為
。由于尺度不變的特性,上述技巧不會影響最終的結果,同時穩定了正向和反向傳遞的數值流動。
Retention 網絡總體結構
對于一個 L 層的 retention 網絡,研究者堆疊多尺度 retention (MSR) 和前饋網絡(FFN)來構建模型。形式上,輸入序列通過一個詞嵌入層被轉換為向量。研究者使用打包后的嵌入
作為輸入,并計算模型的輸出 X^L:
其中,LN (?) 為 LayerNorm [BKH16]。FFN 部分計算為 FFN (X) = gelu (XW_1) W_2,其中 W_1、W_2 為參數矩陣。
訓練:研究者在訓練過程中使用了并行(公式 5)表示和塊循環(公式 7)表示。序列或塊內的并行有效地利用了 GPU 來加速計算。更有利的是,塊循環對于長序列訓練特別有用,這在 FLOPs 和內存消耗方面都是有效的。
推理:在推理過程中,研究者采用了循環表示(公式 6),這非常適合自回歸解碼。O (1) 的復雜度減少了內存占用和推理延遲,同時實現了相當的結果。
*博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。