一、NBCE
NBCE:使用樸素貝葉斯擴(kuò)展LLM的Context處理長(zhǎng)度
蘇神最早提出的擴(kuò)展LLM的context方法,基于bayes啟發(fā)得到的公式:

在問(wèn)答下實(shí)測(cè)確實(shí)不錯(cuò),在較長(zhǎng)context下的閱讀理解還算好用。
局限性是,無(wú)序性,即無(wú)法識(shí)別Context的輸入順序,這在續(xù)寫(xiě)故事等場(chǎng)景可能表現(xiàn)欠佳,做一些依賴(lài)每個(gè)context生成答案,比如提取文檔摘要,效果較差。
outputs=model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True, use_cache=True, past_key_values=past_key_values ) past_key_values=outputs.past_key_values #=====核心代碼開(kāi)始===== beta=0.25 probas=torch.nn.functional.softmax(outputs.logits[:,-1],dim=-1) logits=probas.log() k=(probas*logits).sum(dim=-1)[1:].argmax()+1 logits_max=logits[k] logits_uncond=logits[0] logits=(1+beta)*logits_max-beta*logits_uncond #=====核心代碼結(jié)束===== #構(gòu)建分布,采樣 tau=0.01#tau=1是標(biāo)準(zhǔn)的隨機(jī)采樣,tau->0則是貪心搜索 probas=torch.nn.functional.softmax(logits[None]/tau,dim=-1) next_tokens=torch.multinomial(probas,num_samples=1).squeeze(1)
此處代碼,圖片,文本均選自科學(xué)空間。
二、線性?xún)?nèi)插
llama基于rotary embedding在2048長(zhǎng)度上預(yù)訓(xùn)練,該方法通過(guò)將position壓縮到0~2048之間,從而達(dá)到長(zhǎng)度外推的目的。
longchat將模型微調(diào)為上下文長(zhǎng)度外擴(kuò)為16384,壓縮比為 8。例如,position_ids = 10000 的 token 變?yōu)閜osition_ids = 10000 / 8 = 1250,相鄰 token 10001 變?yōu)?10001 / 8 = 1250.125
該方法的缺陷是需要進(jìn)行一定量的微調(diào),讓模型來(lái)適應(yīng)這種改變。
importtorch importtransformers importtransformers.models.llama.modeling_llama fromeinopsimportrearrange fromfunctoolsimportpartial classCondenseRotaryEmbedding(torch.nn.Module): def__init__(self,dim,ratio,max_position_embeddings=2048,base=10000,device=None): super().__init__() inv_freq=1.0/(base**(torch.arange(0,dim,2).float().to(device)/dim)) self.register_buffer("inv_freq",inv_freq) #Buildheretomake`torch.jit.trace`work. self.ratio=ratio max_position_embeddings*=ratio print(f"CondensingPositionalembeddingsfrom{max_position_embeddings}to{max_position_embeddings//ratio}") self.max_seq_len_cached=max_position_embeddings t=torch.arange(self.max_seq_len_cached,device=self.inv_freq.device,dtype=self.inv_freq.dtype)/ratio freqs=torch.einsum("i,j->ij",t,self.inv_freq) #Differentfrompaper,butitusesadifferentpermutationinordertoobtainthesamecalculation emb=torch.cat((freqs,freqs),dim=-1) dtype=torch.get_default_dtype() self.register_buffer("cos_cached",emb.cos()[None,None,:,:].to(dtype),persistent=False) self.register_buffer("sin_cached",emb.sin()[None,None,:,:].to(dtype),persistent=False) defforward(self,x,seq_len=None): #x:[bs,num_attention_heads,seq_len,head_size] #This`if`blockisunlikelytoberunafterwebuildsin/cosin`__init__`.Keepthelogicherejustincase. ifseq_len>self.max_seq_len_cached: self.max_seq_len_cached=seq_len t=torch.arange(self.max_seq_len_cached,device=x.device,dtype=self.inv_freq.dtype)/self.ratio freqs=torch.einsum("i,j->ij",t,self.inv_freq) #Differentfrompaper,butitusesadifferentpermutationinordertoobtainthesamecalculation emb=torch.cat((freqs,freqs),dim=-1).to(x.device) self.register_buffer("cos_cached",emb.cos()[None,None,:,:].to(x.dtype),persistent=False) self.register_buffer("sin_cached",emb.sin()[None,None,:,:].to(x.dtype),persistent=False) return( self.cos_cached[:,:,:seq_len,...].to(dtype=x.dtype), self.sin_cached[:,:,:seq_len,...].to(dtype=x.dtype), ) defreplace_llama_with_condense(ratio): transformers.models.llama.modeling_llama.LlamaRotaryEmbedding=partial(CondenseRotaryEmbedding,ratio=ratio)
三、NTK-Aware Scaled RoPE
RoPE是一種β進(jìn)制編碼//spaces.ac.cn/archives/9675

