Weight Decay and Regularization in Softmax Regression

Description
In Softmax regression, weight decay (or L2 regularization) is a key technique used to prevent model overfitting. It works by adding a penalty term proportional to the squared norm of the model's weights (parameters) to the original loss function. This constrains the magnitude of the weights, encouraging the model to learn simpler, more generalizable patterns and avoid becoming overly sensitive to noise in the training data.

Detailed Explanation

  1. Background: Softmax Regression and Overfitting

    • Softmax Regression Review: Softmax regression is a generalization of logistic regression for multi-class classification problems. For a classification problem with K classes, the model learns a weight vector W_k for each class k (and a bias term b_k, often folded into W for simplicity). Given an input sample x, the probability that it belongs to class k is given by the Softmax function:
      P(y=k | x) = exp(W_k^T x) / (∑_{j=1}^{K} exp(W_j^T x))
    • Overfitting Risk: When the model has too many parameters or insufficient training data, it might learn specific patterns (or even noise) that exist only in the training set. This leads to good performance on the training set but poor performance on unseen test data, which is known as overfitting. In Softmax regression, if the weights W become very large, the model becomes extremely sensitive to changes in the input features, and the decision boundaries become very "complex" and "sharp," making overfitting likely.
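
    To make this concrete, below is a minimal NumPy sketch (the array shapes and numbers are invented purely for illustration) that computes P(y=k | x) for a single sample and shows how scaling the weights up makes the predicted distribution much sharper:

```python
import numpy as np

def softmax_probs(W, x):
    """P(y=k | x) = exp(W_k^T x) / sum_j exp(W_j^T x).

    W: (K, D) weight matrix, one row W_k per class (bias folded in).
    x: (D,) feature vector.
    """
    logits = W @ x                  # W_k^T x for each class, shape (K,)
    logits = logits - logits.max()  # subtract the max for numerical stability
    exp_logits = np.exp(logits)
    return exp_logits / exp_logits.sum()

# Illustrative numbers only: 3 classes, 2 features.
W = np.array([[ 0.5, -0.2],
              [-0.3,  0.8],
              [ 0.1,  0.1]])
x = np.array([1.0, 2.0])

print(softmax_probs(W, x))       # moderate weights -> softer probability distribution
print(softmax_probs(10 * W, x))  # scaled-up weights -> near one-hot, "sharper" output
```
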
  2. Introducing Regularization: Core Idea

    • Intuition: We generally prefer models with smaller weight values. Smaller weights imply the model does not rely excessively on any single input feature, making its output less sensitive to input variations and thus more robust.
    • Method: To control the magnitude of the weights, we add a regularization term (or penalty term) to the original loss function (typically cross-entropy loss). This term increases as the norm (magnitude) of the weight vectors increases. Consequently, during the optimization process to minimize the loss function, the algorithm must balance minimizing the prediction error (original loss) and keeping the weights small.
  3. Mathematical Form of L2 Regularization (Weight Decay)

    • Original Loss Function: For a training set with N samples, the original cross-entropy loss function for Softmax regression is:
      J(W) = - (1/N) ∑_{i=1}^{N} ∑_{k=1}^{K} 1{y_i = k} log( P(y_i = k | x_i) )
      where 1{y_i = k} is the indicator function: 1 if the true label of sample i is k, and 0 otherwise.
    • Adding the Regularization Term: The new loss function J_regularized(W) with L2 regularization becomes:
      J_regularized(W) = J(W) + (λ / 2) ||W||_F^2
      • J(W): The original cross-entropy loss.
      • ||W||_F^2: The squared Frobenius norm of the weight matrix W. This can be understood as summing the squares of all weight parameters (every element in all W_k vectors). That is, ||W||_F^2 = ∑_{k=1}^{K} ∑_{d=1}^{D} W_{k,d}^2, where D is the feature dimension.
      • λ: The regularization strength hyperparameter (λ > 0). It controls the importance of the regularization term in the total loss.
        • λ = 0: The regularization term has no effect, reverting to the original loss function.
        • λ is very large: The regularization term dominates, forcing the model to prioritize making weights small, potentially leading to underfitting (the model becomes too simple to capture meaningful patterns in the data).
      • (1/2): The factor of 1/2 is included for computational convenience: it cancels the 2 that appears when differentiating the squared norm, leaving a clean gradient term.
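
    As a rough sketch of this objective in code (the softmax and regularized_loss helpers, the sample data, and the λ values below are all made up for illustration):

```python
import numpy as np

def softmax(logits):
    # Row-wise softmax with the usual max subtraction for numerical stability.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def regularized_loss(W, X, y, lam):
    """J_regularized(W) = J(W) + (lam / 2) * ||W||_F^2.

    W: (K, D) weights, X: (N, D) samples, y: (N,) integer labels in [0, K).
    """
    N = X.shape[0]
    probs = softmax(X @ W.T)                                  # (N, K) class probabilities
    cross_entropy = -np.mean(np.log(probs[np.arange(N), y]))  # original loss J(W)
    penalty = 0.5 * lam * np.sum(W ** 2)                      # (lam / 2) * squared Frobenius norm
    return cross_entropy + penalty

# Tiny usage example with made-up numbers: 4 samples, 2 features, 3 classes.
rng = np.random.default_rng(0)
X = rng.normal(size=(4, 2))
y = np.array([0, 2, 1, 0])
W = rng.normal(size=(3, 2))
print(regularized_loss(W, X, y, lam=0.0))   # plain cross-entropy
print(regularized_loss(W, X, y, lam=0.1))   # cross-entropy + weight penalty
```
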
  4. How Regularization Affects Model Optimization (Gradient Descent)
    The model minimizes the loss function via gradient descent. Let's examine how adding the L2 regularization term changes the weight update process.

    • Original Gradient Descent (No Regularization):
      The weight update rule is: W := W - α * ∇_W J(W)
      where α is the learning rate, and ∇_W J(W) is the gradient of the original loss function J(W) with respect to the weights W.
    • Gradient Descent with L2 Regularization:
      First, compute the gradient of the new loss function J_regularized(W) with respect to W:
      ∇_W J_regularized(W) = ∇_W [J(W) + (λ / 2) ||W||_F^2] = ∇_W J(W) + λW
      (Because the derivative of (λ/2) ||W||_F^2 is λW, as the 1/2 cancels with the 2 from the derivative of the square term.)
      The weight update rule then becomes:
      W := W - α * (∇_W J(W) + λW)
      We can rearrange this formula:
      W := W - αλW - α * ∇_W J(W)
      W := (1 - αλ)W - α * ∇_W J(W)
    • Origin of "Weight Decay": Observe the new update formula W := (1 - αλ)W - .... Before each weight update, the current weight W is multiplied by a factor slightly less than 1, (1 - αλ). This operation is independent of the gradient and causes the weights to shrink (or "decay") by a small, steady amount at every step, before the normal gradient update is applied. This is the origin of the term "weight decay." The hyperparameters α and λ jointly determine the per-step decay factor (1 - αλ).
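
    A minimal sketch of one such update step is shown below; grad_J stands for an externally computed gradient ∇_W J(W), and the numbers used in the check are arbitrary:

```python
import numpy as np

def weight_decay_step(W, grad_J, alpha, lam):
    """One gradient-descent step on J_regularized(W).

    Two algebraically identical forms of the update:
      W <- W - alpha * (grad_J + lam * W)            # gradient of the regularized loss
      W <- (1 - alpha * lam) * W - alpha * grad_J    # "decay, then take the usual step"
    """
    return (1.0 - alpha * lam) * W - alpha * grad_J

# Tiny check with made-up numbers: both forms give the same result.
W = np.array([[1.0, -2.0], [0.5, 0.3]])
grad_J = np.array([[0.1, 0.0], [-0.2, 0.4]])
alpha, lam = 0.1, 0.01

direct = W - alpha * (grad_J + lam * W)
decayed = weight_decay_step(W, grad_J, alpha, lam)
print(np.allclose(direct, decayed))  # True
```

    Both forms are algebraically identical; the second simply makes the per-step shrinkage factor (1 - αλ) explicit.
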
  5. Summary of L2 Regularization Effects

    • Suppresses Large Weights: By penalizing large weight values, it keeps the weight vectors smaller in magnitude and smoother overall (see the small experiment sketched after this list).
    • Reduces Model Complexity: Because L2 regularization shrinks all weights rather than zeroing them out, the model spreads its reliance across features instead of depending too heavily on any single one, which leads to a simpler, more stable decision boundary.
    • Improves Generalization: By mitigating overfitting, it enhances the model's performance on unseen data (the test set).
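
    These effects can also be seen in a small self-contained experiment (the synthetic data and every hyperparameter value below are invented purely for illustration); the run with λ > 0 should finish with a noticeably smaller weight norm than the run with λ = 0:

```python
import numpy as np

def softmax(Z):
    Z = Z - Z.max(axis=1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)

def train(X, y, K, lam, alpha=0.1, steps=500):
    """Batch gradient descent on the regularized cross-entropy loss."""
    N, D = X.shape
    Y = np.eye(K)[y]                               # one-hot labels, shape (N, K)
    W = np.zeros((K, D))
    for _ in range(steps):
        P = softmax(X @ W.T)                       # (N, K) predicted probabilities
        grad = (P - Y).T @ X / N                   # gradient of the cross-entropy term
        W = (1 - alpha * lam) * W - alpha * grad   # weight-decay update from Section 4
    return W

# Small synthetic 3-class problem (data and hyperparameters invented for illustration).
rng = np.random.default_rng(0)
X = rng.normal(size=(300, 5))
true_W = rng.normal(size=(3, 5))
y = np.argmax(X @ true_W.T + 0.5 * rng.normal(size=(300, 3)), axis=1)

for lam in (0.0, 0.1):
    W = train(X, y, K=3, lam=lam)
    print(f"lambda={lam}: ||W||_F = {np.linalg.norm(W):.3f}")
```

    Since the only difference between the two runs is λ, any gap in ||W||_F comes from the (1 - αλ) shrinkage applied at every step.
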

Key Points Review

  • Goal: Prevent overfitting in Softmax regression models.
  • Method: Add the squared L2 norm of the weights as a penalty term to the loss function.
  • Core Hyperparameter: λ (regularization strength), which needs to be tuned based on validation set performance.
  • Optimization Impact: In gradient descent, it is equivalent to applying a decay to the weights before each update step.

By introducing weight decay, we add an effective "braking" mechanism to the Softmax regression model, guiding it towards a solution with stronger generalization capabilities.