Joint Gradient Derivation of Softmax Function and Cross-Entropy Loss

Problem Description
In multi-class classification tasks for neural networks, the Softmax function is often used in combination with cross-entropy loss. Interviews frequently require the derivation of the gradient for this combination, i.e., the partial derivative of the loss function L with respect to the model weights W, ∂L/∂W. Understanding this derivation process is crucial for mastering backpropagation.

Knowledge Explanation

  1. Scenario Setting and Symbol Definitions

    • Problem: We have a K-class classification problem.
    • Model: For an input sample x (feature vector), the model first calculates its "score" (logit) for each class k: z_k = w_k^T x + b_k. Here, w_k is the weight vector for class k, and b_k is the bias term. All scores form the vector z = [z_1, z_2, ..., z_K]^T.
    • Softmax Layer: Transforms the score vector z into a probability distribution vector ŷ.
      ŷ_i = Softmax(z)_i = e^{z_i} / (∑_{j=1}^{K} e^{z_j})
      Here, ŷ_i represents the predicted probability that sample x belongs to class i, and it satisfies ∑_{i=1}^{K} ŷ_i = 1.
    • True Label: The true label y is usually represented by a one-hot vector. For example, if the true class is c, then y_c = 1, and for all j ≠ c, y_j = 0.
    • Cross-Entropy Loss Function: Measures the difference between the predicted probability distribution ŷ and the true distribution y.
      L = - ∑_{i=1}^{K} y_i log(ŷ_i)
      Since y is a one-hot vector (only 1 at the true class c, 0 elsewhere), the loss can be simplified to:
      L = - y_c log(ŷ_c) = - log(ŷ_c)
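The definitions above can be sketched numerically. This is a minimal NumPy illustration with made-up logits; the max-shift in the softmax is a standard numerical-stability trick that does not change the result.

```python
import numpy as np

def softmax(z):
    # Subtracting max(z) avoids overflow in exp; the ratio is unchanged.
    e = np.exp(z - np.max(z))
    return e / e.sum()

def cross_entropy(y_hat, y):
    # y is one-hot, so only the true-class term survives the sum.
    return -np.sum(y * np.log(y_hat))

z = np.array([2.0, 1.0, 0.1])   # example logits, K = 3
y = np.array([1.0, 0.0, 0.0])   # one-hot label, true class c = 0
y_hat = softmax(z)              # predicted probabilities, sums to 1
loss = cross_entropy(y_hat, y)  # equals -log(y_hat[0])
```

Because y is one-hot, `cross_entropy(y_hat, y)` reduces to `-np.log(y_hat[0])`, matching the simplified form L = -log(ŷ_c).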
  2. Objective and Derivation Route

    • Objective: Calculate the gradient of the loss L with respect to a specific weight vector w_k (or a specific score z_k), i.e., ∂L / ∂w_k (or ∂L / ∂z_k). This is the core step of backpropagation.
    • Derivation Route: Apply the chain rule. To compute ∂L / ∂w_k, we decompose it as:
      ∂L / ∂w_k = (∂L / ∂z_k) * (∂z_k / ∂w_k)
      Here, ∂z_k / ∂w_k = x (because z_k = w_k^T x + b_k). Therefore, the key is to find the "upstream" gradient ∂L / ∂z_k, the partial derivative of the loss L with respect to the score z_k. We will derive this for all k (from 1 to K).
  3. Key Step: Calculate ∂L / ∂z_k

    • Application of the Chain Rule: The loss L is a function of ŷ, and each ŷ_j is a function of all z. Therefore, according to the multivariate chain rule:
      ∂L / ∂z_k = ∑_{j=1}^{K} (∂L / ∂ŷ_j) * (∂ŷ_j / ∂z_k)
      This summation is necessary because changing one score z_k affects all predicted probabilities ŷ_j.

    • Calculate the First Term: ∂L / ∂ŷ_j
      From the loss function L = - ∑_{i=1}^{K} y_i log(ŷ_i), taking the derivative with respect to a specific ŷ_j:
      ∂L / ∂ŷ_j = - y_j / ŷ_j
      Note that y_j is a constant (the true label).

    • Calculate the Second Term: ∂ŷ_j / ∂z_k (This is the most crucial and error-prone step.)
      This is the partial derivative of the Softmax function. Two cases need to be discussed because the derivative of ŷ_j with respect to z_k yields different results for j = k and j ≠ k.

      • Case 1: When j = k
        ŷ_j = e^{z_j} / (∑_{m=1}^{K} e^{z_m})
        Let S = ∑_{m=1}^{K} e^{z_m}. With j = k, both the numerator e^{z_j} and the denominator S depend on z_k (note ∂S/∂z_k = e^{z_k}), so apply the quotient rule:
        ∂ŷ_j / ∂z_k = (e^{z_j} * S - e^{z_j} * e^{z_k}) / S^2
        Since j=k, this simplifies to: (e^{z_j} * S - e^{z_j} * e^{z_j}) / S^2 = (e^{z_j} / S) * (1 - e^{z_j} / S) = ŷ_j (1 - ŷ_j)

      • Case 2: When j ≠ k
        ŷ_j = e^{z_j} / S
        Now differentiate with respect to z_k (k ≠ j). The numerator e^{z_j} does not depend on z_k, while the denominator S still contains e^{z_k}.
        ∂ŷ_j / ∂z_k = (0 * S - e^{z_j} * e^{z_k}) / S^2 = - (e^{z_j} / S) * (e^{z_k} / S) = - ŷ_j ŷ_k

      • Unified Representation:
        ∂ŷ_j / ∂z_k = ŷ_k (1 - ŷ_k),  if j = k
        ∂ŷ_j / ∂z_k = - ŷ_j ŷ_k,      if j ≠ k
        This result can be concisely written as: ∂ŷ_j / ∂z_k = ŷ_j (δ_{jk} - ŷ_k), where δ_{jk} is the Kronecker delta function (1 when j=k, 0 otherwise).
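The unified formula ∂ŷ_j/∂z_k = ŷ_j (δ_{jk} - ŷ_k) is exactly the softmax Jacobian matrix diag(ŷ) - ŷ ŷ^T. A quick way to trust the two-case derivation is to compare it against central finite differences; here is a small sketch with arbitrary logits.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def softmax_jacobian(z):
    # J[j, k] = y_hat[j] * (delta_jk - y_hat[k]) = diag(y_hat) - outer(y_hat, y_hat)
    y_hat = softmax(z)
    return np.diag(y_hat) - np.outer(y_hat, y_hat)

z = np.array([0.5, -1.2, 2.0])
J = softmax_jacobian(z)

# Central finite differences: perturb each z_k and watch all of y_hat move.
eps = 1e-6
J_num = np.zeros_like(J)
for k in range(len(z)):
    dz = np.zeros_like(z)
    dz[k] = eps
    J_num[:, k] = (softmax(z + dz) - softmax(z - dz)) / (2 * eps)

assert np.allclose(J, J_num, atol=1e-8)
```

Each row (and column) of J sums to zero, reflecting the constraint that the ŷ_i always sum to 1.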

    • Combine the Two Terms, Calculate the Final Gradient
      Now substitute both terms into the sum: ∂L / ∂z_k = ∑_{j=1}^{K} (∂L / ∂ŷ_j) * (∂ŷ_j / ∂z_k) = ∑_{j=1}^{K} (- y_j / ŷ_j) * (∂ŷ_j / ∂z_k)
      Substituting the two cases for ∂ŷ_j / ∂z_k into the summation:
      ∂L / ∂z_k = ∑_{j=1}^{K} (- y_j / ŷ_j) * [ŷ_j (δ_{jk} - ŷ_k)]
      Note that ŷ_j cancels between the -y_j / ŷ_j factor and the ŷ_j from the softmax derivative:
      ∂L / ∂z_k = - ∑_{j=1}^{K} y_j (δ_{jk} - ŷ_k)
      Expanding the summation:
      ∂L / ∂z_k = - [ ∑_{j=1}^{K} y_j δ_{jk} - ∑_{j=1}^{K} y_j ŷ_k ]

      • First term ∑_{j=1}^{K} y_j δ_{jk}: δ_{jk}=1 only when j=k, so this term is simply y_k.
      • Second term ∑_{j=1}^{K} y_j ŷ_k: ŷ_k is independent of j, so it can be factored out, becoming ŷ_k ∑_{j=1}^{K} y_j. The true label y is a one-hot vector, so the sum of all its elements is 1, i.e., ∑_{j=1}^{K} y_j = 1. Therefore, the second term is ŷ_k * 1 = ŷ_k.
        Thus:
        ∂L / ∂z_k = - (y_k - ŷ_k) = ŷ_k - y_k
    • The Final Elegant Result
      The gradient of the loss L with respect to the k-th class score is:
      ∂L / ∂z_k = ŷ_k - y_k
      This is an extremely concise and important result. It states that the gradient is simply the model's predicted probability minus the true one-hot label.
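This closed form is easy to verify numerically: compute ŷ - y directly and compare it against a finite-difference estimate of ∂L/∂z. A minimal sketch with made-up values:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def loss(z, y):
    return -np.sum(y * np.log(softmax(z)))

z = np.array([1.0, 0.3, -0.8])
y = np.array([0.0, 1.0, 0.0])   # true class is index 1

grad = softmax(z) - y           # the closed-form result: dL/dz_k = y_hat_k - y_k

# Central finite-difference check of dL/dz_k
eps = 1e-6
grad_num = np.array([
    (loss(z + eps * np.eye(3)[k], y) - loss(z - eps * np.eye(3)[k], y)) / (2 * eps)
    for k in range(3)
])
assert np.allclose(grad, grad_num, atol=1e-8)
```

A side observation: the gradient components sum to zero (both ŷ and y sum to 1), so increasing one logit's "pull" is always balanced by the others.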

  4. Completing the Full Gradient Calculation
    Now we return to the initial objective, calculating the gradient with respect to the weight w_k:
    ∂L / ∂w_k = (∂L / ∂z_k) * (∂z_k / ∂w_k) = (ŷ_k - y_k) * x
    Similarly, for the bias term b_k:
    ∂L / ∂b_k = ŷ_k - y_k
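Putting the pieces together, the full backward pass for this single linear layer is just an outer product. A minimal sketch, assuming W stacks the w_k^T as rows (shape K×D), so dW's k-th row is (ŷ_k - y_k) x; the shapes and random values are illustrative.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

rng = np.random.default_rng(0)
K, D = 3, 4                      # 3 classes, 4 input features (illustrative)
W = rng.normal(size=(K, D))      # row k is w_k^T
b = rng.normal(size=K)
x = rng.normal(size=D)
y = np.array([0.0, 0.0, 1.0])    # one-hot label

# Forward pass
z = W @ x + b                    # z_k = w_k^T x + b_k
y_hat = softmax(z)

# Backward pass, using dL/dz = y_hat - y
dz = y_hat - y
dW = np.outer(dz, x)             # dL/dw_k = (y_hat_k - y_k) * x, stacked as rows
db = dz                          # dL/db_k = y_hat_k - y_k
```

Note that the whole gradient costs one subtraction and one outer product; no per-class case analysis survives into the implementation, which is exactly why this softmax + cross-entropy pairing is so efficient.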

Summary
The key to this derivation process lies in proficiently applying the chain rule and the derivative of the Softmax function. The final gradient formula ∂L / ∂z_k = ŷ_k - y_k is remarkably simple in form, making its computation in backpropagation very efficient. This result also intuitively tells us that when the prediction ŷ_k is far from the true value y_k, the gradient is large, requiring a significant update to the model parameters; when the prediction is close to the truth, the gradient becomes smaller, reducing the magnitude of the update.