有意思的解釋一下,RoPE 的行為就像一個(gè)時(shí)鐘。12小時(shí)時(shí)鐘基本上是一個(gè)維度為 3、底數(shù)為 60 的 RoPE。因此,每秒鐘,分針轉(zhuǎn)動(dòng) 1/60 分鐘,每分鐘,時(shí)針轉(zhuǎn)動(dòng) 1/60。
現(xiàn)在,如果將時(shí)間減慢 4 倍,那就是二使用的線性RoPE 縮放。不幸的是,現(xiàn)在區(qū)分每一秒,因?yàn)楝F(xiàn)在秒針幾乎每秒都不會(huì)移動(dòng)。
因此,如果有人給你兩個(gè)不同的時(shí)間,僅相差一秒,你將無(wú)法從遠(yuǎn)處區(qū)分它們。NTK-Aware RoPE 擴(kuò)展不會(huì)減慢時(shí)間。一秒仍然是一秒,但它會(huì)使分鐘減慢 1.5 倍,將小時(shí)減慢 2 倍。
這樣,您可以將 90 分鐘容納在一個(gè)小時(shí)中,將 24 小時(shí)容納在半天中。
所以現(xiàn)在你基本上有了一個(gè)可以測(cè)量 129.6k 秒而不是 43.2k 秒的時(shí)鐘。由于在查看時(shí)間時(shí)不需要精確測(cè)量時(shí)針,因此與秒相比,更大程度地縮放小時(shí)至關(guān)重要。
不想失去秒針的精度,但可以承受分針甚至?xí)r針的精度損失。
importtransformers old_init=transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ defntk_scaled_init(self,dim,max_position_embeddings=2048,base=10000,device=None): #Themethodisjustthesethreelines max_position_embeddings=16384 a=8#Alphavalue base=base*a**(dim/(dim-2))#Basechangeformula old_init(self,dim,max_position_embeddings,base,device) transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__=ntk_scaled_init
四、Dynamically Scaled RoPE

