Basic Principles and Training Process of Generative Adversarial Networks (GANs)

Problem Description
A Generative Adversarial Network (GAN) is a framework for training generative models through an adversarial process. It consists of two core components: a Generator (G) and a Discriminator (D). The goal of the generator is to learn the distribution of real data to produce data convincing enough to "pass as real"; the goal of the discriminator is to accurately distinguish whether input data comes from the real dataset or the generator. They co-evolve in a game, ultimately enabling the generator to produce high-quality data.

Basic Principles

  1. Core Idea: The inspiration for GANs comes from zero-sum games in game theory. It models the generation problem as a competition between a generator (G) and a discriminator (D).
  2. Generator (G): Typically a neural network (e.g., a deconvolutional network), it takes a random noise vector z (sampled from a simple distribution like a Gaussian) as input and "maps" or "transforms" it into a fake image G(z). Its objective is to make G(z) as similar as possible to real data.
  3. Discriminator (D): Typically a neural network (e.g., a convolutional network), it takes an image (either a real image x or a generated image G(z)) as input and outputs a scalar probability value representing the likelihood that the image is from real data. Its goal is to make judgments as accurately as possible.
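The two roles can be sketched with deliberately tiny stand-ins: a linear map in place of a deconvolutional generator, and a logistic score in place of a convolutional discriminator. The parameter values (a, b, w, c) are arbitrary illustrations, not trained weights:

```python
import math
import random

def generator(z, a=2.0, b=3.0):
    """Toy generator: maps a noise scalar z to a fake sample a*z + b.
    A real GAN generator would be a deep (often deconvolutional) network."""
    return a * z + b

def discriminator(x, w=1.0, c=-3.0):
    """Toy discriminator: logistic score in (0, 1), the estimated
    probability that x came from the real data distribution."""
    return 1.0 / (1.0 + math.exp(-(w * x + c)))

random.seed(0)
z = random.gauss(0.0, 1.0)     # noise from a standard normal prior p_z(z)
fake = generator(z)            # G(z): a generated sample
score = discriminator(fake)    # D(G(z)): D's belief that the sample is real
print(fake, score)
```

The scalar output of `discriminator` is always strictly between 0 and 1, matching the probability interpretation above.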

Training Process (Step-by-Step)
GAN training is an iterative process. In each iteration step, the discriminator is typically updated first, followed by the generator.

Step One: Fix the Generator G, Update the Discriminator D
The goal of this step is to improve the discriminator's ability to distinguish.

  1. Sample from the real dataset: Draw a mini-batch of m real samples {x^(1), x^(2), ..., x^(m)} from the training data.
  2. Sample from the prior noise: Sample m noise vectors {z^(1), z^(2), ..., z^(m)} from a noise distribution p_z(z) (e.g., standard normal distribution).
  3. Generate fake data: Input the noise vectors into the current generator G to obtain m generated samples {G(z^(1)), G(z^(2)), ..., G(z^(m))}.
  4. Calculate the discriminator loss: The discriminator's objective is to maximize its ability to assign high scores to real data and low scores to generated data. Therefore, its loss function consists of two parts:
    • For real data x, we want D(x) to be close to 1 (i.e., classified as "real").
    • For generated data G(z), we want D(G(z)) to be close to 0 (i.e., classified as "fake").
      Mathematically, the binary cross-entropy loss function is commonly used. The total loss function for the discriminator is:
      \(L_D = -\frac{1}{m} \sum_{i=1}^{m} [\log D(x^{(i)}) + \log(1 - D(G(z^{(i)})))]\)
    • The first term \(\log D(x^{(i)})\) encourages D to output a high probability for real samples.
    • The second term \(\log(1 - D(G(z^{(i)})))\) encourages D to output a low probability for generated samples.
      Because of the leading minus sign, the discriminator's goal is to minimize L_D; equivalently, it maximizes the bracketed sum of log-probabilities.
  5. Gradient descent to update D: Compute the gradient of the loss \(L_D\) with respect to the discriminator parameters \(\theta_d\), \(\nabla_{\theta_d} L_D\), and take a gradient descent step on the discriminator's parameters (equivalently, gradient ascent on the unnegated objective, as in the original GAN formulation):
    \(\theta_d \leftarrow \theta_d - \eta \nabla_{\theta_d} L_D\)
    After this step, the discriminator's ability to distinguish is enhanced.
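The loss from Step One can be sanity-checked numerically. The discriminator scores below are made-up illustrations; the point is that L_D falls as D separates real from fake more confidently:

```python
import math

def discriminator_loss(d_real, d_fake):
    """Binary cross-entropy loss for D over one mini-batch:
    L_D = -(1/m) * sum(log D(x_i) + log(1 - D(G(z_i))))."""
    m = len(d_real)
    return -sum(math.log(dr) + math.log(1.0 - df)
                for dr, df in zip(d_real, d_fake)) / m

# A weak discriminator: near-chance scores on both reals and fakes.
weak = discriminator_loss(d_real=[0.55, 0.50, 0.60], d_fake=[0.50, 0.45, 0.40])
# A strong discriminator: reals scored near 1, fakes near 0.
strong = discriminator_loss(d_real=[0.95, 0.90, 0.97], d_fake=[0.05, 0.10, 0.03])
print(weak, strong)   # stronger separation gives a lower L_D
```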

Step Two: Fix the Discriminator D, Update the Generator G
The goal of this step is to improve the generator's "forgery" ability, allowing it to fool the current discriminator.

  1. Sample from the prior noise: Again, sample m new noise vectors {z^(1), z^(2), ..., z^(m)} from the noise distribution p_z(z). (Typically a new batch of noise is used.)
  2. Generate fake data: Input these noise vectors into the generator G to obtain generated samples {G(z^(1)), G(z^(2)), ..., G(z^(m))}.
  3. Calculate the generator loss: The generator's goal is to make the discriminator misjudge the samples it generates, i.e., it wants D(G(z)) to be close to 1 (classified as "real"). Therefore, the generator's loss function is related to the second part of the discriminator's loss:
    \(L_G = -\frac{1}{m} \sum_{i=1}^{m} \log(D(G(z^{(i)})))\)
    • This formula means the generator wants D(G(z)), the discriminator's score for its generated sample G(z), to be as large as possible. When D(G(z)) is close to 1, \(\log(D(G(z)))\) is close to 0, and the loss \(L_G\) becomes very small.
      The generator's goal is to minimize L_G.
      (The original minimax formulation instead uses \(L_G = \frac{1}{m} \sum_{i=1}^{m} \log(1 - D(G(z^{(i)})))\); however, its gradient vanishes early in training, when the discriminator easily rejects generated samples, so the non-saturating form above is more commonly used in practice.)
  4. Gradient descent to update G: Compute the gradient of the loss \(L_G\) with respect to the generator parameters \(\theta_g\), \(\nabla_{\theta_g} L_G\), and then use gradient descent to update the generator's parameters:
    \(\theta_g \leftarrow \theta_g - \eta \nabla_{\theta_g} L_G\)
    After this step, the generator becomes better at producing data that can deceive the current discriminator.
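The difference between the non-saturating generator loss \(-\log D(G(z))\) and the original minimax form \(\log(1 - D(G(z)))\) shows up in their gradients. The score d = D(G(z)) = 0.01 below is an illustrative early-training value, where the discriminator confidently rejects fakes:

```python
import math

def g_loss_nonsaturating(d_fake):
    """L_G = -(1/m) * sum(log D(G(z_i)))  (the commonly used form)."""
    return -sum(math.log(d) for d in d_fake) / len(d_fake)

def g_loss_minimax(d_fake):
    """L_G = (1/m) * sum(log(1 - D(G(z_i))))  (the original minimax form)."""
    return sum(math.log(1.0 - d) for d in d_fake) / len(d_fake)

d = 0.01   # early training: D easily spots the fakes
print(g_loss_nonsaturating([d]), g_loss_minimax([d]))

# Per-sample gradient magnitude with respect to d = D(G(z)):
#   non-saturating: |d/dd (-log d)|     = 1/d       -> large when d ~ 0
#   minimax:        |d/dd log(1 - d)|   = 1/(1 - d) -> tiny  when d ~ 0
grad_ns = 1.0 / d
grad_mm = 1.0 / (1.0 - d)
print(grad_ns, grad_mm)
```

The non-saturating loss gives the generator a strong learning signal exactly when it is losing badly, which is why it dominates in practice.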

Looping and Convergence

  1. Repeat Step One and Step Two, alternately training the discriminator D and the generator G.
  2. Ideal State (Nash Equilibrium): The ideal state is reached when the distribution of data generated by the generator exactly matches the real data distribution, i.e., \(p_g = p_{data}\). At this point, the discriminator cannot make effective judgments for any input, and its output probability will be constant at 0.5 (i.e., random guessing).
  3. Training Difficulties: The GAN training process is very unstable and prone to issues such as mode collapse (where the generator produces only a few types of samples) or vanishing gradients. Careful design of network architecture, loss functions, and training techniques (such as Wasserstein GAN) is required to mitigate these problems.
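The equilibrium claim (D outputs 0.5 everywhere when \(p_g = p_{data}\)) follows from the optimal-discriminator formula \(D^*(x) = p_{data}(x) / (p_{data}(x) + p_g(x))\) from the original GAN analysis, which holds for a fixed generator. A quick numerical check, assuming both densities are known 1-D Gaussians:

```python
import math

def normal_pdf(x, mu, sigma):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))

def optimal_discriminator(x, p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)): the best possible
    discriminator for a fixed generator."""
    a, b = p_data(x), p_g(x)
    return a / (a + b)

p_data = lambda x: normal_pdf(x, 0.0, 1.0)
p_far  = lambda x: normal_pdf(x, 5.0, 1.0)   # generator far from the data
p_eq   = lambda x: normal_pdf(x, 0.0, 1.0)   # generator matches the data

print(optimal_discriminator(0.0, p_data, p_far))  # near 1: easy to tell apart
print(optimal_discriminator(0.0, p_data, p_eq))   # 0.5: indistinguishable
```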

Through this adversarial training, the generator and discriminator continuously improve in their mutual competition, ultimately making the generator a powerful data generation model.
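Putting the two alternating steps together, here is a minimal end-to-end sketch on a 1-D toy problem. It assumes a linear generator, a logistic discriminator, and hand-derived gradients; the learning rate, batch size, and step count are arbitrary choices, and a real GAN would use deep networks with automatic differentiation:

```python
import math
import random

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

random.seed(42)
# Toy setup: real data ~ N(3, 0.5); generator G(z) = a*z + b;
# discriminator D(x) = sigmoid(w*x + c).
a, b = 1.0, 0.0          # generator parameters
w, c = 0.0, 0.0          # discriminator parameters
lr, m = 0.02, 64         # learning rate and mini-batch size (arbitrary)

for step in range(5000):
    # --- Step One: fix G, update D (gradient descent on L_D) ---
    xs = [random.gauss(3.0, 0.5) for _ in range(m)]   # real samples
    zs = [random.gauss(0.0, 1.0) for _ in range(m)]   # noise
    gs = [a * z + b for z in zs]                      # fake samples
    d_real = [sigmoid(w * x + c) for x in xs]
    d_fake = [sigmoid(w * g + c) for g in gs]
    grad_w = -(sum((1 - dr) * x for dr, x in zip(d_real, xs))
               - sum(df * g for df, g in zip(d_fake, gs))) / m
    grad_c = -(sum(1 - dr for dr in d_real) - sum(d_fake)) / m
    w -= lr * grad_w
    c -= lr * grad_c

    # --- Step Two: fix D, update G (descent on non-saturating L_G) ---
    zs = [random.gauss(0.0, 1.0) for _ in range(m)]   # fresh noise batch
    gs = [a * z + b for z in zs]
    d_fake = [sigmoid(w * g + c) for g in gs]
    grad_a = -sum((1 - df) * w * z for df, z in zip(d_fake, zs)) / m
    grad_b = -sum((1 - df) * w for df in d_fake) / m
    a -= lr * grad_a
    b -= lr * grad_b

fake_mean = sum(a * random.gauss(0.0, 1.0) + b for _ in range(1000)) / 1000
print(round(b, 2), round(fake_mean, 2))  # b should have drifted toward 3
```

Even on this toy problem the two losses chase each other rather than decrease monotonically, which previews the instability issues noted above.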