生成对抗网络(GAN)中的Wasserstein距离与WGAN改进
字数 1831 2025-11-06 22:53:22
生成对抗网络(GAN)中的Wasserstein距离与WGAN改进
描述
在原始生成对抗网络(GAN)的训练中,生成器与判别器通过Jensen-Shannon散度进行分布匹配,但这种方式容易导致梯度消失或模式崩溃。Wasserstein GAN(WGAN)通过引入Wasserstein距离(也称推土机距离)替代原始损失函数,显著提升了训练稳定性。Wasserstein距离衡量两个概率分布之间相互转化所需的最小成本,其关键优势在于即使分布不相交时仍能提供有意义的梯度。
解题过程
-
原始GAN的缺陷分析
- 原始GAN的判别器输出经过Sigmoid函数,损失函数为二元交叉熵。当生成分布与真实分布重叠度低时,判别器容易达到完美分类(损失接近0),导致梯度消失。
- 例如,真实分布P和生成分布Q没有重叠时,Jensen-Shannon散度恒为log2,梯度为0,生成器无法更新。
-
Wasserstein距离定义
- Wasserstein-1距离定义为:
\(W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x,y)\sim\gamma} [\|x-y\|]\)
其中\(\Pi(P_r, P_g)\)是\(P_r\)(真实分布)和\(P_g\)(生成分布)的所有联合分布集合,\(\gamma\)是联合分布的耦合(coupling),表示从\(P_r\)移动质量到\(P_g\)的运输方案。 - 直观理解:将分布\(P_r\)的“沙土”搬运成分布\(P_g\)所需的最小工作量。
- Wasserstein-1距离定义为:
-
从理论到可计算形式
- 直接计算Wasserstein距离不可行(需要遍历所有耦合)。通过Kantorovich-Rubinstein对偶性转化为:
\(W(P_r, P_g) = \sup_{\|f\|_L \leq 1} \mathbb{E}_{x\sim P_r}[f(x)] - \mathbb{E}_{x\sim P_g}[f(x)]\)
其中上确界取自所有1-Lipschitz函数(即满足\(|f(x)-f(y)| \leq |x-y|\)的函数)。 - 意义:判别器\(f\)不再区分真假,而是拟合一个Lipschitz函数,最大化真实样本与生成样本的输出差值。
- 直接计算Wasserstein距离不可行(需要遍历所有耦合)。通过Kantorovich-Rubinstein对偶性转化为:
-
WGAN的改进措施
- 损失函数设计:
判别器(WGAN中称Critic)损失:\(L_D = \mathbb{E}_{x\sim P_g}[f(x)] - \mathbb{E}_{x\sim P_r}[f(x)]\)
生成器损失:\(L_G = -\mathbb{E}_{x\sim P_g}[f(x)]\) - 权重裁剪:为满足Lipschitz约束,早期WGAN强制将判别器参数裁剪到\([-c, c]\)区间。例如设\(c=0.01\),但可能导致梯度减弱或参数聚集在边界。
- 梯度惩罚(WGAN-GP):后续改进通过梯度惩罚项替代权重裁剪:
\(L_{GP} = \lambda \mathbb{E}_{\hat{x}\sim P_{\hat{x}}}[(\|\nabla_{\hat{x}} f(\hat{x})\|_2 - 1)^2]\)
其中\(\hat{x}\)是真实样本与生成样本连线上的随机插值点(\(\hat{x} = \epsilon x_r + (1-\epsilon) x_g, \epsilon \sim U[0,1]\)),\(\lambda\)为惩罚系数(常取10)。
- 损失函数设计:
-
训练流程示例
- 步骤1:初始化生成器G和判别器D(Critic)。
- 步骤2:循环训练直至收敛:
a. 固定G,更新D多次(如5次):- 采样真实样本\(\{x_r\}\)和生成样本\(\{x_g\}\)。
- 计算D的损失\(L_D\)(若用WGAN-GP则加上梯度惩罚项)。
- 反向传播更新D的参数。
b. 固定D,更新G一次: - 采样噪声向量生成样本,计算\(L_G = -\mathbb{E}[f(G(z))]\)。
- 反向传播更新G的参数。
-
关键优势总结
- Wasserstein距离始终连续,提供平滑的梯度信号。
- 训练稳定性提升,减少模式崩溃现象。
- 判别器(Critic)的损失值可近似反映生成质量(值越小说明分布越接近)。