對(duì)于上面的方法二、三,都涉及到一個(gè)超參數(shù)α,用于調(diào)節(jié)縮放比例,該方法是通過(guò)序列長(zhǎng)度動(dòng)態(tài)選擇正確的比例參數(shù),效果可以看上圖。
對(duì)于線性插值,前 2k 上下文的精確位置值,然后在模型逐個(gè)生成標(biāo)記時(shí)重新計(jì)算每個(gè)新序列長(zhǎng)度的位置向量。本質(zhì)上,將比例設(shè)置為原始模型上下文長(zhǎng)度/當(dāng)前序列長(zhǎng)度。
對(duì)于動(dòng)態(tài) NTK,α 的縮放設(shè)置為 (α * 當(dāng)前序列長(zhǎng)度 / 原始模型上下文長(zhǎng)度) - (α - 1)。隨著序列長(zhǎng)度的增加動(dòng)態(tài)縮放超參數(shù)。
importmath
importtorch
classLlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
def__init__(self,dim,max_position_embeddings=2048,base=10000,ntk=False,device=None):
super().__init__()
self.ntk=ntk
self.base=base
self.dim=dim
self.max_position_embeddings=max_position_embeddings
inv_freq=1.0/(base**(torch.arange(0,dim,2).float().to(device)/dim))
self.register_buffer("inv_freq",inv_freq)
#Buildheretomake`torch.jit.trace`work.
self.max_seq_len_cached=max_position_embeddings
t=torch.arange(self.max_seq_len_cached,device=self.inv_freq.device,dtype=self.inv_freq.dtype)
freqs=torch.einsum("i,j->ij",t,self.inv_freq)
#Differentfrompaper,butitusesadifferentpermutationinordertoobtainthesamecalculation
emb=torch.cat((freqs,freqs),dim=-1)
dtype=torch.get_default_dtype()
self.register_buffer("cos_cached",emb.cos()[None,None,:,:].to(dtype),persistent=False)
self.register_buffer("sin_cached",emb.sin()[None,None,:,:].to(dtype),persistent=False)
defforward(self,x,seq_len=None):
#x:[bs,num_attention_heads,seq_len,head_size]
#This`if`blockisunlikelytoberunafterwebuildsin/cosin`__init__`.Keepthelogicherejustincase.
ifseq_len>self.max_seq_len_cached:
self.max_seq_len_cached=seq_len
ifself.ntk:
base=self.base*((self.ntk*seq_len/self.max_position_embeddings)-(self.ntk-1))**(self.dim/(self.dim-2))
inv_freq=1.0/(base**(torch.arange(0,self.dim,2).float().to(x.device)/self.dim))
self.register_buffer("inv_freq",inv_freq)
t=torch.arange(self.max_seq_len_cached,device=x.device,dtype=self.inv_freq.dtype)
ifnotself.ntk:
t*=self.max_position_embeddings/seq_len
freqs=torch.einsum("i,j->ij",t,self.inv_freq)
#Differentfrompaper,butitusesadifferentpermutationinordertoobtainthesamecalculation
emb=torch.cat((freqs,freqs),dim=-1).to(x.device)
self.register_buffer("cos_cached",emb.cos()[None,None,:,:].to(x.dtype),persistent=False)
self.register_buffer("sin_cached",emb.sin()[None,None,:,:].to(x.dtype),persistent=False)
return(
self.cos_cached[:,:,:seq_len,...].to(dtype=x.dtype),
self.sin_cached[:,:,:seq_len,...].to(dtype=x.dtype),
)
五、consistent of Dynamically Scaled RoPE

