自注意力机制(Self-Attention)中的相对位置编码(Relative Positional Encoding)原理与实现详解
1. 题目/知识点描述
在标准的自注意力机制(如Transformer的原始版本)中,模型本身是置换等变的(permutation-equivariant),即对输入序列的顺序不敏感。为了引入序列的顺序信息,原始Transformer使用了绝对位置编码(Absolute Positional Encoding),即为序列中每个位置的词向量加上一个固定的、基于正弦余弦函数的位置编码向量。然而,绝对位置编码存在一个潜在局限:模型在训练时学习到的是固定的绝对位置表示,但在推理时,如果遇到比训练时更长的序列,其位置表示是模型未见过的,可能导致泛化能力下降。此外,在某些任务中(如机器翻译、音乐生成),词与词之间的相对位置关系(例如,“A在B前面3个位置”)比绝对位置(例如,“A在第5个位置”)更为重要。相对位置编码(Relative Positional Encoding)正是为了解决这些问题而提出的。它的核心思想是不将位置信息直接加到输入嵌入上,而是在计算注意力权重时,显式地建模查询向量(Query)与键向量(Key)之间的相对位置偏移。本知识点将详细讲解相对位置编码的动机、核心思想、经典实现方法(如Shaw et al. 2018 和 Transformer-XL/XLNet 中的方法)以及其优势。
2. 循序渐进讲解
步骤1:回顾绝对位置编码的局限性
- 在原始Transformer中,输入嵌入是词嵌入(Word Embedding)和位置嵌入(Position Embedding)的加和:
X_i = E(word_i) + P(i)。这里P(i)是一个仅依赖于绝对位置i的向量。 - 局限性1(长度外推):
P(i)通常是预先用正弦余弦函数定义好的,或者作为可学习参数。模型在训练时只“见过”从1到最大序列长度(如512)的P(i)。在推理时,如果序列长度超过512,模型就遇到了从未见过的位置编码,性能可能下降。 - 局限性2(相对关系建模):在计算注意力时,模型需要判断词
i和词j之间的关系。绝对位置编码需要模型从P(i)和P(j)中隐式地学习到相对距离i-j的信息。这个过程并非直接,可能不够高效和准确。
步骤2:相对位置编码的核心直觉
- 相对位置编码改变了思路:我们不关心词
i在序列中的绝对位置是5,词j的绝对位置是8,我们更关心j在i的后面3个位置(即相对距离i-j = -3或j-i = 3)。 - 核心思想是:在计算词
i(查询)对词j(键)的注意力时,除了原本基于内容的相似度(Q_i · K_j),额外引入一个基于两者相对距离(i - j)的偏置项。 - 数学表达上,注意力分数变为:
e_{ij} = (Q_i * K_j^T) + b_{i-j}。这里b_{i-j}是一个可学习的标量(或向量),它只依赖于相对位置i-j。
步骤3:经典实现方法一 - Shaw et al. 的相对位置编码
- 这篇论文(“Self-Attention with Relative Position Representations”)提出了一种直观的方法。
- 1. 定义相对位置矩阵:假设序列最大长度为L,我们定义所有可能的相对位置偏移
k = i - j,其中-L+1 <= k <= L-1。为每个可能的k学习一个嵌入向量a_k(维度与键向量K相同)。 - 2. 修改注意力计算:在计算注意力权重时,用这个相对位置向量
a_{i-j}来修正键向量K_j。- 原始注意力分数:
e_{ij} = (Q_i * K_j^T) / sqrt(d_k) - 修改后:
e_{ij} = (Q_i * (K_j + a_{i-j})^T) / sqrt(d_k)
- 原始注意力分数:
- 3. 理解:这相当于在衡量
Q_i和K_j的相似度时,不仅看K_j本身的内容,还看j相对于i的位置。a_{i-j}作为一个偏置,告诉模型“当键处于查询的某个相对位置时,其表示应该是什么样子”。 - 4. 实现优化:由于
i和j的组合有L×L种,但i-j只有2L-1种可能,我们可以预先计算好所有a_k,然后通过一个索引查找表(Look-up Table)高效地获取a_{i-j}。
步骤4:经典实现方法二 - Transformer-XL/XLNet 的相对位置编码
- 这种方法更为主流和优雅,被Transformer-XL和XLNet采用。
- 1. 分解注意力分数:它将注意力分数
e_{ij}明确分解为四项:
e_{ij} = (Q_i * K_j^T) + (Q_i * R_{i-j}^T) + (u * K_j^T) + (v * R_{i-j}^T)。这里:(Q_i * K_j^T):基于内容的寻址(Content-based Addressing)。(Q_i * R_{i-j}^T):基于内容的位置偏置(Content-dependent Position Bias)。R_{i-j}是相对位置i-j的编码向量(通常是正弦余弦固定编码或可学习参数)。(u * K_j^T):全局内容偏置(Global Content Bias)。u是一个可学习的向量,代表一个全局的查询向量,用于衡量键K_j的重要性,与位置无关。(v * R_{i-j}^T):全局位置偏置(Global Position Bias)。v是一个可学习的向量,用于衡量相对位置i-j本身的重要性。
- 2. 优势:这种分解使得模型可以灵活地学习不同组成部分的重要性。例如,有些注意力头可能更关注内容(第一项),有些则更关注相对位置(第二、四项)。
- 3. 长度外推性:由于
R_{i-j}通常使用正弦余弦函数(定义域是所有整数),i-j可以远远超过训练时见过的最大长度,模型依然能计算出合理的R_{i-j},从而具有良好的长度外推能力。
步骤5:相对位置编码的优势总结
- 更好的长度外推性:核心依赖相对距离
i-j,而非绝对位置i和j。只要相对距离在训练时覆盖的范围内,模型就能处理更长的序列。 - 更自然的相对关系建模:显式地建模相对距离,更符合许多任务(如语言理解、音乐)的本质,模型更容易学习到“附近”、“远处”等概念。
- 计算效率的权衡:虽然引入了额外的项,但相对位置嵌入
R或a可以复用,且现代深度学习框架(如PyTorch, TensorFlow)有高效的矩阵操作支持,额外的计算开销是可控的。
总结:相对位置编码是对Transformer架构中位置信息表示方式的一个重要改进。它摒弃了为每个绝对位置学习固定表示的做法,转而在注意力计算过程中,动态注入查询与键之间的相对位置信息。通过将相对距离作为一个可学习的偏置项融入注意力分数,模型能够更直接、更泛化地利用序列的顺序信息,特别是在处理长序列和需要理解元素间相对关系的任务上表现出优势。