生成对抗网络(GAN)中的Wasserstein距离与WGAN改进
字数 1831 2025-11-06 22:53:22

生成对抗网络(GAN)中的Wasserstein距离与WGAN改进

描述
在原始生成对抗网络(GAN)的训练中,生成器与判别器通过Jensen-Shannon散度进行分布匹配,但这种方式容易导致梯度消失或模式崩溃。Wasserstein GAN(WGAN)通过引入Wasserstein距离(也称推土机距离)替代原始损失函数,显著提升了训练稳定性。Wasserstein距离衡量两个概率分布之间相互转化所需的最小成本,其关键优势在于即使分布不相交时仍能提供有意义的梯度。

解题过程

  1. 原始GAN的缺陷分析

    • 原始GAN的判别器输出经过Sigmoid函数,损失函数为二元交叉熵。当生成分布与真实分布重叠度低时,判别器容易达到完美分类(损失接近0),导致梯度消失。
    • 例如,真实分布P和生成分布Q没有重叠时,Jensen-Shannon散度恒为log2,梯度为0,生成器无法更新。
  2. Wasserstein距离定义

    • Wasserstein-1距离定义为:
      \(W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x,y)\sim\gamma} [\|x-y\|]\)
      其中\(\Pi(P_r, P_g)\)\(P_r\)(真实分布)和\(P_g\)(生成分布)的所有联合分布集合,\(\gamma\)是联合分布的耦合(coupling),表示从\(P_r\)移动质量到\(P_g\)的运输方案。
    • 直观理解:将分布\(P_r\)的“沙土”搬运成分布\(P_g\)所需的最小工作量。
  3. 从理论到可计算形式

    • 直接计算Wasserstein距离不可行(需要遍历所有耦合)。通过Kantorovich-Rubinstein对偶性转化为:
      \(W(P_r, P_g) = \sup_{\|f\|_L \leq 1} \mathbb{E}_{x\sim P_r}[f(x)] - \mathbb{E}_{x\sim P_g}[f(x)]\)
      其中上确界取自所有1-Lipschitz函数(即满足\(|f(x)-f(y)| \leq |x-y|\)的函数)。
    • 意义:判别器\(f\)不再区分真假,而是拟合一个Lipschitz函数,最大化真实样本与生成样本的输出差值。
  4. WGAN的改进措施

    • 损失函数设计
      判别器(WGAN中称Critic)损失:\(L_D = \mathbb{E}_{x\sim P_g}[f(x)] - \mathbb{E}_{x\sim P_r}[f(x)]\)
      生成器损失:\(L_G = -\mathbb{E}_{x\sim P_g}[f(x)]\)
    • 权重裁剪:为满足Lipschitz约束,早期WGAN强制将判别器参数裁剪到\([-c, c]\)区间。例如设\(c=0.01\),但可能导致梯度减弱或参数聚集在边界。
    • 梯度惩罚(WGAN-GP):后续改进通过梯度惩罚项替代权重裁剪:
      \(L_{GP} = \lambda \mathbb{E}_{\hat{x}\sim P_{\hat{x}}}[(\|\nabla_{\hat{x}} f(\hat{x})\|_2 - 1)^2]\)
      其中\(\hat{x}\)是真实样本与生成样本连线上的随机插值点(\(\hat{x} = \epsilon x_r + (1-\epsilon) x_g, \epsilon \sim U[0,1]\)),\(\lambda\)为惩罚系数(常取10)。
  5. 训练流程示例

    • 步骤1:初始化生成器G和判别器D(Critic)。
    • 步骤2:循环训练直至收敛:
      a. 固定G,更新D多次(如5次):
      • 采样真实样本\(\{x_r\}\)和生成样本\(\{x_g\}\)
      • 计算D的损失\(L_D\)(若用WGAN-GP则加上梯度惩罚项)。
      • 反向传播更新D的参数。
        b. 固定D,更新G一次:
      • 采样噪声向量生成样本,计算\(L_G = -\mathbb{E}[f(G(z))]\)
      • 反向传播更新G的参数。
  6. 关键优势总结

    • Wasserstein距离始终连续,提供平滑的梯度信号。
    • 训练稳定性提升,减少模式崩溃现象。
    • 判别器(Critic)的损失值可近似反映生成质量(值越小说明分布越接近)。
