Mamba:选择性状态空间的线性时序建模

中文标题:Mamba:选择性状态空间的线性时序建模

英文标题:Mamba: Linear-Time Sequence Modeling with Selective State Spaces

发布平台:预印本

发布日期:2023-12-01

引用量(非实时):201

DOI:

作者:Albert Gu, Tri Dao

关键字: #Mamba

文章类型:preprint

品读时间:2024-04-24 10:16

1 文章萃取

1.1 核心观点

本文以结构化状态空间模型(SSM)为基础提出了 Mamba 架构,Mamba 允许模型根据当前输入选择性地记忆或遗忘信息,而这种时间感知的引入阻碍了卷积结构 SSM 的并行化计算,为此 Mamba 设计了一种在循环结构下的并行扫描算法,并借助硬件感知实现计算加速。

实验分析发现,相比于 Transformer,Mamba 的推理速度快 5×,同时其计算性能随着时序长度呈线性增长;以 Mamba 作为主干的模型在不太下游任务中均有较好的性能表现

1.2 综合评价

  • Mamba 在 S4 模型基础上进行了创新性调整,加入了自动信息选择机制
  • Mamba 实现了循环结构的并行化,同时在内存和计算的成本上具备优势
  • Mamba 在长序列等领域具备较大的优势,值得继续探索挖掘背后的潜力

1.3 主观评分:⭐⭐⭐⭐⭐

2 精读笔记

原始论文不推荐缺乏 SSM 相关前置知识的读者直接观看

建议优先阅读本笔记进行知识补齐,或者本人推荐的[第三方参考](A Visual Guide to Mamba and State Space Models]( https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-mamba-and-state )

2.1 状态空间模型

状态空间模型(State Space Models,SSM)

  • 源自传统的控制理论,即通过状态变量对动态系统进行建模
  • SSM 根据当前状态表示 $h(t-1)$ 和输入信息 $x(t)$ 来预测未来的状态 $y(t)$

  • 系数矩阵 $A,B,C,D$ 均是可学习参数,其中 $D$ 类似 skip-connection
  • $A\in \mathbb{R}^{N\times N}$ 表示状态转移矩阵,$B\in \mathbb{R}^{N\times 1}$ 表示输入门,$C\in \mathbb{R}^{N\times 1}$ 表示输出门
  • 矩阵 D 通常不包含在 SSM 模型内部,SSM 模型的整体计算细节如下:

SSM 的结构与 RNN 很相似,只是把权重系数替换为了矩阵系数,去掉了非线性激活

其中状态转移矩阵 $A$ 对应 RNN 中的隐藏状态(hidden state),用于历史信息的总结

2.2 S4:结构化 SSM 时序建模

S4 架构:一种用于超长距离序列建模任务的 SSM

  • S4 架构主要包含三个部分:连续信号的 SSM(上一节已介绍)、HiPPO 矩阵捕捉长期的信息依赖、离散信号下的结构化 SSM(卷积结构与循环结构)
2.2.1 S4 细节 1:处理离散信号

零阶保持技术 (Zero-order hold technique)——从连续信号到离散信号

  • 文本序列一般为离散信号,因此需要引入额外的参数 $\Delta$ 来构建连续信号
  • 参数 $\Delta$ 被称为步长(step size),实现对输入的阶段性保持(resolution)
  • 每次接受到的离散信号后,都保持该值不变,只到接收到新的离散信号:

  • 模型输出阶段,按照步长 $\Delta$ 对连续信号进行采样,得到离散信号

数学形式上,使用以下方式来应用零阶保持:

2.2.2 S4 细节 2:结构化 SSM

状态空间模型的结构化:主要考虑循环结构和(一维)卷积结构两种情况

  • 循环结构的优势是能实现推理高效且支持可变的序列长度,但不支持并行训练
  • 卷积结构的优势是能实现更高效的训练(可并行训练),但对序列长度有要求
  • 两种结构中,状态表示都具备线性时间不变性 (Linear Time Invariance,LTI)
  • LTI 意味着 SSM 的矩阵参数($A,B,C$)对于所有时间步长是固定不变的;即对于任意的输入序列,SSM 使用相同的矩阵参数,SSM 在不感知时间的情况下实现输入的静态表示

