背景

ChatGPT出现后,惊人的效果完全颠覆了业界人员包括笔者的认知,抛开其模型细节层面的因素,已公开的训练方法,需要巨量的数据和计算资源,门槛非常高。本文基于公开资料,希望以量化方式分多篇介绍ChatGPT的分析结论,具体内容包含以下三篇,本文为存储占用分析篇。

ChatGPT模型结构为Transformer,Transformer模型运行时存储可以分成两部分,其一是模型参数,这部分规模是固定的,其二是中间激活,这部分和batch size、sequence length有线性关系,下面对两者分别分析。

模型参数

变量定义:

训练阶段

在训练阶段,采用Adam优化器,先看下Adam的公式SGD优化算法的各种变体

Adam超参数:

  • 学习率
  • 惯性项
  • 惯性项
  • 衰减率

Adam更新公式:

Adam中有两个动量累积值,因此每个模型训练参数会对应4个浮点值,包括参数本身、梯度值、Adam两个动量累积值。这部分参数对精度要求较高,需要用到float32(4 bytes)存储,每个参数占用存储16 bytes。 因此,训练阶段全部模型参数共占用存储为: GPT3 175B模型,训练阶段参数存储占用为2595GB。

推理阶段

推理阶段只有前向过程,用float16存储的话,存储占用为: GPT3 175B模型,推理阶段参数存储占用为324GB。可以看到,A100单卡80G也无法放下,需要更低的存储精度或者多级存储方案,或者单机多卡。

中间激活

变量定义:

TensorFlow、PyTorch、MXNet等深度学习框架以逻辑计算图描述模型,以运行时计算图启动计算,计算图以Tensor(数据)和Operatopn(算子)组织,Operation依赖输入Tensor,通过内部计算得到输出Tensor。更具体一点,前向计算过程中,Operation接收上一Operation产出的激活Tensor和本Operation模型参数Tensor,计算得到激活Tensor给下游Operation,既: 后向计算过程,除上述Tensor外,还有下游Operation后向传回来的梯度Tensor,既: 其中,需要传给上游Operation用于后向传播的梯度计算,传给优化器用于梯度更新。模型参数param和其梯度param grad在上一节内容已经包括在内,所以本节只考虑activation相关的Tensor存储占用即可。Operation内部也有中间变量,这些临时变量生命中期较短,因此本文讨论存储占用时,不考虑这部分,只考虑计算图中Operation之间的Tensor。一个Operation的输出Tensor是下游多个Operation的输入Tensor,此Tensor只占用一份存储即可,因此下文统计存储时只考虑Operation的输出Tensor。中间激活也和Operation的数量或粒度有关,细粒度Operation接口的激活值更多,粗粒度Operation将部分激活转化成了内部临时变量,因此接口的激活值更少。

下面按Operation进行分析。训练阶段经常使用混合精度训练,推理阶段也会采用量化,因此下面统一以半精度float16(2 bytes)进行浮点存储和计算。

Embedding

embedding层是lookup操作,输入是词序列,输出形状是,存储占用为:

Transformer Blocks

transformer block的计算图如下,每个transformer block主要包含四部分,既multi-head attention和mlp,以及两个add&norm。

图片替换文本

Multi-head Attention

multi-head attention结构如图:

图片替换文本

变量定义: multi-head attention的计算逻辑为: 其中,

  • 矩阵乘法算子的输入激活Tensor是E,E在第一个block是词本身embedding和position embedding之和,在其他block是上游block的输出,E的形状为,乘法输出Tensor形状为,因此存储占用为
  • Softmax算子的输入Tensor为,输出形状为,存储占用为
  • 的输入Tensor为,输出为,存储为
  • 每个head的存储为,共个head,存储占用为:
  • concat的输入为,输出为,存储为
  • 输入为激活和参数,输出为,因此存储为

实际实现时,往往不会使用concat,而是将多头的Q、K、V合并成大矩阵,因此将concat存储忽略。综上,multi-head attention的输入存储总占用为:

MLP

multi-head attention后面,接两层的全连接网络,计算逻辑为: 其中,

  • 矩阵乘法算子的输入激活Tensor x是multi-head attention输出,形状为,输出Tensor形状为,因此存储占用为
  • 加法部分输入是激活和参数,输出是,存储占用为
  • 输入和输出都是,输入存储为
  • 的输入是激活和参数,输出是,输入存储为
  • 加法部分输入是激活和参数,输出是,输入存储占用为

综上,mlp部分的存储占用为:

Add&Norm

add部分的输入激活可以与multi-head attention和mlp里的输入共享存储,因此不用计算。norm的输入和输出都是,两个norm的存储占用为。 所以,add&norm部分的存储占用为:

中间激活存储占用

综上,每个transformer block的存储占用为: 个block,因此总存储占用为:

训练阶段

在训练阶段,中间激活需要保留用于反向传播的梯度计算,并且每个激活tensor都对应一个梯度tensor,因此存储占用翻倍,既。 按GPT3 175B配置计算: 的话,每种的存储占用为:

中间激活存储(GB)
128 42624
512 170496
1024 340992
4096 1363968

这里将较小的embedding存储忽略,可以看到,相比模型参数部分占用的2595GB,中间激活存储占比更大,需要借助混合并行(数据并行、流水并行、模型并行)、recompute、ZeRO等技术来解决。在batch为128时,存储消耗2595+42624G,需要近600张A100 80G卡。

推理阶段

在推理阶段,中间激活使用完成后可以立即释放。多个block串行计算,block间可以共享内存。block内部有些tensor也串行依赖,tensor使用完成后即可释放。因此,推理阶段最大存储占用比单block存储大小更小一点。 按GPT3 175B配置计算: 的话,约为1.7G,相比于模型参数占用324G很小,因此推理阶段存储主要消耗在模型参数上。

参考资料