Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

2023-07-18 12:14:32 - 机器之心Pro

编辑:蛋酱、张倩、陈萍

LLM的成功,某种程度上要归功于Transformer架构在自然语言处理任务上的突破。该架构最初是为了克服循环模型的sequentialtraining问题而提出的。这些年来,Transformer已经成为LLM普遍采用的架构。

然而,Transformer的训练并行性是以低效推理为代价的:每一步的复杂度为O(N)且键值缓存受内存限制,让Transformer不适合部署。不断增长的序列长度会增加GPU内存消耗和延迟,并降低推理速度。

研究者们一直在努力开发下一代架构,希望保留训练并行性和Transformer的性能,同时实现高效的O(1)推理。针对这个问题,此前的方法都没能同时实现这几点,至少与Transformer相比没有显示出绝对的优势。

现在,微软研究院和清华大学的研究者已经在这个问题上取得了重大突破。

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

论文链接:https://arxiv.org/pdf/2307.08621.pdf

在这项工作中,研究者提出了retentive网络(RetNet),同时实现了低成本推理、高效长序列建模、媲美Transformer的性能和并行模型训练,打破了「不可能三角」。

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

具体来说,RetNet引入了一种多尺度retention机制来替代多头注意力,它有三种计算范式:并行、循环和分块循环表征。

首先,并行表征使训练并行化,以充分利用GPU设备。其次,循环表征法在内存和计算方面实现了高效的O(1)推理。部署成本和延迟可以显著降低,同时无需键值缓存技巧,大大简化了实现过程。此外,分块循环表征法能够执行高效的长序列建模。研究者对每个局部块进行并行编码以提高计算速度,同时对全局块进行循环编码以节省GPU内存。

论文进行了大量实验来对比RetNet和Transformer及其变体。实验结果表明,RetNet在scaling曲线和上下文学习方面始终具有竞争力。此外,RetNet的推理成本与长度无关。对于7B模型和8k序列长度,RetNet的解码速度是带键值缓存的Transformers的8.4倍,内存节省70%。

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

在训练过程中,RetNet也能够比标准Transformer节省25-50%的内存,实现7倍的加速,并在高度优化的FlashAttention方面具有优势。此外,RetNet的推理延迟对批大小不敏感,从而实现了巨大的吞吐量。

这些令人惊艳的特质让不少研究者惊呼「好得不可思议」,甚至有人将其比作当初「M1芯片」登场所带来的变革意义。看来,RetNet有望成为Transformer的有力继承者。

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

不过,也有研究者提出疑问:这么优秀的表现是否意味着RetNet要在某些方面有所权衡?它能扩展到视觉领域吗?

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

接下来,让我们深入了解RetNet方法的细节。

Retentive网络

RetNet由L个相同的块堆叠而成,其布局与Transformer类似(即残差连接和pre-LayerNorm)。每个RetNet块包含两个模块:多尺度retention(MSR)和前馈网络(FFN)。

给定输入序列 Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

是隐藏维度。然后,计算上下文向量表征

 ,其中

首先被封装为

,RetNet以自回归方式对序列进行编码。输入向量

Retention

RetNet具有循环和并行双重形式的retention机制,因此能够并行地训练模型,同时循环地进行推理。

给定输入 Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

,将其投影为一维函数v(n)=X_n-w_V。考虑一个序列建模问题,通过状态s_n映射v(n)→o(n)。

为简单起见,让v_n,o_n表示v(n),o(n)。此处以循环的方式对映射进行表述:

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

其中,将v_n映射到状态向量s_n,然后实现线性变换,对序列信息进行循环编码。

接下来,使投影Q_n,K_n具有内容感知能力:

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

其中 Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

是可学习矩阵。

将矩阵 Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

。通过将Λ吸收到W_Q和W_K中,可以将方程(1)重写为

。然后得到

对角化,其中

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

其中, Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

称为xPos,即为Transformer提出的相对位置嵌入。进一步将γ简化为标量,公式(3)则变为

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