思考:如何同时发挥循环结构的推理优势和卷积结构的训练优势?

S4 架构的亮点在于打通了 SSM 的循环结构和卷积结构

在循环结构的情况下,SSM 的计算过程如下(初始化 $x_{-1}=0$): $$ \begin{aligned} &\text{h}_{0}=\bar{B}x_{0} \\ &y_{0} =Ch_0=C\bar{B}x_0 \\ &\text{h} _1=\bar{A}h_0+\bar{B}x_1=\bar{A}\bar{B}x_0+\bar{B}x_1 \\ &y_{1} =Ch_{1}=C\bigg(\bar{A}\bar{B}x_{0}+\bar{B}x_{1}\bigg)=C\bar{A}\bar{B}x_{0}+C\bar{B}x_{1} \\ &\text{h} =\bar{A}h_{1}+\bar{B}x_{2}=\bar{A}\Big(\bar{A}\bar{B}x_{0}+\bar{B}x_{1}\Big)+\bar{B}x_{2}=\bar{A}^{2}\bar{B}x_{0}+\bar{A}\bar{B}x_{1}+\bar{B}x_{2} \\ &y_{2} =Ch_{2}=C\Big(\bar{A}^{2}\bar{B}x_{0}+\bar{A}\bar{B}x_{1}+\bar{B}x_{2}\Big)=C\bar{A}^{2}\bar{B}x_{0}+C\bar{A}\bar{B}x_{1}+C\bar{B}x_{2} \\ &y_{k} =C\bar{A}^{k}\bar{B}x_{0}+C\bar{A}^{k-1}\bar{B}x_{1}+\cdots+C\bar{A}\bar{B}x_{k-1}+C\bar{B}x_{k} \end{aligned} $$ 以上计算过程可以向量化为卷积运算,定义卷积核 $\overline{K}$: $$ \begin{aligned} \overline{K}&=(C\overline{B},C\overline{AB},\ldots,C\overline{A}^{k-1}\overline{B})\in\mathbb{R}^K \\ y_{k} &= x * \overline{K} \end{aligned} $$

  • 其中 $L$ 表示模型的层数;$\overline{A,B}$ 表示零阶保持中的矩阵参数
  • 由于离散矩阵参数满足 LTI,因此在实际训练中每次迭代前提前计算好 $\overline{K}$

至此 S4 便可以在训练时借助卷积实现并行化,在推理时转化为循环结构的 SMM

2.2.3 S4 细节 3:HiPPO 矩阵

HiPPO 矩阵:

  • 类似于 RNN,普通循环结构中的状态转移矩阵 $A$ 只能记录有限的状态信息
  • HiPPO 矩阵会尝试将过去的所有输入信号压缩成一个系数向量,因此借助 HiPPO 矩阵来初始化 $A$ 可以更好地近似历史信息/重建近期信号,并实现模型对更长期信息的依赖
  • HiPPO 矩阵初始化可以显著提升 SSM 在 MNIST 上的基准性能(60%->98%)

HiPPO 矩阵的定义如下:

HiPPO 矩阵的示例:

数学理解,HiPPO 矩阵通过跟踪勒让德多项式的系数来实现历史信息的近似

除了以上技术细节,S4 架构还在 HiPPO 矩阵部分进行了很多计算和内存使用上的优化,缓解长序列容易引起的梯度消失/爆炸问题。比如通过针对矩阵 $A$ 进行对角线加低秩(DPLR)处理;或者在快速傅里叶 FFT 在频率空间计算 SSM 的截断生成函数

关于 S4 架构或 HiPPO 的更多技术细节和代码实现可参阅 The Annotated S4

2.3 从 S4 到 Mamba

Mamba 是一种具备选择性的 SSM

  • Mamba 允许模型过滤掉不相关的信息,并无限期地记住重要的信息
  • Mamba 通过硬件感知算法克服了 SSM 中的 LTI 约束,实现性能加速