生成对抗网络(GAN)中的Wasserstein距离与WGAN改进 描述 在原始生成对抗网络(GAN)的训练中,生成器与判别器通过Jensen-Shannon散度进行分布匹配,但这种方式容易导致梯度消失或模式崩溃。Wasserstein GAN(WGAN)通过引入Wasserstein距离(也称推土机距离)替代原始损失函数,显著提升了训练稳定性。Wasserstein距离衡量两个概率分布之间相互转化所需的最小成本,其关键优势在于即使分布不相交时仍能提供有意义的梯度。 解题过程 原始GAN的缺陷分析 原始GAN的判别器输出经过Sigmoid函数,损失函数为二元交叉熵。当生成分布与真实分布重叠度低时,判别器容易达到完美分类(损失接近0),导致梯度消失。 例如,真实分布P和生成分布Q没有重叠时,Jensen-Shannon散度恒为log2,梯度为0,生成器无法更新。 Wasserstein距离定义 Wasserstein-1距离定义为: \( W(P_ r, P_ g) = \inf_ {\gamma \in \Pi(P_ r, P_ g)} \mathbb{E}_ {(x,y)\sim\gamma} [ \|x-y\| ] \) 其中\(\Pi(P_ r, P_ g)\)是\(P_ r\)(真实分布)和\(P_ g\)(生成分布)的所有联合分布集合,\(\gamma\)是联合分布的耦合(coupling),表示从\(P_ r\)移动质量到\(P_ g\)的运输方案。 直观理解:将分布\(P_ r\)的“沙土”搬运成分布\(P_ g\)所需的最小工作量。 从理论到可计算形式 直接计算Wasserstein距离不可行(需要遍历所有耦合)。通过Kantorovich-Rubinstein对偶性转化为: \( W(P_ r, P_ g) = \sup_ {\|f\| L \leq 1} \mathbb{E} {x\sim P_ r}[ f(x)] - \mathbb{E}_ {x\sim P_ g}[ f(x) ] \) 其中上确界取自所有1-Lipschitz函数(即满足\(|f(x)-f(y)| \leq |x-y|\)的函数)。 意义:判别器\(f\)不再区分真假,而是拟合一个Lipschitz函数,最大化真实样本与生成样本的输出差值。 WGAN的改进措施 损失函数设计 : 判别器(WGAN中称Critic)损失:\( L_ D = \mathbb{E} {x\sim P_ g}[ f(x)] - \mathbb{E} {x\sim P_ r}[ f(x) ] \) 生成器损失:\( L_ G = -\mathbb{E}_ {x\sim P_ g}[ f(x) ] \) 权重裁剪 :为满足Lipschitz约束,早期WGAN强制将判别器参数裁剪到\([ -c, c ]\)区间。例如设\(c=0.01\),但可能导致梯度减弱或参数聚集在边界。 梯度惩罚(WGAN-GP) :后续改进通过梯度惩罚项替代权重裁剪: \( L_ {GP} = \lambda \mathbb{E} {\hat{x}\sim P {\hat{x}}}[ (\|\nabla_ {\hat{x}} f(\hat{x})\|_ 2 - 1)^2 ] \) 其中\(\hat{x}\)是真实样本与生成样本连线上的随机插值点(\( \hat{x} = \epsilon x_ r + (1-\epsilon) x_ g, \epsilon \sim U[ 0,1 ] \)),\(\lambda\)为惩罚系数(常取10)。 训练流程示例 步骤1:初始化生成器G和判别器D(Critic)。 步骤2:循环训练直至收敛: a. 固定G,更新D多次(如5次): 采样真实样本\(\{x_ r\}\)和生成样本\(\{x_ g\}\)。 计算D的损失\(L_ D\)(若用WGAN-GP则加上梯度惩罚项)。 反向传播更新D的参数。 b. 固定D,更新G一次: 采样噪声向量生成样本,计算\(L_ G = -\mathbb{E}[ f(G(z)) ]\)。 反向传播更新G的参数。 关键优势总结 Wasserstein距离始终连续,提供平滑的梯度信号。 训练稳定性提升,减少模式崩溃现象。 判别器(Critic)的损失值可近似反映生成质量(值越小说明分布越接近)。