其中†为共轭转置。该公式很容易在训练实例中并行化。

总之,从公式(1)所示的循环建模开始,然后推导出公式(4)中的并行公式。将原始映射v(n)→o(n)视为向量,得到如下的retention机制:

1)Retention的并行表征

如图3a所示,Retention层定义为:

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

与自注意力类似,并行表征使得能够使用GPU高效地训练模型。

2)Retention的循环表征

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

这里的Q,K,V,γ和公式5相同。

3)Retention分块循环表征

并行表征和循环表征的混合形式可以加速训练,特别是对于长序列。此处将输入序列划分为若干小块。在每个块内,按照并行表征(公式(5))进行计算。相反,跨块信息则按照循环表征(公式(6))进行传递。具体来说,让B表示块长度。通过以下方式计算第i个分块的retention输出:

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

其中[i]表示第i个数据块,例如 Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

门控多尺度Retention

在每个层中,研究者使用h=d_model/d个retention头,其中d是头的维度。这些头使用不同的参数矩阵W_Q、W_K、W_V∈R^(d×d)。此外,多尺度retention(MSR)为每个头分配不同的γ。为了简化,研究者将γ设置为在不同层之间相同并保持固定。另外,他们添加了一个swish门[RZL17]来增加层的非线性性。形式上,给定输入X,研究者将该层定义为:

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

其中, Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

为可学习参数,GroupNorm[WH18]对每个头的输出进行归一化,遵循[SPP^+19]中提出的SubLN。注意,这些头使用多个γ尺度,这会带来不同的方差统计结果。所以研究者分别对头的输出进行归一化。

retention的伪代码如图4所示。

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

RetentionScore归一化

研究者利用GroupNorm的尺度不变性来提高retention层的数值精度。具体而言,在GroupNorm中乘以一个标量值不会影响输出和反向梯度,即GroupNorm(α∗head_i)=GroupNorm(head_i)。研究者在公式(5)中实现了三个归一化因子。首先,他们将QK^⊺归一化为QK^⊺/√d。其次,他们将D替换为 Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

。由于尺度不变的特性,上述技巧不会影响最终的结果,同时稳定了正向和反向传递的数值流动。

。然后,retention输出变为 

。第三,他们用R表示retentionscoresR=QK^⊺⊙D,将其归一化为

Retention网络总体结构

对于一个L层的retention网络,研究者堆叠多尺度retention(MSR)和前馈网络(FFN)来构建模型。形式上,输入序列 Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

作为输入,并计算模型的输出X^L:

通过一个词嵌入层被转换为向量。研究者使用打包后的嵌入

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

其中,LN(・)为LayerNorm[BKH16]。FFN部分计算为FFN(X)=gelu(XW_1)W_2,其中W_1、W_2为参数矩阵。

训练:研究者在训练过程中使用了并行(公式5)表示和块循环(公式7)表示。序列或块内的并行有效地利用了GPU来加速计算。更有利的是,块循环对于长序列训练特别有用,这在FLOPs和内存消耗方面都是有效的。

推理:在推理过程中,研究者采用了循环表示(公式6),这非常适合自回归解码。O(1)的复杂度减少了内存占用和推理延迟,同时实现了相当的结果。

与以往方法的联系和区别

表1从不同角度对RetNet与以往的方法进行了比较。对比结果与图2所示的「不可能三角」相呼应。此外,RetNet对于长序列具有线性记忆复杂性,因为它采用了分块循环表示。

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

Transformer:retention的并行表示与Transformers[VSP^+17]有着相似的思路。最相关的Transformer变体是LexTransformer[SDP^+22],它实现了xPos作为位置嵌入。如式(3)所示,retention的推导与xPos一致。与注意力相比,retention消除了softmax并使循环公式成为可能,这非常有利于推理。

S4:与式(2)不同,如果Q_n和K_n是content-unaware的,则公式可简并为S4[GGR21],其中 Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

LinearAttention:变体通常使用各种kernel Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

