本文介紹了一項(xiàng)近似注意力機(jī)制新研究,耶魯大學(xué)、谷歌研究院等機(jī)構(gòu)提出了 HyperAttention,使 ChatGLM2 在 32k 上下文長(zhǎng)度上的推理時(shí)間快了 50%。
Transformer 已經(jīng)成功應(yīng)用于自然語(yǔ)言處理、計(jì)算機(jī)視覺(jué)和時(shí)間序列預(yù)測(cè)等領(lǐng)域的各種學(xué)習(xí)任務(wù)。雖然取得了成功,但這些模型仍面臨著嚴(yán)重的可擴(kuò)展性限制,原因是對(duì)其注意力層的精確計(jì)算導(dǎo)致了二次(在序列長(zhǎng)度上)運(yùn)行時(shí)和內(nèi)存復(fù)雜性。這對(duì)將 Transformer 模型擴(kuò)展到更長(zhǎng)的上下文長(zhǎng)度帶來(lái)了根本性的挑戰(zhàn)。
業(yè)界已經(jīng)探索了各種方法來(lái)解決二次時(shí)間注意力層的問(wèn)題,其中一個(gè)值得注意的方向是近似注意力層中的中間矩陣。實(shí)現(xiàn)這一點(diǎn)的方法包括通過(guò)稀疏矩陣、低秩矩陣進(jìn)行近似,或兩者的結(jié)合。
然而,這些方法并不能為注意力輸出矩陣的近似提供端到端的保證。這些方法旨在更快地逼近注意力的各個(gè)組成部分,但沒(méi)有一種方法能提供完整點(diǎn)積注意力的端到端逼近。這些方法還不支持使用因果掩碼,而因果掩碼是現(xiàn)代 Transformer 架構(gòu)的重要組成部分。最近的理論邊界表明,在一般情況下,不可能在次二次時(shí)間內(nèi)對(duì)注意力矩陣進(jìn)行分項(xiàng)近似。
不過(guò),最近一項(xiàng)名為 KDEFormer 的研究表明,在注意力矩陣項(xiàng)有界的假設(shè)條件下,它能在次二次時(shí)間內(nèi)提供可證明的近似值。從理論上講,KDEFormer 的運(yùn)行時(shí)大約為
;它采用核密度估計(jì) (kernel density estimation,KDE) 來(lái)近似列范數(shù),允許計(jì)算對(duì)注意力矩陣的列進(jìn)行采樣的概率。然而,目前的 KDE 算法缺乏實(shí)際效率,即使在理論上,KDEFormer 的運(yùn)行時(shí)與理論上可行的 O (n) 時(shí)間算法之間也有差距。
在文中,作者證明了在同樣的有界條目假設(shè)下,近線性時(shí)間的
算法是可能的。不過(guò),他們的算法還涉及使用多項(xiàng)式方法來(lái)逼近 softmax,很可能不切實(shí)際。
而在本文中,來(lái)自耶魯大學(xué)、谷歌研究院等機(jī)構(gòu)的研究者提供了一種兩全其美的算法,既實(shí)用高效,又是能實(shí)現(xiàn)最佳近線性時(shí)間保證。此外,該方法還支持因果掩碼,這在以前的工作中是不可能實(shí)現(xiàn)的。
論文標(biāo)題:HyperAttention: Long-context Attention in Near-Linear Time論文鏈接:
https://arxiv.org/abs/2310.05869 本文提出一種名為「HyperAttention」近似注意力機(jī)制,以解決大型語(yǔ)言模型中使用的長(zhǎng)上下文日益復(fù)雜帶來(lái)的計(jì)算挑戰(zhàn)。最近的工作表明,在最壞情況下,除非注意力矩陣的條目有界或矩陣的穩(wěn)定秩較低,否則二次時(shí)間是必要的。 研究者引入了兩個(gè)參數(shù)來(lái)衡量:(1)歸一化注意力矩陣中的最大列范數(shù),(2)檢測(cè)和刪除大條目后,非歸一化注意力矩陣中的行范數(shù)的比例。他們使用這些細(xì)粒度參數(shù)來(lái)反映問(wèn)題的難易程度。只要上述參數(shù)很小,即使矩陣具有無(wú)界條目或較大的穩(wěn)定秩,也能夠?qū)崿F(xiàn)線性時(shí)間采樣算法。 HyperAttention 的特點(diǎn)是模塊化設(shè)計(jì),可以輕松集成其他快速底層實(shí)現(xiàn),特別是 FlashAttention。根據(jù)經(jīng)驗(yàn),使用 LSH 算法來(lái)識(shí)別大型條目,HyperAttention 優(yōu)于現(xiàn)有方法,與 FlashAttention 等 SOTA 解決方案相比,速度有了顯著提高。研究者在各種不同的長(zhǎng)上下文長(zhǎng)度數(shù)據(jù)集上驗(yàn)證了 HyperAttention 的性能。 例如,HyperAttention 使 ChatGLM2 在 32k 上下文長(zhǎng)度上的推理時(shí)間快了 50%,而困惑度從 5.6 增加到 6.3。更大的上下文長(zhǎng)度(例如 131k)和因果掩碼情況下,HyperAttention 在單個(gè)注意力層上速度提升了 5 倍。

