Wasserstein Distance in Generative Adversarial Networks (GANs) and the WGAN Improvement
Description
In the original Generative Adversarial Network (GAN), training the generator against an optimal discriminator amounts to minimizing the Jensen-Shannon divergence between the generated and real distributions. However, this objective is prone to issues like vanishing gradients and mode collapse. The Wasserstein GAN (WGAN) significantly improves training stability by replacing it with the Wasserstein distance (also known as the Earth Mover's Distance). The Wasserstein distance measures the minimum cost required to transform one probability distribution into another. Its key advantage is that it provides meaningful gradients even when the two distributions do not overlap.
Solution Process
- Analysis of Original GAN's Defects
- The discriminator's output in the original GAN passes through a Sigmoid function, with the loss function being binary cross-entropy. When the overlap between the generated distribution and the real distribution is low, the discriminator can easily achieve perfect classification (loss close to 0), leading to gradient vanishing.
- For example, when the real distribution \(P_r\) and the generated distribution \(P_g\) have disjoint supports, the Jensen-Shannon divergence stays constant at \(\log 2\), so the gradient is 0 and the generator cannot be updated (a small numerical sketch follows this list).
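As a minimal numerical sketch (not from the original text), consider two discrete distributions on disjoint supports: the Jensen-Shannon divergence equals \(\log 2\) regardless of how the supports are arranged, so it carries no signal about how far apart the distributions actually are.

```python
import numpy as np

def js_divergence(p, q):
    """Jensen-Shannon divergence between two discrete distributions (natural log)."""
    m = 0.5 * (p + q)
    kl = lambda a, b: np.sum(a[a > 0] * np.log(a[a > 0] / b[a > 0]))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = np.array([0.5, 0.5, 0.0, 0.0])  # "real" distribution P_r
q = np.array([0.0, 0.0, 0.5, 0.5])  # "generated" distribution P_g, disjoint support
print(js_divergence(p, q), np.log(2))  # both ~0.6931: the JSD is stuck at log 2
```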
- Definition of Wasserstein Distance
- The Wasserstein-1 distance is defined as:
\(W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x,y)\sim\gamma} [\|x-y\|]\)
where \(\Pi(P_r, P_g)\) is the set of all joint distributions (couplings) whose marginals are \(P_r\) (the real distribution) and \(P_g\) (the generated distribution), and each \(\gamma\) represents a transport plan for moving mass from \(P_r\) to \(P_g\).
- Intuitive understanding: the minimum "work" required to move the "earth" of distribution \(P_r\) so that it matches distribution \(P_g\) (a computation sketch follows this list).
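As an illustrative sketch not given in the original, the 1-D Wasserstein-1 distance between two empirical samples can be computed with SciPy's `scipy.stats.wasserstein_distance`. For two equal-variance Gaussians it is approximately the gap between their means, and it keeps growing smoothly as that gap widens, unlike the saturating JS divergence.

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)
x_real = rng.normal(loc=0.0, scale=1.0, size=10_000)  # samples from P_r
x_gen = rng.normal(loc=3.0, scale=1.0, size=10_000)   # samples from P_g

# W1 reflects how far apart the distributions are, even when they barely overlap.
print(wasserstein_distance(x_real, x_gen))  # ~3.0
```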
- From Theory to Computable Form
- Directly computing the Wasserstein distance from its definition is intractable (the infimum ranges over all couplings). Via Kantorovich-Rubinstein duality it can be rewritten as:
\(W(P_r, P_g) = \sup_{\|f\|_L \leq 1} \mathbb{E}_{x\sim P_r}[f(x)] - \mathbb{E}_{x\sim P_g}[f(x)]\)
where the supremum is taken over all 1-Lipschitz functions (i.e., functions satisfying \(|f(x)-f(y)| \leq |x-y|\)).
- Implication: the discriminator \(f\) no longer classifies real versus fake; it fits a 1-Lipschitz function that maximizes the gap between its expected outputs on real and generated samples (see the toy sketch after this list).
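A toy sketch (assumptions mine, not from the original): any fixed 1-Lipschitz function plugged into the dual form gives a lower bound on \(W(P_r, P_g)\); the critic's job is to search for the function that makes this bound as tight as possible. In 1-D, \(f(x) = -x\) already recovers the mean gap for the two Gaussians used above.

```python
import numpy as np

rng = np.random.default_rng(0)
x_real = rng.normal(0.0, 1.0, 10_000)  # samples from P_r
x_gen = rng.normal(3.0, 1.0, 10_000)   # samples from P_g

# Dual objective E_{P_r}[f(x)] - E_{P_g}[f(x)] for a fixed 1-Lipschitz f.
f = lambda x: -x  # 1-Lipschitz; sign chosen so the expectation gap is positive
dual_lower_bound = f(x_real).mean() - f(x_gen).mean()
print(dual_lower_bound)  # ~3.0, matching W1 for these two equal-variance Gaussians
```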
- Improvements in WGAN
- Loss Function Design:
Discriminator (called Critic in WGAN) loss: \(L_D = \mathbb{E}_{x\sim P_g}[f(x)] - \mathbb{E}_{x\sim P_r}[f(x)]\)
Generator loss: \(L_G = -\mathbb{E}_{x\sim P_g}[f(x)]\)
- Weight Clipping: To enforce the Lipschitz constraint, the original WGAN clips the critic's parameters to an interval \([-c, c]\) (e.g., \(c = 0.01\)); however, this can weaken gradients or cause the parameters to cluster at the clipping boundaries.
- Gradient Penalty (WGAN-GP): A subsequent improvement replaced weight clipping with a gradient penalty term:
\(L_{GP} = \lambda \mathbb{E}_{\hat{x}\sim P_{\hat{x}}}[(\|\nabla_{\hat{x}} f(\hat{x})\|_2 - 1)^2]\)
where \(\hat{x}\) is a random point on the line segment between a real and a generated sample (\(\hat{x} = \epsilon x_r + (1-\epsilon) x_g\), \(\epsilon \sim U[0,1]\)), and \(\lambda\) is the penalty coefficient (commonly set to 10); a PyTorch-style sketch of this term follows this list.
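A minimal PyTorch-style sketch of the gradient penalty term, assuming inputs of shape `(batch, features)` and a hypothetical `critic` module (the names are illustrative, not from the original):

```python
import torch

def gradient_penalty(critic, x_real, x_gen, lambda_gp=10.0):
    """WGAN-GP term: lambda * E[(||grad_x critic(x_hat)||_2 - 1)^2] at random interpolates."""
    eps = torch.rand(x_real.size(0), 1, device=x_real.device)  # epsilon ~ U[0, 1]
    x_hat = (eps * x_real + (1.0 - eps) * x_gen).detach().requires_grad_(True)
    scores = critic(x_hat)
    grads, = torch.autograd.grad(outputs=scores.sum(), inputs=x_hat, create_graph=True)
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1.0) ** 2).mean()
```

The `create_graph=True` flag keeps the gradient computation differentiable, so the penalty itself can be backpropagated through when updating the critic.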
- Training Process Example
- Step 1: Initialize the generator G and discriminator D (Critic).
- Step 2: Iterate until convergence (a minimal training-loop sketch follows these steps):
a. Fix G, update D multiple times (e.g., 5 times):
- Sample real samples \(\{x_r\}\) and generated samples \(\{x_g\}\).
- Compute D's loss \(L_D\) (add the gradient penalty term if using WGAN-GP).
- Perform backpropagation to update D's parameters.
b. Fix D, update G once:
- Sample noise vectors \(z\) to generate samples, and compute \(L_G = -\mathbb{E}[f(G(z))]\).
- Perform backpropagation to update G's parameters.
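The steps above can be condensed into a minimal PyTorch-style loop. This is a sketch, not the original's code: the tiny models, data, and sizes are arbitrary placeholders, the optimizer settings follow the common WGAN-GP recommendation (Adam with \(\beta_1 = 0\), \(\beta_2 = 0.9\)), and `gradient_penalty` refers to the sketch above.

```python
import torch
import torch.nn as nn

# Illustrative models and synthetic 2-D data; all names and sizes are arbitrary choices.
latent_dim, n_critic, lambda_gp = 8, 5, 10.0
generator = nn.Sequential(nn.Linear(latent_dim, 64), nn.ReLU(), nn.Linear(64, 2))
critic = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))  # no sigmoid: unbounded score
opt_d = torch.optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.9))

real_data = torch.randn(1024, 2) + torch.tensor([4.0, 4.0])  # stand-in "real" distribution
data_loader = torch.utils.data.DataLoader(real_data, batch_size=64)

for epoch in range(10):
    for x_real in data_loader:
        # (a) Fix G, update the critic n_critic times (one real batch reused for brevity).
        for _ in range(n_critic):
            z = torch.randn(x_real.size(0), latent_dim)
            x_gen = generator(z).detach()                    # block gradients into G
            loss_d = critic(x_gen).mean() - critic(x_real).mean()
            loss_d = loss_d + gradient_penalty(critic, x_real, x_gen, lambda_gp)
            opt_d.zero_grad(); loss_d.backward(); opt_d.step()

        # (b) Fix D, update the generator once.
        z = torch.randn(x_real.size(0), latent_dim)
        loss_g = -critic(generator(z)).mean()
        opt_g.zero_grad(); loss_g.backward(); opt_g.step()
```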
- Summary of Key Advantages
- The Wasserstein distance is continuous (and differentiable almost everywhere) in the generator's parameters, so it provides useful gradient signals even when the real and generated distributions do not overlap.
- Training stability is improved, reducing the phenomenon of mode collapse.
- The critic's loss tracks the estimated Wasserstein distance (its negative, \(\mathbb{E}_{P_r}[f(x)] - \mathbb{E}_{P_g}[f(x)]\), approximates \(W(P_r, P_g)\)), so it can serve as a rough indicator of generation quality: a smaller estimated distance means the generated distribution is closer to the real one.