Numerical Stability and Log Tricks of the Softmax Function

Problem Description:
When the Softmax function is evaluated directly from its definition, floating-point overflow or underflow can occur whenever the inputs have large magnitude (large positive or large negative values). We need to understand where this instability comes from and learn the log tricks that make the computation stable.

Analysis of Numerical Instability:

  1. Original Softmax formula: For class j, its probability is P_j = e^{z_j} / Σ_k e^{z_k}
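  2. When some z_k is large (e.g., 1000), e^{z_k} exceeds the largest representable floating-point value (about 1.8 × 10^308 in double precision, reached near e^{709.78}; near e^{88.7} in single precision), so it overflows to inf and the ratio inf/inf for that class evaluates to NaN
  3. When all z_k are large negative numbers (e.g., -1000), every e^{z_k} underflows to 0, so the denominator becomes 0 and the result is 0/0 (NaN in IEEE arithmetic, or a division-by-zero error in some environments)
  4. Even without outright overflow, small probabilities can lose precision near the underflow threshold or round to 0, in which case a later log(P_j) evaluates to -inf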
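A minimal pure-Python sketch (the function name naive_softmax is hypothetical) reproduces these failures. One caveat: Python's math.exp raises OverflowError rather than returning inf, and float division by zero raises ZeroDivisionError rather than producing NaN, so here the failures surface as exceptions; IEEE-style array libraries such as NumPy instead return inf or NaN with a runtime warning.

```python
import math

def naive_softmax(z):
    """Softmax evaluated exactly as written, with no shift."""
    exps = [math.exp(v) for v in z]   # math.exp raises OverflowError for v above ~709.78
    total = sum(exps)                 # becomes 0.0 if every exponential underflowed
    return [e / total for e in exps]  # 0.0 / 0.0 raises ZeroDivisionError in Python

# Inputs of moderate size work fine.
print(naive_softmax([1.0, 2.0, 3.0]))

# Large positive inputs overflow.
try:
    naive_softmax([1000.0, 1000.0])
except OverflowError as err:
    print("overflow:", err)

# Large negative inputs underflow to 0, so the denominator is 0.
try:
    naive_softmax([-1000.0, -1000.0])
except ZeroDivisionError as err:
    print("underflow:", err)
```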

Principles of Stability Tricks:

  1. Core idea: Softmax is shift-invariant (translation-invariant): adding the same constant to every input, or subtracting it from every input, leaves the output unchanged
  2. Proof: multiplying the numerator and denominator by e^{-c} gives P_j = e^{z_j}/Σ_k e^{z_k} = e^{z_j}e^{-c}/Σ_k e^{z_k}e^{-c} = e^{z_j-c}/Σ_k e^{z_k-c}, where c is an arbitrary constant; the common factor e^{-c} cancels
  3. By choosing an appropriate c, we can keep the inputs to the exponential function within a safe numerical range
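A quick numerical check of shift invariance, written as a minimal sketch: softmax_shifted is a hypothetical helper that evaluates the formula with an explicit constant c. For moderate inputs, every choice of c should give the same probabilities up to rounding error.

```python
import math

def softmax_shifted(z, c):
    """Softmax with every input shifted by the constant c (c = 0 is the original formula)."""
    exps = [math.exp(v - c) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

z = [1.0, 2.0, 3.0]
base = softmax_shifted(z, 0.0)
print("c = 0:", base)
for c in (-5.0, 2.5, 100.0):
    shifted = softmax_shifted(z, c)
    # Differences should be at the level of floating-point rounding.
    max_diff = max(abs(a - b) for a, b in zip(base, shifted))
    print(f"c = {c}: max |difference| = {max_diff:.1e}")
```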

Specific Implementation Methods:

  1. The most common strategy is to set c = max(z_i), i.e., subtract the maximum value from the input vector
  2. Calculation steps:
    a. Find the maximum value in the input vector z: m = max(z_1, z_2, ..., z_K)
    b. Shift each element: z'_i = z_i - m
    c. Compute the stabilized Softmax: P_j = e^{z'_j} / Σ_k e^{z'_k}
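  3. Every input to the exponential is then ≤ 0, so each e^{z'_i} lies in [0, 1] and overflow is impossible; moreover, the term for the maximum element is exactly e^0 = 1, so the denominator is at least 1 and can never underflow to 0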
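Steps a to c translate almost line by line into code. This is a minimal pure-Python sketch (stable_softmax is a hypothetical name), not a framework implementation; it handles exactly the inputs that break the unshifted formula.

```python
import math

def stable_softmax(z):
    """Softmax with the max-subtraction trick (steps a to c)."""
    m = max(z)                             # a. maximum of the input vector
    shifted = [v - m for v in z]           # b. z'_i = z_i - m, all <= 0
    exps = [math.exp(v) for v in shifted]  # each value in [0, 1]; the max term is exactly 1.0
    total = sum(exps)                      # therefore total >= 1: no division by zero
    return [e / total for e in exps]       # c. P_j = e^{z'_j} / sum_k e^{z'_k}

print(stable_softmax([1000.0, 1000.0]))    # [0.5, 0.5]
print(stable_softmax([-1000.0, -1000.0]))  # [0.5, 0.5]
print(stable_softmax([1000.0, 1001.0, 1002.0]))
```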

Log-Softmax Trick:

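  1. When training neural networks, log(Softmax) is often needed directly (for example, in the negative log-likelihood or cross-entropy loss), and computing it in the log domain avoids underflow of small probabilities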
  2. Stable formula for directly computing log(Softmax):
    log(P_j) = z_j - m - log(Σ_k e^{z_k - m})
  3. The logarithm is applied only to the sum Σ_k e^{z_k - m}, which is at least 1 because it contains the term e^0 = 1, so it is always finite; by contrast, computing Softmax first and then taking the logarithm yields log(0) = -inf whenever P_j underflows to 0
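The sketch below (hypothetical names log_softmax and softmax_then_log) contrasts the direct log-domain formula with taking the logarithm of already-computed probabilities. Both routes use the max shift, which shows that the shift alone does not protect the final log step.

```python
import math

def log_softmax(z):
    """log(P_j) = z_j - m - log(sum_k e^{z_k - m}), computed without forming P_j."""
    m = max(z)
    # The sum contains the term e^0 = 1, so it is >= 1 and its log is finite.
    log_sum = math.log(sum(math.exp(v - m) for v in z))
    return [v - m - log_sum for v in z]

def softmax_then_log(z):
    """The less stable route: form the probabilities first, then take logs."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [math.log(e / total) for e in exps]  # math.log(0.0) raises ValueError

z = [0.0, -1000.0]
print(log_softmax(z))  # [0.0, -1000.0]
try:
    softmax_then_log(z)
except ValueError as err:
    print("softmax-then-log failed:", err)
```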

Practical Application Considerations:

  1. Deep learning frameworks implement the stable version directly, for example PyTorch's nn.LogSoftmax (and torch.nn.functional.log_softmax)
  2. When both Softmax and log-Softmax are needed, compute log-Softmax first and obtain Softmax by exponentiating it
  3. The same trick applies to cross-entropy loss: for the true class y, -log P_y = log(Σ_k e^{z_k - m}) + m - z_y, so log-Softmax and the loss can be fused into a single step that is both more efficient and more stable (PyTorch's nn.CrossEntropyLoss, for instance, takes raw scores and is equivalent to LogSoftmax followed by NLLLoss)
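Points 2 and 3 can be sketched as follows, again as a minimal pure-Python illustration with hypothetical helper names. The cross-entropy function takes the raw scores z and an integer class index target, and uses -log P_y = log(Σ_k e^{z_k - m}) + m - z_y, so the probability itself is never formed.

```python
import math

def log_softmax(z):
    """Stable log-Softmax via the max shift."""
    m = max(z)
    log_sum = math.log(sum(math.exp(v - m) for v in z))
    return [v - m - log_sum for v in z]

def softmax_from_log(log_p):
    """Obtain Softmax by exponentiating an existing log-Softmax."""
    return [math.exp(v) for v in log_p]

def cross_entropy_from_logits(z, target):
    """-log P_target computed in one step from the raw scores z."""
    m = max(z)
    return math.log(sum(math.exp(v - m) for v in z)) + m - z[target]

z = [2.0, 1.0, 0.1]
log_p = log_softmax(z)
print(softmax_from_log(log_p))                       # probabilities summing to 1
print(cross_entropy_from_logits(z, 0), -log_p[0])    # agree up to rounding
print(cross_entropy_from_logits([0.0, -1000.0], 1))  # 1000.0
```

The last call shows the benefit of fusing: the exact loss is about 1000, whereas computing P_y first would round it to 0 and give an infinite loss.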

With this treatment, Softmax and log-Softmax can be computed safely for inputs of large magnitude while remaining mathematically identical to the original formula.