Transformer模型中的多头注意力机制原理与实现
字数 1155 2025-11-17 03:13:00
Transformer模型中的多头注意力机制原理与实现
一、多头注意力机制的基本概念
多头注意力机制是Transformer模型的核心组成部分,它通过并行运行多个独立的注意力头来捕捉输入序列中不同子空间的特征。其核心思想是:将输入映射到多个不同的表示子空间,在每个子空间中计算注意力,最后将结果合并。这样可以让模型同时关注不同位置和不同语义层面的信息。
二、多头注意力的实现步骤
-
输入线性变换
输入序列经过三个独立的线性层(权重矩阵W_Q、W_K、W_V),分别生成查询(Q)、键(K)、值(V)矩阵。假设输入维度为d_model(例如512),则每个线性层的输出维度为d_k(例如64)。- 公式:Q = XW_Q, K = XW_K, V = XW_V(X为输入矩阵)
-
分头(Head Splitting)
将Q、K、V矩阵按头数(h,例如8)分割成h个子矩阵。每个头的维度为d_k = d_model / h(例如512/8=64)。- 操作:将Q、K、V的最后一维重塑为(h, seq_len, d_k),使每个头独立处理子空间。
-
缩放点积注意力计算
在每个头上独立计算注意力:- 步骤1:计算Q与K的点积,得到注意力分数矩阵(维度:seq_len × seq_len)。
- 步骤2:缩放分数(除以√d_k),防止点积结果过大导致梯度消失。
- 步骤3:应用Softmax归一化,得到注意力权重。
- 步骤4:用权重对V加权求和,得到该头的输出(维度:seq_len × d_k)。
- 公式:Attention(Q, K, V) = Softmax(QK^T / √d_k) V
-
多头输出合并
将h个头的输出拼接(Concat)在一起,恢复为原始维度d_model(例如8个头×64维=512维)。- 操作:将h个(seq_len, d_k)矩阵拼接为(seq_len, h×d_k)。
-
输出线性变换
通过一个线性层(权重矩阵W_O)将拼接后的结果映射到最终输出维度(通常保持为d_model)。- 公式:MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W_O
三、多头注意力的优势
- 并行化能力:多个头可并行计算,提升效率。
- 多样化特征捕捉:不同头可能关注不同模式(如语法结构、语义关联)。
- 模型容量提升:通过子空间分解增加参数和表达能力。
四、代码实现示意(PyTorch风格)
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, h):
super().__init__()
self.d_k = d_model // h
self.h = h
self.W_Q = nn.Linear(d_model, d_model) # 实际实现中通常分拆为h个小矩阵
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
def forward(self, Q, K, V, mask=None):
# 分头:重塑为 (batch_size, h, seq_len, d_k)
Q = self.W_Q(Q).view(batch_size, seq_len, self.h, self.d_k).transpose(1, 2)
K = self.W_K(K).view(batch_size, seq_len, self.h, self.d_k).transpose(1, 2)
V = self.W_V(V).view(batch_size, seq_len, self.h, self.d_k).transpose(1, 2)
# 计算缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(self.d_k)
if mask is not None:
scores.masked_fill_(mask == 0, -1e9)
attn_weights = torch.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# 合并头并输出
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
return self.W_O(attn_output)
五、总结
多头注意力通过分解特征空间,使Transformer能够同时从多个角度捕捉依赖关系。其设计平衡了计算效率与表达能力,成为现代大语言模型的基石。实际应用中,头数(h)和维度(d_k)需根据任务和资源调整。