自注意力机制(Self-Attention)中的掩码(Masking)机制详解
字数 3012 2025-12-07 20:41:22

自注意力机制(Self-Attention)中的掩码(Masking)机制详解

一、知识点描述

在自注意力机制中,掩码(Masking)是一种控制注意力计算范围的技术,通过屏蔽(设置为负无穷大)某些注意力权重,防止模型“看到”或“关注”特定位置的信息。它在多种场景下至关重要,例如:

  1. 处理变长序列:在批量训练中,用特殊填充符(如<pad>)补齐不同长度序列,掩码可防止注意力机制关注这些无意义的填充位置。
  2. 解码器的自回归生成:在Transformer解码器或类似自回归模型中,为保持生成过程的因果性(Causal),需确保每个位置在生成时只能关注其自身及之前的已知位置,不能“偷看”未来的信息。
  3. 特定任务的信息屏蔽:在某些任务中,可能有选择地屏蔽某些输入部分,以引导其关注特定信息。

二、解题过程/原理讲解

步骤1:理解自注意力的基础计算
给定一个输入序列的矩阵表示 \(X \in \mathbb{R}^{n \times d}\)(n为序列长度,d为特征维度),自注意力首先通过线性变换得到查询(Q)、键(K)、值(V)矩阵:

\[ Q = XW^Q, \quad K = XW^K, \quad V = XW^V \]

然后计算注意力分数矩阵 \(S \in \mathbb{R}^{n \times n}\)

\[ S = \frac{QK^T}{\sqrt{d_k}} \]

其中,\(d_k\) 是K的维度,缩放因子 \(\sqrt{d_k}\) 用于稳定梯度。接着,通过对S的每一行应用Softmax函数,得到注意力权重矩阵 \(A\)

\[ A = \text{softmax}(S) \]

最终输出是V的加权和:

\[ \text{Output} = A V \]

步骤2:引入掩码的核心思想
掩码的核心是在Softmax之前,修改注意力分数矩阵S。我们生成一个与S形状相同的掩码矩阵 \(M \in \mathbb{R}^{n \times n}\)。对于S中我们希望屏蔽的位置(i, j),在M的对应位置放置一个非常大的负数(如 -1e9 或负无穷);对于允许关注的位置,在M的对应位置放置0。
然后,执行一个元素级加法:

\[ S_{\text{masked}} = S + M \]

由于Softmax函数会指数放大较大的输入值,那些被加了一个很大负数的位置,在Softmax之后,其输出权重会趋近于0。即:

\[ A_{\text{masked}} = \text{softmax}(S + M) \]

这样,在计算输出时,模型就不会从那些被屏蔽位置的V中获取信息。

步骤3:详解两种主要掩码类型

  • 填充掩码(Padding Mask)

    • 目标:防止注意力机制关注输入序列中的填充符(<pad>)。
    • 生成方法
      1. 输入一个形状为 \((batch\_size, seq\_len)\) 的序列,以及对应的填充位置标识(例如,pad_token_id)。
      2. 创建一个布尔掩码(或0/1掩码),形状为 \((batch\_size, 1, 1, seq\_len)\)\((batch\_size, 1, seq\_len, seq\_len)\)。对于每个序列,pad_token_id对应的位置为True(或1),其余为False(或0)。通常会将这个掩码扩展为4维,以匹配注意力分数矩阵的维度(batch_size, num_heads, seq_len, seq_len)。
      3. 将这个掩码矩阵M中True的位置替换为很大的负数(如-1e9),False的位置替换为0。
    • 计算效果:对于S中的每一行(对应一个查询位置),其在所有填充符键(K)位置上的分数都会被压低,使其注意力权重为0。例如,对于一个长度为4的序列 [token1, token2, <pad>, <pad>],掩码会确保所有查询位置在计算时都不关注第3、4个位置。
  • 因果掩码/前瞻掩码(Causal Mask / Look-Ahead Mask)

    • 目标:在自回归生成任务(如文本生成、语音合成)中,确保解码器在生成第t个位置时,只能“看到”第1到第t-1个位置(及自身)的信息,不能看到未来的t+1到n位置。这保证了生成的因果性。
    • 生成方法
      1. 创建一个形状为 \((seq\_len, seq\_len)\) 的上三角矩阵(upper triangular matrix),主对角线及以上元素为True(或1),其余为False(或0)。例如,当seq_len=4时:
        [[1, 1, 1, 1],
         [0, 1, 1, 1],
         [0, 0, 1, 1],
         [0, 0, 0, 1]]
        
        这表示第i行(查询位置i)可以关注第j列(键位置j),其中j <= i。
      2. 将这个矩阵取反,得到因果掩码True的位置表示需要被屏蔽的未来信息。取反后矩阵为:
        [[0, 1, 1, 1],
         [1, 0, 1, 1],
         [1, 1, 0, 1],
         [1, 1, 1, 0]]
        
        这里1(或True)表示屏蔽。注意,通常实现中主对角线为0(允许关注自身),但有时也屏蔽自身(即主对角线也为1)。
      3. 同样,将掩码矩阵M中True的位置替换为很大的负数,False的位置替换为0,然后与S相加。
    • 计算效果:对于S中的每一行(查询位置),其对应列索引大于行索引的所有位置(即未来的键位置)的分数都会被压低,权重为0。这迫使模型在生成当前词时,仅依赖于已生成的词。

