Transformer模型中的梯度检查点(Gradient Checkpointing)技术详解
字数 1119 2025-11-15 01:48:33
Transformer模型中的梯度检查点(Gradient Checkpointing)技术详解
1. 问题背景
Transformer模型(尤其是大型预训练模型如BERT、GPT)在训练时面临显存瓶颈。模型显存占用主要包括两部分:
- 模型参数显存:存储权重、梯度等。
- 激活值显存:前向传播过程中产生的中间结果(如注意力分数、层输出),需保留用于反向传播计算梯度。
随着模型层数加深(如GPT-3有96层),激活值显存成为主要限制。例如,训练一个10亿参数的模型可能需要数十GB显存,其中大部分被激活值占用。
2. 梯度检查点的核心思想
目标:通过牺牲部分计算时间,显著减少激活值显存占用。
原理:在前向传播时,不保存所有中间激活值,而是只保留少量关键层的激活值。在反向传播时,临时重新计算缺失的中间结果。
- 传统方法:存储所有激活值,显存复杂度为O(L)(L为层数)。
- 梯度检查点:仅存储部分激活值(如每k层存一次),显存复杂度降至O(√L)或O(log L)。
3. 具体实现步骤
以Transformer的N层堆叠为例,假设选择每2层作为一个检查点段:
前向传播阶段:
- 对第1-2层正常计算,但只保留第2层的输出(检查点),释放第1层的激活值。
- 计算第3-4层时,仅保留第4层的输出,释放第3层激活值。
- 重复直到最后一层。
反向传播阶段(从顶层到底层):
- 需要计算第4层的梯度时,利用保存的第2层输出,重新前向计算第3-4层,得到第3层的激活值,再完成第4层的反向传播。
- 计算第2层的梯度时,利用第2层的检查点,重新计算第1-2层的中间结果。
关键细节:
- 检查点选择:通常均匀分段(如每k层),也可根据层显存需求动态调整。
- 重计算代价:每段需额外执行一次前向计算,总计算量增加约30%,但显存节省可达50%以上。
4. 数学与代码示意
设Transformer层函数为layer_i(x),传统前向传播:
activations = [input]
for i in range(L):
activations.append(layer_i(activations[-1]))
梯度检查点版本:
checkpoints = {} # 存储第k层的输出
segments = [(0,2), (2,4), ...] # 分段策略
for start, end in segments:
x = checkpoints.get(start, activations[start])
for i in range(start, end):
x = layer_i(x)
checkpoints[end] = x # 存储检查点
反向传播时,对每个段从右向左重计算并求导。
5. 实际应用与权衡
- 适用场景:显存受限的大模型训练(如GPT、T5)。
- 权衡因素:
- 显存节省:可减少50%-90%激活值显存。
- 时间开销:增加20%-40%计算时间(因重计算)。
- 扩展优化:结合CPU offloading(将检查点存到CPU)、选择性重计算(仅重计算高显存层)进一步优化。
6. 总结
梯度检查点是一种用计算换显存的技术,通过分段存储和重计算打破显存与模型规模的线性关系,使训练超大规模Transformer成为可能。其核心是在时间与空间之间找到平衡,是现代大模型训练的关键技术之一。