基于梯度的元学习(Gradient-based Meta-Learning)原理与MAML算法详解
字数 4458 2025-12-08 14:06:28

基于梯度的元学习(Gradient-based Meta-Learning)原理与MAML算法详解

1. 问题描述与背景

基于梯度的元学习是一种“学会学习”(Learning to Learn)的范式,旨在训练一个模型,使其能够仅用少量样本就快速适应新的任务。核心思想是:在一个由大量相关任务组成的数据集上进行“元训练”,使模型获得良好的初始化参数或快速适应的能力。当面对一个从未见过但同分布的新任务时,模型可以通过少量样本和几次梯度更新就能达到很好的性能。模型无关的元学习(Model-Agnostic Meta-Learning, MAML)是其中最著名和基础的算法之一。

  • 核心挑战:如何找到一组初始模型参数,使得从这个起点出发,对任意新任务进行少量几步梯度下降更新后,模型在该任务上的性能就能最大化。
  • 与普通监督学习的区别:普通监督学习在单个任务上训练,目标是让模型在该任务上表现好。元学习则在任务分布上训练,目标是让模型“学会如何快速适应”,其“损失”是在新任务上适应后的表现。

2. 核心概念与问题形式化

假设我们有一个任务分布 \(p(\mathcal{T})\)。对于从这个分布中采样的每个任务 \(\mathcal{T}_i\)

  • 都有一个支持集(Support Set) :用于任务内部的快速适应(类似于训练集)。
  • 都有一个查询集(Query Set) :用于评估适应后的模型性能,并计算“元损失”以更新初始参数(类似于测试集)。

MAML的目标是学习一组初始参数 \(\theta\)

  • 对于任务 \(\mathcal{T}_i\),从初始参数 \(\theta\) 开始,在任务的支持集上计算损失 \(\mathcal{L}_{\mathcal{T}_i}\),并进行一次(或几次)内循环(Inner Loop) 梯度更新,得到任务特定的参数 \(\theta'_i\)

\[ \theta'_i = \theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta}) \]

其中 $ \alpha $ 是内循环的学习率,$ f_{\theta} $ 是参数化的模型。
  • 然后,在任务的查询集上评估适应后的模型 \(f_{\theta'_i}\),计算损失 \(\mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})\)。这个损失反映了从初始点 \(\theta\) 出发,经过一次适应后在新任务上的好坏。

元目标最小化所有任务上,适应后查询损失的总和

\[\min_{\theta} \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i}) = \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(f_{\theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta})}) \]

注意,最终优化的对象是初始参数 \(\theta\),而损失函数依赖于经过梯度更新后的参数 \(\theta'_i\)

3. MAML算法详解与步骤推导

MAML通过梯度下降来优化上述元目标,因此需要计算元损失对初始参数 \(\theta\) 的梯度。这个过程称为外循环(Outer Loop) 更新。

步骤1:采样任务批次
从任务分布 \(p(\mathcal{T})\) 中采样一个批次的任务,例如 \(\mathcal{T}_1, \mathcal{T}_2, ..., \mathcal{T}_n\)

步骤2:内循环适应(适应阶段)
对于每个任务 \(\mathcal{T}_i\)

  1. 用当前初始参数 \(\theta\) 初始化模型。
  2. 在任务 \(\mathcal{T}_i\) 的支持集数据上计算损失 \(\mathcal{L}_{\mathcal{T}_i}(f_{\theta})\)
  3. 计算内循环梯度 \(\nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta})\)
  4. 执行一次(或k次)梯度更新,得到适应后的参数 \(\theta'_i\)

\[ \theta'_i = \theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta}) \]

注意,这一步是“前向”计算,但产生了新的参数。

步骤3:计算元梯度(元优化阶段)

  1. 对于每个任务 \(\mathcal{T}_i\),使用适应后的参数 \(\theta'_i\),在任务 \(\mathcal{T}_i\) 的查询集上计算损失 \(\mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})\)。这个损失记为元损失
  2. 将所有任务的元损失求和:\(\mathcal{L}_{meta} = \sum_i \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})\)
  3. 计算元损失对初始参数 \(\theta\) 的梯度,即元梯度 \(\nabla_{\theta} \mathcal{L}_{meta}\)
    • 这是整个算法的核心和难点。因为 \(\theta'_i = \theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta})\),所以 \(\mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})\)\(\theta\) 的一个复合函数。计算这个梯度需要用到二阶导数
    • 具体计算时,需要使用自动微分工具,在计算图上对 \(\theta\) 求导。这个过程等价于计算 \(\mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})\)\(\theta'_i\) 的梯度,然后再乘以 \(\theta'_i\)\(\theta\) 的雅可比矩阵。而这个雅可比矩阵包含了内循环梯度计算过程的微分。

