自监督学习中的SimSiam(Simple Siamese Networks)方法原理与实现详解
字数 3298 2025-12-12 07:05:58

自监督学习中的SimSiam(Simple Siamese Networks)方法原理与实现详解

题目描述

在自监督学习领域,SimSiam是一种简单而有效的非对比学习方法,它无需使用负样本、大型批次或动量编码器,就能学习到高质量的视觉表示。本题要求深入理解SimSiam的核心思想、结构设计、损失函数以及其避免表示崩塌(Collapse)的原理,并掌握其实现的关键细节。

循序渐进讲解

步骤1:背景与动机

问题背景: 自监督学习的核心目标是利用无标签数据学习通用的特征表示。对比学习(如SimCLR)通过拉近正样本对(同一图像的不同增强视图)、推开负样本对(不同图像的视图)来实现,但这通常需要大量负样本和/或大型批次,计算成本高。BYOL和MoCo引入了动量编码器等复杂组件。
SimSiam的动机: 提出一个更简单的框架,仅使用孪生网络(Siamese Network)停止梯度(Stop-Gradient) 操作,无需上述复杂机制,也能学习有效表示并避免表示崩塌(即所有输入映射到同一特征向量的平凡解)。

步骤2:SimSiam的整体架构与流程

SimSiam的结构极其简洁:

  1. 输入: 对同一张输入图像 \(x\) 应用两种随机数据增强(如裁剪、颜色抖动),得到两个视图 \(x_1\)\(x_2\)
  2. 孪生编码网络: 两个视图共享权重的编码器网络 \(f\)(通常是ResNet等骨干网络),将图像映射为特征向量。得到两个特征输出:\(z_1 = f(x_1)\)\(z_2 = f(x_2)\)
  3. 投影MLP: 一个共享权重的多层感知机 \(h\)(称为投影头),将特征向量映射到投影空间。得到投影向量:\(p_1 = h(z_1)\)\(p_2 = h(z_2)\)
  4. 预测MLP: 一个额外的多层感知机 \(g\)(称为预测头),只应用于其中一个分支(例如处理 \(p_1\)),输出预测向量。得到预测:\(q_1 = g(p_1)\)
  5. 对称损失计算: 计算 \(q_1\)\(p_2\) 之间的相似性损失(如负余弦相似度),同时交换两个视图的角色,计算对称损失。
  6. 停止梯度(关键操作): 在计算损失时,对其中一个分支的投影输出(如 \(p_2\) )执行停止梯度(stop-grad) 操作,这意味着在反向传播时,梯度不会通过 \(p_2\) 回传到编码器 \(f\) 和投影头 \(h\)

核心流程公式化:
对于一个视图对:

\[ \mathcal{L}(p_1, q_2) = -\frac{p_1}{\|p_1\|_2} \cdot \frac{q_2}{\|q_2\|_2}, \quad \text{其中 } q_2 = g(h(f(x_2))) \]

总损失是对称的:

\[ \mathcal{L} = \frac{1}{2} [\mathcal{L}(p_1, \text{stopgrad}(q_2)) + \mathcal{L}(p_2, \text{stopgrad}(q_1))] \]

步骤3:核心组件详解

  1. 编码器 \(f\) 骨干网络,如ResNet。输出是特征表示 \(z\)
  2. 投影头 \(h\) 通常是一个3层MLP(全连接层+BN+ReLU),将特征 \(z\) 映射到投影空间。其输出 \(p\) 参与损失计算,但其中一个分支的 \(p\) 会被停止梯度。
  3. 预测头 \(g\) 一个2层MLP(隐藏层+BN+ReLU,输出层无BN/ReLU),输入是 \(p\),输出是预测 \(q\)预测头 \(g\) 是SimSiam防止崩塌的关键,它为网络提供了一个非对称的结构,使得两个分支不完全相同。
  4. 停止梯度(Stop-Gradient): 这是SimSiam最精妙的设计。在对称损失中,对于其中一个视图的投影输出 \(p\),我们将其视为“常数”或“目标”(类似于对比学习中的动量编码器输出),不更新生成它的编码器和投影头参数。这创造了一种动态:一个分支努力使预测 \(q\) 匹配另一个分支的固定目标 \(p\),而目标 \(p\) 本身又由不断更新的网络产生(只是梯度不从此处回传)。这避免了两个分支轻易地协商退化到同一个常数解。

