生成对抗网络(GAN)的基本原理与训练过程
字数 2476 2025-11-05 23:47:54
生成对抗网络(GAN)的基本原理与训练过程
题目描述
生成对抗网络(GAN)是一种通过对抗过程训练生成模型的框架。它包含两个核心组件:生成器(Generator)和判别器(Discriminator)。生成器的目标是学习真实数据的分布,从而生成足以“以假乱真”的数据;判别器的目标是准确区分输入数据是来自真实数据集还是生成器。两者在博弈中共同进化,最终使生成器能够产出高质量的数据。
基本原理
- 核心思想:GAN的灵感来源于博弈论中的零和博弈。它将生成问题建模为生成器(G)和判别器(D)之间的竞争。
- 生成器(G):通常是一个神经网络(如反卷积网络),它接收一个随机噪声向量z(从简单分布如高斯分布中采样)作为输入,并将其“映射”或“转换”成一张假图像G(z)。其目标是让G(z)尽可能像真实数据。
- 判别器(D):通常是一个神经网络(如卷积网络),它接收一张图像(可以是真实图像x,也可以是生成图像G(z))作为输入,输出一个标量概率值,表示该图像是真实数据的可能性。其目标是尽可能准确地进行判断。
训练过程(循序渐进)
GAN的训练是一个迭代过程,在每个迭代步骤中,通常会先更新判别器,再更新生成器。
步骤一:固定生成器G,更新判别器D
这一步的目标是提升判别器的辨别能力。
- 从真实数据集中采样:从一个批次(mini-batch)的真实数据中采样m个真实样本 {x^(1), x^(2), ..., x^(m)}。
- 从先验噪声中采样:从噪声分布p_z(z)(如标准正态分布)中采样m个噪声向量 {z^(1), z^(2), ..., z^(m)}。
- 生成假数据:将噪声向量输入当前的生成器G,得到m个生成样本 {G(z^(1)), G(z^(2)), ..., G(z^(m))}。
- 计算判别器损失:判别器的目标是最大化它给真实数据打高分、给生成数据打低分的能力。因此,其损失函数由两部分组成:
- 对于真实数据x,我们希望D(x)接近1(即判别为“真”)。
- 对于生成数据G(z),我们希望D(G(z))接近0(即判别为“假”)。
数学上,常用二元交叉熵损失函数。判别器的总损失函数为:
\(L_D = -\frac{1}{m} \sum_{i=1}^{m} [\log D(x^{(i)}) + \log(1 - D(G(z^{(i)})))]\) - 这个公式的第一项 \(\log D(x^{(i)})\) 鼓励D对真实样本输出高概率。
- 第二项 \(\log(1 - D(G(z^{(i)})))\) 鼓励D对生成样本输出低概率。
判别器的目标是最大化L_D,即使得这个损失值尽可能大(因为它是log概率的和)。
- 梯度上升更新D:在实践上,我们通常采用最小化一个等价的损失函数(即 \(-L_D\)),然后使用梯度下降法。计算损失 \(L_D\) 关于判别器参数 \(\theta_d\) 的梯度 \(\nabla_{\theta_d} L_D\),然后使用梯度上升(或对 \(-L_D\) 做梯度下降)来更新判别器的参数:
\(\theta_d \leftarrow \theta_d + \eta \nabla_{\theta_d} L_D\)
经过这一步,判别器的辨别能力得到了增强。
步骤二:固定判别器D,更新生成器G
这一步的目标是提升生成器的“造假”能力,让它能骗过当前的判别器。
- 从先验噪声中采样:再次从噪声分布p_z(z)中采样m个新的噪声向量 {z^(1), z^(2), ..., z^(m)}。(通常使用新的一批噪声)
- 生成假数据:将这些噪声向量输入生成器G,得到生成样本 {G(z^(1)), G(z^(2)), ..., G(z^(m))}。
- 计算生成器损失:生成器的目标是让判别器对自己生成的样本做出错误判断,即希望D(G(z))接近1(被判别为“真”)。因此,生成器的损失函数与判别器损失的第二部分相关:
\(L_G = -\frac{1}{m} \sum_{i=1}^{m} \log(D(G(z^{(i)})))\)- 这个公式的意思是,生成器希望它生成的样本G(z)在经过判别器D判断后,得到的概率值D(G(z))尽可能大。当D(G(z))接近1时,\(\log(D(G(z)))\) 接近0,损失 \(L_G\) 就会很小。
生成器的目标是最小化L_G。
(另一种常见且理论上更稳定的形式是 \(L_G = \frac{1}{m} \sum_{i=1}^{m} \log(1 - D(G(z^{(i)}))\),但其梯度在训练初期可能较差,所以上述形式更常用。)
- 这个公式的意思是,生成器希望它生成的样本G(z)在经过判别器D判断后,得到的概率值D(G(z))尽可能大。当D(G(z))接近1时,\(\log(D(G(z)))\) 接近0,损失 \(L_G\) 就会很小。
- 梯度下降更新G:计算损失 \(L_G\) 关于生成器参数 \(\theta_g\) 的梯度 \(\nabla_{\theta_g} L_G\),然后使用梯度下降来更新生成器的参数:
\(\theta_g \leftarrow \theta_g - \eta \nabla_{\theta_g} L_G\)
经过这一步,生成器变得更善于生成能欺骗当前判别器的数据。
循环与收敛
- 重复执行步骤一和步骤二,交替训练判别器D和生成器G。
- 理想状态(纳什均衡):当生成器生成的数据分布与真实数据分布完全一致,即 \(p_g = p_{data}\) 时,达到理想状态。此时,对于任何输入,判别器都无法做出有效判断,其输出概率将恒为0.5(即随机猜测)。
- 训练难点:GAN的训练过程非常不稳定,容易出现模式崩溃(Mode Collapse,即生成器只产生少数几种样本)或梯度消失等问题。需要精心设计网络结构、损失函数和训练技巧(如Wasserstein GAN)来改善。
通过这种对抗性的训练,生成器和判别器在相互博弈中不断进步,最终使生成器成为一个强大的数据生成模型。