\[ \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i}) = \frac{\partial \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})}{\partial \theta} = \frac{\partial \mathcal{L}_{\mathcal{T}_i}}{\partial \theta'_i} \cdot \frac{\partial \theta'_i}{\partial \theta} \]

其中 $ \frac{\partial \theta'_i}{\partial \theta} = I - \alpha \frac{\partial^2 \mathcal{L}_{\mathcal{T}_i}(f_{\theta})}{\partial \theta^2} $(假设一次内循环更新),这里出现了二阶导(Hessian矩阵)。

步骤4:外循环更新(元更新)
使用计算出的元梯度,更新初始参数 \(\theta\)

\[\theta \leftarrow \theta - \beta \nabla_{\theta} \mathcal{L}_{meta} \]

其中 \(\beta\) 是外循环的学习率(元学习率)。

4. 直观理解与关键点

  • 一阶近似(FOMAML):计算完整二阶导的代价很高。MAML论文提出,在很多时候可以忽略二阶项,只使用一阶梯度(即假设 \(\frac{\partial \theta'_i}{\partial \theta} \approx I\)),仍然能取得很好效果。这称为一阶MAML。
  • 学习的是一个良好的初始化点:MAML不学习一个可以直接做预测的模型,而是学习一个“起点”。这个起点位于所有任务参数空间的“中心”位置,从这个点出发,朝任何一个具体任务的最优解方向走一小步(少量梯度更新),就能达到不错的效果。
  • 与预训练(Pre-training)的区别:预训练(如在大量数据上训练,然后微调)寻找的是一个在任务分布上平均表现好的点。而MAML通过在元损失中显式地模拟测试时的适应过程(内循环),寻找的是一个对后续梯度更新敏感、能通过快速适应变好的点。MAML的点可能初始表现不如预训练模型,但它的潜力(适应性)更强。

5. 算法流程总结

  1. 初始化 模型参数 \(\theta\)
  2. While 未收敛:
    • 采样一批任务 \(\mathcal{T}_i \sim p(\mathcal{T})\)
    • For 每个任务 \(\mathcal{T}_i\)
      • 内循环:计算适应后参数 \(\theta'_i = \theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}^{support}(f_{\theta})\)
      • 在查询集上计算适应后损失 \(\mathcal{L}_{\mathcal{T}_i}^{query}(f_{\theta'_i})\)
    • 外循环:计算元梯度 \(\nabla_{\theta} \sum_i \mathcal{L}_{\mathcal{T}_i}^{query}(f_{\theta'_i})\)
    • 更新初始参数:\(\theta \leftarrow \theta - \beta \nabla_{\theta} \sum_i \mathcal{L}_{\mathcal{T}_i}^{query}(f_{\theta'_i})\)
  3. 返回 学习到的初始参数 \(\theta\)

在测试(或部署)时,对于新任务,我们利用学到的 \(\theta\) 作为起点,用该新任务提供的少量支持集样本,执行几次与内循环相同的梯度更新,即可得到适用于该新任务的模型。

基于梯度的元学习(Gradient-based Meta-Learning)原理与MAML算法详解 1. 问题描述与背景 基于梯度的元学习是一种“学会学习”(Learning to Learn)的范式,旨在训练一个模型,使其能够仅用少量样本就快速适应新的任务。核心思想是:在一个由大量相关任务组成的数据集上进行“元训练”,使模型获得良好的初始化参数或快速适应的能力。当面对一个从未见过但同分布的新任务时,模型可以通过少量样本和几次梯度更新就能达到很好的性能。模型无关的元学习(Model-Agnostic Meta-Learning, MAML)是其中最著名和基础的算法之一。 核心挑战 :如何找到一组 初始模型参数 ,使得从这个起点出发,对 任意 新任务进行少量几步梯度下降更新后,模型在该任务上的性能就能最大化。 与普通监督学习的区别 :普通监督学习在单个任务上训练,目标是让模型在该任务上表现好。元学习则在任务分布上训练,目标是让模型“学会如何快速适应”,其“损失”是在新任务上适应后的表现。 2. 核心概念与问题形式化 假设我们有一个任务分布 \( p(\mathcal{T}) \)。对于从这个分布中采样的每个任务 \( \mathcal{T}_ i \): 都有一个 支持集(Support Set) :用于任务内部的快速适应(类似于训练集)。 都有一个 查询集(Query Set) :用于评估适应后的模型性能,并计算“元损失”以更新初始参数(类似于测试集)。 MAML的目标是学习一组 初始参数 \( \theta \)。 对于任务 \( \mathcal{T} i \),从初始参数 \( \theta \) 开始,在任务的支持集上计算损失 \( \mathcal{L} {\mathcal{T} i} \),并进行一次(或几次) 内循环(Inner Loop) 梯度更新,得到任务特定的参数 \( \theta' i \): \[ \theta' i = \theta - \alpha \nabla {\theta} \mathcal{L} {\mathcal{T} i}(f {\theta}) \] 其中 \( \alpha \) 是内循环的学习率,\( f {\theta} \) 是参数化的模型。 然后,在任务的查询集上评估适应后的模型 \( f_ {\theta' i} \),计算损失 \( \mathcal{L} {\mathcal{T} i}(f {\theta'_ i}) \)。这个损失反映了从初始点 \( \theta \) 出发,经过一次适应后在新任务上的好坏。 元目标 是 最小化所有任务上,适应后查询损失的总和 : \[ \min_ {\theta} \sum_ {\mathcal{T} i \sim p(\mathcal{T})} \mathcal{L} {\mathcal{T} i}(f {\theta' i}) = \sum {\mathcal{T} i \sim p(\mathcal{T})} \mathcal{L} {\mathcal{T} i}(f {\theta - \alpha \nabla_ {\theta} \mathcal{L}_ {\mathcal{T} i}(f {\theta})}) \] 注意,最终优化的对象是初始参数 \( \theta \),而损失函数依赖于经过梯度更新后的参数 \( \theta'_ i \)。 3. MAML算法详解与步骤推导 MAML通过梯度下降来优化上述元目标,因此需要计算元损失对初始参数 \( \theta \) 的梯度。这个过程称为 外循环(Outer Loop) 更新。 步骤1:采样任务批次 从任务分布 \( p(\mathcal{T}) \) 中采样一个批次的任务,例如 \( \mathcal{T}_ 1, \mathcal{T}_ 2, ..., \mathcal{T}_ n \)。 步骤2:内循环适应(适应阶段) 对于每个任务 \( \mathcal{T}_ i \): 用当前初始参数 \( \theta \) 初始化模型。 在任务 \( \mathcal{T} i \) 的支持集数据上计算损失 \( \mathcal{L} {\mathcal{T} i}(f {\theta}) \)。 计算 内循环梯度 \( \nabla_ {\theta} \mathcal{L}_ {\mathcal{T} i}(f {\theta}) \)。 执行一次(或k次)梯度更新,得到 适应后的参数 \( \theta' i \): \[ \theta' i = \theta - \alpha \nabla {\theta} \mathcal{L} {\mathcal{T} i}(f {\theta}) \] 注意,这一步是“前向”计算,但产生了新的参数。 步骤3:计算元梯度(元优化阶段) 对于每个任务 \( \mathcal{T}_ i \),使用 适应后的参数 \( \theta'_ i \),在任务 \( \mathcal{T} i \) 的查询集上计算损失 \( \mathcal{L} {\mathcal{T} i}(f {\theta'_ i}) \)。这个损失记为 元损失 。 将所有任务的元损失求和:\( \mathcal{L} {meta} = \sum_ i \mathcal{L} {\mathcal{T} i}(f {\theta'_ i}) \)。 计算 元损失对初始参数 \( \theta \) 的梯度 ,即元梯度 \( \nabla_ {\theta} \mathcal{L}_ {meta} \)。 这是整个算法的核心和难点。因为 \( \theta' i = \theta - \alpha \nabla {\theta} \mathcal{L} {\mathcal{T} i}(f {\theta}) \),所以 \( \mathcal{L} {\mathcal{T} i}(f {\theta'_ i}) \) 是 \( \theta \) 的一个复合函数。计算这个梯度需要用到 二阶导数 。 具体计算时,需要使用自动微分工具,在计算图上对 \( \theta \) 求导。这个过程等价于计算 \( \mathcal{L}_ {\mathcal{T} i}(f {\theta'_ i}) \) 对 \( \theta' i \) 的梯度,然后再乘以 \( \theta' i \) 对 \( \theta \) 的雅可比矩阵。而这个雅可比矩阵包含了内循环梯度计算过程的微分。 \[ \nabla {\theta} \mathcal{L} {\mathcal{T} i}(f {\theta' i}) = \frac{\partial \mathcal{L} {\mathcal{T} i}(f {\theta' i})}{\partial \theta} = \frac{\partial \mathcal{L} {\mathcal{T}_ i}}{\partial \theta'_ i} \cdot \frac{\partial \theta'_ i}{\partial \theta} \] 其中 \( \frac{\partial \theta' i}{\partial \theta} = I - \alpha \frac{\partial^2 \mathcal{L} {\mathcal{T} i}(f {\theta})}{\partial \theta^2} \)(假设一次内循环更新),这里出现了二阶导(Hessian矩阵)。 步骤4:外循环更新(元更新) 使用计算出的元梯度,更新初始参数 \( \theta \): \[ \theta \leftarrow \theta - \beta \nabla_ {\theta} \mathcal{L}_ {meta} \] 其中 \( \beta \) 是外循环的学习率(元学习率)。 4. 直观理解与关键点 一阶近似(FOMAML) :计算完整二阶导的代价很高。MAML论文提出,在很多时候可以忽略二阶项,只使用一阶梯度(即假设 \( \frac{\partial \theta'_ i}{\partial \theta} \approx I \)),仍然能取得很好效果。这称为一阶MAML。 学习的是一个良好的初始化点 :MAML不学习一个可以直接做预测的模型,而是学习一个“起点”。这个起点位于所有任务参数空间的“中心”位置,从这个点出发,朝任何一个具体任务的最优解方向走一小步(少量梯度更新),就能达到不错的效果。 与预训练(Pre-training)的区别 :预训练(如在大量数据上训练,然后微调)寻找的是一个在任务分布上平均表现好的点。而MAML通过 在元损失中显式地模拟测试时的适应过程(内循环) ,寻找的是一个 对后续梯度更新敏感、能通过快速适应变好 的点。MAML的点可能初始表现不如预训练模型,但它的潜力(适应性)更强。 5. 算法流程总结 初始化 模型参数 \( \theta \)。 While 未收敛: 采样一批任务 \( \mathcal{T}_ i \sim p(\mathcal{T}) \)。 For 每个任务 \( \mathcal{T}_ i \): 内循环 :计算适应后参数 \( \theta' i = \theta - \alpha \nabla {\theta} \mathcal{L}_ {\mathcal{T} i}^{support}(f {\theta}) \)。 在查询集上计算适应后损失 \( \mathcal{L}_ {\mathcal{T} i}^{query}(f {\theta'_ i}) \)。 外循环 :计算元梯度 \( \nabla_ {\theta} \sum_ i \mathcal{L}_ {\mathcal{T} i}^{query}(f {\theta'_ i}) \)。 更新初始参数:\( \theta \leftarrow \theta - \beta \nabla_ {\theta} \sum_ i \mathcal{L}_ {\mathcal{T} i}^{query}(f {\theta'_ i}) \)。 返回 学习到的初始参数 \( \theta \)。 在测试(或部署)时,对于新任务,我们利用学到的 \( \theta \) 作为起点,用该新任务提供的少量支持集样本,执行几次与内循环相同的梯度更新,即可得到适用于该新任务的模型。