Transformer取代者登場!微軟、清華剛推出RetNet:成本低、速度快、性能強(3)
與以往方法的聯系和區別
表 1 從不同角度對 RetNet 與以往的方法進行了比較。對比結果與圖 2 所示的「不可能三角」相呼應。此外,RetNet 對于長序列具有線性記憶復雜性,因為它采用了分塊循環表示。
Transformer:retention 的并行表示與 Transformers [VSP^+17] 有著相似的思路。最相關的 Transformer 變體是 Lex Transformer [SDP^+22],它實現了 xPos 作為位置嵌入。如式 (3) 所示,retention 的推導與 xPos 一致。與注意力相比,retention 消除了 softmax 并使循環公式成為可能,這非常有利于推理。
S4:與式 (2) 不同,如果 Q_n 和 K_n 是 content-unaware 的,則公式可簡并為 S4 [GGR21],其中
Linear Attention:變體通常使用各種 kernel來取代 softmax 函數。然而,線性注意力難以有效地編碼位置信息,導致模型性能下降。此外,研究者從頭開始重新檢查序列建模,而不是以近似 softmax 為目標。
AFT/RWKV:Attention Free Transformer (AFT) 簡化了點積對元素運算的關注,并將 softmax 移動到關鍵向量。RWKV 用指數衰減取代 AFT 的位置嵌入,并循環運行模型進行訓練和推理。相比之下,retention 保留了高維狀態來編碼序列信息,有助于提高表達能力和性能。
xPos/RoPE:與為 Transformers 提出的相對位置嵌入方法相比,公式(3)呈現出與 xPos [SDP^+22] 和 RoPE [SLP^+21] 類似的表達式。
Sub-LayerNorm:如公式(8)所示,retention 層使用 Sub-LayerNorm [WMH^+22] 對輸出進行歸一化。由于多尺度建模導致不同頭的方差不同,研究者將原始的 LayerNorm 替換為 GroupNorm。
實驗結果
該研究進行了大量的實驗來評估 RetNet,包括語言建模任務、下游任務上零樣本、少樣本學習性能,此外,研究者還比較了 RetNet 訓練和推理的速度、內存消耗和延遲等指標。
與 Transformer 的比較
語言建模任務。圖 5 報告了基于 Transformer 和 RetNet 的語言模型在驗證集上的困惑度(perplexity)結果。實驗給出了 13 b、2.7B 和 6.7B 三種模型尺寸的縮放曲線。表明,RetNet 取得了與 Transformer 可比較的結果。
更重要的是,這一結果還表明了 RetNet 在大小擴展方面更具優勢。除了性能優勢外,實驗中 RetNet 的訓練也非常穩定。RetNet 是 Transformer 的有力競爭對手。研究者根據經驗發現,當模型規模大于 2B 時,RetNet 開始超越 Transformer。
該研究還在各種下游任務上對語言模型進行了比較。他們使用 6.7B 大小的模型進行了零樣本和 4 個樣本學習的評估,如表 3 所示。表中展示的關于準確率的數字與圖 5 中呈現的語言建模困惑度一致。在零樣本學習和上下文學習設置中,RetNet 在性能上與 Transformer 相當。
訓練成本
表 4 比較了 Transformer 和 RetNet 在訓練速度和內存開銷方面的結果,其中訓練序列長度為 8192。此外,該研究還將其與 FlashAttention 進行了比較。
實驗結果表明,在訓練過程中,RetNet 比 Transformer 更節省內存,并且具有更高的吞吐量。即使與 FlashAttention 相比,RetNet 在速度和內存成本方面仍然具有競爭力。此外,由于不依賴于特定的內核,用戶可以輕松高效地在其他平臺上訓練 RetNet。例如,研究者可以在具有良好吞吐量的 AMD MI200 集群上訓練 RetNet 模型。
推理成本
圖 6 比較了 Transformer 和 RetNet 在推理過程中的內存成本、吞吐量和延遲。實驗中使用了 A100-80GB GPU 評估了 6.7B 模型。圖 6 顯示,RetNet 在推理成本方面優于 Transformer。
內存:如圖 6a 所示,由于 KV(鍵和值)緩存,Transformer 的內存成本呈線性增長。相比之下,RetNet 的內存消耗即使對于長序列也保持一致。
吞吐量:如圖 6b 所示,隨著解碼長度的增加,Transformer 的吞吐量開始下降。相比之下,RetNet 通過利用 Retention 的循環表征,在解碼過程中具有更高的吞吐量,并且與長度無關。
延遲:延遲是部署中的重要指標,它極大地影響用戶體驗。圖 6c 報告了解碼延遲。實驗結果顯示,增加批次大小會使 Transformer 的延遲變大。此外,Transformer 的延遲隨著輸入長度的增加而增加得更快。為了使延遲可接受,研究者不得不限制批次大小,這會損害 Transformer 的整體推理吞吐量。相比之下,RetNet 的解碼延遲優于 Transformer,并且在不同的批次大小和輸入長度下幾乎保持不變。
與 Transformer 變體比較
下表表明,RetNet 在不同的數據集上優于先前的方法。RetNet 不僅在領域內語料庫上取得更好的評估結果,還在幾個領域外數據集上獲得更低的困惑度。這種優越的性能使得 RetNet 成為 Transformer 的有力繼任者。
消融實驗
下表列出了 RetNet 的各種設計選擇,并在表 6 中報告了語言建模結果。
*博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。