Transformer模型中的多头注意力机制原理与实现
字数 1216 2025-11-05 23:47:39
Transformer模型中的多头注意力机制原理与实现
问题描述
多头注意力机制是Transformer模型的核心组件,它通过并行运行多个独立的注意力头来增强模型的表达能力。每个头在不同的表示子空间中学习关注输入的不同部分,最后将各头的输出组合起来。这种设计使模型能够同时捕捉不同类型的依赖关系(如语法结构、语义关联等)。
核心原理分步解析
1. 单头注意力回顾
- 输入:查询矩阵Q、键矩阵K、值矩阵V
- 计算:Attention(Q,K,V) = softmax(QKᵀ/√dₖ)V
- 其中√dₖ是缩放因子,防止内积过大导致梯度消失
2. 多头注意力的核心思想
- 将Q、K、V通过不同的线性投影拆分成h个头(h通常取8-16)
- 每个头在低维子空间(dₖ = d_model/h)独立计算注意力
- 最后将所有头的输出拼接并通过线性变换融合
3. 具体实现步骤
步骤1:线性投影拆分
- 对每个头i(i=1,...,h)使用独立的权重矩阵:
- Q⁽ⁱ⁾ = QWᵢ^Q (Wᵢ^Q ∈ ℝ^{d_model × dₖ})
- K⁽ⁱ⁾ = KWᵢ^K
- V⁽ⁱ⁾ = VWᵢ^V
- 投影后每个头的维度为dₖ = d_model/h
步骤2:并行注意力计算
- 每个头独立计算缩放点积注意力:
- headᵢ = Attention(Q⁽ⁱ⁾, K⁽ⁱ⁾, V⁽ⁱ⁾)
- = softmax(Q⁽ⁱ⁾K⁽ⁱ⁾ᵀ/√dₖ)V⁽ⁱ⁾
步骤3:多头输出拼接
- 将所有头的输出沿特征维度拼接:
- MultiHead(Q,K,V) = Concat(head₁, head₂, ..., headₕ)
- 拼接后维度恢复为d_model
步骤4:最终线性投影
- 通过可学习的权重矩阵W^O进行融合:
- Output = MultiHead(Q,K,V)W^O
- 其中W^O ∈ ℝ^{d_model × d_model}
4. 数学公式完整表达
MultiHead(Q,K,V) = Concat(head₁,...,headₕ)W^O
headᵢ = Attention(QWᵢ^Q, KWᵢ^K, VWᵢ^V)
5. 多头注意力的优势分析
- 并行化:各头可同时计算,充分利用硬件加速
- 表示多样性:不同头学习关注不同模式(如局部/全局、语法/语义)
- 模型容量:增加头数相当于增加子网络数量,提升表达能力
- 梯度多样性:各头提供互补的梯度信号,改善优化过程
6. 实现细节与代码思路
# 伪代码示例
class MultiHeadAttention:
def __init__(self, d_model, num_heads):
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 初始化所有投影矩阵
self.W_Q = nn.Linear(d_model, d_model) # 实际实现中通常分开初始化
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
def forward(self, Q, K, V):
batch_size = Q.size(0)
# 线性投影并分头
Q = self.W_Q(Q).view(batch_size, -1, self.num_heads, self.d_k)
K = self.W_K(K).view(batch_size, -1, self.num_heads, self.d_k)
V = self.W_V(V).view(batch_size, -1, self.num_heads, self.d_k)
# 转置以便矩阵运算 (batch_size, num_heads, seq_len, d_k)
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# 计算缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# 转置回并拼接
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.d_model)
# 最终线性变换
return self.W_O(attn_output)
7. 实际应用中的关键点
- 残差连接:多头注意力输出会与输入相加,缓解梯度消失
- 层归一化:通常在注意力前后应用,稳定训练过程
- 掩码机制:解码器中使用因果掩码确保自回归性质
- 计算效率:虽然头数增加,但每个头维度降低,总计算量基本不变
通过这种分头设计的机制,Transformer能够同时从多个角度分析输入序列的依赖关系,这是其在机器翻译、文本生成等任务上取得突破性进展的重要原因。