步骤4:掩码的组合使用与计算实现
在实际的Transformer解码器中,通常需要同时使用填充掩码和因果掩码

  1. 组合逻辑:最终的掩码M_final 是两个掩码的逻辑“或”(OR)操作。即,一个位置只要被任一掩码标记为需要屏蔽,就会被最终屏蔽。
    M_final = PaddingMask | CausalMask
  2. 计算步骤示例
    a. 假设S的形状为 (batch_size, num_heads, seq_len, seq_len)
    b. 生成形状为 (batch_size, 1, 1, seq_len) 的填充掩码M_pad(需要广播到num_heads和seq_len维度)。
    c. 生成形状为 (1, 1, seq_len, seq_len) 的因果掩码M_causal(需要广播到batch_size和num_heads维度)。
    d. 合并掩码:M = M_pad | M_causal
    e. 掩码处理:S_masked = S + M * (-1e9)
    f. 计算注意力权重:A = softmax(S_masked, dim=-1)
    g. 计算输出:Output = A @ V

三、核心要点总结

  • 目的:掩码机制是控制自注意力“可见范围”的关键技术,用于处理变长序列和保证自回归生成的因果性
  • 位置:在计算注意力权重(Softmax)之前应用,通过向注意力分数中添加一个很大的负值来实现屏蔽。
  • 类型
    • 填充掩码:防止关注无意义的填充符,基于输入序列的实际长度生成。
    • 因果掩码:防止“偷看”未来信息,通常是一个上三角矩阵,确保位置i只能关注位置j(j <= i)。
  • 实现:两种掩码可以组合使用,通过逻辑“或”合并,然后在Softmax前与原始注意力分数相加。
