长短期记忆网络(LSTM)中的遗忘门、输入门、输出门控制信号计算与信息流详解
字数 2143 2025-12-14 15:35:14
长短期记忆网络(LSTM)中的遗忘门、输入门、输出门控制信号计算与信息流详解
描述
长短期记忆网络是循环神经网络的改进结构,核心是通过三个门控单元(遗忘门、输入门、输出门)和细胞状态来解决长期依赖问题。本知识点将深入解析每个门控单元的计算过程、控制信号的产生机制,以及信息在LSTM单元中的完整流动路径。
逐步讲解
1. 核心问题背景
传统RNN在反向传播时,梯度需要沿时间步连续相乘,易导致梯度消失/爆炸,难以学习长距离依赖。LSTM通过引入“细胞状态”作为信息高速公路,并让门控单元控制信息的增减,从而让梯度在细胞状态上更稳定地流动。
2. LSTM单元结构概览
一个LSTM单元在时刻t包含以下组成部分:
- 细胞状态(C_t):贯穿时间的水平线,保存长期记忆
- 隐藏状态(h_t):当前时刻的输出,包含短期记忆
- 三个门控:遗忘门(f_t)、输入门(i_t)、输出门(o_t)
- 候选记忆(~C_t):当前输入可能存入的新信息
其中门控本质是sigmoid神经网络层,输出0~1的值,0表示“完全阻止”,1表示“完全放行”。
3. 遗忘门控制信号计算
作用:决定细胞状态C_{t-1}中哪些信息应该被丢弃。
计算步骤:
- 输入:当前输入x_t、上一时刻隐藏状态h_{t-1}
- 参数:权重矩阵W_f、偏置b_f
- 公式:f_t = σ(W_f·[h_{t-1}, x_t] + b_f)
- 细节:将h_{t-1}和x_t拼接为向量,经线性变换后输入sigmoid函数。输出向量f_t每个维度对应C_{t-1}的一个维度,表示该维度信息的保留比例(0表示全忘,1表示全记)。
4. 输入门控制信号与候选记忆生成
作用:决定当前输入中哪些新信息应该存入细胞状态。
包含两个部分:
- 输入门i_t:控制哪些候选值会被更新
i_t = σ(W_i·[h_{t-1}, x_t] + b_i) - 候选记忆~C_t:由当前输入生成的可能新记忆
~C_t = tanh(W_C·[h_{t-1}, x_t] + b_C)
细节:~C_t使用tanh(输出-1~1)作为激活,生成一个候选的向量值;i_t则控制这个候选值有多少会加入细胞状态。
5. 细胞状态更新
作用:将旧记忆与新信息结合,形成更新后的长期记忆。
公式:C_t = f_t ⊙ C_{t-1} + i_t ⊙ ~C_t
步骤分解:
- 遗忘阶段:f_t ⊙ C_{t-1}(逐元素相乘),按比例丢弃旧信息
- 新增阶段:i_t ⊙ ~C_t,按比例添加候选新信息
- 相加:得到更新后的细胞状态C_t
物理意义:这是LSTM的核心,通过加法而非乘法更新,使得梯度在C_t上可以稳定流动(梯度以加法而非连乘方式传递)。
6. 输出门控制信号与隐藏状态生成
作用:基于当前细胞状态,决定输出哪些信息到隐藏状态。
步骤:
- 输出门控制信号:o_t = σ(W_o·[h_{t-1}, x_t] + b_o)
- 当前细胞状态经tanh变换:tanh(C_t)(将C_t值压缩到-1~1,作为可输出的记忆内容)
- 隐藏状态计算:h_t = o_t ⊙ tanh(C_t)
物理意义:o_t控制当前记忆的哪些部分会输出到隐藏状态h_t。h_t同时作为当前时刻输出和下一时刻的输入。
7. 完整信息流总结
以时间顺序梳理:
- 接收输入x_t和上一状态h_{t-1}、C_{t-1}
- 并行计算三个门(f_t, i_t, o_t)和候选记忆~C_t
- 用f_t和i_t更新细胞状态:C_t = f_tC_{t-1} + i_t~C_t
- 用o_t和tanh(C_t)计算隐藏状态h_t
- 输出h_t,并将(C_t, h_t)传递到下一时刻
8. 与梯度消失问题的关系
关键点:细胞状态更新公式C_t = f_t⊙C_{t-1} + i_t⊙~C_t
- 梯度在C_t上的反向传播路径中,包含一条从C_t直接到C_{t-1}的加法路径(通过f_t⊙C_{t-1}项)
- 这条路径的梯度是逐元素乘以f_t,而f_t是sigmoid输出(接近0~1),避免了传统RNN中权重矩阵连乘导致的指数级缩小/放大
- 虽然理论上仍可能存在梯度消失(当f_t接近0时),但网络可以通过学习将f_t设为接近1,从而主动保持梯度流动
9. 实例计算(简化)
假设h_{t-1}和x_t均为二维向量,拼接为4维向量:
- 若遗忘门计算得f_t=[0.8,0.1],表示保留旧细胞状态第一维80%、第二维10%
- 若输入门i_t=[0.7,0.9],候选记忆~C_t=[0.5,-0.3]
- 则C_t = [0.8,0.1]⊙C_{t-1} + [0.7,0.9]⊙[0.5,-0.3]
- 若输出门o_t=[0.6,0.4],则h_t = o_t ⊙ tanh(C_t)
这个门控机制使LSTM可以选择性记住长期信息、遗忘无关信息、输出相关信息,从而有效处理序列中的长期依赖。