Detailed Explanation of Backpropagation Through Time (BPTT) Algorithm for Recurrent Neural Networks (RNN)
1. Problem Description
BPTT (Backpropagation Through Time) is the backpropagation algorithm used to train Recurrent Neural Networks (RNNs). Unlike standard backpropagation, BPTT handles sequential data by unfolding the RNN into a chain over time steps and computing the gradient of the loss with respect to the shared parameters, accumulating contributions from every time step. This section details the mathematical principles and computational steps of BPTT, as well as the root causes of the vanishing/exploding gradient problem.
2. Review of RNN Forward Propagation
Assume the input sequence is \(\{x_1, x_2, ..., x_T\}\). The hidden state \(h_t\) and output \(y_t\) are computed as follows:
\[h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h) \]
\[y_t = W_{hy} h_t + b_y \]
where \(W_{hh}, W_{xh}, W_{hy}\) are weight matrices, \(b_h, b_y\) are biases, and the initial hidden state \(h_0\) is often set to 0.
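As a concrete point of reference, here is a minimal NumPy sketch of this forward pass; the dimensions, random initialization, and sequence length are illustrative assumptions, not values from the text:

```python
import numpy as np

rng = np.random.default_rng(0)
n_x, n_h, n_y, T = 4, 8, 3, 10           # illustrative sizes

# Parameters (small random init, just for the sketch)
W_xh = rng.normal(0, 0.1, (n_h, n_x))
W_hh = rng.normal(0, 0.1, (n_h, n_h))
W_hy = rng.normal(0, 0.1, (n_y, n_h))
b_h = np.zeros(n_h)
b_y = np.zeros(n_y)

def forward(xs):
    """xs: list of T input vectors. Returns hidden states and outputs."""
    h = np.zeros(n_h)                     # h_0 = 0
    hs, ys = [h], []                      # keep h_0 around for the backward pass
    for x in xs:
        a = W_hh @ h + W_xh @ x + b_h     # pre-activation a_t
        h = np.tanh(a)                    # h_t
        y = W_hy @ h + b_y                # y_t
        hs.append(h); ys.append(y)
    return hs, ys

xs = [rng.normal(size=n_x) for _ in range(T)]
hs, ys = forward(xs)
```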
3. Chain Rule Principle of BPTT
Let the loss at time step \(t\) be \(L_t\) (e.g., cross-entropy loss), and the total loss \(L = \sum_{t=1}^T L_t\). The goal of BPTT is to compute \(\frac{\partial L}{\partial W}\) (taking \(W_{hh}\) as an example):
- Since \(h_t\) depends on \(h_{t-1}\), and \(h_{t-1}\) in turn depends on \(W_{hh}\), gradients need to be propagated backward through time.
- The gradient with respect to \(W_{hh}\) is:
\[\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^T \sum_{k=1}^t \frac{\partial L_t}{\partial h_t} \frac{\partial h_t}{\partial h_k} \frac{\partial h_k}{\partial W_{hh}} \]
where \(\frac{\partial h_t}{\partial h_k} = \prod_{j=k+1}^t \frac{\partial h_j}{\partial h_{j-1}}\) is the key term, representing the propagation of the gradient from step \(t\) back to step \(k\), and \(\frac{\partial h_k}{\partial W_{hh}}\) denotes the immediate partial derivative at step \(k\) (treating \(h_{k-1}\) as a constant).
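For concreteness, with \(T = 2\) the double sum expands to three terms:
\[\frac{\partial L}{\partial W_{hh}} = \frac{\partial L_1}{\partial h_1}\frac{\partial h_1}{\partial W_{hh}} + \frac{\partial L_2}{\partial h_2}\frac{\partial h_2}{\partial W_{hh}} + \frac{\partial L_2}{\partial h_2}\frac{\partial h_2}{\partial h_1}\frac{\partial h_1}{\partial W_{hh}} \]
The last term is the one that travels backward through time via \(\frac{\partial h_2}{\partial h_1}\).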
4. Specific Calculation Steps
Step 1: Compute the gradient of the loss with respect to the hidden state
Starting from the last time step \(T\), define \(\delta_t = \frac{\partial L}{\partial h_t}\), treated as a column vector. For \(t = T\):
\[\delta_T = \left(\frac{\partial y_T}{\partial h_T}\right)^{\!\top} \frac{\partial L_T}{\partial y_T} = W_{hy}^\top \left(\mathrm{softmax}(y_T) - \hat{y}_T\right) \]
where \(\hat{y}_T\) is the one-hot true label and \(\mathrm{softmax}(y_T) - \hat{y}_T\) is the familiar output gradient when \(y_T\) feeds a softmax layer trained with cross-entropy loss. For \(t < T\):
\[\delta_t = W_{hy}^\top \frac{\partial L_t}{\partial y_t} + \left(\frac{\partial h_{t+1}}{\partial h_t}\right)^{\!\top} \delta_{t+1} \]
The second term reflects the temporal dependency: the current hidden state \(h_t\) affects the future losses from \(L_{t+1}\) to \(L_T\).
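Continuing the NumPy sketch above, a minimal implementation of this recursion might look as follows; it assumes softmax outputs with cross-entropy loss and a list `targets` of one-hot label vectors aligned with `ys` (both assumptions for illustration, not from the text):

```python
def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def hidden_deltas(hs, ys, targets):
    """Step 1: delta_t = dL/dh_t, computed from t = T down to t = 1."""
    T = len(ys)
    deltas = [None] * T
    d_future = np.zeros(n_h)                 # no contribution from beyond step T
    for t in reversed(range(T)):             # code index t is time step t+1
        dy = softmax(ys[t]) - targets[t]     # dL_t/dy_t for softmax + cross-entropy
        deltas[t] = W_hy.T @ dy + d_future   # current-step term + gradient from the future
        # (dh_{t+1}/dh_t)^T delta_{t+1}, handed to the previous time step;
        # tanh'(a) = 1 - tanh(a)^2 = 1 - h^2, and hs[t+1] is this step's hidden state
        d_future = W_hh.T @ ((1 - hs[t + 1]**2) * deltas[t])
    return deltas
```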
Step 2: Compute \(\frac{\partial h_{t+1}}{\partial h_t}\)
From \(h_{t+1} = \tanh(W_{hh} h_t + W_{xh} x_{t+1} + b_h)\), let \(a_t = W_{hh} h_{t-1} + W_{xh} x_t + b_h\) denote the pre-activation, so that \(h_t = \tanh(a_t)\). Then:
\[\frac{\partial h_{t+1}}{\partial h_t} = \text{diag}(\tanh'(a_{t+1}))\, W_{hh} \]
where \(\tanh'(z) = 1 - \tanh^2(z)\) is applied element-wise. In the backward recursion of Step 1, this Jacobian enters through its transpose, \(W_{hh}^\top \text{diag}(\tanh'(a_{t+1}))\).
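A quick finite-difference sanity check of this Jacobian, reusing the parameters from the forward-pass sketch (the test vectors are arbitrary):

```python
def jacobian_h_next(h, x_next):
    """Analytic Jacobian dh_{t+1}/dh_t = diag(tanh'(a_{t+1})) @ W_hh."""
    a_next = W_hh @ h + W_xh @ x_next + b_h
    return np.diag(1 - np.tanh(a_next)**2) @ W_hh

h, x_next, eps = rng.normal(size=n_h), rng.normal(size=n_x), 1e-6
num = np.zeros((n_h, n_h))
for i in range(n_h):
    dh = np.zeros(n_h); dh[i] = eps
    f_plus  = np.tanh(W_hh @ (h + dh) + W_xh @ x_next + b_h)
    f_minus = np.tanh(W_hh @ (h - dh) + W_xh @ x_next + b_h)
    num[:, i] = (f_plus - f_minus) / (2 * eps)   # central difference, column i
assert np.allclose(num, jacobian_h_next(h, x_next), atol=1e-6)
```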
Step 3: Aggregate parameter gradients
The gradient for \(W_{hh}\) is the sum of contributions from all time steps:
\[\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^T \text{diag}(\tanh'(a_t)) \delta_t h_{t-1}^T \]
Similarly, \(\frac{\partial L}{\partial W_{xh}}\) and \(\frac{\partial L}{\partial b_h}\) can be computed.
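A sketch of this accumulation, continuing the code above (the gradients for \(W_{hy}\) and \(b_y\) follow the same pattern and are omitted for brevity; the one-hot `targets` below are illustrative):

```python
def parameter_grads(xs, hs, deltas):
    """Step 3: accumulate dL/dW_hh, dL/dW_xh, dL/db_h over all time steps."""
    dW_hh, dW_xh, db_h = np.zeros_like(W_hh), np.zeros_like(W_xh), np.zeros_like(b_h)
    for t in range(len(xs)):                 # code index t is time step t+1
        g = (1 - hs[t + 1]**2) * deltas[t]   # tanh'(a_t) * delta_t, element-wise
        dW_hh += np.outer(g, hs[t])          # outer product with h_{t-1}
        dW_xh += np.outer(g, xs[t])          # outer product with x_t
        db_h  += g
    return dW_hh, dW_xh, db_h

# Putting the pieces together
targets = [np.eye(n_y)[rng.integers(n_y)] for _ in range(T)]
deltas = hidden_deltas(hs, ys, targets)
dW_hh, dW_xh, db_h = parameter_grads(xs, hs, deltas)
```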
5. Analysis of Vanishing/Exploding Gradient Problem
From \(\frac{\partial h_t}{\partial h_k} = \prod_{j=k+1}^t \frac{\partial h_j}{\partial h_{j-1}}\), the magnitudes of the eigenvalues (or singular values) of each Jacobian \(\frac{\partial h_j}{\partial h_{j-1}} = \text{diag}(\tanh'(a_j)) W_{hh}\) determine how much the gradient is scaled at each step. If the largest eigenvalue magnitude of \(W_{hh}\) exceeds 1, the repeated multiplication can make the gradient grow exponentially (explosion); if it is below 1 (and the factor \(\tanh' \le 1\) only shrinks it further), the gradient decays exponentially toward 0 (vanishing).
- Example: If \(\frac{\partial h_j}{\partial h_{j-1}} \approx \alpha I\) (a scaled identity matrix), then \(\frac{\partial h_t}{\partial h_k} \approx \alpha^{t-k} I\). When \(t - k\) is large, \(\alpha^{t-k}\) shrinks toward 0 if \(\alpha < 1\) (vanishing) and grows without bound if \(\alpha > 1\) (explosion).
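This is easy to observe numerically. Continuing the sketch above, repeatedly applying a Jacobian of the form \(\alpha I\) with \(\alpha = 0.9\) versus \(\alpha = 1.1\) (both illustrative values) shrinks or blows up a back-propagated vector exponentially:

```python
# Effect of 50 repeated Jacobian products of the form (alpha * I)
for alpha in (0.9, 1.1):
    J = alpha * np.eye(n_h)
    g = np.ones(n_h)                 # stand-in for dL_t/dh_t
    for _ in range(50):
        g = J.T @ g                  # propagate back one step
    print(alpha, np.linalg.norm(g))  # ~alpha**50 times the initial norm
```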
6. Optimization Methods and Summary
- Gradient Clipping: Rescales the gradient when its norm exceeds a threshold, alleviating the explosion problem (a minimal sketch follows this list).
- Gated Architectures: LSTMs/GRUs use gating mechanisms to control gradient flow, mitigating the vanishing problem.
- Truncated BPTT: Only propagates gradients back a fixed number of steps, balancing computation and long-term dependencies.
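As a concrete illustration of the first point, here is a minimal sketch of clipping by global norm, continuing the NumPy code above (the threshold of 5.0 is an arbitrary illustrative choice):

```python
def clip_by_global_norm(grads, max_norm=5.0):
    """Rescale all gradients so that their combined L2 norm is at most max_norm."""
    total = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total > max_norm:
        grads = [g * (max_norm / total) for g in grads]
    return grads

# e.g. dW_hh, dW_xh, db_h = clip_by_global_norm([dW_hh, dW_xh, db_h])
```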
BPTT is the core algorithm for training RNNs, but its gradient issues have driven the development of more advanced sequence models (e.g., Transformer).