方法概覽
點(diǎn)積注意涉及處理三個(gè)輸入矩陣: Q (queries) 、K (key)、V (value),大小均為 nxd,其中 n 是輸入序列中的 token 數(shù),d 是潛在表征的維度。這一過(guò)程的輸出結(jié)果如下:
這里,矩陣 A := exp (QK^T) 被定義為 QK^T 的元素指數(shù)。D 是一個(gè) n×n 對(duì)角矩陣,由 A 各行之和導(dǎo)出, 這里
。在這種情況下,矩陣 A 被稱為「注意力矩陣」,(D^-1 ) A 被稱為「softmax 矩陣」。值得注意的是,直接計(jì)算注意力矩陣 A 需要 Θ(n2d)運(yùn)算,而存儲(chǔ)它需要消耗 Θ(n2)內(nèi)存。因此,直接計(jì)算 Att 需要 Ω(n2d)的運(yùn)行時(shí)和 Ω(n2)的內(nèi)存。
研究者目標(biāo)是高效地近似輸出矩陣 Att,同時(shí)保留其頻譜特性。他們的策略包括為對(duì)角縮放矩陣 D 設(shè)計(jì)一個(gè)近線性時(shí)間的高效估計(jì)器。此外,他們通過(guò)子采樣快速逼近 softmax 矩陣 D^-1A 的矩陣乘積。更具體地說(shuō),他們的目標(biāo)是找到一個(gè)具有有限行數(shù)
的采樣矩陣
以及一個(gè)對(duì)角矩陣
,從而滿足誤差的算子規(guī)范的以下約束:


算法
為了在近似 Att 時(shí)獲得頻譜保證,本文第一步是對(duì)矩陣 D 的對(duì)角線項(xiàng)進(jìn)行 1 ± ε 近似。隨后,根據(jù) V 的平方行??-norms,通過(guò)采樣逼近 (D^-1)A 和 V 之間的矩陣乘積。 近似 D 的過(guò)程包括兩個(gè)步驟。首先,使用植根于 Hamming 排序 LSH 的算法來(lái)識(shí)別注意力矩陣中的主要條目,如定義 1 所示。第二步是隨機(jī)選擇一小部分 K。本文將證明,在矩陣 A 和 D 的某些溫和假設(shè)條件下,這種簡(jiǎn)單的方法可以建立估計(jì)矩陣的頻譜邊界。研究者的目標(biāo)是找到一個(gè)足夠精確的近似矩陣 D,滿足:

,使得
。
算法的第一步是使用 Hamming 排序 LSH (sortLSH) 將鍵和查詢散列到大小均勻的桶中,從而識(shí)別注意力矩陣 A 中的大型條目。算法 1 詳細(xì)介紹了這一過(guò)程,圖 1 直觀地說(shuō)明了這一過(guò)程。


?整合近似對(duì)角線
和近似
與值矩陣 V 之間矩陣乘積的子程序。因此,研究者引入了 HyperAttention,這是一種高效算法,可以在近似線性時(shí)間內(nèi)近似公式(1)中具有頻譜保證的注意力機(jī)制。算法 3 將定義注意力矩陣中主導(dǎo)條目的位置的掩碼 MH 作為輸入。這個(gè)掩碼可以使用 sortLSH 算法(算法 1)生成,也可以是一個(gè)預(yù)定義的掩碼,類(lèi)似于 [7] 中的方法。研究者假定大條目掩碼 M^H 在設(shè)計(jì)上是稀疏的,而且其非零條目數(shù)是有界的
。
如圖 2 所示,本文方法基于一個(gè)重要的觀察結(jié)果。屏蔽注意力 M^C⊙A 可以分解成三個(gè)非零矩陣,每個(gè)矩陣的大小是原始注意力矩陣的一半。完全位于對(duì)角線下方的 A_21 塊是未屏蔽注意力。因此,我們可以使用算法 2 近似計(jì)算其行和。
圖 2 中顯示的兩個(gè)對(duì)角線區(qū)塊
和
是因果注意力,其大小只有原來(lái)的一半。為了處理這些因果關(guān)系,研究者采用遞歸方法,將它們進(jìn)一步分割成更小的區(qū)塊,并重復(fù)這一過(guò)程。算法 4 中給出了這一過(guò)程的偽代碼。


實(shí)驗(yàn)及結(jié)果
研究者通過(guò)擴(kuò)展現(xiàn)有大語(yǔ)言模型來(lái)處理 long range 序列,進(jìn)而對(duì)算法進(jìn)行基準(zhǔn)測(cè)試。所有實(shí)驗(yàn)都在單個(gè) 40GB 的 A100 GPU 上運(yùn)行,并用 FlashAttention 2 來(lái)進(jìn)行精確的注意力計(jì)算。 Monkey Patching自注意力 研究者首先在兩個(gè)預(yù)訓(xùn)練 LLM 上評(píng)估 HyperAttention,選擇了實(shí)際應(yīng)用中廣泛使用的具有不同架構(gòu)的兩個(gè)模型:chatglm2-6b-32k 和 phi-1.5。 在操作中,他們通過(guò)替換為 HyperAttention 來(lái) patch 最終的?注意力層,其中?的數(shù)量可以從 0 到每個(gè) LLM 中所有注意力層的總數(shù)不等。請(qǐng)注意,兩個(gè)模型中的注意力都需要因果掩碼,并且遞歸地應(yīng)用算法 4 直到輸入序列長(zhǎng)度 n 小于 4,096。對(duì)于所有序列長(zhǎng)度,研究者將 bucket 大小 b 和采樣列數(shù) m 均設(shè)置為 256。他們從困惑度和加速度兩個(gè)方面評(píng)估了這類(lèi) monkey patched 模型的性能。 同時(shí)研究者使用了一個(gè)長(zhǎng)上下文基準(zhǔn)數(shù)據(jù)集的集合 LongBench,它包含了 6 個(gè)不同的任務(wù),即單 / 多文檔問(wèn)答、摘要、小樣本學(xué)習(xí)、合成任務(wù)和代碼補(bǔ)全。他們選擇了編碼序列長(zhǎng)度大于 32,768 的數(shù)據(jù)集的子集,并且如果長(zhǎng)度超過(guò) 32,768,則進(jìn)行剪枝。接著計(jì)算每個(gè)模型的困惑度,即下一個(gè) token 預(yù)測(cè)的損失。為了突出長(zhǎng)序列的可擴(kuò)展性,研究者還計(jì)算所有注意力層的總加速,無(wú)論是由 HyperAttention 還是 FlashAttention 執(zhí)行。 結(jié)果如下圖 3 所示,即使經(jīng)過(guò) HyperAttention 的 monkey patch,chatglm2-6b-32k 仍顯示出合理的困惑度。例如替換 20 層后,困惑度大約增加了 1,并在達(dá)到 24 層之前繼續(xù)緩慢增加。注意力層的運(yùn)行時(shí)提升了大約 50%。如果所有層都被替換,則困惑度上升到 12,運(yùn)行速度提升 2.3。phi-1.5 模型也表現(xiàn)出了類(lèi)似的情況,但隨著 HyperAttention 數(shù)量的增加,困惑度會(huì)線性增長(zhǎng)。
此外,研究者評(píng)估了 LongBench 數(shù)據(jù)集上 monkey patched chatglm2-6b-32k 的性能,并計(jì)算單 / 多文檔問(wèn)答、摘要、小樣本學(xué)習(xí)、合成任務(wù)和代碼補(bǔ)全等各自任務(wù)上的評(píng)估分?jǐn)?shù)。結(jié)果如下表 1 所示。
雖然替換 HyperAttention 通常會(huì)導(dǎo)致性能下降,但他們觀察到它的影響會(huì)基于手頭任務(wù)發(fā)生變化。例如,摘要和代碼補(bǔ)全相對(duì)于其他任務(wù)具有最強(qiáng)的穩(wěn)健性。

