在 AI 領域,文本翻譯、語音識別、股價預測等場景都離不開序列數據處理。循環神經網絡(RNN)作為最早的序列建模工具,開創了 “記憶歷史信息” 的先河;而長短期記憶網絡(LSTM)則通過創新設計,突破了 RNN 的核心局限。今天,我們從原理、梯度推導到實踐,全面解析這兩大經典模型。
一、基礎鋪墊:RNN 的核心邏輯與痛點
RNN 的核心是讓模型 “記住過去”—— 通過隱藏層的循環連接,將前一時刻的信息傳遞到當前時刻,從而捕捉序列的時序關聯。但這種 “全記憶” 設計,也埋下了梯度消失的隱患。
1.1 核心結構與參數

RNN 結構簡化為 “輸入層 - 隱藏層 - 輸出層”,關鍵組件如下:
- 輸入:Xt(第 t 時刻輸入,如文本中的詞向量)
- 隱藏狀態:St(存儲截至 t 時刻的歷史信息,核心記憶載體)
- 輸出:Ot(第 t 時刻預測結果,如分類標簽)
- 共享參數(所有時間步復用):Wx:輸入→隱藏層權重矩陣(維度:隱藏層維度 × 輸入維度)Ws:隱藏層→自身的循環權重矩陣(維度:隱藏層維度 × 隱藏層維度,關鍵)Wo:隱藏層→輸出層權重矩陣(維度:輸出維度 × 隱藏層維度)偏置:b?(隱藏層偏置,維度:隱藏層維度 ×1)、b?(輸出層偏置,維度:輸出維度 ×1)
- 激活函數:隱藏層用 tanh(值縮至 [-1,1]),輸出層用 Softmax(分類)或線性激活(回歸)
1.2 前向傳播:信息如何流動?
前向傳播是 “輸入→輸出” 的計算過程,每個時間步的結果依賴前一時刻的隱藏狀態(以下基于標量簡化,向量場景邏輯一致):更新隱藏狀態

當前記憶 St由 “當前輸入 Xt” 和 “歷史記憶 St-1” 共同決定,tanh 確保狀態值在合理范圍。計算輸出

輸出僅依賴當前記憶St,體現 “歷史信息已壓縮到St中”。示例:若序列長度為 3(t=1,2,3),初始狀態 S?=0(無歷史信息):

1.3 反向傳播(BPTT)與梯度推導
模型訓練依賴時間反向傳播(BPTT):通過鏈式法則回溯所有時間步,計算損失對參數的梯度,再用梯度下降更新參數。假設損失函數為交叉熵損失 Loss = L (Ot, yt)(yt為 T 時刻真實標簽),核心是推導 Loss 對 Wx、Ws、Wo的梯度。1.3.1 核心梯度推導步驟步驟 1:計算 Loss 對輸出Ot的梯度若輸出層用 Softmax 激活 + 交叉熵損失,對單個樣本有:

當i=j時,等于:

當i≠j時,等于:

所以,softmax函數的導數可以表示為:

我們只需要將softmax層的輸出pi,pj代入上面的公式就可以做求導計算了。在多分類任務中,我們通常使用交叉熵損失函數(cross-entropy loss function)來評估模型的性能。交叉熵損失函數的定義如下:

其中yj是真實標簽的one??ot向量,pj是softmax函數的輸出。交叉熵損失函數的作用是衡量模型的預測概率p和真實標簽y之間的差異。交叉熵損失越小,表示模型的預測值越接近真實的標簽。經驗告訴我們,當使用softmax函數作為輸出層激活函數時,最好使用交叉熵作為其損失函數,這是因為交叉熵和softmax函數的結合可以簡化反向傳播的計算。為了證明這一點,我們對交叉熵函數求導:

其中?pj/?zi就是上文推導的softmax的導數,將其代入式中可得:

所以y是one-hot向量,所以:

最后,化簡得到的交叉熵函數的求導公式:

步驟 2:計算 Loss 對隱藏狀態 S?的梯度隱藏狀態St同時影響當前輸出Ot和下一時刻隱藏狀態 St+1,因此梯度需分兩部分:

拆解導數項:

由St+1=tanh(WxXt+1+WsSt+b1)求導:tanh'(x)=1?tanh2(x)因此遞推公式為:

(向量場景需轉置)步驟 3:計算 Loss 對參數的梯度對 Wo的梯度:

(向量場景下為外積)對 Wx的梯度:W?在所有時間步共享,需累加各時間步貢獻:

對Ws的梯度:同理,Ws的梯度為各時間步貢獻的累加:

1.3.2 梯度消失的核心原因:累乘衰減
從Ws的梯度公式可見,遠時刻(如 t=1)對梯度的貢獻需經過多次 tanh'(Sk)?Ws的累乘(k 從 2 到 T):tanh'(Sk) ∈ [0,1](tanh 導數特性,最大值為 1,多數時刻小于 0.5)|Ws| < 1(為避免數值爆炸,初始化時會限制權重范圍)導致累乘項隨時間步指數級衰減,例如:若tanh'(Sk)=0.5,|Ws|=0.8,序列長度T=10,則累乘項 =(0.5×0.8)^9≈0.00026,遠時刻梯度趨近于 0,模型無法捕捉長期依賴。
突破局限:LSTM 的創新設計與梯度推導
1997 年提出的 LSTM,通過“記憶細胞 + 門控機制”實現 “選擇性記憶”—— 保留重要信息、過濾噪聲,從根本上緩解梯度消失。
2.1 核心結構:三門 + 記憶細胞

