Transformer模型中的掩码多头注意力机制详解
1. 题目/知识点描述
Transformer模型由编码器和解码器组成,而解码器的核心组件之一是掩码多头注意力机制。它的核心目标是在自回归生成任务中,确保解码器在生成某个位置的输出时,只能“看到”或“关注”到该位置之前(包括自身)的已知信息,而不能“看到”未来的信息,以保持生成过程的因果性。这个机制通过一个掩码来实现,是理解Transformer如何用于序列生成的关键。
2. 核心问题与解决思路
- 核心问题:在机器翻译、文本生成等任务中,当我们用解码器逐个生成目标序列的单词时,在生成第
t个单词时,我们只有前t-1个单词是已知的。模型在计算当前时刻的注意力分布时,必须防止其“偷看”到未来尚未生成的单词。 - 解决思路:在计算注意力权重之前,在注意力分数矩阵中加入一个“掩码”。这个掩码会将未来位置对应的注意力得分设置为一个绝对值非常大的负数(如
-1e9),这样在后续的Softmax操作中,这些位置的权重就会趋近于0,从而实现“屏蔽未来信息”的效果。
3. 知识详解与推导步骤
步骤一:回顾标准缩放点积注意力
标准注意力的计算过程为:
注意力(Q, K, V) = softmax( (Q * K^T) / sqrt(d_k) ) * V
其中,Q (查询)、K (键)、V (值) 是输入序列经过线性变换得到的矩阵。计算 Q*K^T 得到一个注意力分数矩阵,其尺寸为 [目标序列长度, 源序列长度]。每个元素 (i, j) 表示第i个目标位置 对 第j个源位置 的关注程度。
步骤二:引入因果掩码
在解码器的自注意力层中,Q、K、V 都来自同一个目标序列(上一个解码器层的输出或初始的输入嵌入)。为了保持因果性,我们需要确保在计算目标序列中第i个位置的输出时,只能关注到第1到第i个位置。
- 构造一个掩码矩阵
M,其尺寸与注意力分数矩阵S = (Q * K^T) / sqrt(d_k)相同。 M是一个上三角矩阵(不包含主对角线):- 对于所有
j > i(即未来位置)的M[i, j]设为-∞或一个很大的负数(如-1e9)。 - 对于所有
j <= i(即过去和当前位置)的M[i, j]设为0。
- 对于所有
- 掩码应用:在计算Softmax之前,将这个掩码矩阵加到注意力分数矩阵上:
S_masked = S + M。 - 数学表达:
因果注意力(Q, K, V) = softmax( (Q * K^T) / sqrt(d_k) + M ) * V - 效果:对于未来位置(
j > i),S_masked[i, j]会变成一个非常大的负数。经过Softmax计算后,这些位置对应的权重exp(很大负数) ≈ 0,从而实现了屏蔽。
步骤三:多头注意力中的掩码
多头注意力是将Q、K、V 在特征维度上分割成h个头,然后并行地对每个头独立进行步骤二的掩码注意力计算。其过程为:
- 对每个头
i(i=1,...,h),将Q, K, V通过不同的线性投影得到Q_i, K_i, V_i。 - 对每个头,计算掩码注意力:
head_i = 因果注意力(Q_i, K_i, V_i)。 - 将所有头的输出拼接起来:
Concat(head_1, ..., head_h)。 - 对拼接结果做一次线性投影,得到最终的多头注意力输出。
步骤四:在Transformer解码器中的位置
在一个标准的Transformer解码器层中,通常包含两个注意力子层:
- 掩码多头自注意力层:如上所述,
Q, K, V均来自解码器自身的上一层的输出。这是实现自回归生成的关键,确保每个位置只依赖于其左侧的已知上下文。 - 编码器-解码器注意力层:
Q来自解码器的掩码自注意力层输出,而K和V来自编码器的最终输出。这个层通常不加掩码,因为解码器可以任意关注完整的、已全部编码完成的源序列信息。
4. 示例说明
假设我们正在解码生成目标序列 [“我”, “爱”, “AI”],当前已生成 [“我”, “爱”],正在生成第三个词。
- 在掩码多头自注意力中,当计算“AI”这个词的表示时,它的
Q向量只能与“我”的K向量、“爱”的K向量以及“AI”自身的K向量(这是查询自身)计算分数,而不能与未来的(尚未生成的)词的K向量计算。掩码确保了这一点。 - 在编码器-解码器注意力中,当计算“AI”这个词的表示时,它的
Q向量可以与源语言序列(例如[“I”, “love”, “AI”])中所有位置的K向量计算注意力,从而决定“AI”应该重点“参考”源语言中的哪个部分。
5. 核心要点与意义
- 因果性保证:掩码多头注意力是Transformer模型能够用于序列生成任务(如GPT、翻译、摘要)的基石,它严格模拟了“从左到右”的生成过程。
- 并行训练:尽管在推断时是自回归的,但训练时,由于目标序列已知,可以一次性计算所有位置的掩码注意力,并通过掩码确保位置间的依赖关系正确。这比RNN的顺序计算效率更高。
- 区分“自”与“交叉”:要区分解码器内部的“掩码自注意力”(因果的)和与编码器交互的“交叉注意力”(非因果的)。
通过这种方法,Transformer解码器巧妙地结合了并行训练的效率和序列生成的因果约束,成为现代生成式模型的强大基础。