Principle and Role of Batch Normalization

1. Problem Background

During the training of deep neural networks, internal covariate shift is a common issue: the distribution of each layer's inputs changes as the parameters of the preceding layers update. This forces the use of lower learning rates and more careful parameter initialization, and it slows convergence. Batch Normalization (BN) was proposed to address this problem.


2. The Core Idea of Batch Normalization

The core idea of BN is: normalize the input of each layer to have zero mean and unit variance, thereby stabilizing the data distribution. However, normalization alone can limit the network's representational capacity (for example, it would confine a sigmoid's inputs to its near-linear regime), so learnable scaling and shifting parameters must be introduced.


3. Detailed Algorithm Steps

Assume the input to a layer is a batch of data \(\mathbf{X} \in \mathbb{R}^{m \times d}\), where \(m\) is the batch size and \(d\) is the feature dimension.

Step 1: Calculate Batch Mean and Variance

For each feature dimension \(j\) (column), compute:

\[\mu_j = \frac{1}{m} \sum_{i=1}^m x_{ij}, \quad \sigma_j^2 = \frac{1}{m} \sum_{i=1}^m (x_{ij} - \mu_j)^2 \]

(In practice, the variance calculation may use an unbiased estimator or a smoothed version, but the original paper uses the batch variance directly.)
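As a minimal sketch in NumPy (the batch `X` here is a hypothetical 4×3 array; the seed and shapes are illustrative only), the column-wise statistics are one reduction each:

```python
import numpy as np

# Hypothetical example: a batch of m = 4 samples with d = 3 features.
rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3)).astype(np.float32)

# Column-wise batch statistics, matching the formulas above.
mu = X.mean(axis=0)   # shape (3,): one mean per feature
var = X.var(axis=0)   # biased estimator (divides by m), as in the original paper
```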

Step 2: Normalization

Standardize each feature value:

\[\hat{x}_{ij} = \frac{x_{ij} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}} \]

where \(\epsilon > 0\) is a very small constant (e.g., \(10^{-5}\)) to prevent division by zero.
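Continuing the sketch from Step 1, the normalization is a single broadcasted expression:

```python
eps = 1e-5  # the small constant epsilon from the formula above

# Broadcasting subtracts each column's mean and divides by its std.
X_hat = (X - mu) / np.sqrt(var + eps)

# Sanity check: each column is now (approximately) zero-mean, unit-variance.
print(X_hat.mean(axis=0))  # ~0
print(X_hat.var(axis=0))   # ~1
```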

Step 3: Scale and Shift (Learnable Parameters)

Introduce two learnable parameters \(\gamma_j\) and \(\beta_j\) to transform the normalized values:

\[y_{ij} = \gamma_j \hat{x}_{ij} + \beta_j \]

This step allows the network to recover the original distribution if that is optimal: when \(\gamma_j = \sqrt{\sigma_j^2 + \epsilon}\) and \(\beta_j = \mu_j\), the transform reduces to the identity, so normalization costs no representational capacity.
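Completing the sketch, the affine step is elementwise per feature; initializing \(\gamma = 1\) and \(\beta = 0\) (a common convention, not mandated by the algorithm) makes the initial transform an identity on the normalized values:

```python
d = X.shape[1]

# Learnable per-feature parameters, initialized to the identity transform.
gamma = np.ones(d, dtype=np.float32)
beta = np.zeros(d, dtype=np.float32)

Y = gamma * X_hat + beta

# Setting gamma = sqrt(var + eps) and beta = mu would undo the
# normalization and recover X exactly, as noted above.
```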


4. Differences Between Training and Inference Phases

  • Training Phase: The mean and variance are computed from the current batch data.
  • Inference Phase: Reliable statistics cannot be computed from a single sample, so fixed global statistics are used instead, typically an exponential moving average (EMA) of the batch means and variances collected during training (see the sketch after this list).
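The following sketch shows both modes in one place, assuming a simple `BatchNorm1D` class (the class name, `training` flag, and the momentum value 0.1 are illustrative choices, not from the original paper). Running statistics are updated by EMA during training and frozen at inference:

```python
import numpy as np

class BatchNorm1D:
    """Illustrative batch norm for 2-D inputs of shape (m, d); not production code."""

    def __init__(self, d, eps=1e-5, momentum=0.1):
        self.gamma = np.ones(d, dtype=np.float32)   # learnable scale
        self.beta = np.zeros(d, dtype=np.float32)   # learnable shift
        self.running_mean = np.zeros(d, dtype=np.float32)
        self.running_var = np.ones(d, dtype=np.float32)
        self.eps = eps
        self.momentum = momentum

    def __call__(self, x, training):
        if training:
            # Training: use statistics of the current batch...
            mu = x.mean(axis=0)
            var = x.var(axis=0)
            # ...and fold them into the running EMA for later inference.
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mu
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            # Inference: use the fixed global statistics.
            mu, var = self.running_mean, self.running_var
        x_hat = (x - mu) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta

# Usage on a hypothetical batch:
X = np.random.default_rng(0).standard_normal((4, 3)).astype(np.float32)
bn = BatchNorm1D(d=3)
y_train = bn(X, training=True)    # batch statistics, EMA updated
y_infer = bn(X, training=False)   # frozen running statistics
```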

5. The Roles of Batch Normalization

  1. Accelerates Training Convergence: Reduces internal covariate shift, allowing for higher learning rates.
  2. Alleviates Gradient Vanishing/Exploding: Normalization keeps activation values within a stable range, making gradients more controllable.
  3. Provides a Mild Regularization Effect: Each batch's statistics are noisy estimates of the population statistics, so every sample's normalization depends on the other samples in its batch; this injected randomness acts somewhat like Dropout.
  4. Reduces Sensitivity to Parameter Initialization.

6. Considerations and Limitations

  • Small Batch Size Problem: When the batch size is too small, the estimated mean and variance are noisy, which can noticeably degrade performance.
  • Challenges in RNNs: In sequence models the statistics vary across time steps, so alternatives such as Layer Normalization are often used instead.
  • Interaction with Dropout: Since BN already has a regularization effect, the Dropout rate may need adjustment when used together.

Through the above steps, BN has become a widely adopted component in deep learning, and it is especially prevalent in modern convolutional and residual networks.