LSTM 的核心是 “記憶細胞(C?)” 和三個門控,分工明確(以下基于標量簡化):
組件 | 功能 | 激活函數 | 參數(權重 + 偏置) |
記憶細胞 Ct | 長期記憶載體,狀態平緩更新 | 無 | 依賴門控參數 |
遺忘門 ft | 控制保留多少歷史細胞狀態 Ct-1 | σ(Sigmoid,輸出[0,1]) | Wxf、W?f、bf |
更新門 it | 控制加入多少新信息到 Ct | σ(輸出 [0,1]) | Wxi、W?i、bi |
候選記憶 gt | 生成當前時刻的新候選信息 | tanh(輸出[-1,1]) | Wxg、W?g、bg |
輸出門 ot | 控制 Ct輸出到隱藏狀態 ht的比例 | σ(輸出 [0,1]) | Wxo、W?o、bo |
隱藏狀態 ht | 短期記憶,用于當前輸出 | tanh(輸出[-1,1]) | Wyo、bo |
? | 元素相乘 | 無 | 無 |
⊕ | 元素相加 | 無 | 無 |
σ 函數輸出 [0,1],完美適配 “門控控制”(1 = 完全保留,0 = 完全過濾);tanh 確保信息值在合理范圍
2.2 前向傳播:5 步完成記憶更新
LSTM 的前向傳播圍繞 “記憶細胞更新” 展開,步驟清晰:遺忘門:決定 “丟什么”ft=σ(Wxf?Xt+W?f??t?1+bf)例:ft=0.9→保留 90% 歷史記憶Ct?1;ft=0.1→過濾 90% 舊信息。更新門 + 候選記憶:決定 “加什么”更新門:it=σ(Wxi?Xt+W?i?ht?1+bi)(控制新信息的權重)候選記憶:gt=tanh(Wxg?Xt+W?g?ht?1+bg)(當前時刻的新信息)更新記憶細胞:“丟舊 + 加新”Ct=Ct?1?ft+gt?it?為對應元素相乘,Ct同時承載 “長期歷史Ct?1?ft” 和 “當前新信息gt?it”。輸出門:決定 “輸出什么”ot=σ(Wxo?Xt+W?o??t?1+bo)生成隱藏狀態與輸出ht=ot?tanh (Ct)(tanh 將Ct縮至 [-1,1],再通過ot過濾)yt=Wy???t+by(最終預測結果,分類任務需加 Softmax)
2.3 反向傳播與梯度推導
LSTM 的反向傳播仍基于 BPTT,但需同時更新三門參數和記憶細胞相關梯度,核心是確保記憶細胞 C?的梯度穩定傳遞。假設損失 Loss = L (yt,y't)(y't為真實標簽),以下為關鍵梯度推導。
2.3.1 核心梯度 1:Loss 對記憶細胞 C?的導數
記憶細胞Ct同時影響當前隱藏狀態?t和下一時刻記憶細胞Ct+1,梯度公式為:

拆解導數項:?Loss/??t:損失對隱藏狀態的梯度,由輸出層反向推導:

(包含當前輸出和下一時刻四門的貢獻)??t/?Ct=ot?tanh'(Ct)(由?t=ot?tanh (Ct)求導)?Ct+1/?Ct=ft+1(由Ct+1=Ct?ft+1+gt+1?it+1求導)最終遞推公式:

2.3.2 核心梯度 2:Loss 對門控參數的導數(以遺忘門為例)遺忘門參數(Wxf、Whf、bf)的梯度需通過鏈式法則推導:先求 Loss 對遺忘門輸出ft的梯度:

再求 Loss 對遺忘門權重 Wxf的梯度:

(σ函數導數為σ(x)?(1-σ(x)),此處 ft=σ(...),故?ft/?Wxf?ft?(1?ft)?Xt)同理,Loss對Whf的梯度:

更新門、輸出門、候選記憶的參數梯度推導邏輯一致,最終所有參數通過梯度下降(如 Adam 優化器)更新。2.3.3 LSTM 如何緩解梯度消失?對比 RNN 的梯度路徑,LSTM 的記憶細胞梯度傳遞具有決定性優勢:從?Loss?Ct的遞推公式可見,當模型需要保留長期信息時,會通過參數學習使遺忘門ft+1≈1,此時:

由于 tanh'(Ct)∈[0,1],ot∈[0,1],但核心是?Loss/?Ct+1直接傳遞到?Loss/?Ct,無指數級衰減。即使序列長度達到 100+,遠時刻(如 t=1)的梯度仍能穩定傳遞到當前時刻(如 t=100),從而有效捕捉長期依賴。
關鍵補充:模型如何 “學習” 讓ft+1≈1?
遺忘門ft+1的輸出由以下公式決定:ft+1=σ(Wxf?Xt+1+W?f??t+bf)
其中σ是 Sigmoid 函數,當輸入值>2 時,σ(x)≈0.95(接近1)。模型通過以下兩種方式學習讓ft+1≈1:
初始化階段:設置遺忘門偏置 bf>0
工程實踐中,會將遺忘門的偏置bf初始化為1~2(而非默認0),此時即使Wxf?Xt+1+W?f??t=0,ft+1=σ(bf)≈0.73(已較高),為后續學習 “保留長期信息” 奠定基礎。訓練階段:通過損失反向調整參數當模型因 “未保留遠時刻信息” 導致 Loss 升高時,反向傳播會調整Wxf、W?f、bf的取值:若 t=1 的信息對 t=100 的預測很重要,但當前f2=0.1(過濾了 t=1 的信息),則 Loss 會增大;反向傳播時,?Loss/?f2為正值(增加f2可降低 Loss),進而通過?Loss/?Wf調整權重,使f2增大;反復迭代后,模型會學習到 “對重要的長期信息,讓ft+1≈1。
三、RNN vs LSTM:怎么選?
兩大模型各有優劣,需結合場景匹配:
維度 | 循環神經網絡(RNN) | 長短期記憶網絡(LSTM) |
記憶能力 | 僅短期依賴 | 長期依賴(序列長度 100+) |
梯度問題 | 易出現梯度消失,遠時刻信息丟失 | 記憶細胞梯度穩定,緩解梯度消失 |
模型復雜度 | 低(僅 3 組核心參數:W?、W?、W?) | 高(9 組核心參數:3 門 ×3 組權重 + 輸出層權重) |
參數數量 | 少(如隱藏層維度 H=128,輸入維度 D=64,參數量≈1282+128×64=24576) | 多(同上述維度,參數量≈4×(1282+128×64)=98304,約為 RNN 的 4 倍) |
計算效率 | 快(前向 / 反向傳播步驟少) | 慢(門控計算多) |
訓練難度 | 低(參數少,收斂快,易實現) | 高(參數多,易過擬合,需更多數據和正則化) |
核心優勢 | 結構簡單、訓練速度快、資源占用低 | 魯棒性強、長期依賴捕捉能力突出、任務精度高 |
四、工程實踐小貼士4.1 模型選擇策略先簡后繁:先用 RNN 驗證短序列任務可行性,若精度不達標(如測試集準確率 < 85%),再替換為 LSTM;折中方案:若 LSTM 計算壓力大,可選用 GRU(門控循環單元)—— 簡化為重置門和更新門 2 個門,參數量比 LSTM 少 25%,性能接近 LSTM;
數據適配:若序列長度差異大(如文本長度 5-200 詞),可采用 “截斷 + 填充”(固定序列長度)或 “動態批處理”(同批次序列長度一致)。4.2 LSTM 性能優化技巧參數裁剪:隱藏層維度從 256 降至 128,參數量減少 75%,訓練速度提升 2-3 倍;序列分段:將長序列(如 1000 幀音頻)拆分為 10 個 100 幀子序列,采用 “滾動預測” 拼接結果;量化訓練:將 32 位浮點數參數轉為 16 位半精度,顯存占用減少 50%,推理速度提升 1.5 倍;正則化:添加 Dropout(隱藏層 dropout 率 0.2-0.5)、L2 正則化(權重衰減系數 1e-4),緩解過擬合。
4.3 常見問題排查
問題現象 | 可能原因 | 解決方案 |
訓練 loss 不下降 | 1. 學習率過高 / 過低2. 梯度消失(LSTM 遺忘門ft過小) | 1. 調整學習率(如 Adam優化器默認 0.001,可嘗試0.0001-0.01)2. 初始化遺忘門偏置bf為1-2(使ft初始值接近 1) |
測試集 loss 波動大 | 1. 數據量不足2. 序列長度分布不均 | 1. 數據增強(如文本同義詞替換、時序數據加噪)2. 按序列長度分組訓練,平衡各長度樣本占比 |
總結RNN 作為序列建模的 “基石”,以簡單的循環結構開創了歷史信息復用的思路,但受限于梯度消失無法處理長序列;LSTM 則通過記憶細胞和門控機制的創新,從梯度傳遞路徑上解決了長期依賴問題,成為長序列任務的經典方案。盡管當前 Transformer(如 BERT、GPT)在多數序列任務中表現更優,但 RNN 和 LSTM 的核心思想(時序關聯捕捉、選擇性記憶)仍是理解復雜序列模型的基礎,也是 AI 工程師在資源受限場景下的重要選擇。你在項目中用過 RNN 或 LSTM 嗎?遇到過哪些訓練難題?歡迎在評論區分享你的實踐經驗!
本文轉自:秦芯智算
-
神經網絡
+關注
關注
42文章
4838瀏覽量
107744 -
rnn
+關注
關注
0文章
92瀏覽量
7345 -
LSTM
+關注
關注
0文章
63瀏覽量
4378
發布評論請先 登錄
一文讀懂LSTM與RNN:從原理到實戰,掌握序列建模核心技術
評論