Mamba

本文最后更新于:2024年3月12日 晚上

1 Introduction

Mamba是一次用状态空间模型来做深度学习的Foundation Model的尝试,原论文是《Mamba: Linear-Time Sequence Modeling with Selective State Spaces》,arXiv: 2312.00752.

2 前置知识:状态空间模型

2.1 连续情况

状态空间模型在控制系统中常见,其目的是建立一个输入到中间状态(latent state)再到输出的关系。假设输入的信号是u(t)u(t),中间状态是x(t)x(t),输出为y(t)y(t),那么状态空间模型可由两个方程表示

{x(t)=Ax(t)+Bu(t)y(t)=Cx(t)+Du(t)\begin{cases} x'(t) = Ax(t) + Bu(t) \\ y(t) = Cx(t) + Du(t) \end{cases}

方程1是一个关于中间变量和输入信号的微分方程,可以解出x(t)x(t),可以看作对状态之间和输入的建模,方程2则建立了输出,状态和输入的关系。

论文中忽略参了参数D,或者说D=0D=0,作者在S4模型中解释为因为项Du(t)Du(t)可看作模型的“跳跃连接”。如果我们只看方程2,那么抛开状态x(t)x(t)所做的变换,y(t)=[  ]+Du(t)y(t) = [\ \cdot\ ] +Du(t)实际上建立了一条从输入u(t)u(t)直接到输出y(t)y(t)的路径,只不过乘上了参数DD

2.2 离散化

计算机系统无法处理连续的微分方程,而离散化状态空间模型已经是一个well-studied问题,本文中采用ZOH方法来进行离散化,具体来说,如果我们记

{A=exp(ΔA)B=(ΔA)1(exp(ΔAI)ΔB)\begin{cases} \overline{A} = \exp{(\Delta A)}\\ \overline{B} = (\Delta A)^{-1}(\exp(\Delta A-I)\cdot \Delta B)\\ \end{cases}

其中Δ\Delta为一个可调参数,我们称作步距(step size),这个参数实际上也作为可学习参数让模型自己学。那么我们可以将状态空间模型写成序列递推的形式

{hk=Ahk1+Bukyk=Chk\begin{cases} h_k = \overline{A}h_{k-1}+\overline{B}u_k \\ y_k = C h_k \end{cases}

方程1中已经不需要计算微分方程,而改为计算一个关于hk1h_{k-1}hkh_k的递推方程。我们可以把hkRNh_k\in \mathbb{R}^N看作一个RNN里面的隐藏状态(hidden state),而矩阵A\overline{A}是过渡矩阵(transition matrix)。

2.3 卷积并行化

到这里,状态空间模型仍然是递推的,也就意味着无法并行化,但注意其与RNN不同的地方,状态之间的过渡是没有激活函数的,这意味着我们可以迭代这个公式来提前计算中间结果。具体来说,假设初始状态h1=0h_{-1}=0,那么迭代方程hk=Ahk1+Bukh_k = \overline{A}h_{k-1}+\overline{B}u_k可以显式的得到

h0=Bu0y0=CBu0h1=ABu0+Bu1y1=CABu0+CBu1h2=A2Bu0+ABu1+Bu2y2=CA2Bu0+CABu1+CBu2\begin{aligned} &h_0 = \overline{B}u_0\quad y_0=\overline{C}\overline{B}u_0\\ &h_1 = \overline{AB}u_0+\overline{B}u_1\quad y_1=\overline{CAB}u_0+\overline{CB}u_1\\ &h_2=\overline{A^2B}u_0+\overline{AB}u_1+\overline{B}u_2\quad y_2=\overline{CA^2B}u_0+\overline{CAB}u_1+\overline{CB}u_2\\ &\dots \end{aligned}

对于第kk个时间点,可以显式地将yky_k写出来

yk=CAkBu0+CAk1Bu1++CABuk1+CBuky_k = \overline{CA^kB}u_0+\overline{CA^{k-1}B}u_1+\cdots+\overline{CAB}u_{k-1}+\overline{CB}u_k

我们把它写成两个向量点乘

yk=CAkBu0+CAk1Bu1++CABuk1+CBuk=[CAkBCAk1BCABCB][u0u1uk1uk]\begin{aligned} y_k &= \overline{CA^kB}u_0+\overline{CA^{k-1}B}u_1+\cdots+\overline{CAB}u_{k-1}+\overline{CB}u_k\\ &=\begin{bmatrix} \overline{CA^kB} \\ \overline{CA^{k-1}B} \\ \cdots \\ \overline{CAB} \\ \overline{CB} \end{bmatrix}\cdot \begin{bmatrix} u_0\\ u_1\\ \cdots \\ u_{k-1}\\ u_k \end{bmatrix} \end{aligned}

我们对于一个确定的序列长度k+1k+1,可以记

K=[CAkBCAk1BCABCB]\overline{K} = \begin{bmatrix} \overline{CA^kB} \\ \overline{CA^{k-1}B} \\ \cdots \\ \overline{CAB} \\ \overline{CB} \end{bmatrix}

那么序列y=[y0,y1,,yk]y=[y_0, y_1,\dots, y_k]实际上可以由卷积

y=Kuy = \overline{K}*u

计算,具体来说,我们举一个长度为3的序列的例子,

u=[u0u1u2]K=[CA2BCABCB]y=[y0y1y2]u=\begin{bmatrix} u_0\\ u_1\\ u_2 \end{bmatrix}\quad \overline{K}= \begin{bmatrix} \overline{CA^2B}\\ \overline{CAB}\\ \overline{CB} \end{bmatrix}\quad y=\begin{bmatrix} y_0\\ y_1\\ y_2 \end{bmatrix}

我们对输入序列前pad个数等于dimu1\dim{u}-1的0元素,得到

u=[00u0u1u2]u=\begin{bmatrix} 0\\ 0\\ u_0\\ u_1\\ u_2 \end{bmatrix}

然后,我们将K\overline{K}当作滑动窗口与uu对齐并滑动,

[CA2BCABCB][00u0u1u2]\begin{aligned} \begin{bmatrix} \overline{CA^2B}\\ \overline{CAB}\\ \overline{CB} \end{bmatrix} \begin{bmatrix} 0\\ 0\\ u_0\\ u_1\\ u_2 \end{bmatrix} \end{aligned}

卷积核从第一个元素位置开始向下滑动并做点积,结果为

y0=CA2B0+CAB0+CBu0=CBu0y1=CA2B0+CABu0+CBu1=CABu0+CBu1y2=CA2Bu0+CABu1+CBu2=CA2Bu0+CABu1+CBu2\begin{aligned} y_0&=\overline{CA^2B}\cdot 0 +\overline{CAB}\cdot 0+\overline{CB}\cdot u_0\\ &=\overline{CB}u_0\\ y_1&=\overline{CA^2B}\cdot 0 +\overline{CAB}\cdot u_0+\overline{CB}\cdot u_1\\ &=\overline{CAB}u_0+\overline{CB}u_1\\ y_2&=\overline{CA^2B}\cdot u_0 +\overline{CAB}\cdot u_1+\overline{CB}\cdot u_2\\&= \overline{CA^2B}u_0+\overline{CAB}u_1+\overline{CB}u_2 \end{aligned}

与我们之前递归计算的结果是一样的。也就意味着,这个模型可以以卷积的形式,提前算出卷积核并与输入做卷积运算,这就实现了并行化计算。

2.4 HiPPO矩阵

另一个值得注意的技术是作者对矩阵AA用了特殊的初始化技巧,具体来说,AA会被初始化为下面这个矩阵

Ank={(2n+1)1/2(2k+1)1/2n>kn+1n=k0n<kA_{nk} = \begin{cases} (2n+1)^{1/2}(2k+1)^{1/2}\quad n>k\\ n+1\quad n=k\\ 0\quad n<k \end{cases}

这是因为之前随机初始化矩阵A的时候效果并不好,所以使用HiPPO矩阵进行初始化,它能够很好的压缩历史记忆,对于近期的信息衰减较小,而对于过往的记忆衰减较大。如果我们以一个4×44\times 4的方阵为例,HiPPO矩阵为

[1000120013301354]\begin{bmatrix} 1 & 0& 0 &0 \\ 1 &2 & 0 & 0\\ 1 & 3 & 3 & 0 \\ 1 & 3 & 5 & 4 \end{bmatrix}

3 Mamba模型

3.1 动机

作者认为,序列建模的一个基本问题在于怎么把上下文压缩到一个更小的状态,例如Attention就完全不压缩上下文,我们在自回归计算的时候显式地存储了整个上下文(KV-Cache)。所以Transformer是effective and inefficient。而循环的模型(RNN和状态空间模型)都有一个有限的状态,而压缩上下文到这个状态的好坏就决定了这类模型的effectiveness。

在这里作者举了两个例子,一个是复制任务,一个是选择性复制任务。

复制任务定义为,给定一个序列,模型学习将序列的元素复制到几个时间点后的地方,例如原序列为[1,2,3,4,0,0,0,0][1,2,3,4,0,0,0,0],模型学习将其复制到[0,0,0,0,1,2,3,4][0,0,0,0,1,2,3,4]

选择性复制则稍有不同,它要求模型将序列选择性的复制到几个时间点后,例如给定序列[1,2,x,3,0,0,0,0][1,2,x,3,0,0,0,0],模型要学会忽略x,将数据复制为[0,0,0,0,0,1,2,3][0,0,0,0,0,1,2,3]。原序列中x的位置和个数都是随机的。

Induction Heads任务要求模型通过序列“回忆”输入序列,根据上下文来检索答案,例如给定序列[1,x,3,4,0,0,1,?][1,x,3,4,0,0,1,?],其中?要求模型填入下一项,那么模型将需要检索上下文,发现1后面会跟上x,然后推断出下一项为x。这是LLM的一项关键能力。

作者认为,时不变模型,也就是之前的S4模型,参数是固定的,不随输入变化的,这导致模型无法进行内容感知(content-aware)推理。在选择性复制任务中,模型无法根据输入内容中要忽略token的位置和个数来选择性忽略(Optimization只能让模型学会忽略一个固定的pattern),Induction Heads任务也是如此,我们无法以不依赖输入的方式影响沿着序列传递的隐藏状态。

所以作者决定,将S4模型中的参数改成input-dependent。

3.2 S6算法

之前S4模型的算法是

Input: x: (B, L, D)
Output: y: (B, L, D)

  1. A: (D, N) <- Parameter # 初始化矩阵A
  2. B: (D, N) <- Parameter # 初始化矩阵B
  3. C: (D, N) <- Parameter # 初始化矩阵C
  4. Δ\Delta: D <- τΔ\tau_\Delta(Parameter) 初始化Δ\Delta
  5. A,B\overline{A,B}: (D, N) <- discretize(Δ,A,B\Delta, A, B) # 离散化矩阵
  6. y <- SSM(A,B,C\overline{A,B},C)(x) # SSM计算
  7. return y

在这里τΔ\tau_{\Delta}是一个激活函数,为softplus,具体来说,

softplus(x)=log[1+exp(x)]\mathrm{softplus}(x) = \log[1+\exp(x)]

可看作平滑的ReLU\mathrm{ReLU}

更改后的算法(叫做S6)为

Input: x: (B, L, D)
Output: y: (B, L, D)

  1. A: (D, N) <- Parameter # 初始化矩阵A
  2. B: (B, L, N) <- sB(x)s_B(x) # 映射输入为矩阵B
  3. C: (B, L, N) <- sC(x)s_C(x) # 映射输入为矩阵C
  4. Δ\Delta: (B, L, D) <- τΔ\tau_\Delta(Parameter+sΔ(x)s_\Delta(x)) # 映射与随机参数共同决定Δ\Delta
  5. A,B\overline{A,B}: (B, L, D, N) <- discretize(Δ,A,B\Delta, A, B) # 离散化矩阵
  6. y <- SSM(A,B,C\overline{A,B},C)(x) # SSM计算
  7. return y

其中

{sB(x)=LinearDN(x)sC(x)=LinearDN(x)sΔ(x)=BroadcastD[LinearD1(x)]\begin{cases} s_B(x) = \mathrm{Linear_{D\rightarrow N}}(x) \\ s_C(x) = \mathrm{Linear_{D\rightarrow N}}(x)\\ s_\Delta(x) = \mathrm{Broadcast_{D}}[\mathrm{Linear_{D\rightarrow 1}}(x)] \end{cases}

所以,S6算法实际上只是多使用了一层映射将参数矩阵与输入联系起来,从而达到参数随着输入动态变化的效果。

笔者注:这里shape可能对不上,代码中计算矩阵乘法是这样的:
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
,而代码中Δ\Delta的shape就是(B, D, L)。与输入做乘法的时候也是这样:
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
迭代计算的部分是
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
注意在这里,u为输入,x为中间状态。x初始化为
x = A.new_zeros((batch, dim, dstate)) shape: (B, D, N)

但现在就出现了一个问题,即然参数随着输入变化了,那模型就是一个time-varying的系统,之前提前算卷积核然后用卷积方式并行计算的trick就不能用了。为了避免完全迭代计算,作者使用了三个trick来进行加速,分别是Kernel Fusion,Parallel Scan和Recomputation。

3.3 高效地实现S6算法

3.3.1 Kernel Fusion

现代GPU加速器一般有两个内存空间,HBM(High-Bandwidth Memory)和SRAM(Static Random-Access Memory)。这构成了GPU的内存层级。我们知道,内存分层级则意味着速度快的内存小,内存大的速度慢。HBM则是慢的那个,SRAM是快的那个。

GPU运算的时候,会将数据从HBM加载到SRAM中,运算完毕后再存回HBM。那么如果我们用多个CUDA Kernel来处理Mamba的迭代过程,就会导致多个Kernel对两个内存的读写。

例如,我们以三个CUDA Kernel为例,那么kernel1读取HBM的数据到SRAM,处理之后返回到HBM,Kernel2读取HBM中的结果开始处理,结果又存回HBM,Kernel3再读取HBM里的结果处理后再存回HBM。这些kernel可能 是负责离散化,迭代和最后输出的与矩阵C的矩阵乘法的。作者则把离散化,迭代(后续替代为Parallel scan)和矩阵乘法写到一个自定义的Kernel里面,这样就把几个不同的Kernel融合为一个,叫做Kernel Fusion。

具体来说,他们在一个CUDA Kernel内做以下步骤

  1. 从HBM中读取Δ\Delta, A, B, C到SRAM中
  2. 在SRAM中离散化,得到A,B\overline{A,B}
  3. 做parallel scan,在SRAM中得到中间状态
  4. 乘矩阵C,得到最终结果并写回HBM

作者声称能提速20-40倍。

3.3.2 Parallel Scan

Parallel scan是使用并行化的方法来进行序列操作,最开始是用在prefix-sum问题上。给出一个序列[1,2,3,4,5][1,2,3,4,5], prefix-sum指对前n个元素进行求和,n[1,k]n\in [1,k], k为序列长,得到序列[1,3,6,10,15][1,3,6,10,15]。这个问题用for-loop当然可以很简单的解决,但它可以用多线程来并行计算。

给定一个长度为8的输入序列[x1,x2,x3,x4,x5,x6,x7,x8][x_1,x_2,x_3,x_4,x_5,x_6,x_7,x_8],我们按以下方式计算,首先开四个线程分别计算得到

a=[x1+x2],b=[x3+x4],c=[x5+x6],d=[x7+x8]a=[x_1+x_2], b=[x_3+x_4], c=[x_5+x_6], d=[x_7+x_8]

再开两个线程计算上面得到的结果,得到

e=[x1+x2+x3+x4],f=[x5+x6+x7+x8]e=[x_1+x_2+x_3+x_4], f=[x_5+x_6+x_7+x_8]

最后一个线程求和得到

g=[x1+x2+x3+x4+x5+x6+x7+x8]g=[x_1+x_2+x_3+x_4+x_5+x_6+x_7+x_8]

上述过程叫Up-Sweep。我们利用以上得到的结果和原序列,按照以下方式再次计算。首先一个线程计算得到

h=e+c=[x1+x2+x3+x4+x5+x6]h=e+c=[x_1+x_2+x_3+x_4+x_5+x_6]

再开四个线程计算得到

a+x3=[x1+x2+x3],e+x5=[x1+x2+x3+x4+x5]h+x7=[x1+x2+x3+x4+x5+x6+x7]\begin{aligned} &a+x_3=[x_1+x_2+x_3], e+x_5=[x_1+x_2+x_3+x_4+x_5]\\ &h+x_7=[x_1+x_2+x_3+x_4+x_5+x_6+x_7] \end{aligned}

上述过程称为Down-Sweep, 此时我们已经得到了完整的prefix-sum,为

[x1,a,a+x3,e,e+x5,h,h+x7,g][x_1,a,a+x_3,e,e+x_5,h,h+x_7,g]

Mamba的迭代过程可以定义为类似于prefix-sum的问题,假设prefix-sum的输入是x=[xi]x=[x_i],其输出是为y=[yi]y=[y_i],那么关系为yi=yi1+xiy_i=y_{i-1}+x_i。上述Mamba的状态迭代方程为hk=Ahk1+Bukh_k = \overline{A}h_{k-1}+\overline{B}u_k,抛开多乘了两个参数,这两个关系式的形式是Identical的。所以我们可以使用Parallel Scan来并行化Mamba的计算。

3.3.3 Recomputation

作者提到,他们为了节省显存,他们carefully使用Recomputation技巧,也就是前向过程不存储中间状态,而是在反向传播的时候再重新计算,具体的操作在论文附录中有提到,这一部分应该不难阅读。

3.4 Mamba Block

这一部分就是喜闻乐见的搭积木环节了,block如下
Mamba Blcok

注意这里作者并没画全所有的Module,完整的block大概是这个样子
Mamba Blcok

4 总结

已经有好几个声称能媲美或打败Transformer的模型了,他们应该有种种缺点所以最终没有被大规模采用。Mamba可以说是少见的被follow的很快的工作,但个人感觉在某些方面应该还是比不过Transformer。最近谷歌在3月1日新发了Griffin和Hawk模型(arXiv:2402.19427),好像还没开源,可以观望一下。

也佩服作者Albert Gu的毅力,S4模型,HiPPO矩阵等工作都有他的参与,可谓是一路下来把状态空间模型从不work改进到work。Griffin论文中也见到了Albert Gu的参与。


Mamba
https://jesseprince.github.io/2024/03/11/sequence/ssm/mamba/
作者
林正
发布于
2024年3月11日
许可协议