方法四存在一個(gè)問(wèn)題是,因?yàn)棣潦莿?dòng)態(tài)的,因?yàn)榻獯a是有cache的,所以,在生成第100個(gè)token時(shí),算的α和第200個(gè)token時(shí),算的α?xí)r不一致的。
query和key的rotation base不一致,正確的應(yīng)該時(shí)這樣
importmath fromtypingimportList,Optional,Tuple,Union importtorch importtorch.nn.functionalasF importtorch.utils.checkpoint fromtorchimportnn fromtransformers.models.llama.modeling_llamaimportrepeat_kv,apply_rotary_pos_emb fromtransformers.models.llama.modeling_llamaimportLlamaAttention defforward( self, hidden_states:torch.Tensor, attention_mask:Optional[torch.Tensor]=None, position_ids:Optional[torch.LongTensor]=None, past_key_value:Optional[Tuple[torch.Tensor]]=None, output_attentions:bool=False, use_cache:bool=False, )->Tuple[torch.Tensor,Optional[torch.Tensor],Optional[Tuple[torch.Tensor]]]: bsz,q_len,_=hidden_states.size() ifself.pretraining_tp>1: key_value_slicing=(self.num_key_value_heads*self.head_dim)//self.pretraining_tp query_slices=self.q_proj.weight.split((self.num_heads*self.head_dim)//self.pretraining_tp,dim=0) key_slices=self.k_proj.weight.split(key_value_slicing,dim=0) value_slices=self.v_proj.weight.split(key_value_slicing,dim=0) query_states=[F.linear(hidden_states,query_slices[i])foriinrange(self.pretraining_tp)] query_states=torch.cat(query_states,dim=-1) key_states=[F.linear(hidden_states,key_slices[i])foriinrange(self.pretraining_tp)] key_states=torch.cat(key_states,dim=-1) value_states=[F.linear(hidden_states,value_slices[i])foriinrange(self.pretraining_tp)] value_states=torch.cat(value_states,dim=-1) else: query_states=self.q_proj(hidden_states) key_states=self.k_proj(hidden_states) value_states=self.v_proj(hidden_states) query_states=query_states.view(bsz,q_len,self.num_heads,self.head_dim).transpose(1,2) key_states=key_states.view(bsz,q_len,self.num_key_value_heads,self.head_dim).transpose(1,2) value_states=value_states.view(bsz,q_len,self.num_key_value_heads,self.head_dim).transpose(1,2) kv_seq_len=key_states.shape[-2] ifpast_key_valueisnotNone: kv_seq_len+=past_key_value[0].shape[-2] cos,sin=self.rotary_emb(value_states,seq_len=kv_seq_len) ifpast_key_valueisnotNone: #reusekw/oRoPE key_states=torch.cat([past_key_value[0],key_states],dim=2) #applyRoPEafterretrievingallkeysandqueries query_states,rotated_key_states=apply_rotary_pos_emb(query_states,key_states,cos,sin,position_ids) ifpast_key_valueisnotNone: #reusev,self_attention value_states=torch.cat([past_key_value[1],value_states],dim=2) past_key_value=(key_states,value_states)ifuse_cacheelseNone#cachethekeyw/oRoPE #repeatk/vheadsifn_kv_heads1: attn_output=attn_output.split(self.hidden_size//self.pretraining_tp,dim=2) o_proj_slices=self.o_proj.weight.split(self.hidden_size//self.pretraining_tp,dim=1) attn_output=sum([F.linear(attn_output[i],o_proj_slices[i])foriinrange(self.pretraining_tp)]) else: attn_output=self.o_proj(attn_output) ifnotoutput_attentions: attn_weights=None returnattn_output,attn_weights,past_key_value defreplace_llama_attn_with_consistent_ntk_rope(): LlamaAttention.forward=forward
審核編輯:劉清
-
解碼器
+關(guān)注
關(guān)注
9文章
1219瀏覽量
43405 -
LLM
+關(guān)注
關(guān)注
1文章
346瀏覽量
1331
原文標(biāo)題:淺談LLM的長(zhǎng)度外推
文章出處:【微信號(hào):zenRRan,微信公眾號(hào):深度學(xué)習(xí)自然語(yǔ)言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
對(duì)比解碼在LLM上的應(yīng)用
餓了么確認(rèn)收購(gòu)百度外賣(mài)!最快本周收購(gòu),百度外賣(mài)為何會(huì)變成百度的棄子?
餓了么正式宣布收購(gòu)百度外賣(mài) 后者人員架構(gòu)不變以獨(dú)立品牌運(yùn)營(yíng)
餓了么正式宣布收購(gòu)百度外賣(mài) 內(nèi)部郵件曝光
LLM性能的主要因素
使用MLC-LLM支持RWKV-5推理的過(guò)程思考
如何利用位置編碼實(shí)現(xiàn)長(zhǎng)度外推?
LLM推理加速新范式!推測(cè)解碼(Speculative Decoding)最新綜述
hdmi線纜長(zhǎng)度根據(jù)什么決定選擇
什么是LLM?LLM的工作原理和結(jié)構(gòu)
LLM模型的應(yīng)用領(lǐng)域
llm模型有哪些格式
CS1-U DC/AC5-240V磁性開(kāi)關(guān)長(zhǎng)度要求
什么是LLM?LLM在自然語(yǔ)言處理中的應(yīng)用
新品 | LLM-8850 Kit,高性能AI加速卡套件 DinMeter v1.1,1/32DIN標(biāo)準(zhǔn)嵌入式開(kāi)發(fā)板
LLM的長(zhǎng)度外推淺談
評(píng)論