自监督学习中的SimSiam(Simple Siamese Networks)方法原理与实现详解
题目描述
在自监督学习领域,SimSiam是一种简单而有效的非对比学习方法,它无需使用负样本、大型批次或动量编码器,就能学习到高质量的视觉表示。本题要求深入理解SimSiam的核心思想、结构设计、损失函数以及其避免表示崩塌(Collapse)的原理,并掌握其实现的关键细节。
循序渐进讲解
步骤1:背景与动机
问题背景: 自监督学习的核心目标是利用无标签数据学习通用的特征表示。对比学习(如SimCLR)通过拉近正样本对(同一图像的不同增强视图)、推开负样本对(不同图像的视图)来实现,但这通常需要大量负样本和/或大型批次,计算成本高。BYOL和MoCo引入了动量编码器等复杂组件。
SimSiam的动机: 提出一个更简单的框架,仅使用孪生网络(Siamese Network) 和停止梯度(Stop-Gradient) 操作,无需上述复杂机制,也能学习有效表示并避免表示崩塌(即所有输入映射到同一特征向量的平凡解)。
步骤2:SimSiam的整体架构与流程
SimSiam的结构极其简洁:
- 输入: 对同一张输入图像 \(x\) 应用两种随机数据增强(如裁剪、颜色抖动),得到两个视图 \(x_1\) 和 \(x_2\)。
- 孪生编码网络: 两个视图共享权重的编码器网络 \(f\)(通常是ResNet等骨干网络),将图像映射为特征向量。得到两个特征输出:\(z_1 = f(x_1)\),\(z_2 = f(x_2)\)。
- 投影MLP: 一个共享权重的多层感知机 \(h\)(称为投影头),将特征向量映射到投影空间。得到投影向量:\(p_1 = h(z_1)\),\(p_2 = h(z_2)\)。
- 预测MLP: 一个额外的多层感知机 \(g\)(称为预测头),只应用于其中一个分支(例如处理 \(p_1\)),输出预测向量。得到预测:\(q_1 = g(p_1)\)。
- 对称损失计算: 计算 \(q_1\) 与 \(p_2\) 之间的相似性损失(如负余弦相似度),同时交换两个视图的角色,计算对称损失。
- 停止梯度(关键操作): 在计算损失时,对其中一个分支的投影输出(如 \(p_2\) )执行停止梯度(stop-grad) 操作,这意味着在反向传播时,梯度不会通过 \(p_2\) 回传到编码器 \(f\) 和投影头 \(h\)。
核心流程公式化:
对于一个视图对:
\[ \mathcal{L}(p_1, q_2) = -\frac{p_1}{\|p_1\|_2} \cdot \frac{q_2}{\|q_2\|_2}, \quad \text{其中 } q_2 = g(h(f(x_2))) \]
总损失是对称的:
\[ \mathcal{L} = \frac{1}{2} [\mathcal{L}(p_1, \text{stopgrad}(q_2)) + \mathcal{L}(p_2, \text{stopgrad}(q_1))] \]
步骤3:核心组件详解
- 编码器 \(f\): 骨干网络,如ResNet。输出是特征表示 \(z\)。
- 投影头 \(h\): 通常是一个3层MLP(全连接层+BN+ReLU),将特征 \(z\) 映射到投影空间。其输出 \(p\) 参与损失计算,但其中一个分支的 \(p\) 会被停止梯度。
- 预测头 \(g\): 一个2层MLP(隐藏层+BN+ReLU,输出层无BN/ReLU),输入是 \(p\),输出是预测 \(q\)。预测头 \(g\) 是SimSiam防止崩塌的关键,它为网络提供了一个非对称的结构,使得两个分支不完全相同。
- 停止梯度(Stop-Gradient): 这是SimSiam最精妙的设计。在对称损失中,对于其中一个视图的投影输出 \(p\),我们将其视为“常数”或“目标”(类似于对比学习中的动量编码器输出),不更新生成它的编码器和投影头参数。这创造了一种动态:一个分支努力使预测 \(q\) 匹配另一个分支的固定目标 \(p\),而目标 \(p\) 本身又由不断更新的网络产生(只是梯度不从此处回传)。这避免了两个分支轻易地协商退化到同一个常数解。
步骤4:SimSiam如何避免表示崩塌?
表示崩塌是指所有输入都映射到相同的输出向量,使得损失达到一个较低的平凡值(例如,如果 \(p\) 和 \(q\) 都是常数向量,它们的余弦相似度可能很高)。SimSiam通过非对称结构 + 停止梯度来避免:
- 非对称性: 预测头 \(g\) 只存在于一个分支。如果网络崩塌,所有 \(p\) 都相同,那么 \(q = g(p)\) 也会相同。但由于 \(g\) 的存在且可学习,如果输入是常数,\(g\) 将很难优化(梯度很小),学习会停滞。实际上,\(g\) 鼓励编码器产生多样化的 \(p\),以便 \(g\) 能学到有意义的变换。
- 停止梯度: 这防止了崩塌的 trivial 解成为一个简单的固定点。如果将停止梯度移除,系统可能迅速崩塌,因为两个分支可以互相适应并收敛到常数解。停止梯度使得目标 \(p\) 不完全跟随预测分支 \(q\) 的变化,引入了一种“追赶”动态,稳定了学习过程。可以将其理解为一种隐式的期望最大化(EM)算法:停止梯度的分支类似于“E步”计算目标表示,带预测头的分支类似于“M步”更新参数以匹配目标。
步骤5:损失函数与优化
- 损失函数: 使用负余弦相似度,等价于L2归一化后的均方误差。
\[ D(p, q) = -\frac{p}{\|p\|_2} \cdot \frac{q}{\|q\|_2} \]
对称总损失:\(\mathcal{L} = \frac{1}{2}(D(p_1, \text{stopgrad}(q_2)) + D(p_2, \text{stopgrad}(q_1)))\)。
- 优化器: 通常使用SGD或Adam,结合余弦学习率衰减。批次大小可以较小(如256),无需像对比学习那样需要超大批次。
- 评估: 预训练完成后,丢弃投影头 \(h\) 和预测头 \(g\),使用编码器 \(f\) 提取的特征 \(z\) 在下游任务(如图像分类)上训练线性分类器或进行微调。
步骤6:关键实现细节
- 预测头的结构: 预测头必须是非线性的(至少包含一个隐藏层),且输出层不应有BN或ReLU,以保持预测的多样性。
- 停止梯度的实现: 在深度学习框架中,可通过
.detach()(PyTorch)或tf.stop_gradient(TensorFlow)实现。 - 对称损失: 必须对两个视图都计算损失并平均,以充分利用数据。
- 数据增强: 与SimCLR等类似,使用随机裁剪、颜色抖动、灰度化、高斯模糊等组合。增强策略对性能至关重要。
- Batch Normalization的使用: 在投影头和预测头中使用BN有助于训练,但需注意BN可能隐含的跨样本通信(在SimSiam中,由于批次较小且BN在单个设备上运行,其影响有限)。
步骤7:总结与意义
SimSiam证明了:
- 表示学习不一定需要负样本、动量编码器或极大批次。
- 简单的孪生网络架构,结合停止梯度和预测头,足以学习强大的特征表示。
- 其成功挑战了之前对比学习的一些“必要”条件,推动了自监督学习向更简洁、高效的方向发展。
核心要点: SimSiam的精髓在于“非对称孪生网络”和“停止梯度”的协同,创造了一个动态的、稳定的优化目标,从而在避免崩塌的同时学习到有区分度的表示。