来取代softmax函数。然而,线性注意力难以有效地编码位置信息,导致模型性能下降。此外,研究者从头开始重新检查序列建模,而不是以近似softmax为目标。

AFT/RWKV:AttentionFreeTransformer(AFT)简化了点积对元素运算的关注,并将softmax移动到关键向量。RWKV用指数衰减取代AFT的位置嵌入,并循环运行模型进行训练和推理。相比之下,retention保留了高维状态来编码序列信息,有助于提高表达能力和性能。

xPos/RoPE:与为Transformers提出的相对位置嵌入方法相比,公式(3)呈现出与xPos[SDP^+22]和RoPE[SLP^+21]类似的表达式。

Sub-LayerNorm:如公式(8)所示,retention层使用Sub-LayerNorm[WMH^+22]对输出进行归一化。由于多尺度建模导致不同头的方差不同,研究者将原始的LayerNorm替换为GroupNorm。

实验结果

该研究进行了大量的实验来评估RetNet,包括语言建模任务、下游任务上零样本、少样本学习性能,此外,研究者还比较了RetNet训练和推理的速度、内存消耗和延迟等指标。

与Transformer的比较

语言建模任务。图5报告了基于Transformer和RetNet的语言模型在验证集上的困惑度(perplexity)结果。实验给出了13b、2.7B和6.7B三种模型尺寸的缩放曲线。表明,RetNet取得了与Transformer可比较的结果。

更重要的是,这一结果还表明了RetNet在大小扩展方面更具优势。除了性能优势外,实验中RetNet的训练也非常稳定。RetNet是Transformer的有力竞争对手。研究者根据经验发现,当模型规模大于2B时,RetNet开始超越Transformer。

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

该研究还在各种下游任务上对语言模型进行了比较。他们使用6.7B大小的模型进行了零样本和4个样本学习的评估,如表3所示。表中展示的关于准确率的数字与图5中呈现的语言建模困惑度一致。在零样本学习和上下文学习设置中,RetNet在性能上与Transformer相当。

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

训练成本

表4比较了Transformer和RetNet在训练速度和内存开销方面的结果,其中训练序列长度为8192。此外,该研究还将其与FlashAttention进行了比较。

实验结果表明,在训练过程中,RetNet比Transformer更节省内存,并且具有更高的吞吐量。即使与FlashAttention相比,RetNet在速度和内存成本方面仍然具有竞争力。此外,由于不依赖于特定的内核,用户可以轻松高效地在其他平台上训练RetNet。例如,研究者可以在具有良好吞吐量的AMDMI200集群上训练RetNet模型。

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

推理成本

图6比较了Transformer和RetNet在推理过程中的内存成本、吞吐量和延迟。实验中使用了A100-80GBGPU评估了6.7B模型。图6显示,RetNet在推理成本方面优于Transformer。

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

内存:如图6a所示,由于KV(键和值)缓存,Transformer的内存成本呈线性增长。相比之下,RetNet的内存消耗即使对于长序列也保持一致。

吞吐量:如图6b所示,随着解码长度的增加,Transformer的吞吐量开始下降。相比之下,RetNet通过利用Retention的循环表征,在解码过程中具有更高的吞吐量,并且与长度无关。

延迟:延迟是部署中的重要指标,它极大地影响用户体验。图6c报告了解码延迟。实验结果显示,增加批次大小会使Transformer的延迟变大。此外,Transformer的延迟随着输入长度的增加而增加得更快。为了使延迟可接受,研究者不得不限制批次大小,这会损害Transformer的整体推理吞吐量。相比之下,RetNet的解码延迟优于Transformer,并且在不同的批次大小和输入长度下几乎保持不变。

与Transformer变体比较

下表表明,RetNet在不同的数据集上优于先前的方法。RetNet不仅在领域内语料库上取得更好的评估结果,还在几个领域外数据集上获得更低的困惑度。这种优越的性能使得RetNet成为Transformer的有力继任者。

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

消融实验

下表列出了RetNet的各种设计选择,并在表6中报告了语言建模结果。

Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强

今日热搜