Gated Mechanisms and Cell State Principles of Long Short-Term Memory (LSTM) Networks
Problem/Topic Description
Long Short-Term Memory (LSTM) networks are a special type of Recurrent Neural Network (RNN) designed specifically to address the long-term dependency problem (i.e., vanishing and exploding gradients) of simple RNNs. Their core innovation is the introduction of the "cell state" and a set of "gated mechanisms." This topic requires you to understand the internal structure of an LSTM, particularly how the forget gate, input gate, and output gate work together to selectively remember or forget information, thereby enabling effective learning of long-term dependencies.
Problem-Solving Process/Principle Explanation
We will progressively deconstruct the internal workings of a single LSTM computational unit (one time step).
Step 1: Reviewing the Problem—Limitations of Simple RNNs
- Basic Structure: The hidden state h_t of a simple RNN is a function of the current input x_t and the previous hidden state h_{t-1}: h_t = tanh(W · [h_{t-1}, x_t] + b).
- Core Problem: During backpropagation through time, the gradient is multiplied over and over by the same weight matrix W. When the eigenvalues of this matrix are less than 1, the gradient decays exponentially toward 0 (vanishing gradient), preventing the network from learning dependencies across distant time steps; when they are greater than 1, the gradient instead grows exponentially (exploding gradient). The short sketch below illustrates the decay numerically.
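A minimal numeric sketch of this effect, assuming a 16-dimensional hidden state and a fixed stand-in matrix for the recurrent Jacobian (the size, seed, and scaling are illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the recurrent Jacobian; rescaled so its largest singular
# value (spectral norm) is 0.9, i.e. strictly less than 1.
W = rng.standard_normal((16, 16))
W *= 0.9 / np.linalg.norm(W, 2)

grad = rng.standard_normal(16)        # gradient arriving from the loss at the last step
for t in range(1, 101):
    grad = W.T @ grad                 # one step of backpropagation through time
    if t % 25 == 0:
        print(f"after {t:3d} steps: ||grad|| = {np.linalg.norm(grad):.2e}")
# The norm decays roughly like 0.9**t; rescaling W above 1 makes it blow up instead.
```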
Step 2: LSTM's Solution—Cell State and Gating
LSTM addresses these issues by introducing a "highway" that runs through the entire sequence, the Cell State (C_t), together with Gated Mechanisms ("gates") that protect and control this highway.
- Cell State (C_t):
  - Description: This is the core of LSTM, acting like a conveyor belt running from the beginning to the end of the sequence. It is designed to undergo only slight, linear interactions, allowing information to flow along it without drastic changes and facilitating stable gradient propagation.
  - Analogy: Imagine the cell state as the main thesis of a paper. During writing, we might add new arguments or revise old ones, but the main thesis remains largely unchanged.
- Gated Mechanism:
  - Description: A gate is a structure that allows information to pass selectively, implemented using a Sigmoid activation function and an element-wise multiplication operation.
  - Principle: The Sigmoid function outputs a value between 0 and 1. This value can be interpreted as a "pass-through ratio":
    - 0 means "completely block all information" (fully closed).
    - 1 means "allow all information to pass" (fully open).
    - 0.6 means "allow 60% of the information to pass."
  - An LSTM unit contains three such gates: the Forget Gate, the Input Gate, and the Output Gate, as the sketch below illustrates.
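A minimal numeric sketch of the gating idea, with hand-picked scores chosen purely for illustration: a Sigmoid turns raw scores into pass-through ratios, and element-wise multiplication applies them to a vector of information.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

information = np.array([2.0, -1.0, 0.5, 3.0])   # values flowing through the network
gate_scores = np.array([6.0, -6.0, 0.4, 0.0])   # raw scores a learned layer might produce

gate = sigmoid(gate_scores)       # ~[1.00, 0.00, 0.60, 0.50]: pass-through ratios
passed = gate * information       # element-wise: keep, block, partially pass
print(gate.round(2), passed.round(2))
```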
Step 3: Detailed Step-by-Step Workflow of the Three Gates
Consider a single time step t: the unit receives the current input x_t, the previous hidden state h_{t-1}, and the previous cell state C_{t-1}. Its goal is to compute the current hidden state h_t and cell state C_t.
Process Overview (Conceptual):
[h_{t-1}, x_t] -> (Forget Gate + Input Gate) -> update C_{t-1} to C_t -> (Output Gate) -> produce h_t from C_t
Step 1: Decide What Information to Discard/Forget (Forget Gate)
- Purpose: First, we need to decide what information to discard from the old cell state C_{t-1}. This is done by the "Forget Gate."
- Calculation:
  - Concatenate the previous hidden state h_{t-1} and the current input x_t.
  - Pass the result through a Sigmoid activation layer (the forget gate) to obtain a vector f_t with values between 0 and 1.
- Formula: f_t = σ(W_f · [h_{t-1}, x_t] + b_f)
- Interpretation: Each value in f_t is a "retention score" for the corresponding dimension of C_{t-1}. A value close to 1 means "completely retain this information"; a value close to 0 means "completely forget this information" (see the sketch below).
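A small NumPy sketch of the forget-gate computation, assuming a hidden/cell size of 4, an input size of 3, and randomly initialized parameters (these sizes and values are illustrative, not taken from any particular library):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

hidden_size, input_size = 4, 3
rng = np.random.default_rng(1)

h_prev = rng.standard_normal(hidden_size)     # h_{t-1}
C_prev = rng.standard_normal(hidden_size)     # C_{t-1}
x_t    = rng.standard_normal(input_size)      # current input

# Forget gate parameters; W_f acts on the concatenation [h_{t-1}, x_t]
W_f = rng.standard_normal((hidden_size, hidden_size + input_size))
b_f = np.zeros(hidden_size)

concat = np.concatenate([h_prev, x_t])        # [h_{t-1}, x_t]
f_t = sigmoid(W_f @ concat + b_f)             # one retention score in (0, 1) per cell dimension
print(f_t.round(2))
```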
Step 2: Decide What New Information to Store/Update (Input Gate and Candidate Cell State)
This step consists of two parts, working together to decide what new information to add to the cell state.
- a) Input Gate: Decide which new information is worth updating
  - Purpose: Decide to what extent we will update the cell state with new information.
  - Calculation: As with the forget gate, use a Sigmoid layer to generate a vector i_t.
  - Formula: i_t = σ(W_i · [h_{t-1}, x_t] + b_i)
  - Interpretation: i_t determines to what extent we will adopt the candidate values generated next.
- b) Candidate Cell State: Generate new candidate values
  - Purpose: Create a vector of candidate cell state values ~C_t, which contains the new information that might be added to the cell state.
  - Calculation: Use a tanh activation layer (output range -1 to 1) to generate the candidate values.
  - Formula: ~C_t = tanh(W_C · [h_{t-1}, x_t] + b_C)
  - Interpretation: ~C_t is the "proposal" of new information generated from the current input and the past state (see the continued sketch below).
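Continuing the same sketch (reusing sigmoid, concat, hidden_size, and input_size from above; W_i, b_i, W_C, and b_C are again randomly initialized, illustrative parameters):

```python
# Input gate: how much of each candidate dimension to admit
W_i = rng.standard_normal((hidden_size, hidden_size + input_size))
b_i = np.zeros(hidden_size)
i_t = sigmoid(W_i @ concat + b_i)             # values in (0, 1)

# Candidate cell state: the "proposal" of new content, squashed into (-1, 1)
W_C = rng.standard_normal((hidden_size, hidden_size + input_size))
b_C = np.zeros(hidden_size)
C_tilde = np.tanh(W_C @ concat + b_C)         # ~C_t
```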
Step 3: Update the Cell State (from C_{t-1} to C_t)
- Purpose: Update the old cell state C_{t-1} to the new cell state C_t.
- Calculation:
  - Multiply the old cell state C_{t-1} by f_t (the forget gate output) to drop the information we previously decided to forget.
  - Multiply the candidate cell state ~C_t by i_t (the input gate output) to scale the new information we decided to add.
  - Sum the results of the two steps above.
- Core Formula: C_t = f_t * C_{t-1} + i_t * ~C_t (element-wise multiplication)
- Interpretation: This is the most critical step in LSTM. It combines "selective forgetting" and "selective remembering." The update is additive, unlike the repeated matrix multiplication of simple RNNs, and this additive path greatly alleviates the vanishing gradient problem because gradients can flow more stably along the cell state (see the one-line sketch below).
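Continuing the sketch, the update itself is a single line of element-wise arithmetic:

```python
# Selective forgetting plus selective remembering, all element-wise
C_t = f_t * C_prev + i_t * C_tilde
# C_{t-1} enters only through an element-wise scaling and an addition, so the
# gradient along the cell-state path avoids repeated full-matrix multiplication.
print(C_t.round(2))
```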
Step 4: Decide What to Output Based on the Updated Cell State (Output Gate)
- Purpose: Finally, we need to decide what this time step should output (i.e., the hidden state h_t) based on the updated cell state C_t.
- Calculation:
  - Output Gate: First, run a Sigmoid layer (the output gate) to decide which parts of the cell state to output.
    - Formula: o_t = σ(W_o · [h_{t-1}, x_t] + b_o)
  - Modulate the Cell State: Then pass the cell state C_t through a tanh function (scaling its values to between -1 and 1) and multiply it by the output gate o_t.
    - Final Output Formula: h_t = o_t * tanh(C_t)
- Interpretation: h_t is the final output of this LSTM unit at the current time step (and also the hidden state passed to the next time step). It is a filtered portion of the complete information contained in the cell state C_t, as determined by the output gate (see the sketch below).
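Completing the sketch with the output gate and the new hidden state (W_o and b_o are again illustrative, randomly initialized parameters):

```python
# Output gate: which parts of the (squashed) cell state become the hidden state
W_o = rng.standard_normal((hidden_size, hidden_size + input_size))
b_o = np.zeros(hidden_size)
o_t = sigmoid(W_o @ concat + b_o)

h_t = o_t * np.tanh(C_t)      # filtered view of C_t; also passed on to the next time step
print(h_t.round(2))
```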
Summary
LSTM achieves fine-grained control over the cell state (long-term memory) through its ingenious gated mechanisms:
- Forget Gate: Decides what to discard from long-term memory.
- Input Gate: Decides what new information to store in long-term memory.
- Cell State Update: Updates long-term memory in an additive and smooth manner, which is key to solving the vanishing gradient problem.
- Output Gate: Decides what to read from the current long-term memory as the output for the current moment.
This mechanism enables LSTM to effectively learn and utilize long-term dependencies in time series data, making it a powerful tool for processing sequential data (such as text, speech, and time series forecasting).
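As a recap, here is a minimal, self-contained NumPy sketch that gathers the four equations above into one forward step; the function name, parameter layout, and sizes are illustrative assumptions, not a reference implementation.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, params):
    """One LSTM time step following the four equations above.

    `params` maps each gate name to a (W, b) pair acting on [h_{t-1}, x_t].
    """
    concat = np.concatenate([h_prev, x_t])
    f_t = sigmoid(params["f"][0] @ concat + params["f"][1])      # forget gate
    i_t = sigmoid(params["i"][0] @ concat + params["i"][1])      # input gate
    C_tilde = np.tanh(params["C"][0] @ concat + params["C"][1])  # candidate cell state
    C_t = f_t * C_prev + i_t * C_tilde                           # additive cell-state update
    o_t = sigmoid(params["o"][0] @ concat + params["o"][1])      # output gate
    h_t = o_t * np.tanh(C_t)                                     # new hidden state
    return h_t, C_t

# Tiny usage example with random parameters (hidden size 4, input size 3)
rng = np.random.default_rng(0)
hidden, inp = 4, 3
params = {k: (rng.standard_normal((hidden, hidden + inp)), np.zeros(hidden))
          for k in ("f", "i", "C", "o")}
h, C = np.zeros(hidden), np.zeros(hidden)
for x in rng.standard_normal((5, inp)):      # run over a short sequence of 5 inputs
    h, C = lstm_step(x, h, C, params)
print(h.round(3))
```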