基于梯度的元学习(Gradient-based Meta-Learning)原理与MAML算法详解
1. 问题描述与背景
基于梯度的元学习是一种“学会学习”(Learning to Learn)的范式,旨在训练一个模型,使其能够仅用少量样本就快速适应新的任务。核心思想是:在一个由大量相关任务组成的数据集上进行“元训练”,使模型获得良好的初始化参数或快速适应的能力。当面对一个从未见过但同分布的新任务时,模型可以通过少量样本和几次梯度更新就能达到很好的性能。模型无关的元学习(Model-Agnostic Meta-Learning, MAML)是其中最著名和基础的算法之一。
- 核心挑战:如何找到一组初始模型参数,使得从这个起点出发,对任意新任务进行少量几步梯度下降更新后,模型在该任务上的性能就能最大化。
- 与普通监督学习的区别:普通监督学习在单个任务上训练,目标是让模型在该任务上表现好。元学习则在任务分布上训练,目标是让模型“学会如何快速适应”,其“损失”是在新任务上适应后的表现。
2. 核心概念与问题形式化
假设我们有一个任务分布 \(p(\mathcal{T})\)。对于从这个分布中采样的每个任务 \(\mathcal{T}_i\):
- 都有一个支持集(Support Set) :用于任务内部的快速适应(类似于训练集)。
- 都有一个查询集(Query Set) :用于评估适应后的模型性能,并计算“元损失”以更新初始参数(类似于测试集)。
MAML的目标是学习一组初始参数 \(\theta\)。
- 对于任务 \(\mathcal{T}_i\),从初始参数 \(\theta\) 开始,在任务的支持集上计算损失 \(\mathcal{L}_{\mathcal{T}_i}\),并进行一次(或几次)内循环(Inner Loop) 梯度更新,得到任务特定的参数 \(\theta'_i\):
\[ \theta'_i = \theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta}) \]
其中 $ \alpha $ 是内循环的学习率,$ f_{\theta} $ 是参数化的模型。
- 然后,在任务的查询集上评估适应后的模型 \(f_{\theta'_i}\),计算损失 \(\mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})\)。这个损失反映了从初始点 \(\theta\) 出发,经过一次适应后在新任务上的好坏。
元目标是最小化所有任务上,适应后查询损失的总和:
\[\min_{\theta} \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i}) = \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(f_{\theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta})}) \]
注意,最终优化的对象是初始参数 \(\theta\),而损失函数依赖于经过梯度更新后的参数 \(\theta'_i\)。
3. MAML算法详解与步骤推导
MAML通过梯度下降来优化上述元目标,因此需要计算元损失对初始参数 \(\theta\) 的梯度。这个过程称为外循环(Outer Loop) 更新。
步骤1:采样任务批次
从任务分布 \(p(\mathcal{T})\) 中采样一个批次的任务,例如 \(\mathcal{T}_1, \mathcal{T}_2, ..., \mathcal{T}_n\)。
步骤2:内循环适应(适应阶段)
对于每个任务 \(\mathcal{T}_i\):
- 用当前初始参数 \(\theta\) 初始化模型。
- 在任务 \(\mathcal{T}_i\) 的支持集数据上计算损失 \(\mathcal{L}_{\mathcal{T}_i}(f_{\theta})\)。
- 计算内循环梯度 \(\nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta})\)。
- 执行一次(或k次)梯度更新,得到适应后的参数 \(\theta'_i\):
\[ \theta'_i = \theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta}) \]
注意,这一步是“前向”计算,但产生了新的参数。
步骤3:计算元梯度(元优化阶段)
- 对于每个任务 \(\mathcal{T}_i\),使用适应后的参数 \(\theta'_i\),在任务 \(\mathcal{T}_i\) 的查询集上计算损失 \(\mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})\)。这个损失记为元损失。
- 将所有任务的元损失求和:\(\mathcal{L}_{meta} = \sum_i \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})\)。
- 计算元损失对初始参数 \(\theta\) 的梯度,即元梯度 \(\nabla_{\theta} \mathcal{L}_{meta}\)。
- 这是整个算法的核心和难点。因为 \(\theta'_i = \theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta})\),所以 \(\mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})\) 是 \(\theta\) 的一个复合函数。计算这个梯度需要用到二阶导数。
- 具体计算时,需要使用自动微分工具,在计算图上对 \(\theta\) 求导。这个过程等价于计算 \(\mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})\) 对 \(\theta'_i\) 的梯度,然后再乘以 \(\theta'_i\) 对 \(\theta\) 的雅可比矩阵。而这个雅可比矩阵包含了内循环梯度计算过程的微分。
\[ \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i}) = \frac{\partial \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})}{\partial \theta} = \frac{\partial \mathcal{L}_{\mathcal{T}_i}}{\partial \theta'_i} \cdot \frac{\partial \theta'_i}{\partial \theta} \]
其中 $ \frac{\partial \theta'_i}{\partial \theta} = I - \alpha \frac{\partial^2 \mathcal{L}_{\mathcal{T}_i}(f_{\theta})}{\partial \theta^2} $(假设一次内循环更新),这里出现了二阶导(Hessian矩阵)。
步骤4:外循环更新(元更新)
使用计算出的元梯度,更新初始参数 \(\theta\):
\[\theta \leftarrow \theta - \beta \nabla_{\theta} \mathcal{L}_{meta} \]
其中 \(\beta\) 是外循环的学习率(元学习率)。
4. 直观理解与关键点
- 一阶近似(FOMAML):计算完整二阶导的代价很高。MAML论文提出,在很多时候可以忽略二阶项,只使用一阶梯度(即假设 \(\frac{\partial \theta'_i}{\partial \theta} \approx I\)),仍然能取得很好效果。这称为一阶MAML。
- 学习的是一个良好的初始化点:MAML不学习一个可以直接做预测的模型,而是学习一个“起点”。这个起点位于所有任务参数空间的“中心”位置,从这个点出发,朝任何一个具体任务的最优解方向走一小步(少量梯度更新),就能达到不错的效果。
- 与预训练(Pre-training)的区别:预训练(如在大量数据上训练,然后微调)寻找的是一个在任务分布上平均表现好的点。而MAML通过在元损失中显式地模拟测试时的适应过程(内循环),寻找的是一个对后续梯度更新敏感、能通过快速适应变好的点。MAML的点可能初始表现不如预训练模型,但它的潜力(适应性)更强。
5. 算法流程总结
- 初始化 模型参数 \(\theta\)。
- While 未收敛:
- 采样一批任务 \(\mathcal{T}_i \sim p(\mathcal{T})\)。
- For 每个任务 \(\mathcal{T}_i\):
- 内循环:计算适应后参数 \(\theta'_i = \theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}^{support}(f_{\theta})\)。
- 在查询集上计算适应后损失 \(\mathcal{L}_{\mathcal{T}_i}^{query}(f_{\theta'_i})\)。
- 外循环:计算元梯度 \(\nabla_{\theta} \sum_i \mathcal{L}_{\mathcal{T}_i}^{query}(f_{\theta'_i})\)。
- 更新初始参数:\(\theta \leftarrow \theta - \beta \nabla_{\theta} \sum_i \mathcal{L}_{\mathcal{T}_i}^{query}(f_{\theta'_i})\)。
- 返回 学习到的初始参数 \(\theta\)。
在测试(或部署)时,对于新任务,我们利用学到的 \(\theta\) 作为起点,用该新任务提供的少量支持集样本,执行几次与内循环相同的梯度更新,即可得到适用于该新任务的模型。