2.4 Mamba 细节 1:信息的选择

目前 S4 架构的局限性:

  • Transformer 能根据输入序列动态计算注意力得分,从而进行信息的过滤
  • 而在 SSM 中,矩阵 A、B 和 C 的静态性质限制了模型对内容的感知问题

S4 架构中的关键矩阵参数:

已知输入 $x_k \in \mathbb{R}^{B\times L\times D}$,输出 $y_k \in \mathbb{R}^{B\times L\times D}$

其中 $B$ 表示 batch size;$L$ 表示序列长度;$D$ 表示输入向量的维度

Mamba 架构中的关键矩阵参数:

  • Mamba 改变了关键矩阵参数的张量/维度设计,使得它们具备了时间感知性;即矩阵 $\Delta,B,C$ 会随着输入的序列长度 $L$ 或 batch size $B$ 的变动而改变
  • 在 Mamba 中状态转移矩阵 $A$ 保持不变,但影响状态的方式($B,C$)是动态的
  • 较小的步长 $\Delta$ 会使得模型更关注历史信息;反之则会使得模型更关注当前的输入

至此,Mamba 架构实现了对信息的选择(自主选择信息在隐藏状态中的保留与删除)

思考:Mamba 对 S4架构的改动,引入了具备时间感知性的动态矩阵参数,违反了 LTI 约束;SSM 模型循环结构无法转化为卷积结构,也就无法在训练阶段实现并行化计算

2.5 Mamba 细节 2:并行与硬件感知

复习:RBB 的扫描操作(scan operation)

  • 上图展示了扫描操作的基本过程,该计算循环过程很显然不适合并行化计算

Mamba 通过计算部分序列并迭代组合,来实现并行扫描算法

SRAM:常用于 GPU 中的 L1、L2 缓存,高功耗,低容量,低延迟

DRAM:常用于 GPU 中的显存,功耗更低,容量更大,但延迟更高

硬件感知算法 Hardware-aware Algorithm:

  • GPU 的缺陷: SRAM 和 DRAM 之间的频繁通信导致的计算瓶颈
  • 类似 Flash Attention,Mamba 通过核融合限制 SRAM 和 DRAM 间的通信次数
  • 在 SRAM 中积累一批结果再集中写入 DRAM,从而降低来回读写的次数

最后,硬件感知算法不会保存中间状态,而是在后向传递时对中间状态进行重新计算;因为重新计算的成本,比从相对较慢的 DRAM 中读取中间状态的成本更低

2.5.1 Mamba 架构总结

最终的 Mamba 计算流程可表示如下:

  1. 针对输入 $x_t$ 进行选择性扫描和压缩(Selection Mechanism,Project)
  2. 通过步长 $\Delta_t$ 和零阶保持对离散信号进行处理(Discretize)后进入结构化 SSM
  3. 结构化 SSM 模块先收集来自输入门 $B_t$ 的信息和历史潜在表示 $h_{t-1}$,然后经过状态转移矩阵 $A$ 更新潜在表示得到 $h_{t}$,并最终通过输出门 $C_t$ 实现对目标的预测 $y_t$
  4. 矩阵 A 通过 HiPPO 初始化捕获长程依赖关系;并行扫描和硬件感知算法用于计算加速

类似于 Transformer,Mamba 可融入任意的神经网络模型

Mamba 块:

  • H3(Hungry Hungry Hippo)是一种经典的 SSM 架构
  • 本文通过组合 H3 与门控 MLP,提出了 Mamba 块的基本结构

重复 Mamba 块,与标准归一化和残差连接交错,形成 Mamba 模型架构:

2.6 实验结果分析

用于对比模型盘点:

  • Hyena:使用全局卷积来近似 SSM,在音频和视觉等领域表现出色
  • RWKV:基于线性注意力近似,旨在为语言建模设计高效的 RNN
  • RetNet:通过引入额外的门控来优化 SSM 的并行计算路径
  • H3++:结合了线性注意力和 SSM,通过门控连接来增强模型的性能
  • Transformer++:当前最优的 Transformer 架构,计算快效果好

