ChatGPT量化分析(二) - 存储占用分析
背景
ChatGPT出现后,惊人的效果完全颠覆了业界人员包括笔者的认知,抛开其模型细节层面的因素,已公开的训练方法,需要巨量的数据和计算资源,门槛非常高。本文基于公开资料,希望以量化方式分多篇介绍ChatGPT的分析结论,具体内容包含以下三篇,本文为存储占用分析篇。
ChatGPT模型结构为Transformer,Transformer模型运行时存储可以分成两部分,其一是模型参数,这部分规模是固定的,其二是中间激活,这部分和batch size、sequence length有线性关系,下面对两者分别分析。
模型参数
变量定义:
训练阶段
在训练阶段,采用Adam优化器,先看下Adam的公式SGD优化算法的各种变体。
Adam超参数:
- 学习率
- 惯性项
- 惯性项
- 衰减率
Adam更新公式:
Adam中有
推理阶段
推理阶段只有前向过程,用float16存储的话,存储占用为:
中间激活
变量定义:
TensorFlow、PyTorch、MXNet等深度学习框架以逻辑计算图描述模型,以运行时计算图启动计算,计算图以Tensor(数据)和Operatopn(算子)组织,Operation依赖输入Tensor,通过内部计算得到输出Tensor。更具体一点,前向计算过程中,Operation接收上一Operation产出的激活Tensor和本Operation模型参数Tensor,计算得到激活Tensor给下游Operation,既:
下面按Operation进行分析。训练阶段经常使用混合精度训练,推理阶段也会采用量化,因此下面统一以半精度float16(2 bytes)进行浮点存储和计算。
Embedding
embedding层是lookup操作,输入是词序列,输出形状是
Transformer Blocks
transformer block的计算图如下,每个transformer block主要包含四部分,既multi-head attention和mlp,以及两个add&norm。Multi-head Attention
multi-head attention结构如图:变量定义:
- 矩阵乘法算子
的输入激活Tensor是E,E在第一个block是词本身embedding和position embedding之和,在其他block是上游block的输出,E的形状为 ,乘法输出Tensor形状为 ,因此存储占用为 , 和 同 ; - Softmax算子的输入Tensor为
和 ,输出形状为 ,存储占用为 ; 的输入Tensor为 和 ,输出为 ,存储为 ;- 每个head的存储为
,共 个head,存储占用为: ; - concat的输入为
个 ,输出为 ,存储为 ; 输入为激活 和参数 ,输出为 ,因此存储为 。
实际实现时,往往不会使用concat,而是将多头的Q、K、V合并成大矩阵,因此将concat存储忽略。综上,multi-head attention的输入存储总占用为:
MLP
multi-head attention后面,接两层的全连接网络,计算逻辑为:
- 矩阵乘法算子
的输入激活Tensor x是multi-head attention输出,形状为 ,输出Tensor形状为