自注意力机制(Self-Attention)中的掩码(Masking)机制详解 一、知识点描述 在自注意力机制中,掩码(Masking)是一种控制注意力计算范围的技术,通过屏蔽(设置为负无穷大)某些注意力权重,防止模型“看到”或“关注”特定位置的信息。它在多种场景下至关重要,例如: 处理变长序列 :在批量训练中,用特殊填充符(如 <pad> )补齐不同长度序列,掩码可防止注意力机制关注这些无意义的填充位置。 解码器的自回归生成 :在Transformer解码器或类似自回归模型中,为保持生成过程的因果性(Causal),需确保每个位置在生成时只能关注其自身及之前的已知位置,不能“偷看”未来的信息。 特定任务的信息屏蔽 :在某些任务中,可能有选择地屏蔽某些输入部分,以引导其关注特定信息。 二、解题过程/原理讲解 步骤1:理解自注意力的基础计算 给定一个输入序列的矩阵表示 \( X \in \mathbb{R}^{n \times d} \)(n为序列长度,d为特征维度),自注意力首先通过线性变换得到查询(Q)、键(K)、值(V)矩阵: \[ Q = XW^Q, \quad K = XW^K, \quad V = XW^V \] 然后计算注意力分数矩阵 \( S \in \mathbb{R}^{n \times n} \): \[ S = \frac{QK^T}{\sqrt{d_ k}} \] 其中,\( d_ k \) 是K的维度,缩放因子 \( \sqrt{d_ k} \) 用于稳定梯度。接着,通过对S的每一行应用Softmax函数,得到注意力权重矩阵 \( A \): \[ A = \text{softmax}(S) \] 最终输出是V的加权和: \[ \text{Output} = A V \] 步骤2:引入掩码的核心思想 掩码的核心是 在Softmax之前,修改注意力分数矩阵S 。我们生成一个与S形状相同的掩码矩阵 \( M \in \mathbb{R}^{n \times n} \)。对于S中我们希望屏蔽的位置(i, j),在M的对应位置放置一个非常大的负数(如 -1e9 或负无穷);对于允许关注的位置,在M的对应位置放置0。 然后,执行一个元素级加法: \[ S_ {\text{masked}} = S + M \] 由于Softmax函数会指数放大较大的输入值,那些被加了一个很大负数的位置,在Softmax之后,其输出权重会趋近于0。即: \[ A_ {\text{masked}} = \text{softmax}(S + M) \] 这样,在计算输出时,模型就不会从那些被屏蔽位置的V中获取信息。 步骤3:详解两种主要掩码类型 填充掩码(Padding Mask) 目标 :防止注意力机制关注输入序列中的填充符( <pad> )。 生成方法 : 输入一个形状为 \( (batch\_size, seq\_len) \) 的序列,以及对应的填充位置标识(例如, pad_token_id )。 创建一个布尔掩码(或0/1掩码),形状为 \( (batch\_size, 1, 1, seq\_len) \) 或 \( (batch\_size, 1, seq\_len, seq\_len) \)。对于每个序列, pad_token_id 对应的位置为 True (或1),其余为 False (或0)。通常会将这个掩码扩展为4维,以匹配注意力分数矩阵的维度( batch_size, num_heads, seq_len, seq_len )。 将这个掩码矩阵M中 True 的位置替换为很大的负数(如 -1e9 ), False 的位置替换为0。 计算效果 :对于S中的每一行(对应一个查询位置),其在所有填充符键(K)位置上的分数都会被压低,使其注意力权重为0。例如,对于一个长度为4的序列 [token1, token2, <pad>, <pad>] ,掩码会确保所有查询位置在计算时都不关注第3、4个位置。 因果掩码/前瞻掩码(Causal Mask / Look-Ahead Mask) 目标 :在自回归生成任务(如文本生成、语音合成)中,确保解码器在生成第t个位置时,只能“看到”第1到第t-1个位置(及自身)的信息,不能看到未来的t+1到n位置。这保证了生成的因果性。 生成方法 : 创建一个形状为 \( (seq\_len, seq\_len) \) 的上三角矩阵(upper triangular matrix),主对角线及以上元素为 True (或1),其余为 False (或0)。例如,当 seq_len=4 时: 这表示第i行(查询位置i)可以关注第j列(键位置j),其中j <= i。 将这个矩阵取反,得到 因果掩码 : True 的位置表示需要被屏蔽的未来信息。取反后矩阵为: 这里 1 (或 True )表示屏蔽。注意,通常实现中主对角线为0(允许关注自身),但有时也屏蔽自身(即主对角线也为1)。 同样,将掩码矩阵M中 True 的位置替换为很大的负数, False 的位置替换为0,然后与S相加。 计算效果 :对于S中的每一行(查询位置),其对应列索引大于行索引的所有位置(即未来的键位置)的分数都会被压低,权重为0。这迫使模型在生成当前词时,仅依赖于已生成的词。 步骤4:掩码的组合使用与计算实现 在实际的Transformer解码器中,通常需要 同时使用填充掩码和因果掩码 。 组合逻辑 :最终的掩码M_ final 是两个掩码的逻辑“或”(OR)操作。即,一个位置只要被任一掩码标记为需要屏蔽,就会被最终屏蔽。 M_final = PaddingMask | CausalMask 计算步骤示例 : a. 假设S的形状为 (batch_size, num_heads, seq_len, seq_len) 。 b. 生成形状为 (batch_size, 1, 1, seq_len) 的填充掩码M_ pad(需要广播到num_ heads和seq_ len维度)。 c. 生成形状为 (1, 1, seq_len, seq_len) 的因果掩码M_ causal(需要广播到batch_ size和num_ heads维度)。 d. 合并掩码: M = M_pad | M_causal 。 e. 掩码处理: S_masked = S + M * (-1e9) 。 f. 计算注意力权重: A = softmax(S_masked, dim=-1) 。 g. 计算输出: Output = A @ V 。 三、核心要点总结 目的 :掩码机制是控制自注意力“可见范围”的关键技术,用于处理 变长序列 和保证 自回归生成的因果性 。 位置 :在计算注意力权重(Softmax) 之前 应用,通过向注意力分数中添加一个很大的负值来实现屏蔽。 类型 : 填充掩码 :防止关注无意义的填充符,基于输入序列的实际长度生成。 因果掩码 :防止“偷看”未来信息,通常是一个上三角矩阵,确保位置i只能关注位置j(j <= i)。 实现 :两种掩码可以组合使用,通过逻辑“或”合并,然后在Softmax前与原始注意力分数相加。