在 Chinchilla 缩放定律下预训练时,语言任务优于同类开源模型:

  • 左图对应序列长度为 2048 的情况;左图对应序列长度为 8192 的情况
  • 最终效果(评价指标为困惑度 perplexity,越低越好)打平 Transformer++

随着时序长度的增加,Mamba 计算耗时更少并且不会内存溢出:

运行效率分析:

  • 左子图结论:本文提出的并行扫描算法比 PyTorch 的标准实现快 40×
  • 右子图结论:作为循环模型,Mamba 的计算吞吐量比 Transformer 高 5×

消融实验小结:

  • Mamba 块与 H3 块的差异不大,性能差异主要体现在选择性 SSM
  • $\Delta$ 是最重要的参数,对模型性能影响最大;$\Delta$ 即使是 1 维也是有效的
  • 增加 SSM 中的隐藏状态维度 $N$ 能显著提高性能,同时总参数量增加不大

其他结论:

  • Mamba 适用领域广泛,尤其适合基因组、音频和视频等长时序的情况
  • 本文实验多局限于小规模数据,不确定 Mamba 在大尺度模型上的表现

论文作者在后续补充实验中验证了 Mamba 在 3B 参数量下依然表现出色

3 论文后续

Mamba 的名称起源:因为架构中有太多的 ssss 了(蛇的拟声)

RNN VS Mamba(摘自理解 Mamba 模型 - 董鑫

改动 效果
Mamba 去掉了 RNN 的非线性 (tanh) 方便在序列上进行多线程计算, 更适合 GPU
Mamba 的 hidden state 的维度比较高 Mamba 的隐藏状态是 $D\times N$;一个更大的 hidden state 更利于记住更多东西
Mamba 的 A 矩阵在设计时就关注到了让 hidden state 更好的记忆这个问题 (可能) hidden state 能记得更多
Mamba 让每个位置都用一套不同的参数 Capacity 更高, (可能) 学习能力更强

总结:Mamba 简化了传统 RNN 非线性, 但是增加了其参数量和复杂度

未来展望:

  • Mamba 底层理论复杂,不太具备返璞归真的数学美感
  • Mamba 架构无疑具备一定的创新性,但实际潜力不明
  • Mamba 是站在巨人的肩膀上,揉合了很多前人的智慧
  • Mamba 需要一款表现经验的大模型来展示其真正魅力

MambaOut 一文针对 Mamba 的适用性进行分析,并认为 Mamba 的机制非常适合具有长序列和自回归特性的任务。但对部分不符合长序列或自回归特性的图像任务(比如图像分类),Mamba 的表现通常不如卷积和基于注意力的模型(剔除 SSM 模块后反而性能有所提升)

Mamba 原作者在 2024 年 5 月底发表了 Mamba-2 新架构,该论文提出一个叫结构化状态空间二元性(Structured State Space Duality,SSD)的理论框架,将 Transformer 中的注意力机制与 SSM 进行了统一(都可以表示成可半分离矩阵 Semiseparable Matrices 的变换)。同时基于 SSD 思想的新算法,Mamba-2 引入了很多 Transformer 架构的优化方法,使得最终的 Mamba-2 训练更快,性能更强 ——对 Mamba2 感兴趣的话也推荐阅读官方出品的 Mamba2 系列博客

基于 Mamba 框架的衍生模型:

  • 2024-05-08 StyleMamba :基于 SSM 的文本驱动高效图像风格迁移框架
  • 2024-05-23 DiM:Mamba 的效率+扩散模型的表达能力,生成高分辨率图像
  • 2024-05-24 Meteor:借助 Mamba 理解多面性信息,构建大型语言和视觉模型
  • 2024-05-26 Zamba:7B 的 SSM 和 transformer 混合模型,提高长序列处理能力
  • 2024-05-30 DeMamba:用 SSM 捕捉多区域的时空不一致性,实现 AI 视频检测
  • 2024-06-05 Audio Mamba:纯粹基于 SSM 的音频分类模型,处理长音频更高效

相关资源

往年同期文章