?
原文標(biāo)題:全新近似注意力機(jī)制HyperAttention:對(duì)長(zhǎng)上下文友好、LLM推理提速50%
文章出處:【微信公眾號(hào):智能感知與物聯(lián)網(wǎng)技術(shù)研究所】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
-
物聯(lián)網(wǎng)
+關(guān)注
關(guān)注
2945文章
47820瀏覽量
414886
原文標(biāo)題:全新近似注意力機(jī)制HyperAttention:對(duì)長(zhǎng)上下文友好、LLM推理提速50%
文章出處:【微信號(hào):tyutcsplab,微信公眾號(hào):智能感知與物聯(lián)網(wǎng)技術(shù)研究所】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
堪稱史上最強(qiáng)推理芯片!英偉達(dá)發(fā)布 Rubin CPX,實(shí)現(xiàn)50倍ROI
NVIDIA BlueField-4為推理上下文記憶存儲(chǔ)平臺(tái)提供強(qiáng)大支持
NVIDIA DGX SuperPOD為Rubin平臺(tái)橫向擴(kuò)展提供藍(lán)圖
深入解析NVIDIA Nemotron 3系列開(kāi)放模型
奇異摩爾入選2025中國(guó)科創(chuàng)好公司半導(dǎo)體榜單
大語(yǔ)言模型如何處理上下文窗口中的輸入
NVIDIA TensorRT LLM 1.0推理框架正式上線
請(qǐng)問(wèn)riscv中斷還需要軟件保存上下文和恢復(fù)嗎?
米爾RK3576部署端側(cè)多模態(tài)多輪對(duì)話,6TOPS算力驅(qū)動(dòng)30億參數(shù)LLM
【「DeepSeek 核心技術(shù)揭秘」閱讀體驗(yàn)】+看視頻+看書(shū)籍+國(guó)產(chǎn)開(kāi)源大模型DeepSeekV3技術(shù)詳解--1
如何在NVIDIA Blackwell GPU上優(yōu)化DeepSeek R1吞吐量
鴻蒙NEXT-API19獲取上下文,在class中和ability中獲取上下文,API遷移示例-解決無(wú)法在EntryAbility中無(wú)法使用最新版
詳解 LLM 推理模型的現(xiàn)狀
全新近似注意力機(jī)制HyperAttention:對(duì)長(zhǎng)上下文友好、LLM推理提速50%
評(píng)論