大模型的显存占用

大模型的显存占用主要取决于以下几点:

  1. 参数量:以常见的大模型 Llama2为例,其常见的参数量包括7B、13B、70B;其中B表示十亿(billion)的参数级别,7B也就代表70亿个参数
  2. 参数精度:常见的浮点精度包括float32(占用4字节,32bit)、float16(16bit)、int8(8bit)、int4(4bit)等,占用空间依次递减,但模型的预测效果也会下滑

以Llama2-7B模型为例,在精度为float32的情况下,模型占用显存为: $$7\times 10^9\times 4 bit=28\times 10^9/1024 KB=28\times 10^9/1024^3 \ GB\approx26.077GB$$ 以此类推,精度为float16、int8和int4的情况下,模型占用显存约为13G、6.5G、3.26G

此外,模型推理时还需要存储一些中间过程文件,因此实际显存占用会比计算值高一些;一般经验认为,此额外开销 <= 20%,即实际推理显存占用 ≈ 1.2 倍的模型显存占用

不同尺度模型推理时的所需显存:

以上情况仅考虑了模型部署和推理时的内存占用,模型训练时则会大得多。一般的经验推论,认为模型训练所需的显存会是模型推理时的十几倍;另一种说法,则认为训练时的显存占用约是参数量的 20x 或 16x(比如训练一个 7B 参数量的模型,就需要至少 140G 的显存)

模型训练时的显存占用影响因素:参数量、梯度、优化器参数、样本大小、BatchSize

更详细的模型训练计算成本和显存占用的计算可参阅 Transformer Math 101

参考:大模型之显存占用计算

往年同期文章