Transformer模型中的缩放点积注意力(Scaled Dot-Product Attention)原理与实现
字数 1060 2025-11-11 09:26:02
Transformer模型中的缩放点积注意力(Scaled Dot-Product Attention)原理与实现
描述
缩放点积注意力是Transformer模型的核心组件,用于计算输入序列中各个位置之间的相关性权重。它通过查询(Query)、键(Key)和值(Value)三个矩阵的交互,实现对输入信息的加权聚合。与普通点积注意力相比,缩放因子(Scaling Factor)的引入解决了点积值过大导致的梯度消失问题,是保证模型稳定训练的关键设计。
解题过程
-
基本概念:Query、Key、Value
- Query(Q):表示当前需要计算注意力的位置,好比"提问者"。
- Key(K):表示序列中所有位置提供的标识,用于与Query匹配,好比"答案的标签"。
- Value(V):存储每个位置的实际信息内容,在注意力权重确定后被加权求和,好比"答案的内容"。
- 三者的维度:设输入序列长度为L,向量维度为d。Q、K、V通常通过线性变换从输入序列得到,维度均为L×d(实际处理中常使用批量和多头形式,此处以单个注意力头为例)。
-
计算未缩放的注意力分数
- 步骤:通过Query和Key的点积计算相关性分数。
分数矩阵 = Q · K^T - 维度:(L×d) · (d×L) = L×L,得到一个L×L的矩阵,其中每个元素表示两个位置之间的相关性分数。
- 问题:当维度d较大时,点积的结果可能非常大,导致softmax函数的梯度极小(因为softmax在输入值极大时梯度趋近于0)。
- 步骤:通过Query和Key的点积计算相关性分数。
-
引入缩放因子(Scaling)
- 解决方法:将点积结果除以√d_k(d_k是Key的维度),缩小方差。
缩放后的分数 = (Q · K^T) / √d_k - 原理:点积的方差随d_k增大而增加,除以√d_k使方差稳定在1左右,确保梯度处于健康范围。
- 若未缩放:梯度消失问题会使模型难以训练。
- 解决方法:将点积结果除以√d_k(d_k是Key的维度),缩小方差。
-
应用softmax获取注意力权重
- 步骤:对缩放后的分数矩阵的每一行应用softmax函数,使得每一行的权重之和为1。
注意力权重 = softmax(缩放后的分数, dim=-1) - 作用:将分数转化为概率分布,表示每个位置对当前Query的贡献程度。
- 步骤:对缩放后的分数矩阵的每一行应用softmax函数,使得每一行的权重之和为1。
-
加权求和得到输出
- 步骤:用注意力权重对Value矩阵进行加权求和。
输出 = 注意力权重 · V - 维度:(L×L) · (L×d) = L×d,输出与输入序列长度相同,每个位置是全局信息的加权组合。
- 步骤:用注意力权重对Value矩阵进行加权求和。
完整公式与代码示意
- 公式:
Attention(Q, K, V) = softmax(Q · K^T / √d_k) · V - Python代码示例(PyTorch风格):
import torch import torch.nn.functional as F def scaled_dot_product_attention(Q, K, V): d_k = Q.size(-1) # 获取Key的维度 scores = torch.matmul(Q, K.transpose(-2, -1)) # Q · K^T scores = scores / (d_k ** 0.5) # 缩放 attn_weights = F.softmax(scores, dim=-1) # 按行softmax output = torch.matmul(attn_weights, V) # 加权求和 return output
总结
缩放点积注意力通过缩放因子解决了高维点积的梯度问题,使Transformer能够稳定处理长序列。其设计实现了序列内任意位置的直接交互,突破了RNN的顺序计算限制,成为现代大语言模型的基础。