自监督学习中的DINO(Distillation with No Labels)方法原理与实现详解
描述:DINO(Distillation with No Labels)是一种在图像领域取得显著效果的自监督学习方法,其核心思想是在无标签数据上,通过知识蒸馏(Knowledge Distillation)的方式,让一个“学生网络”学习一个“教师网络”的表示。该方法的关键创新在于,教师网络的参数是学生网络参数的指数移动平均(Exponential Moving Average, EMA),并且通过避免使用负样本对的对比学习,避免了大型负样本内存库的需求,实现了高效且强大的视觉表示学习。
我将分步为您讲解其核心思想、实现细节和关键技巧。
第一步:DINO的基本框架与核心直觉
-
直觉来源:在监督学习中,知识蒸馏通常让一个小的学生网络去模仿一个大的、训练好的教师网络的输出(如分类概率)。DINO将这一思想迁移到无标签场景,核心问题是:在没有标签的情况下,教师网络应该预测什么来指导学生网络?
-
基本框架:DINO框架包含两个结构相同但参数不同的神经网络:学生网络 \(g_{\theta_s}\) 和教师网络 \(g_{\theta_t}\)。它们都包含一个主干网络(如ViT、ResNet)和一个投影头(Projection Head,将特征映射到一个紧凑的表示空间)。
- 输入:一张无标签图片 \(x\)。
- 数据增强:对 \(x\) 应用两种不同的随机增强(裁剪、颜色扰动、高斯模糊等),得到两个视图(views):局部视图 \(x_1\)(例如,随机小裁剪)和全局视图 \(x_2\)(例如,中心大裁剪或原图)。局部视图通常作为学生网络的输入,全局视图作为教师网络的输入。这是鼓励模型学习局部到全局的对应关系,即从图片的局部细节推断出全局的语义信息。
- 前向传播:学生网络处理局部视图,得到输出 \(P_s\);教师网络处理全局视图,得到输出 \(P_t\)。
- 目标:让学生网络的输出 \(P_s\) 与教师网络的输出 \(P_t\) 尽可能一致。
第二步:教师网络的更新机制——EMA(指数移动平均)
这是DINO最关键的技巧之一。教师网络并不是一个预训练好的固定网络,而是学生网络的“历史平均版本”。
- 更新公式:在每个训练步骤(step)结束时,教师网络的参数 \(\theta_t\) 会按照以下规则更新:
\[ \theta_t \leftarrow \lambda \theta_t + (1 - \lambda) \theta_s \]
其中,$ \theta_s $ 是当前学生网络的参数,$ \lambda $ 是一个接近1的动量系数(例如0.996)。这意味着教师网络的参数是学生网络参数在整个训练过程中的**平滑、缓慢变化的指数移动平均**。
- 为什么这样做?
- 稳定性:教师网络的变化比学生网络慢得多,这为学生网络提供了一个稳定、一致的监督信号。如果教师网络更新太快(比如直接复制学生网络),容易导致模型崩溃(例如,所有输出都收敛到同一个点)。
- 避免模型崩溃:EMA机制自然地产生了一个“慢速教师”,它综合了学生网络历史上探索过的较好参数,引导学生网络朝着一个更稳定、更通用的表示空间进化。
第三步:输出处理与损失函数——中心化和锐化
DINO的另一个核心技巧是对网络输出概率分布的特殊处理。
-
网络输出:学生和教师网络的最终输出都是一个 \(K\) 维向量,并经过 Softmax 函数 转换为一个概率分布。这里 \(K\) 是一个超参数(例如65536),可以理解为一个虚拟的类别数。模型的目标是学习将这些无标签图片归类到这 \(K\) 个虚拟“原型”(prototype)上。
-
中心化(Centering):
- 操作:在计算教师网络的输出概率前,先对其 \(K\) 维输出向量进行中心化,即减去一个滑动平均值 \(c\)。
\[ g_t(x) \leftarrow g_t(x) - c \]
* **动机**:防止模型坍缩到一个平凡解——即教师网络将所有输入都预测为同一个维度(虚拟类别)。中心化操作确保了输出分布在各个维度上保持平衡,鼓励模型去发现和利用所有可用的“虚拟类别”。
* **更新规则**:中心化参数 $ c $ 本身也通过EMA更新:$ c \leftarrow m c + (1-m) \frac{1}{B} \sum_{i=1}^{B} g_t(x_i) $,其中 $ B $ 是批次大小,$ m $ 是一个动量系数(例如0.9)。
- 温度系数锐化(Sharpening with Temperature):
- 在应用Softmax时,为教师网络和学生网络使用不同的温度系数 \(\tau_t\) 和 \(\tau_s\),且通常 \(\tau_t < \tau_s\)。
\[ P_t(x) = \text{softmax}(g_t(x) / \tau_t) \\ P_s(x) = \text{softmax}(g_s(x) / \tau_s) \]
* **作用**:较低的 $ \tau_t $ 会使教师网络的输出概率分布更“尖锐”(即概率更集中于少数维度),这为**学生网络提供了一个更清晰、置信度更高的目标**。而学生网络使用较高的温度,使其输出更平滑,更容易学习。
- 损失函数:使用标准的交叉熵损失(Cross-Entropy Loss),让学生网络的输出分布 \(P_s\) 去匹配教师网络的输出分布 \(P_t\)。
\[ \mathcal{L} = H(P_t, P_s) = -\sum_{i=1}^{K} P_t^{(i)} \log P_s^{(i)} \]
注意,这里的“标签”是教师网络产生的概率分布 $ P_t $。由于教师网络由EMA产生,并且使用了中心化和锐化,这个目标分布是稳定且有信息量的。
第四步:DINO算法流程总结
- 初始化:学生网络参数 \(\theta_s\) 随机初始化,教师网络参数 \(\theta_t\) 初始化为与 \(\theta_s\) 相同。中心化参数 \(c\) 初始化为0。
- 循环每个批次:
a. 采样:从数据集中采样一批图片 \(x\)。
b. 数据增强:为每张图片 \(x\) 生成两个增强视图:局部视图 \(x_1\) 和全局视图 \(x_2\)。
c. 前向传播:学生网络处理 \(x_1\),得到输出 \(g_s\);教师网络处理 \(x_2\),得到输出 \(g_t\)。
d. 输出处理:
* 对学生网络输出:计算 \(P_s = \text{softmax}(g_s / \tau_s)\)。
* 对教师网络输出:先进行中心化 \(g_t \leftarrow g_t - c\),再计算 \(P_t = \text{softmax}(g_t / \tau_t)\)。
e. 计算损失:\(\mathcal{L} = H(P_t, P_s)\)。
f. 参数更新:
* 通过梯度下降,只更新学生网络参数 \(\theta_s\)。
* 通过EMA更新教师网络参数:\(\theta_t \leftarrow \lambda \theta_t + (1-\lambda) \theta_s\)。
* 通过EMA更新中心化参数:\(c \leftarrow m c + (1-m) \frac{1}{B} \sum_{i=1}^{B} g_t(x_{2}^{(i)})\)。 - 训练结束:训练完成后,丢弃投影头,使用主干网络提取的特征作为下游任务(如图像分类、目标检测、分割)的输入特征或用于微调。
核心优势:
- 无需负样本:避免了维护庞大负样本内存库的开销和复杂性。
- 简单有效:结合了知识蒸馏、EMA、多裁剪视图等成熟思想,形成了强大的表示学习框架。
- 可扩展性强:特别适合与视觉Transformer(ViT)结合,在ImageNet等数据集上取得了当时最优的自监督学习性能。
通过以上步骤,DINO成功地构建了一个稳定的、自我改进的学习循环:学生网络不断向一个“更好的历史自我”(教师网络)学习,从而在无标签数据上学习到具有丰富语义信息的视觉表示。