最近算法同学反馈模型的显存太大了,batch size 打不上去,让我们看一下,所以有了这次排查。

以下提到的模型信息和真实模型无关。

排查的工具

工欲善其事必先利其器,那么用什么东西来排查 PyTorch 的显存呢。

PyTorch 内直接获取显存信息

在 PyTorch 内我们可以通过 torch.cuda.memory_stats() 来查看当前的显存状态,一般来说我们会关注这几个指标。

  • reserved_bytes.all.peak: 显存池中的最大显存
  • allocated_bytes.all.peak:使用中的最大显存

在 torch 中显存由显存池管理,如果删掉了一个 tensor 并不会直接释放它的显存,而是留在池子中等待再次分配。

正因为此,torch 中 allocated 和 reserved 没有满的话并不代表没有显存压力,比如 80G 的显卡 reserved 了 75G,allocated 仅 60G,但还是可能会出现申请一个 6G 显存但是 reserved 池子中没有,只能调用 torch.cuda.empty_cache() 来清空缓存,这样会造成训练降速。

PyTorch memory snapshot

打印 reserved 和 allocated 还是不太直观,我们可以用 PyTorch 提供的另一个工具来 dump 显存以更精确的显存问题分析。

torch.cuda.memory._record_memory_history()

run_your_code()
torch.cuda.memory._dump_snapshot("my_snapshot.pickle")

打出这个 dump 文件之后,可以在 memory_viz 里查看文件内容,可以详细的看到每块显存是何时申请的。

pasted-image-1772958494354.webp

PyTorch Timeline

Timeline 是我们常用的分析 torch 训练性能的工具,在 Timeline 中,除了算子的耗时,你也可以看到算子的显存使用情况。

在 CPU Timeline 中,可以点击算子下面的小三角,即可在详细信息中看到显存的申请和释放情况。

pasted-image-1772958640675.webp

显存分析

显存主要是两块,模型显存和激活值。

FSDP 的模型显存分析

首先看一下模型的情况,模型有 4 层,每层的参数量是 1B,总参数量约为 4B,也就是说只看参数的话应该是 16G 显存,参数、梯度和优化器状态应该是 16G + 16G + 32G 为 64 G。

但是我们的模型使用了 FSDP2,除此之外还需要考虑 FSDP2 的作用。默认情况下 FSDP2 会对模型做全切分,包括参数,梯度和优化器状态,假设我们现在是 8 个 rank,参数、梯度和优化器状态应该在每一个 rank 上占用 8G,当然梯度并不是每一刻都存在,我们可以先不讨论梯度,只有参数和优化器状态之下每一个 rank 应该占用 6G。

以上是在 fp32 下进行讨论,那如果我们配置了 FSDP2 的混合精度呢?我们把 param_dtypereduce_dtype 都设置成 bf16 能节省显存吗?答案是并不会减少这 6G 的显存占用。

FSDP2 的混合精度计算流程是这样的,如果正确配置了 submodules 和 forward prefetch,那么会在 layer N 计算时,对 layer N + 1 的参数做 All Gather,因为配置了 reduce_dtype=bf16,所以会先 cast 到 bf16 再通信,又因为配置了 param_dtype=bf16,所以通信完了之后获取到了参数就是 bf16 的。

在 layer N 的前向计算完之后,FSDP2 的默认行为是释放掉这个 bf16 的完整参数。当反向时,layer N 反向时,会先 All Gather Layer N - 1 的参数,同时 Layer N + 1 的梯度已经计算好了也不会留着,而是 Reduce Scatter 到每一个 rank 上。最后做 Optimizer.step(),这里就不需要额外的通信和cast了,直接用切分的fp32权重和梯度做计算即可。

所以在 FSDP 前向计算中,最多同时存在两个 layer 的 bf16 完整参数,所以此时的显存峰值应该是

shard_param (fp32) + shard_optimizer (fp32) + 2 * layer_full_param (bf16) = 10G

在反向计算中,最多同时存在 layer N + 1 的梯度,layer N 的参数和 layer N 的梯度,所以此时的显存峰值应该是

shard_param (fp32) + shard_optimizer (fp32) + 2 * layer_full_grad (bf16) + layer_full_weight (bf16) = 12G

pasted-image-1773764469031.webp

这个图有点问题,没能体现出来分层的 all gather, reduce scatter,并且没有体现出来反向前还要做一次 all gather,凑合看吧

激活值分析

激活值则主要来自给模型输入的 batch。在大模型训练中,大模型的 batch 里可能是没有过 tokenizer 的原始语句或者是过完 tokenizer 的 id,进到模型之后用 nn.Embedding 去取 Embedding,所以它 batch 中单条样本通常不会多大。

而推荐模型一般不会使用 nn.Embedding,而是使用一个外置的 Parameter Server,在数据流中去查表,我们称之为 pull_sparse,查出 Embedding 后给 Tensor 配置 requires_grad,然后放入 batch 中。反向结束之后需要从 Embedding Tensor 获取 grad,然后做 push sparse 把梯度返回给 Parameter Server。

在看完前面说的 Memory Snapshot 之后,我确定问题就出在激活值这里。

我们的框架做了 prefetch,正常情况下,我们会在前向开始时做上一步的 push_sparse,在后向开始时做下一步的 pull_sparse。所以在合理情况下,最多只会有三份 embedding 大小的显存,如下图所示。

pasted-image-1772970680750.webp

但实际的情况是什么呢?

Step N - 1 的 batch 和 grad 一直到 step N 结束才被释放,也就是训练中最多会存在5份 embedding 大小的显存。

pasted-image-1772971315154.webp

原因分析

现在我们看到了 Embedding 显存异常占用的表象,但是为什么会这样呢?为什么 batch 和 grad 没有能正常释放,这个问题困扰了我很久,我用 objgraph 发现找不到引用。最后发现这是一个 PyTorch 和 Python,框架和模型代码共同作用导致的问题。

我们的训练代码大概是这样的

while train_not_end():
    data = next(dataset)
    outputs = train_one_step(data)
    process_outputs(outputs)

模型代码是这样的

def train_one_step(data):
    loss = model(data)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return {"loss": loss} 

乍一看应该就是普通的训练循环,但问题出在这里

    return {"loss": loss} 

只是返回 loss 为什么会造成显存泄露呢?这里有两个原因:

  1. PyTorch 中 loss 会引用整个计算图,所有带 backward fn 的参数和梯度都不会释放。
  2. Python 中循环变量不会在循环的结尾释放,而是在循环下一次运行到这里时,变量被覆盖才会释放。

两者共同作用,导致配置了 requires_grad 的 Embedding Tensor 无法被释放。

所以修复的方式也很简单

  1. 不要直接返回 loss 变量,而是返回 loss.item()
  2. 或者在训练循环的结尾删掉 data 和 output

总结

Python 里总有这种奇怪的小特性,这次算是被坑了一把。虽然病因很简单,但是 profile 显存,算显存理论值这块还是挺有趣的,也算是有些收获。

Reference

  1. Understanding CUDA Memory Usage
  2. Getting Started with Fully Sharded Data Parallel (FSDP2)
  3. Autograd mechanics