步骤4:SimSiam如何避免表示崩塌?

表示崩塌是指所有输入都映射到相同的输出向量,使得损失达到一个较低的平凡值(例如,如果 \(p\)\(q\) 都是常数向量,它们的余弦相似度可能很高)。SimSiam通过非对称结构 + 停止梯度来避免:

  • 非对称性: 预测头 \(g\) 只存在于一个分支。如果网络崩塌,所有 \(p\) 都相同,那么 \(q = g(p)\) 也会相同。但由于 \(g\) 的存在且可学习,如果输入是常数,\(g\) 将很难优化(梯度很小),学习会停滞。实际上,\(g\) 鼓励编码器产生多样化的 \(p\),以便 \(g\) 能学到有意义的变换。
  • 停止梯度: 这防止了崩塌的 trivial 解成为一个简单的固定点。如果将停止梯度移除,系统可能迅速崩塌,因为两个分支可以互相适应并收敛到常数解。停止梯度使得目标 \(p\) 不完全跟随预测分支 \(q\) 的变化,引入了一种“追赶”动态,稳定了学习过程。可以将其理解为一种隐式的期望最大化(EM)算法:停止梯度的分支类似于“E步”计算目标表示,带预测头的分支类似于“M步”更新参数以匹配目标。

步骤5:损失函数与优化

  • 损失函数: 使用负余弦相似度,等价于L2归一化后的均方误差。

\[ D(p, q) = -\frac{p}{\|p\|_2} \cdot \frac{q}{\|q\|_2} \]

对称总损失:\(\mathcal{L} = \frac{1}{2}(D(p_1, \text{stopgrad}(q_2)) + D(p_2, \text{stopgrad}(q_1)))\)

  • 优化器: 通常使用SGD或Adam,结合余弦学习率衰减。批次大小可以较小(如256),无需像对比学习那样需要超大批次。
  • 评估: 预训练完成后,丢弃投影头 \(h\) 和预测头 \(g\),使用编码器 \(f\) 提取的特征 \(z\) 在下游任务(如图像分类)上训练线性分类器或进行微调。

步骤6:关键实现细节

  1. 预测头的结构: 预测头必须是非线性的(至少包含一个隐藏层),且输出层不应有BN或ReLU,以保持预测的多样性。
  2. 停止梯度的实现: 在深度学习框架中,可通过 .detach()(PyTorch)或 tf.stop_gradient(TensorFlow)实现。
  3. 对称损失: 必须对两个视图都计算损失并平均,以充分利用数据。
  4. 数据增强: 与SimCLR等类似,使用随机裁剪、颜色抖动、灰度化、高斯模糊等组合。增强策略对性能至关重要。
  5. Batch Normalization的使用: 在投影头和预测头中使用BN有助于训练,但需注意BN可能隐含的跨样本通信(在SimSiam中,由于批次较小且BN在单个设备上运行,其影响有限)。

步骤7:总结与意义

SimSiam证明了:

  • 表示学习不一定需要负样本、动量编码器或极大批次
  • 简单的孪生网络架构,结合停止梯度和预测头,足以学习强大的特征表示
  • 其成功挑战了之前对比学习的一些“必要”条件,推动了自监督学习向更简洁、高效的方向发展。

核心要点: SimSiam的精髓在于“非对称孪生网络”和“停止梯度”的协同,创造了一个动态的、稳定的优化目标,从而在避免崩塌的同时学习到有区分度的表示。

