自注意力机制(Self-Attention)的计算复杂度分析
描述
自注意力机制是Transformer模型的核心组件,它赋予了模型捕捉序列中任意位置元素间依赖关系的能力。然而,这种强大的能力伴随着计算上的代价。理解其计算复杂度对于模型设计、优化以及在长序列场景下的应用至关重要。这个知识点旨在详细拆解自注意力机制的计算过程,并逐步分析其时间与空间复杂度。
解题过程
第一步:回顾自注意力机制的基本计算步骤
首先,我们快速回顾自注意力机制的计算流程。对于一个包含 n 个 token 的输入序列,每个 token 被表示为一个 d 维的向量。我们将所有 token 的向量堆叠成一个矩阵 X,其形状为 (n, d)。
自注意力的计算涉及以下关键步骤:
-
线性变换:将输入
X通过三个可学习的权重矩阵W^Q,W^K,W^V(维度均为(d, d_k),通常d_k = d)进行投影,得到查询(Query)、键(Key)、值(Value)矩阵:Q = X W^Q(形状:(n, d_k))K = X W^K(形状:(n, d_k))V = X W^V(形状:(n, d_k))
-
注意力分数计算:计算每个查询向量与所有键向量的点积,得到注意力分数矩阵
S。S = Q K^T(形状:(n, n))- 这里的
K^T是键矩阵K的转置,形状为(d_k, n)。
-
缩放与 Softmax:对注意力分数矩阵进行缩放(除以
sqrt(d_k))并应用 Softmax 函数,将其转化为概率分布形式的注意力权重矩阵A。A = softmax(S / sqrt(d_k))(形状:(n, n))
-
加权求和:使用注意力权重矩阵
A对值矩阵V进行加权求和,得到最终的输出矩阵O。O = A V(形状:(n, d_k))
第二步:逐步骤分析时间复杂度
时间复杂度通常用大O符号表示,衡量计算步骤数随输入规模(这里主要是序列长度 n 和特征维度 d)的增长速度。
-
线性变换 (
Q,K,V) 的计算复杂度:- 计算
Q = X W^Q。矩阵X的形状是(n, d),W^Q的形状是(d, d_k)。根据矩阵乘法规则,计算Q需要n * d * d_k次乘加操作。 - 同理,计算
K和V也各需要n * d * d_k次操作。 - 因此,线性变换步骤的总复杂度是 O(n * d * d_k)。由于通常
d_k与d是同一数量级(例如d_k = d),我们可以简化为 O(n * d²)。
- 计算
-
注意力分数计算 (
S = Q K^T) 的复杂度:- 矩阵
Q的形状是(n, d_k),K^T的形状是(d_k, n)。 - 它们的乘积
S是一个(n, n)的矩阵。计算这个矩阵中的每一个元素都需要d_k次乘加操作(一个行向量和一个列向量的点积)。 - 由于
S有n * n个元素,所以总操作次数是n * n * d_k。 - 因此,这一步的复杂度是 O(n² * d_k)。同样,若
d_k为常数或与d相关,可简化为 O(n² * d)。这是自注意力机制复杂度的关键部分,因为它与序列长度n的平方成正比。
- 矩阵
-
Softmax 计算的复杂度:
- 对
S矩阵的每一行(对应一个 token 的注意力分数)进行 Softmax 操作。计算一行的 Softmax 需要先计算该行的指数和(O(n) 操作),然后每个元素除以这个和(O(n) 操作)。 - 因为有
n行,所以总复杂度是 O(n²)。
- 对
-
加权求和 (
O = A V) 的复杂度:- 矩阵
A的形状是(n, n),V的形状是(n, d_k)。 - 它们的乘积
O是一个(n, d_k)的矩阵。计算这个矩阵中的每一个元素都需要对n个元素进行加权求和。 - 因此,总操作次数是
n * n * d_k。 - 这一步的复杂度也是 O(n² * d_k) 或 O(n² * d)。
- 矩阵
第三步:总结整体复杂度
将上述所有步骤的复杂度相加:
- O(n * d²) (线性变换)
- O(n² * d) (注意力分数)
- O(n²) (Softmax)
- O(n² * d) (加权求和)
在大O表示法中,我们只保留最高阶的项。当序列长度 n 很大时,n² 项将主导整个计算时间。而 d(特征维度)通常是固定的(如 512, 1024)。因此,自注意力机制的整体时间复杂度是 O(n² * d)。
核心结论:自注意力机制的计算成本随着序列长度 n 呈二次方增长。这是它在处理超长文本、高分辨率图像或长视频序列时的主要瓶颈。
第四步:空间复杂度分析
空间复杂度关注的是存储中间结果所需的内存。
- 在计算过程中,我们需要存储最大的中间矩阵是注意力分数矩阵
S和注意力权重矩阵A,它们的形状都是(n, n)。 - 因此,空间复杂度也是 O(n²)。对于一个有 1000 个 token 的序列,就需要存储一个 1000x1000 的矩阵,这对于内存是很大的负担。
第五步:复杂度的意义与改进方向
理解这个复杂度分析的意义在于:
- 解释模型限制:它解释了为什么原始的Transformer模型难以直接处理非常长的序列(例如,一本书或一整部电影)。
- 指导模型设计:为了克服 O(n²) 的复杂度,研究者们提出了多种高效的注意力变体,例如:
- 局部注意力:只计算每个 token 与一个局部窗口内其他 token 的注意力,将复杂度降至 O(n * k * d),其中
k是窗口大小。 - 稀疏注意力:设计特定的稀疏模式,只计算部分 token 对之间的注意力。
- 线性注意力:通过核函数近似技巧,将 Q 和 K 的乘积顺序重排,从而将复杂度降至 O(n * d²)。
- 分块计算:如 Longformer、BigBird 等模型就采用了这些思想。
- 局部注意力:只计算每个 token 与一个局部窗口内其他 token 的注意力,将复杂度降至 O(n * k * d),其中
通过以上循序渐进的分析,我们可以看到,自注意力机制的强大功能是以计算上的二次方复杂度为代价的,而这也正是当前许多研究致力于优化和改进的重点领域。