生成对抗网络(GAN)中的Wasserstein距离与WGAN改进
题目描述
生成对抗网络(GAN)在训练中常面临模式崩溃和训练不稳定的问题。传统GAN通过Jensen-Shannon散度来衡量真实分布与生成分布之间的差异,但在两分布不重叠或重叠可忽略时,其梯度会消失,导致训练困难。Wasserstein GAN(WGAN)通过引入Wasserstein距离(又称Earth-Mover距离)替代Jensen-Shannon散度,有效缓解了梯度消失问题,提供了更稳定的训练过程和损失指标。本题目将详细解释Wasserstein距离的动机、定义,以及WGAN如何利用此距离改进GAN的训练。
解题过程
步骤1:传统GAN的训练问题根源
- 核心目标:GAN包含一个生成器G和一个判别器D。判别器的目标是区分真实样本和生成样本,生成器的目标是欺骗判别器。理想状态是达到纳什均衡。
- 损失函数:原始GAN的损失函数基于Jensen-Shannon散度。当两分布(真实分布\(P_r\)和生成分布\(P_g\))不重叠或重叠可忽略时,Jensen-Shannon散度为常数,梯度几乎为零,导致生成器无法获得有效的梯度更新。
- 后果:判别器过早达到最优,梯度消失,生成器停止优化。这表现为模式崩溃(生成样本多样性低)和训练不稳定。
步骤2:Wasserstein距离的引入动机
- 核心思想:Wasserstein距离衡量将一个分布“移动”成另一个分布所需的最小“工作量”。即使两分布无重叠,该距离仍能提供平滑的度量,从而避免梯度消失。
- 直观比喻:将两堆土(分布)相互搬运所需的最小土方量。距离值越小,两堆土形状越相似。
- 数学优势:Wasserstein距离是连续的,即使在分布不重叠时也能提供有意义的梯度信号。
步骤3:Wasserstein距离的定义
- 形式化定义:对于真实分布\(P_r\)和生成分布\(P_g\),其Wasserstein-1距离(Earth-Mover距离)定义为:
\[ W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma} [\|x - y\|] \]
其中,\(\Pi(P_r, P_g)\)是\(P_r\)和\(P_g\)所有可能的联合分布集合,每个联合分布\(\gamma\)表示一个“搬运方案”,\(\|x - y\|\)是搬运成本。
2. 直观理解:在所有可能的联合分布中,寻找使期望成本最小的那个。这个最小期望成本就是Wasserstein距离。
步骤4:从理论距离到可优化损失(Kantorovich-Rubinstein对偶)
- 优化难题:上述原始定义涉及在所有联合分布上求下确界,难以直接计算。
- 关键转换:通过Kantorovich-Rubinstein对偶定理,Wasserstein距离可转化为:
\[ W(P_r, P_g) = \sup_{\|f\|_L \leq 1} \left[ \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)] \right] \]
其中,\(f\)是满足1-Lipschitz连续的函数集合中的任一函数(即函数变化率不超过1)。
3. 对偶形式的意义:寻找一个函数\(f\),使真实样本的期望值与生成样本的期望值之差最大,但函数本身需满足Lipschitz约束(梯度不超过1)。
步骤5:WGAN的算法实现
- 判别器改为Critic:在WGAN中,判别器D不再输出样本为真的概率,而是输出一个评分(称为Critic),其目标是最大化上述对偶形式中的差值。
- 损失函数:
- Critic损失:\(L_D = -\left( \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim p(z)}[D(G(z))] \right)\),需最小化。
- 生成器损失:\(L_G = -\mathbb{E}_{z \sim p(z)}[D(G(z))]\),需最小化。
- Lipschitz约束的实现:为确保Critic满足1-Lipschitz约束,WGAN提出权重裁剪(Weight Clipping),即将Critic的参数限制在固定范围(如[-0.01, 0.01])。这强制Critic成为K-Lipschitz函数。
- 训练流程:
- 对Critic进行多次更新(如5次),每次更新后裁剪权重。
- 对生成器进行一次更新。
- 重复直至收敛。
步骤6:WGAN的改进与后续发展
- 权重裁剪的问题:权重裁剪可能导致Critic学习能力受限(参数聚集在裁剪边界),梯度消失或爆炸。
- WGAN-GP的改进:提出梯度惩罚(Gradient Penalty)替代权重裁剪。在损失函数中添加一项,强制Critic的梯度范数接近1:
\[ \lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\left[ (\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right] \]
其中,\(\hat{x}\)是真实样本和生成样本连线上的随机插值点。这使Critic更稳定地满足Lipschitz约束。
3. 优势总结:
- 训练更稳定,避免模式崩溃。
- Critic损失可反映生成样本质量,适合作为训练监控指标。
- 对网络架构和超参数更鲁棒。
步骤7:关键点回顾
- Wasserstein距离通过“最小搬运成本”衡量分布差异,提供平滑梯度。
- 通过对偶形式将优化问题转化为寻找Lipschitz函数的最大化差值问题。
- WGAN用Critic替代判别器,损失函数基于Wasserstein距离的对偶形式。
- 通过权重裁剪或梯度惩罚实现Lipschitz约束。
- WGAN-GP进一步提高了训练稳定性和生成质量。
通过以上步骤,你可以理解Wasserstein距离如何解决传统GAN的梯度消失问题,以及WGAN如何将其转化为可优化的损失函数。掌握这一知识点,有助于设计更稳定的生成模型。