自注意力机制(Self-Attention)中的掩码(Masking)机制详解
一、知识点描述
在自注意力机制中,掩码(Masking)是一种控制注意力计算范围的技术,通过屏蔽(设置为负无穷大)某些注意力权重,防止模型“看到”或“关注”特定位置的信息。它在多种场景下至关重要,例如:
- 处理变长序列:在批量训练中,用特殊填充符(如
<pad>)补齐不同长度序列,掩码可防止注意力机制关注这些无意义的填充位置。 - 解码器的自回归生成:在Transformer解码器或类似自回归模型中,为保持生成过程的因果性(Causal),需确保每个位置在生成时只能关注其自身及之前的已知位置,不能“偷看”未来的信息。
- 特定任务的信息屏蔽:在某些任务中,可能有选择地屏蔽某些输入部分,以引导其关注特定信息。
二、解题过程/原理讲解
步骤1:理解自注意力的基础计算
给定一个输入序列的矩阵表示 \(X \in \mathbb{R}^{n \times d}\)(n为序列长度,d为特征维度),自注意力首先通过线性变换得到查询(Q)、键(K)、值(V)矩阵:
\[ Q = XW^Q, \quad K = XW^K, \quad V = XW^V \]
然后计算注意力分数矩阵 \(S \in \mathbb{R}^{n \times n}\):
\[ S = \frac{QK^T}{\sqrt{d_k}} \]
其中,\(d_k\) 是K的维度,缩放因子 \(\sqrt{d_k}\) 用于稳定梯度。接着,通过对S的每一行应用Softmax函数,得到注意力权重矩阵 \(A\):
\[ A = \text{softmax}(S) \]
最终输出是V的加权和:
\[ \text{Output} = A V \]
步骤2:引入掩码的核心思想
掩码的核心是在Softmax之前,修改注意力分数矩阵S。我们生成一个与S形状相同的掩码矩阵 \(M \in \mathbb{R}^{n \times n}\)。对于S中我们希望屏蔽的位置(i, j),在M的对应位置放置一个非常大的负数(如 -1e9 或负无穷);对于允许关注的位置,在M的对应位置放置0。
然后,执行一个元素级加法:
\[ S_{\text{masked}} = S + M \]
由于Softmax函数会指数放大较大的输入值,那些被加了一个很大负数的位置,在Softmax之后,其输出权重会趋近于0。即:
\[ A_{\text{masked}} = \text{softmax}(S + M) \]
这样,在计算输出时,模型就不会从那些被屏蔽位置的V中获取信息。
步骤3:详解两种主要掩码类型
-
填充掩码(Padding Mask)
- 目标:防止注意力机制关注输入序列中的填充符(
<pad>)。 - 生成方法:
- 输入一个形状为 \((batch\_size, seq\_len)\) 的序列,以及对应的填充位置标识(例如,
pad_token_id)。 - 创建一个布尔掩码(或0/1掩码),形状为 \((batch\_size, 1, 1, seq\_len)\) 或 \((batch\_size, 1, seq\_len, seq\_len)\)。对于每个序列,
pad_token_id对应的位置为True(或1),其余为False(或0)。通常会将这个掩码扩展为4维,以匹配注意力分数矩阵的维度(batch_size, num_heads, seq_len, seq_len)。 - 将这个掩码矩阵M中
True的位置替换为很大的负数(如-1e9),False的位置替换为0。
- 输入一个形状为 \((batch\_size, seq\_len)\) 的序列,以及对应的填充位置标识(例如,
- 计算效果:对于S中的每一行(对应一个查询位置),其在所有填充符键(K)位置上的分数都会被压低,使其注意力权重为0。例如,对于一个长度为4的序列
[token1, token2, <pad>, <pad>],掩码会确保所有查询位置在计算时都不关注第3、4个位置。
- 目标:防止注意力机制关注输入序列中的填充符(
-
因果掩码/前瞻掩码(Causal Mask / Look-Ahead Mask)
- 目标:在自回归生成任务(如文本生成、语音合成)中,确保解码器在生成第t个位置时,只能“看到”第1到第t-1个位置(及自身)的信息,不能看到未来的t+1到n位置。这保证了生成的因果性。
- 生成方法:
- 创建一个形状为 \((seq\_len, seq\_len)\) 的上三角矩阵(upper triangular matrix),主对角线及以上元素为
True(或1),其余为False(或0)。例如,当seq_len=4时:
这表示第i行(查询位置i)可以关注第j列(键位置j),其中j <= i。[[1, 1, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1]] - 将这个矩阵取反,得到因果掩码:
True的位置表示需要被屏蔽的未来信息。取反后矩阵为:
这里[[0, 1, 1, 1], [1, 0, 1, 1], [1, 1, 0, 1], [1, 1, 1, 0]]1(或True)表示屏蔽。注意,通常实现中主对角线为0(允许关注自身),但有时也屏蔽自身(即主对角线也为1)。 - 同样,将掩码矩阵M中
True的位置替换为很大的负数,False的位置替换为0,然后与S相加。
- 创建一个形状为 \((seq\_len, seq\_len)\) 的上三角矩阵(upper triangular matrix),主对角线及以上元素为
- 计算效果:对于S中的每一行(查询位置),其对应列索引大于行索引的所有位置(即未来的键位置)的分数都会被压低,权重为0。这迫使模型在生成当前词时,仅依赖于已生成的词。
步骤4:掩码的组合使用与计算实现
在实际的Transformer解码器中,通常需要同时使用填充掩码和因果掩码。
- 组合逻辑:最终的掩码M_final 是两个掩码的逻辑“或”(OR)操作。即,一个位置只要被任一掩码标记为需要屏蔽,就会被最终屏蔽。
M_final = PaddingMask | CausalMask - 计算步骤示例:
a. 假设S的形状为(batch_size, num_heads, seq_len, seq_len)。
b. 生成形状为(batch_size, 1, 1, seq_len)的填充掩码M_pad(需要广播到num_heads和seq_len维度)。
c. 生成形状为(1, 1, seq_len, seq_len)的因果掩码M_causal(需要广播到batch_size和num_heads维度)。
d. 合并掩码:M = M_pad | M_causal。
e. 掩码处理:S_masked = S + M * (-1e9)。
f. 计算注意力权重:A = softmax(S_masked, dim=-1)。
g. 计算输出:Output = A @ V。
三、核心要点总结
- 目的:掩码机制是控制自注意力“可见范围”的关键技术,用于处理变长序列和保证自回归生成的因果性。
- 位置:在计算注意力权重(Softmax)之前应用,通过向注意力分数中添加一个很大的负值来实现屏蔽。
- 类型:
- 填充掩码:防止关注无意义的填充符,基于输入序列的实际长度生成。
- 因果掩码:防止“偷看”未来信息,通常是一个上三角矩阵,确保位置i只能关注位置j(j <= i)。
- 实现:两种掩码可以组合使用,通过逻辑“或”合并,然后在Softmax前与原始注意力分数相加。