Transformer模型中的缩放点积注意力(Scaled Dot-Product Attention)原理与实现
字数 1060 2025-11-11 09:26:02

Transformer模型中的缩放点积注意力(Scaled Dot-Product Attention)原理与实现

描述
缩放点积注意力是Transformer模型的核心组件,用于计算输入序列中各个位置之间的相关性权重。它通过查询(Query)、键(Key)和值(Value)三个矩阵的交互,实现对输入信息的加权聚合。与普通点积注意力相比,缩放因子(Scaling Factor)的引入解决了点积值过大导致的梯度消失问题,是保证模型稳定训练的关键设计。

解题过程

  1. 基本概念:Query、Key、Value

    • Query(Q):表示当前需要计算注意力的位置,好比"提问者"。
    • Key(K):表示序列中所有位置提供的标识,用于与Query匹配,好比"答案的标签"。
    • Value(V):存储每个位置的实际信息内容,在注意力权重确定后被加权求和,好比"答案的内容"。
    • 三者的维度:设输入序列长度为L,向量维度为d。Q、K、V通常通过线性变换从输入序列得到,维度均为L×d(实际处理中常使用批量和多头形式,此处以单个注意力头为例)。
  2. 计算未缩放的注意力分数

    • 步骤:通过Query和Key的点积计算相关性分数。
      分数矩阵 = Q · K^T
      
    • 维度:(L×d) · (d×L) = L×L,得到一个L×L的矩阵,其中每个元素表示两个位置之间的相关性分数。
    • 问题:当维度d较大时,点积的结果可能非常大,导致softmax函数的梯度极小(因为softmax在输入值极大时梯度趋近于0)。
  3. 引入缩放因子(Scaling)

    • 解决方法:将点积结果除以√d_k(d_k是Key的维度),缩小方差。
      缩放后的分数 = (Q · K^T) / √d_k
      
    • 原理:点积的方差随d_k增大而增加,除以√d_k使方差稳定在1左右,确保梯度处于健康范围。
    • 若未缩放:梯度消失问题会使模型难以训练。
  4. 应用softmax获取注意力权重

    • 步骤:对缩放后的分数矩阵的每一行应用softmax函数,使得每一行的权重之和为1。
      注意力权重 = softmax(缩放后的分数, dim=-1)
      
    • 作用:将分数转化为概率分布,表示每个位置对当前Query的贡献程度。
  5. 加权求和得到输出

    • 步骤:用注意力权重对Value矩阵进行加权求和。
      输出 = 注意力权重 · V
      
    • 维度:(L×L) · (L×d) = L×d,输出与输入序列长度相同,每个位置是全局信息的加权组合。

完整公式与代码示意

  • 公式:
    Attention(Q, K, V) = softmax(Q · K^T / √d_k) · V
    
  • Python代码示例(PyTorch风格):
    import torch
    import torch.nn.functional as F
    
    def scaled_dot_product_attention(Q, K, V):
        d_k = Q.size(-1)  # 获取Key的维度
        scores = torch.matmul(Q, K.transpose(-2, -1))  # Q · K^T
        scores = scores / (d_k ** 0.5)  # 缩放
        attn_weights = F.softmax(scores, dim=-1)  # 按行softmax
        output = torch.matmul(attn_weights, V)  # 加权求和
        return output
    

总结
缩放点积注意力通过缩放因子解决了高维点积的梯度问题,使Transformer能够稳定处理长序列。其设计实现了序列内任意位置的直接交互,突破了RNN的顺序计算限制,成为现代大语言模型的基础。

Transformer模型中的缩放点积注意力(Scaled Dot-Product Attention)原理与实现 描述 缩放点积注意力是Transformer模型的核心组件,用于计算输入序列中各个位置之间的相关性权重。它通过查询(Query)、键(Key)和值(Value)三个矩阵的交互,实现对输入信息的加权聚合。与普通点积注意力相比,缩放因子(Scaling Factor)的引入解决了点积值过大导致的梯度消失问题,是保证模型稳定训练的关键设计。 解题过程 基本概念:Query、Key、Value Query(Q) :表示当前需要计算注意力的位置,好比"提问者"。 Key(K) :表示序列中所有位置提供的标识,用于与Query匹配,好比"答案的标签"。 Value(V) :存储每个位置的实际信息内容,在注意力权重确定后被加权求和,好比"答案的内容"。 三者的维度:设输入序列长度为L,向量维度为d。Q、K、V通常通过线性变换从输入序列得到,维度均为L×d(实际处理中常使用批量和多头形式,此处以单个注意力头为例)。 计算未缩放的注意力分数 步骤:通过Query和Key的点积计算相关性分数。 维度:(L×d) · (d×L) = L×L,得到一个L×L的矩阵,其中每个元素表示两个位置之间的相关性分数。 问题:当维度d较大时,点积的结果可能非常大,导致softmax函数的梯度极小(因为softmax在输入值极大时梯度趋近于0)。 引入缩放因子(Scaling) 解决方法:将点积结果除以√d_ k(d_ k是Key的维度),缩小方差。 原理:点积的方差随d_ k增大而增加,除以√d_ k使方差稳定在1左右,确保梯度处于健康范围。 若未缩放:梯度消失问题会使模型难以训练。 应用softmax获取注意力权重 步骤:对缩放后的分数矩阵的每一行应用softmax函数,使得每一行的权重之和为1。 作用:将分数转化为概率分布,表示每个位置对当前Query的贡献程度。 加权求和得到输出 步骤:用注意力权重对Value矩阵进行加权求和。 维度:(L×L) · (L×d) = L×d,输出与输入序列长度相同,每个位置是全局信息的加权组合。 完整公式与代码示意 公式: Python代码示例(PyTorch风格): 总结 缩放点积注意力通过缩放因子解决了高维点积的梯度问题,使Transformer能够稳定处理长序列。其设计实现了序列内任意位置的直接交互,突破了RNN的顺序计算限制,成为现代大语言模型的基础。