自监督学习中的SimSiam(Simple Siamese Networks)方法原理与实现详解 题目描述 在自监督学习领域,SimSiam是一种简单而有效的非对比学习方法,它无需使用负样本、大型批次或动量编码器,就能学习到高质量的视觉表示。本题要求深入理解SimSiam的核心思想、结构设计、损失函数以及其避免表示崩塌(Collapse)的原理,并掌握其实现的关键细节。 循序渐进讲解 步骤1:背景与动机 问题背景: 自监督学习的核心目标是利用无标签数据学习通用的特征表示。对比学习(如SimCLR)通过拉近正样本对(同一图像的不同增强视图)、推开负样本对(不同图像的视图)来实现,但这通常需要大量负样本和/或大型批次,计算成本高。BYOL和MoCo引入了动量编码器等复杂组件。 SimSiam的动机: 提出一个更简单的框架,仅使用 孪生网络(Siamese Network) 和 停止梯度(Stop-Gradient) 操作,无需上述复杂机制,也能学习有效表示并避免表示崩塌(即所有输入映射到同一特征向量的平凡解)。 步骤2:SimSiam的整体架构与流程 SimSiam的结构极其简洁: 输入: 对同一张输入图像 \( x \) 应用两种随机数据增强(如裁剪、颜色抖动),得到两个视图 \( x_ 1 \) 和 \( x_ 2 \)。 孪生编码网络: 两个视图共享权重的编码器网络 \( f \)(通常是ResNet等骨干网络),将图像映射为特征向量。得到两个特征输出:\( z_ 1 = f(x_ 1) \),\( z_ 2 = f(x_ 2) \)。 投影MLP: 一个共享权重的多层感知机 \( h \)(称为投影头),将特征向量映射到投影空间。得到投影向量:\( p_ 1 = h(z_ 1) \),\( p_ 2 = h(z_ 2) \)。 预测MLP: 一个额外的多层感知机 \( g \)(称为预测头),只应用于其中一个分支(例如处理 \( p_ 1 \)),输出预测向量。得到预测:\( q_ 1 = g(p_ 1) \)。 对称损失计算: 计算 \( q_ 1 \) 与 \( p_ 2 \) 之间的相似性损失(如负余弦相似度),同时交换两个视图的角色,计算对称损失。 停止梯度(关键操作): 在计算损失时,对其中一个分支的投影输出(如 \( p_ 2 \) )执行 停止梯度(stop-grad) 操作,这意味着在反向传播时,梯度不会通过 \( p_ 2 \) 回传到编码器 \( f \) 和投影头 \( h \)。 核心流程公式化: 对于一个视图对: \[ \mathcal{L}(p_ 1, q_ 2) = -\frac{p_ 1}{\|p_ 1\|_ 2} \cdot \frac{q_ 2}{\|q_ 2\|_ 2}, \quad \text{其中 } q_ 2 = g(h(f(x_ 2))) \] 总损失是对称的: \[ \mathcal{L} = \frac{1}{2} [ \mathcal{L}(p_ 1, \text{stopgrad}(q_ 2)) + \mathcal{L}(p_ 2, \text{stopgrad}(q_ 1)) ] \] 步骤3:核心组件详解 编码器 \( f \): 骨干网络,如ResNet。输出是特征表示 \( z \)。 投影头 \( h \): 通常是一个3层MLP(全连接层+BN+ReLU),将特征 \( z \) 映射到投影空间。其输出 \( p \) 参与损失计算,但其中一个分支的 \( p \) 会被停止梯度。 预测头 \( g \): 一个2层MLP(隐藏层+BN+ReLU,输出层无BN/ReLU),输入是 \( p \),输出是预测 \( q \)。 预测头 \( g \) 是SimSiam防止崩塌的关键 ,它为网络提供了一个非对称的结构,使得两个分支不完全相同。 停止梯度(Stop-Gradient): 这是SimSiam最精妙的设计。在对称损失中,对于其中一个视图的投影输出 \( p \),我们将其视为“常数”或“目标”(类似于对比学习中的动量编码器输出),不更新生成它的编码器和投影头参数。这创造了一种 动态 :一个分支努力使预测 \( q \) 匹配另一个分支的固定目标 \( p \),而目标 \( p \) 本身又由不断更新的网络产生(只是梯度不从此处回传)。这避免了两个分支轻易地协商退化到同一个常数解。 步骤4:SimSiam如何避免表示崩塌? 表示崩塌是指所有输入都映射到相同的输出向量,使得损失达到一个较低的平凡值(例如,如果 \( p \) 和 \( q \) 都是常数向量,它们的余弦相似度可能很高)。SimSiam通过 非对称结构 + 停止梯度 来避免: 非对称性: 预测头 \( g \) 只存在于一个分支。如果网络崩塌,所有 \( p \) 都相同,那么 \( q = g(p) \) 也会相同。但由于 \( g \) 的存在且可学习,如果输入是常数,\( g \) 将很难优化(梯度很小),学习会停滞。实际上,\( g \) 鼓励编码器产生多样化的 \( p \),以便 \( g \) 能学到有意义的变换。 停止梯度: 这防止了崩塌的 trivial 解成为一个简单的固定点。如果将停止梯度移除,系统可能迅速崩塌,因为两个分支可以互相适应并收敛到常数解。停止梯度使得目标 \( p \) 不完全跟随预测分支 \( q \) 的变化,引入了一种“追赶”动态,稳定了学习过程。 可以将其理解为一种隐式的期望最大化(EM)算法 :停止梯度的分支类似于“E步”计算目标表示,带预测头的分支类似于“M步”更新参数以匹配目标。 步骤5:损失函数与优化 损失函数: 使用负余弦相似度,等价于L2归一化后的均方误差。 \[ D(p, q) = -\frac{p}{\|p\|_ 2} \cdot \frac{q}{\|q\|_ 2} \] 对称总损失:\( \mathcal{L} = \frac{1}{2}(D(p_ 1, \text{stopgrad}(q_ 2)) + D(p_ 2, \text{stopgrad}(q_ 1))) \)。 优化器: 通常使用SGD或Adam,结合余弦学习率衰减。批次大小可以较小(如256),无需像对比学习那样需要超大批次。 评估: 预训练完成后,丢弃投影头 \( h \) 和预测头 \( g \),使用编码器 \( f \) 提取的特征 \( z \) 在下游任务(如图像分类)上训练线性分类器或进行微调。 步骤6:关键实现细节 预测头的结构: 预测头必须是非线性的(至少包含一个隐藏层),且输出层不应有BN或ReLU,以保持预测的多样性。 停止梯度的实现: 在深度学习框架中,可通过 .detach() (PyTorch)或 tf.stop_gradient (TensorFlow)实现。 对称损失: 必须对两个视图都计算损失并平均,以充分利用数据。 数据增强: 与SimCLR等类似,使用随机裁剪、颜色抖动、灰度化、高斯模糊等组合。增强策略对性能至关重要。 Batch Normalization的使用: 在投影头和预测头中使用BN有助于训练,但需注意BN可能隐含的跨样本通信(在SimSiam中,由于批次较小且BN在单个设备上运行,其影响有限)。 步骤7:总结与意义 SimSiam证明了: 表示学习 不一定需要负样本、动量编码器或极大批次 。 简单的孪生网络架构,结合停止梯度和预测头,足以学习强大的特征表示 。 其成功挑战了之前对比学习的一些“必要”条件,推动了自监督学习向更简洁、高效的方向发展。 核心要点: SimSiam的精髓在于“非对称孪生网络”和“停止梯度”的协同,创造了一个动态的、稳定的优化目标,从而在避免崩塌的同时学习到有区分度的表示。