[论文翻译]Slim Attention:无需损失精度即可将上下文内存减半 —— K-cache 是 MHA 所需的全部


原文地址:https://arxiv.org/pdf/2503.05840v1


Slim attention: cut your context memory in half without loss of accuracy — K-cache is all you need for MHA

Slim Attention:无需损失精度即可将上下文内存减半 —— K-cache 是 MHA 所需的全部

Nils Graef*, Andrew Wasielewski, Open Machine


Abstract

摘要

Slim attention shrinks the context memory size by 2x for transformer models with MHA (multi-head attention), which can speed up inference by up to 2x for large context windows. Slim attention is an exact, mathematically identical implementation of the standard attention mechanism and therefore doesn't compromise model accuracy. In other words, slim attention losslessly compresses the context memory by a factor of 2. For encoder-decoder transformers, the context memory size can be reduced even further: For the Whisper models for example, slim attention reduces the context memory by 8x, which can speed up token generation by 5x for batch size 64 for example. And for rare cases where the MHA projection dimension is larger than d_model, the memory can be reduced by a factor of 32 for the T5-11B model for example. See [1] for code and more transformer tricks, and [2] for a YouTube video about this paper.

Slim attention 通过将上下文内存大小缩小 2 倍来加速具有 MHA(多头注意力)的 Transformer 模型的推理,对于大上下文窗口,推理速度最多可提升 2 倍。Slim attention 是标准注意力机制的精确数学等价实现,因此不会影响模型准确性。换句话说,Slim attention 无损地将上下文内存压缩了 2 倍。对于编码器-解码器 Transformer,上下文内存大小可以进一步减少:例如,对于 Whisper 模型,Slim attention 将上下文内存减少了 8 倍,在批量大小为 64 的情况下,Token 生成速度可以提升 5 倍。在极少数情况下,当 MHA 投影维度大于 d_model 时,例如对于 T5-11B 模型,内存可以减少 32 倍。代码和更多 Transformer 技巧请参见 [1],关于本文的 YouTube 视频请参见 [2]。

Fig. 1 illustrates how slim attention computes the value (V) projections from the key (K) projections in a mathematically equivalent way without hurting model accuracy. Therefore, we only need to store the keys in memory, instead of storing both keys and values. This reduces the size of the context memory (aka KV-cache) by half. Alternatively, slim attention can double the context window size without increasing context memory. However, calculating V from K on-the-fly requires additional compute, which we will discuss below.

图 1 展示了 slim attention 如何在不损害模型准确性的情况下,以数学等价的方式从键 (K) 投影中计算值 (V) 投影。因此,我们只需要在内存中存储键,而不需要同时存储键和值。这将上下文内存(即 KV-cache)的大小减少了一半。或者,slim attention 可以在不增加上下文内存的情况下将上下文窗口大小翻倍。然而,从 K 中实时计算 V 需要额外的计算,我们将在下面讨论这一点。


Figure 1: Mathematically identical implementations of multi-headed self-attention with square weight matrices ∈ R^{d×d}. Left: vanilla version. Right: proposed version where values V are computed from keys K with W_KV = W_K^{-1} W_V. The symbol ≡ denotes mathematical identity.

图 1: 具有方阵权重矩阵 (∈ R^{d×d}) 的多头自注意力的数学等价实现。左图:原始版本。右图:提出的版本,其中值 V 通过键 K 计算,W_KV = W_K^{-1} W_V。符号 ≡ 表示数学等价。

Slim attention is applicable to transformers that use MHA (multi-head attention [3]) instead of MQA (multi-query attention [4]) or GQA (grouped query attention [5]), which includes LLMs such as CodeLlama-7B and Aya-23-35B, SLMs such as Phi-3-mini and SmolLM2-1.7B, VLMs (vision language models) such as LLAVA, audio-language models such as Qwen2-Audio-7B, and encoder-decoder transformer models such as Whisper [6] and T5 [7]. Table 1 lists various MHA transformer models ranging from 9 million to 35 billion parameters. The last column of Table 1 specifies the KV-cache size (in number of activations) for each model to support its maximum context length, where the KV-cache size equals 2·h·d_k · (number of layers) · (context length).

Slim attention 适用于使用 MHA (多头注意力 [3]) 而非 MQA (多查询注意力 [4]) 或 GQA (分组查询注意力 [5]) 的 Transformer 模型,这包括诸如 CodeLlama-7B 和 Aya-23-35B 等大语言模型,Phi-3-mini 和 SmolLM2-1.7B 等小语言模型,LLAVA 等视觉语言模型,Qwen2-Audio-7B 等音频语言模型,以及 Whisper [6] 和 T5 [7] 等编码器-解码器 Transformer 模型。表 1 列出了从 900 万到 350 亿参数的各种 MHA Transformer 模型。表 1 的最后一列指定了每个模型支持其最大上下文长度所需的 KV 缓存大小 (以激活数表示),其中 KV 缓存大小等于 2·h·d_k · 层数 · 上下文长度。
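As a quick sanity check of this formula, here is a minimal Python sketch (our own helper, not from the paper) that reproduces two of the KV-cache sizes listed in Table 1 from the expression 2·h·d_k·layers·context-length:

```python
def kv_cache_size(h, d_k, layers, context_len):
    """KV-cache size in number of activations: 2 * h * d_k * layers * context_length."""
    return 2 * h * d_k * layers * context_len

# Values taken from Table 1 (results are rounded in the table):
print(kv_cache_size(h=32, d_k=128, layers=32, context_len=16 * 1024))   # CodeLlama-7B:    ~4.3e9
print(kv_cache_size(h=32, d_k=96,  layers=32, context_len=128 * 1024))  # Phi-3-mini-128k: ~25.8e9
```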

Table 1: Various transformers with MHA (instead of MQA or GQA) and their maximum KV-cache sizes (in number of activations) based on their respective maximum context length. h is the number of attention-heads, d is the embedding dimension (aka hidden size), and d_k is the head dimension. See appendix for more MHA models.

表 1: 使用 MHA (而非 MQA 或 GQA) 的各种 Transformer 及其基于各自最大上下文长度的最大 KV 缓存大小 (以激活数表示)。h 是注意力头的数量,d 是嵌入维度 (即隐藏大小),d_k 是头维度。更多 MHA 模型请参见附录。

| 年份 | 发布者 | 模型 | 参数量 | d | 层数 | h | d_k | 上下文长度 | 上下文内存 |
|---|---|---|---|---|---|---|---|---|---|
| 2024 | Meta | CodeLlama-7B [8] | 7B | 4,096 | 32 | 32 | 128 | 16k | 4.3B |
| 2024 | Meta | CodeLlama-13B [8] | 13B | 5,120 | 40 | 40 | 128 | 16k | 6.7B |
| 2024 | Google | CodeGemma-7B [9] | 8.5B | 3,072 | 28 | 16 | 256 | 8k | 1.9B |
| 2024 | Cohere | Aya-23-35B [10] | 35B | 8,192 | 40 | 64 | 128 | 8k | 5.4B |
| 2024 | HuggingFace | SmolLM2-1.7B [11] | 1.7B | 2,048 | 24 | 32 | 64 | 8k | 0.8B |
| 2024 | HuggingFace | SmolVLM [12] | 2.3B | 2,048 | 24 | 32 | 64 | 16k | 1.6B |
| 2024 | Together AI | Evo-1-131k [13] | 6.5B | 4,096 | 32 | 32 | 128 | 128k | 34.4B |
| 2024 | Microsoft | Phi-3-mini-128k [14] | 3.8B | 3,072 | 32 | 32 | 96 | 128k | 25.8B |
| 2024 | Microsoft | BitNet_b1_58-3B [15] | 3.3B | 3,200 | 26 | 32 | 100 | 2k | 0.3B |
| 2024 | Apple | DCLM-7B [16] | 6.9B | 4,096 | 32 | 32 | 128 | 2k | 0.5B |
| 2024 | Ai2 | OLMo-1B [17] | 1.3B | 2,048 | 16 | 16 | 128 | 4k | 0.3B |
| 2024 | Ai2 | OLMo-2-1124-13B [17] | 13.7B | 5,120 | 40 | 40 | 128 | 4k | 1.7B |
| 2024 | Amazon | Chronos-Bolt-tiny [18] | 9M | 256 | 4 | 4 | 64 | 0.5k | 1M |
| 2024 | Amazon | Chronos-Bolt-base [18] | 205M | 768 | 12 | 12 | 64 | 0.5k | 9.4M |
| 2024 | Alibaba | Qwen2-Audio-7B [19] | 8.4B | 4,096 | 32 | 32 | 128 | 8k | 2.1B |
| 2023 | LLaVA | LLaVA-NeXT-Video [20] | 7.1B | 4,096 | 32 | 32 | 128 | 4k | 1.1B |
| 2023 | LLaVA | LLaVA-Vicuna-13B [21] | 13.4B | 5,120 | 40 | 40 | 128 | 4k | 1.7B |
| 2023 | LMSYS | Vicuna-7B-16k [22] | 7B | 4,096 | 32 | 32 | 128 | 16k | 4.3B |
| 2023 | LMSYS | Vicuna-13B-16k [22] | 13B | 5,120 | 40 | 40 | 128 | 16k | 6.7B |
| 2022 | Google | Flan-T5-base [23] | 248M | 768 | 12 | 12 | 64 | 0.5k | 9.4M |
| 2022 | Google | Flan-T5-XXL [23] | 11.3B | 4,096 | 24 | 64 | 64 | 0.5k | 101M |
| 2022 | OpenAI | Whisper-tiny [6] | 38M | 384 | 4 | 6 | 64 | enc: 1500 / dec: 448 | 6M |
| 2022 | OpenAI | Whisper-large-v3 [6] | 1.5B | 1,280 | 32 | 20 | 64 | enc: 1500 / dec: 448 | 160M |
| 2019 | OpenAI | GPT-2 XL [24] | 1.5B | 1,600 | 48 | 25 | 64 | 1k | 157M |

For long contexts, the KV-cache can be even larger than the parameter memory: For batch size 1 and 1 byte (FP8) per parameter and activation, the Phi-3-mini-128k model for example has a 3.8GB parameter memory and requires 25GB for its KV-cache to support a context length of 128K tokens. For a batch size of 16 for example, the KV-cache grows to 16 × 25GB = 400GB. Therefore, memory bandwidth and capacity become the bottleneck for supporting long context.

对于长上下文,KV缓存甚至可能比参数内存更大:以批量大小为1且每个参数和激活占用1字节(FP8)为例,Phi-3-mini-128k 模型的参数内存为3.8GB,而其KV缓存需要25GB来支持128K token的上下文长度。例如,当批量大小为16时,KV缓存将增长到 16 × 25GB = 400GB。因此,内存带宽和容量成为支持长上下文的瓶颈。

For a memory bound system with batch size 1, generating each token takes as long as reading all (activated) parameters and all KV-caches from memory. Therefore, slim attention can speed up the token generation by up to 2x for long contexts. For the Phi-3-mini-128k model with 3.8GB parameters for example, slim attention reduces the KV-cache size from 25GB to 12.5GB, which reduces the total memory from 28.8GB to 16.3GB, and thus speeds up the token generation by up to 1.8x for batch size 1 (the maximum speedup happens for the generation of the very last token of the 128K tokens). And for batch size 16 for example, the speedup is (400 + 3.8) / (200 + 3.8) ≈ 2x.

对于批处理大小为1的内存受限系统,生成每个Token所需的时间与从内存中读取所有(激活的)参数和所有KV缓存的时间相当。因此,对于长上下文,slim attention可以将Token生成速度提升至多2x。以拥有3.8GB参数的Phi-3-mini-128k模型为例,slim attention将KV缓存大小从25GB减少到12.5GB,从而将总内存从28.8GB减少到16.3GB,并在批处理大小为1的情况下将Token生成速度提升至多1.8x(最大加速发生在生成128K Token中的最后一个Token时)。而对于批处理大小为16的情况,加速比为 (400 + 3.8) / (200 + 3.8) ≈ 2x。

The vanilla transformer [3] defines the self-attention Y of input X as follows, where h is the number of heads:

原始 Transformer [3] 将输入 X 的自注意力 Y 定义如下,其中 h 是头数:

$$Q = X W_Q, \qquad K = X W_K \tag{1}$$

$$V = X W_V \tag{2}$$

$$Y = \mathrm{concat}\left(\mathrm{head}_1, \ldots, \mathrm{head}_h\right) W_O, \qquad \mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i \tag{3}$$

with W_Q = concat(W_{Q,1}, …, W_{Q,h}), W_K = concat(W_{K,1}, …, W_{K,h}), and W_V = concat(W_{V,1}, …, W_{V,h}), and without the causal mask for simplicity. The matrices Q, K, V, W_Q, W_K, and W_V are split into h submatrices, one for each attention-head. Input X, output Y, queries Q, keys K, and values V are ∈ R^{n×d}, where n is the current sequence length (in tokens) and d = d_model is the dimension of the embeddings.

其中 W_Q = concat(W_{Q,1}, …, W_{Q,h}),W_K = concat(W_{K,1}, …, W_{K,h}),以及 W_V = concat(W_{V,1}, …, W_{V,h}),并且为了简化,省略了因果掩码。矩阵 Q、K、V、W_Q、W_K 和 W_V 被分割成 h 个子矩阵,每个注意力头对应一个。输入 X、输出 Y、查询 Q、键 K 和值 V 都属于 R^{n×d},其中 n 是当前序列长度(以 token 为单位),d = d_model 是嵌入的维度。

For MHA, the weight matrices W_K and W_V are usually square matrices ∈ R^{d×d}, which allows us to calculate V from K as follows: refactoring equation (1) as X = K W_K^{-1} lets us reconstruct X from K, which we can then plug into equation (2) to get

对于 MHA (Multi-Head Attention) 来说,权重矩阵 W_K 和 W_V 通常是方阵 (∈ R^{d×d}),这使得我们可以从 K 计算 V,如下所示:将方程 (1) 重构为 X = K W_K^{-1},我们可以从 K 重构 X,然后将其代入方程 (2) 以得到

$$V = X W_V = \left(K W_K^{-1}\right) W_V = K W_{KV} \qquad \text{with} \qquad W_{KV} = W_K^{-1} W_V \tag{4}$$

and W_{KV,i} ∈ R^{d×d_v}. Fig. 1 illustrates the modified attention scheme that calculates V from K according to equation (4). For inference, W_KV = W_K^{-1} W_V can be precomputed offline and stored in the parameter file instead of W_V. This requires that W_K is invertible (i.e. non-singular). In general, any square matrix can be inverted if its determinant is non-zero. It's extremely unlikely that a large matrix has a determinant that is exactly 0.

且 W_{KV,i} ∈ R^{d×d_v}。图 1 展示了根据公式 (4) 从 K 计算 V 的改进注意力机制。在推理时,W_KV = W_K^{-1} W_V 可以预先离线计算并存储在参数文件中,以取代 W_V。这要求 W_K 是可逆的(即非奇异)。一般来说,任何方阵只要其行列式不为零就可以求逆。大型矩阵的行列式恰好为 0 的情况极为罕见。
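The identity is easy to verify numerically. Below is a minimal NumPy sketch (variable names and sizes are ours, chosen only for illustration): it precomputes W_KV = W_K^{-1} W_V offline and checks that the values rebuilt from the K-cache match the vanilla value projection.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 8, 64                        # sequence length, model dimension
X    = rng.standard_normal((n, d))  # attention-layer input
W_K  = rng.standard_normal((d, d))  # square key projection (invertible with overwhelming probability)
W_V  = rng.standard_normal((d, d))  # square value projection

K    = X @ W_K                      # keys: the only tensor kept in the context memory
W_KV = np.linalg.inv(W_K) @ W_V     # precomputed offline, stored instead of W_V

V_vanilla = X @ W_V                 # vanilla value projection
V_slim    = K @ W_KV                # values recomputed from the K-cache
print(np.allclose(V_vanilla, V_slim))   # True (up to floating-point error)
```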

Related work. Slim attention is somewhat similar to DeepSeek’s multi-head latent attention (MLA) [25]. Unlike MLA, slim attention is an exact post-training implementation of existing MHA models (including models with RoPE).

相关工作。Slim attention 与 DeepSeek 的多头潜在注意力 (MLA) [25] 有些相似。与 MLA 不同,Slim attention 是对现有 MHA 模型(包括带有 RoPE 的模型)的精确训练后实现。

1 K-cache is all you need

1 K-cache 就是你所需要的

Inference consists of the following two phases, which are illustrated in Fig. 2 for the vanilla MHA with KV-cache, where p is the number of input-tokens and n is the total number of current tokens including input-tokens and generated tokens, so n = p+1, …, n_max, and n_max is the context window length:

推理包括以下两个阶段,如图 2 所示的带有 KV-cache 的普通 MHA (Multi-Head Attention) 中,其中 p 是输入 token 的数量,n 是当前 token 的总数,包括输入 token 和生成的 token,因此 n = p+1, …, n_max,且 n_max 是上下文窗口长度:

• During the prompt-phase (aka prefill phase), all p input-tokens are batched up and processed in parallel. In this phase, the K and V projections are stored in the KV-cache.
• During the generate-phase (aka decoding phase), each output-token is generated sequentially (aka autoregressively). For each iteration of the generate-phase, only one new K-vector and one new V-vector are calculated and stored in the KV-cache, while all the previously stored KV-vectors are read from the cache.

• 在提示阶段(也称为预填充阶段),所有 p 个输入 Token 会被打包成批并行处理。在此阶段,K 和 V 投影会被存储在 KV 缓存中。
• 在生成阶段(也称为解码阶段),每个输出 Token 会依次生成(也称为自回归生成)。在生成阶段的每次迭代中,只会计算一个新的 K 向量和一个新的 V 向量并存储在 KV 缓存中,同时从缓存中读取所有先前存储的 KV 向量。


Figure 2: Standard MHA with KV-cache during (a) prompt-phase and (b) generate-phase.

图 2: 标准 MHA 在 (a) 提示阶段和 (b) 生成阶段使用 KV-cache 的情况。

Fig. 3 illustrates slim attention, which only has a K-cache because V is now calculated from K. Plugging equation (4) into (3) yields

图 3 展示了 slim attention,它只有一个 K-cache,因为 V 现在是从 K 计算得出的。将方程 (4) 代入 (3) 得到

$$\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) K\, W_{KV,i} \tag{5}$$

Equation (5) can be computed in two different ways:

方程 (5) 可以通过两种不同的方式计算:

• Option 1 (unoptimized): Compute V_i = K W_{KV,i} first, and then multiply it with softmax(·). This option is used by Fig. 3(a) and 3(b). Complexity: multiplying K ∈ R^{n×d} with W_{KV,i} ∈ R^{d×d_k} takes 2·n·d·d_k OPs, and multiplying softmax(·) ∈ R^{1×n} with the n×d_k result takes 2·n·d_k OPs.

• 选项 1 (未优化):首先计算 V_i = K W_{KV,i},然后将其与 softmax(·) 相乘。此选项用于图 3(a) 和 3(b)。复杂度:将 K ∈ R^{n×d} 与 W_{KV,i} ∈ R^{d×d_k} 相乘需要 2·n·d·d_k 次操作 (OPs),将 softmax(·) ∈ R^{1×n} 与 n×d_k 的结果相乘需要 2·n·d_k 次操作。

• Option 2 (optimized): First multiply softmax(·) with K, and then multiply the result by W_{KV,i}. This option is illustrated in Fig. 3(c). During the generate-phase, this option has lower compute complexity than option 1: multiplying softmax(·) with K takes 2·n·d OPs, and multiplying the result with W_{KV,i} takes 2·d·d_k OPs. (A short sketch comparing both orderings follows this list.)

• 选项 2 (优化):首先将 softmax(·) 与 K 相乘,然后将结果与 W_{KV,i} 相乘。此选项如图 3(c) 所示。在生成阶段,此选项的计算复杂度低于选项 1:将 softmax(·) 与 K 相乘需要 2·n·d 次操作,将结果与 W_{KV,i} 相乘需要 2·d·d_k 次操作。
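The two options differ only in the order of the matrix products, so they return the same result. A small single-head NumPy sketch of the generate-phase (illustrative shapes and names, not code from the paper):

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, d_k = 1024, 512, 64                         # context length, model dim, head dim
scores = rng.random(n); scores /= scores.sum()    # stand-in for the softmax output, shape (n,)
K      = rng.standard_normal((n, d))              # K-cache
W_KV_i = rng.standard_normal((d, d_k))            # per-head slice of the precomputed W_KV

# Option 1 (unoptimized): materialize V_i = K @ W_KV_i first, then weight it by the scores
out1 = scores @ (K @ W_KV_i)                      # ~2*n*d*d_k OPs

# Option 2 (optimized): weight K by the scores first, then apply W_KV_i to a single d-vector
out2 = (scores @ K) @ W_KV_i                      # ~2*n*d + 2*d*d_k OPs

print(np.allclose(out1, out2))                    # True: associativity gives an identical result
```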


Figure 3: Slim attention without V-cache during (a) prompt-phase; (b) unoptimized and (c) optimized generate-phase

图 3: 无 V-cache 的 Slim attention 在 (a) 提示阶段;(b) 未优化和 (c) 优化生成阶段

Option 2 above uses the same associativity trick as MLA, see appendix C of [25]. During the prompt-phase, Fig. 3(a) has the same computational complexity as the vanilla scheme shown in Fig. 2(a). However, during the generate-phase, the proposed scheme has a slightly higher complexity than the vanilla scheme.

上述选项 2 使用了与 MLA 相同的结合性技巧,详见 [25] 的附录 C。在提示阶段,图 3(a) 的计算复杂度与图 2(a) 所示的原始方案相同。然而,在生成阶段,所提出的方案比原始方案的计算复杂度略高。

Table 2 specifies the complexity per token per layer during the generate-phase for batch size 1. The columns labeled "OPs", "reads", and "intensity" specify the computational complexity (as number of OPs), the number of memory reads, and the arithmetic intensity, resp. We define the arithmetic intensity here as the number of OPs per each activation or parameter read from memory. Specifically, the projection complexity includes calculating XW_Q, XW_K, XW_V, and multiplying with weight matrices W_O and W_KV. And the memory reads for projections include reading all four weight matrices; while the memory reads of the MHA include reading the K-cache (and the V-cache for the vanilla implementation). See appendix for more details on MHA complexity.

表 2 指定了在生成阶段(batch size 为 1)每层每个 Token 的复杂度。标有“OPs”、“reads”和“intensity”的列分别指定了计算复杂度(以 OP 的数量表示)、内存读取次数和算术强度。我们在这里将算术强度定义为每次从内存中读取激活值或参数时的 OP 数量。具体来说,投影复杂度包括计算 XW_Q、XW_K、XW_V,以及与权重矩阵 W_O 和 W_KV 的乘法运算。投影的内存读取包括读取所有四个权重矩阵;而多头注意力机制 (MHA) 的内存读取包括读取 K-cache(对于普通实现还包括 V-cache)。有关 MHA 复杂度的更多详细信息,请参阅附录。

Table 2: Complexity per token per layer during the generate-phase for batch size 1

表 2: 生成阶段每层每个 Token 的复杂度(批量大小为 1)

| | 投影 OPs | 投影读取 | 投影强度 | MHA OPs | MHA 读取 | MHA 强度 |
|---|---|---|---|---|---|---|
| Vanilla,见图 2(b) | 8d² | 4d² | 2 | 4nd | 2nd | 2 |
| 未优化,图 3(b) | (2n + 6)d² | 4d² | (n + 3)/2 | 4nd | nd | 4 |
| 优化后,图 3(c) | 8d² | 4d² | 2 | 2nd(h + 1) | nd | 2h + 2 |

Note that for batch-size B , the arithmetic intensity of the vanilla transformer during the generate-phase is 2B for the FFNs and the attention-projections, but only 2 for the remaining attention operations (softmax arguments and weighted sum of V) because each of the B tokens has its own KV-cache.

注意,对于批量大小 B,普通 Transformer 在生成阶段的算术强度对于 FFN 和注意力投影是 2B,但对于剩余的注意力操作(softmax 参数和 V 的加权和)仅为 2,因为 B 个 token 中的每一个都有自己的 KV 缓存。
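To make the Table 2 formulas concrete, the following sketch (with arbitrarily chosen but representative values d = 4096, h = 32, n = 32k) evaluates the generate-phase MHA intensities for batch size 1:

```python
d, h, n = 4096, 32, 32 * 1024     # model dimension, attention heads, current sequence length

# MHA complexity per token per layer for batch size 1, taken from Table 2:
vanilla_ops, vanilla_reads = 4 * n * d,           2 * n * d   # vanilla, Fig. 2(b)
slim_ops,    slim_reads    = 2 * n * d * (h + 1), n * d       # optimized slim attention, Fig. 3(c)

print(vanilla_ops / vanilla_reads)   # 2.0            -> arithmetic intensity of vanilla MHA
print(slim_ops / slim_reads)         # 66.0 = 2h + 2  -> arithmetic intensity of slim attention
print(vanilla_reads / slim_reads)    # 2.0            -> slim attention halves the context-memory reads
```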

Table 3 shows the arithmetic intensity (now defined as OPs per memory byte) of various SoCs, TPUs, and GPUs, which vary from 93 to 583. A system is memory bound (i.e. limited by memory bandwidth) if the arithmetic intensity of the executed program is below the chip's arithmetic intensity. Here, the maximum arithmetic intensity of slim attention is 2h + 2, see Table 2, where h is the number of attention-heads, which ranges between 16 and 64 for the models listed in Table 1. So the peak arithmetic intensity (up to 130 for h = 64) is usually less than the system's intensity (except for Apple's M4 Max), which means that the system is still memory bound during the token generation phase. Therefore, slim attention speeds up the processing by up to 2x as it reduces the context memory reads by half. Furthermore, slim attention enables processing all heads in parallel as a single matrix-matrix multiplication instead of multiple vector-matrix multiplications, which is usually more efficient and faster on many machines. And slim attention is also compatible with Flash Attention [26], which performs softmax and value accumulation in parallel.

表 3 展示了各种 SoC、TPU 和 GPU 的算术强度(此处定义为每内存字节的操作数),其范围从 93 到 583。如果执行程序的算术强度低于芯片的算术强度,则系统受内存限制(即受内存带宽限制)。在这里,slim attention 的最大算术强度为 2h + 2,参见表 2,其中 h 是注意力头的数量,表 1 中列出的模型的 h 范围在 16 到 64 之间。因此,峰值算术强度(对于 h = 64 最高为 130)通常小于系统的强度(除了 Apple 的 M4 Max),这意味着在 Token 生成阶段系统仍然受内存限制。因此,slim attention 通过将上下文内存读取减少一半,将处理速度提高了最多 2x。此外,slim attention 使得所有头可以作为一个单一的矩阵-矩阵乘法并行处理,而不是多个向量-矩阵乘法,这在许多机器上通常更高效且更快。slim attention 还与 Flash Attention [26] 兼容,后者可以并行执行 softmax 和值累加。

Table 3: TOPS, memory bandwidth, and arithmetic intensity of popular chips

表 3: 热门芯片的 TOPS、内存带宽和算术强度

芯片 TOPS (int8) 理论内存带宽 (GB/s) 算术强度 (每字节操作数)
Rockchip RK3588 6 19 316
Apple A18 [27] 35 60 583
Apple M4 Max [27] 38 410 93
Google TPU v4 [28] 275 1,200 229
Google TPU v5p [28] 918 2,765 332
NVIDIA H200 [29] 1,980 4,800 413
NVIDIA B200 [29] 4,500 8,000 563

2 Taking advantage of softmax sparsities

2 利用 softmax 稀疏性

In this section we describe how we can take advantage of softmax sparsities (i.e. sparsities in the attention scores) to reduce the computational complexity of the attention blocks. In some applications, many attention scores are 0 or close to zero. For those attention scores (i.e. attention scores smaller than a threshold), we can simply skip the corresponding V-vector, i.e. we don't have to add those skipped vectors to the weighted sum of V-vectors. This reduces the complexity of calculating the weighted sum of V-vectors. For example, for a sparsity factor S = 0.8 (i.e. 80% of the scores are 0), the complexity is reduced by a factor of 1/(1 − S) = 5.

在本节中,我们将描述如何利用 softmax 稀疏性(即注意力分数的稀疏性)来降低注意力块的计算复杂度。在某些应用中,许多注意力分数为 0 或接近零。对于那些注意力分数(即小于某个阈值的注意力分数),我们可以直接跳过相应的 V 向量,也就是说,我们不需要将那些跳过的向量添加到 V 向量的加权和中。这降低了计算 V 向量加权和的复杂度。例如,对于稀疏因子 S = 0.8(即 80% 的分数为 0),复杂度降低了 1/(1 − S) = 5 倍。
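A minimal sketch of this skip, assuming a simple absolute threshold on the attention scores (the threshold value and all names are illustrative, not prescribed by the paper):

```python
import numpy as np

rng = np.random.default_rng(2)
n, d_k = 1024, 64
scores = rng.random(n) ** 8; scores /= scores.sum()   # peaky distribution -> many near-zero scores
V      = rng.standard_normal((n, d_k))                # value vectors (with slim attention: K @ W_KV_i)

threshold = 1e-4
keep = scores > threshold                 # V-vectors whose contribution we actually accumulate
out_sparse = scores[keep] @ V[keep]       # weighted sum over the kept V-vectors only
out_dense  = scores @ V                   # exact weighted sum, for comparison

print(1.0 - keep.mean())                             # achieved sparsity factor S
print(np.max(np.abs(out_sparse - out_dense)))        # small approximation error
```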

By the way, taking advantage of softmax sparsities is also possible for systems with KV-cache where V is not computed from K. In this case, skipping V-vectors with zero scores means that we don't have to read those V-vectors from the KV-cache, which speeds up the autoregressive generate-phase for memory bound systems. However, this will never speed it up more than slim attention's removal of the entire V-cache. Furthermore, for MQA and GQA, each V-vector is shared among multiple (e.g. 4 or more) queries, so we can only skip reading a V-vector from memory if all 4 (or more) attention scores are zero for this shared V-vector, which reduces the savings significantly. For example, if the V-vectors are shared among 4 queries and the attention scores have sparsity S = 0.8, then the probability of all four scores being 0 is only S⁴ ≈ 0.41, so we can only skip 41% of the V-vectors.

顺便提一下,对于具有KV缓存(其中V不是从K计算得出)的系统,利用softmax稀疏性也是可行的。在这种情况下,跳过得分为零的V向量意味着我们无需从KV缓存中读取这些V向量,从而加速了内存受限系统的自回归生成阶段。然而,这永远不会比slim attention移除整个V缓存带来的加速效果更好。此外,对于MQA(多查询注意力)和GQA(分组查询注意力),每个V向量在多个(例如4个或更多)查询之间共享,因此只有当所有4个(或更多)注意力得分对于这个共享的V向量都为零时,我们才能跳过从内存中读取该V向量,这大大减少了节省。例如,如果V向量在4个查询之间共享,且注意力得分的稀疏性为S = 0.8,那么所有四个得分都为0的概率仅为S⁴ ≈ 0.41,因此我们只能跳过41%的V向量。

3 Support for RoPE

3 对 RoPE 的支持

Many transformers nowadays use RoPE (rotary positional embedding) [30], which applies positional encoding to the Q and K projections, but not the V projections. In general, RoPE can be applied to the K projections either before storing them in K-cache or after reading them from K-cache. The former is preferred because of lower computational complexity during the generate-phase (so that each K-vector is RoPE’d only once instead of multiple times). However, if the RoPE’d keys are stored in K-cache, then we first need to un-RoPE them before we can compute V from K. The following details two options to support RoPE.

如今,许多 Transformer 使用 RoPE (rotary positional embedding) [30],它将位置编码应用于 Q 和 K 的投影,但不应用于 V 的投影。一般来说,RoPE 可以在将 K 投影存储到 K-cache 之前或从 K-cache 读取之后应用于 K 投影。前者是首选,因为在生成阶段计算复杂度较低(这样每个 K 向量只需进行一次 RoPE 操作,而不是多次)。然而,如果 RoPE 处理后的键存储在 K-cache 中,那么在从 K 计算 V 之前,我们需要先对其进行反 RoPE 操作。以下详细介绍了支持 RoPE 的两种选项。

Option 1 is for the case where we don’t take advantage of softmax sparsities. In this case, we apply RoPE to the K-vectors after reading them from K-cache during the generate-phase. That way we can use the raw K-vectors for computing V from K.

选项1适用于我们不利用softmax稀疏性的情况。在这种情况下,我们在生成阶段从K缓存中读取K向量后,对它们应用RoPE。这样我们就可以使用原始的K向量从K计算V。

Option 2 is for the case where we take advantage of softmax sparsities as detailed in the previous section. In this case, RoPE is applied to the K-vectors before writing them into the K-cache. And when they are read from K-cache during the generate-phase, then we have to revert (or undo) the RoPE-encoding before we can use the K-vectors to compute the V-vectors (i.e. multiplying the K-vectors with the attention scores). However, we only need to do this for a portion of the K-vectors, depending on the sparsity factor S. For example, for S = 0.8, we only need to revert the RoPE-encoding for 20% of the K-vectors. The RoPE encoding can be reverted (aka RoPE-decoding) by simply performing a rotation in the opposite direction by the same amount as shown below for the 2D case.

选项 2 适用于我们利用前一节中详述的 softmax 稀疏性的情况。在这种情况下,RoPE 在将 K 向量写入 K 缓存之前应用于它们。当在生成阶段从 K 缓存中读取它们时,我们必须先撤销(或还原)RoPE 编码,然后才能使用 K 向量来计算 V 向量(即将 K 向量与注意力分数相乘)。然而,我们只需要对一部分 K 向量执行此操作,具体取决于稀疏因子 S。例如,对于 S = 0.8,我们只需要对 20% 的 K 向量进行 RoPE 编码的还原。RoPE 编码可以通过简单地执行相反方向、相同角度的旋转来还原(即 RoPE 解码),如下所示为 2D 情况。

RoPE encoding:

RoPE编码:

$$\begin{pmatrix} k_1' \\ k_2' \end{pmatrix} = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} k_1 \\ k_2 \end{pmatrix}$$

RoPE decoding:

RoPE解码:

$$\begin{pmatrix} k_1 \\ k_2 \end{pmatrix} = \begin{pmatrix} \cos m\theta & \sin m\theta \\ -\sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} k_1' \\ k_2' \end{pmatrix}$$

Note that the RoPE decoding uses the same trigonometric coefficients (such as cos mθ) as the RoPE encoding. Therefore, we only need one look-up table that can be used for both RoPE encoding and decoding.

请注意,RoPE解码使用与RoPE编码相同的三角系数(例如 cos mθ)。因此,我们只需要一个查找表,即可用于RoPE编码和解码。
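A 2D sketch of this encode/decode pair (one rotary pair only, illustrative code rather than a full RoPE implementation): decoding rotates by the opposite angle while reusing the same cos/sin coefficients, so decode(encode(k)) recovers the original key exactly.

```python
import numpy as np

def rope_encode(k, angle):
    """RoPE encoding: rotate the 2D pair (k1, k2) by +angle = +m*theta."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([k[0] * c - k[1] * s, k[0] * s + k[1] * c])

def rope_decode(k_rot, angle):
    """RoPE decoding: rotate by -angle, reusing the same cos/sin coefficients."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([k_rot[0] * c + k_rot[1] * s, -k_rot[0] * s + k_rot[1] * c])

k, m_theta = np.array([0.3, -1.2]), 0.7                               # key pair, position-dependent angle
print(np.allclose(rope_decode(rope_encode(k, m_theta), m_theta), k))  # True
```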

4 Support for bias

4 支持偏差

Since PaLM’s removal of bias terms from all its projection layers [31], most transformer models nowadays do the same. However, some models are still using biases today (especially older models that are still relevant today such as Whisper). In this section, we briefly discuss how projection layers with bias can be supported. We show how the biases of two of the four attention projection layers can be eliminated in a mathematically equivalent way.

自 PaLM 从其所有投影层中移除偏置项 [31] 以来,现今大多数 Transformer 模型也采取了相同的做法。然而,一些模型至今仍在使用偏置项(尤其是像 Whisper 这样至今仍具影响力的较旧模型)。在本节中,我们简要讨论了如何支持带有偏置的投影层。我们展示了如何在数学上等价地消除四个注意力投影层中两个的偏置。

Bias removal for V projections: This bias can be combined with the bias of the output projection layer as follows. Recall that all value vectors v_i plus their constant bias b are multiplied by the attention scores s_i (i.e. the softmax outputs) and summed up, such as

V 投影的偏置去除:这种偏置可以与输出投影层的偏置结合如下。回想一下,所有的值向量 v_i 加上它们的常数偏置 b 都会乘以注意力分数 s_i(即 softmax 输出)并求和,例如

$$\sum_{i=1}^{n} s_i (v_i + b) = \sum_{i=1}^{n} s_i v_i + \sum_{i=1}^{n} s_i b = \sum_{i=1}^{n} s_i v_i + b$$

The last equal sign holds because the sum over all attention-scores si is always 1 as per softmax definition (because the softmax generates a probability distribution that always adds up to 1). We can now merge the bias b with bias c of the

最后一个等号成立是因为根据 softmax 的定义,所有注意力分数 si 的总和始终为 1(因为 softmax 生成的概率分布总和始终为 1)。现在我们可以将偏差 b 与偏差 c 合并。

preceding output projection layer (O) as follows: $y = (x + b) W_O + c = x W_O + (b W_O + c) = x W_O + c^*$, with the new bias $c^* = b W_O + c$. This new bias vector $c^*$ can be computed offline, before inference time. Or simply remove the V-bias already during training.

即与前面的输出投影层 (O) 的偏置合并,如下所示:$y = (x + b) W_O + c = x W_O + (b W_O + c) = x W_O + c^*$,其中新的偏置为 $c^* = b W_O + c$。这个新的偏置向量 $c^*$ 可以在推理之前离线计算,或者直接在训练期间移除 V 偏置。
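A small NumPy check of this bias folding (random shapes and our own variable names): merging the V-bias b into the output-projection bias via c* = b·W_O + c leaves the layer output unchanged, because the attention scores sum to 1.

```python
import numpy as np

rng = np.random.default_rng(3)
n, d = 8, 64
scores = rng.random(n); scores /= scores.sum()    # softmax output, sums to 1
V      = rng.standard_normal((n, d))              # value vectors without bias
b      = rng.standard_normal(d)                   # V-projection bias
W_O    = rng.standard_normal((d, d))              # output projection weights
c      = rng.standard_normal(d)                   # output-projection bias

y_ref    = (scores @ (V + b)) @ W_O + c           # vanilla: bias added to every value vector
c_star   = b @ W_O + c                            # folded bias c*, precomputed offline
y_folded = (scores @ V) @ W_O + c_star            # V-bias removed from the value path

print(np.allclose(y_ref, y_folded))               # True
```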

Bias removal for K projections: The bias of the K projection cancels out due to the constant invariance of the softmax function. For example, say we have 2-dimensional heads, then the dot-product p between query-vector q = (q_1 + b_1, q_2 + b_2) with bias b and key-vector k = (k_1 + c_1, k_2 + c_2) with bias c is as follows:

K 投影的偏置去除:由于 softmax 函数的常数不变性,K 投影的偏置会被抵消。例如,假设我们有二维的头,那么带有偏置 b 的查询向量 q = (q_1 + b_1, q_2 + b_2) 与带有偏置 c 的键向量 k = (k_1 + c_1, k_2 + c_2) 之间的点积 p 如下:

$$p = (q_1 + b_1)(k_1 + c_1) + (q_2 + b_2)(k_2 + c_2) = [q_1 k_1 + q_2 k_2] + [q_1 c_1 + q_2 c_2] + [b_1 k_1 + b_2 k_2] + [b_1 c_1 + b_2 c_2] = q_1 k_1 + q_2 k_2 + f(q) + b_1 k_1 + b_2 k_2 + \text{constant},$$

where f(q)=q1c1+q2c2 is a function of the query-vector only; and “constant” is a constant that only depends on the two biases b and c . Now recall that the softmax function doesn’t change if a constant is added to all its arguments. Because all arguments of the attention softmax use the same single query-vector q , f(q) is the same for all arguments and is therefore constant and can be removed from all softmax arguments. As a result, we can remove the entire bias-vector c from the keys. But we still need the bias-vector b for the queries. However, this assumes that there is no RoPE applied between the projections and the dot-product calculation, which is fortunately the case for Whisper for example.

其中 f(q)=q1c1+q2c2 是仅与查询向量相关的函数;而“常数”是一个仅依赖于两个偏置 bc 的常数。现在回想一下,softmax 函数在其所有参数上加上一个常数时不会改变。由于注意力 softmax 的所有参数都使用相同的查询向量 qf(q) 对所有参数都是相同的,因此是常数,可以从所有 softmax 参数中移除。因此,我们可以从键中移除整个偏置向量 c。但我们仍然需要查询的偏置向量 b。然而,这假设在投影和点积计算之间没有应用 RoPE(旋转位置编码),幸运的是,例如 Whisper 就是这种情况。

5 Support for non-square weight matrices

5 支持非方形权重矩阵

Some transformers with MHA use non-square weight matrices for their K and V projections. Specifically, these models do not satisfy d = d_k·h. The table below shows three such models where e = d_k·h > d. Let's also define the aspect ratio r as r = e/d. For example, Google's T5-11B model has a large aspect ratio of r = 16.

一些使用多头注意力机制 (MHA) 的 Transformer 在其 K 和 V 投影中使用了非方阵权重矩阵。具体来说,这些模型不满足 d = d_k·h。下表展示了三个这样的模型,其中 e = d_k·h > d。我们还可以定义宽高比 r = e/d。例如,Google 的 T5-11B 模型具有较大的宽高比 r = 16。

| 模型 | d | d_k | h | e = d_k·h | 宽高比 r = e/d |
|---|---|---|---|---|---|
| CodeGemma-7B | 3,072 | 256 | 16 | 4,096 | 1.3 |
| T5-3B | 1,024 | 128 | 32 | 4,096 | 4 |
| T5-11B | 1,024 | 128 | 128 | 16,384 | 16 |

There are two options to reduce the KV-cache by 2x or more, which are compared in the table below and summarized as follows:

有两种方法可以将 KV-cache 减少 2x 或更多,下表对这两种方法进行了比较并总结如下:

• Option 1: Because the K weight matrix is non-square, inverting this matrix is not straightforward. Moreover, the resulting matrix W_KV ∈ R^{e×e} has r times more parameters than W_V ∈ R^{d×e}.
• Option 2: Instead of caching K and then calculating V from K on-the-fly, we can cache the smaller d-dimensional vectors X (taken before the projections) and then calculate both projections (K and V) from X on-the-fly. The cache is now r times smaller than option 1, and 2r times smaller than the baseline, for example 32 times smaller for the T5-11B model. However, this comes at a slightly higher computational cost.

• 选项 1:由于 K 权重矩阵是非方阵,对其求逆并不直接。而且,得到的矩阵 W_KV ∈ R^{e×e} 的参数量是 W_V ∈ R^{d×e} 的 r 倍。
• 选项 2:与其缓存 K 然后从 K 实时计算 V,我们可以缓存投影之前较小的 d 维向量 X,然后从 X 实时计算两个投影(K 和 V)。此时缓存比选项 1 小 r 倍,比基线小 2r 倍,例如对于 T5-11B 模型,缓存小了 32 倍。然而,这会带来稍高的计算成本。

| | Baseline | Option 1 | Option 2 |
|---|---|---|---|
| Cache reduction factor | 1 | 2 | 2r |
| Size of W_V or W_KV | d·e | e² (r times larger) | d·e |
| Computational complexity | baseline | higher | even higher |
| Support for RoPE? | Yes | Yes | No |

| | 基线 | 选项 1 | 选项 2 |
|---|---|---|---|
| 缓存减少因子 | 1 | 2 | 2r |
| W_V 或 W_KV 的大小 | d·e | e² (r 倍大) | d·e |
| 计算复杂度 | 基线 | 更高 | 甚至更高 |
| 支持 RoPE? | 是 | 是 | 否 |

Option 1: The standard matrix inverse is defined only for square matrices, and the inversion functions in NumPy and SciPy are limited to such matrices. We want to compute the inverse of W_K ∈ R^{d×e} with e > d such that W_K W_K^{-1} = I, where I is the identity matrix and W_K^{-1} is the so-called right inverse of W_K. We compute W_K^{-1} by using a trick that inverts the term W_K W_K^⊤ instead of W_K as follows:

选项 1:标准矩阵求逆仅针对方阵定义,NumPy 和 SciPy 中的求逆函数也仅限于此类矩阵。我们希望计算 W_K ∈ R^{d×e}(其中 e > d)的逆,使得 W_K W_K^{-1} = I,其中 I 是单位矩阵,W_K^{-1} 是 W_K 的所谓右逆。我们通过对 W_K W_K^⊤(而非 W_K)求逆的技巧来计算 W_K^{-1},如下所示:

$$I = W_K \underbrace{W_K^\top \left(W_K W_K^\top\right)^{-1}}_{W_K^{-1}}$$

In the equation above, everything on the right side of W_K has to be the inverse of W_K, thus W_K^{-1} = W_K^⊤ (W_K W_K^⊤)^{-1}. We can now use the matrix inversion function of NumPy to compute the inverse of the term W_K W_K^⊤, which is a square d×d matrix. Now we can calculate W_KV = W_K^{-1} W_V. However, storing W_KV instead of the original W_V takes r times more space in memory, which is an issue for large aspect-ratios r.

在上面的等式中,W_K 右侧的所有内容合起来必须是 W_K 的逆,因此 W_K^{-1} = W_K^⊤ (W_K W_K^⊤)^{-1}。我们现在可以使用 NumPy 的矩阵求逆函数来计算 W_K W_K^⊤ 的逆,它是一个 d×d 的方阵。接着我们可以计算 W_KV = W_K^{-1} W_V。然而,存储 W_KV 而不是原始的 W_V 会占用 r 倍的内存空间,这对于大宽高比 r 来说是一个问题。
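A short NumPy sketch of this right-inverse trick (the shapes are scaled-down stand-ins for T5-like dimensions): it also checks that V can still be rebuilt exactly from K, because W_K·W_K^{-1} = I holds exactly for the right inverse.

```python
import numpy as np

rng = np.random.default_rng(4)
d, e = 64, 256                                     # e = d_k * h > d, i.e. aspect ratio r = 4
W_K  = rng.standard_normal((d, e))
W_V  = rng.standard_normal((d, e))

# Right inverse: W_K @ W_K_inv = I, with W_K_inv = W_K^T (W_K W_K^T)^{-1}
W_K_inv = W_K.T @ np.linalg.inv(W_K @ W_K.T)       # shape (e, d); only a d x d matrix is inverted
print(np.allclose(W_K @ W_K_inv, np.eye(d)))       # True

W_KV = W_K_inv @ W_V                               # shape (e, e): r times larger than W_V
X    = rng.standard_normal((8, d))                 # some layer inputs
print(np.allclose((X @ W_K) @ W_KV, X @ W_V))      # True: V is still recovered exactly from K
```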

Option 2 caches the X-matrix instead of KV or just K, where the X-matrix contains the input activations of the attention layer (before the projections). Recomputing all K-vectors from X by multiplying X with weight matrix W_K would require 2·n·d·e operations and would be very expensive. A lower-complexity option is illustrated in Fig. 4, which is similar to the trick illustrated in Fig. 3(c). Recall that for the i-th head (i = 1, …, h), the softmax argument (without the scaling factor 1/√d_k) is A_i = Q_i K_i^⊤, where Q_i = X W_{Q,i} and K_i = X W_{K,i}. For the generate-phase, there is only one input-vector x_n for the query, but there are n input-vectors X for the key and value projections. We can take advantage of this and modify A as follows (which uses the identity (BC)^⊤ = C^⊤ B^⊤ for transposing the product of arbitrary matrices B and C):

选项 2 缓存的是 X 矩阵,而不是 KV 或仅 K,其中 X 矩阵包含注意力层的输入激活(在投影之前)。通过将 X 与权重矩阵 W_K 相乘来重新计算所有 K 向量需要 2·n·d·e 次操作,这将非常昂贵。图 4 展示了一个复杂度较低的方案,类似于图 3(c) 中的技巧。回想一下,对于第 i 个头 (i = 1, …, h),softmax 参数(不包括缩放因子 1/√d_k)为 A_i = Q_i K_i^⊤,其中 Q_i = X W_{Q,i} 且 K_i = X W_{K,i}。在生成阶段,查询只有一个输入向量 x_n,但键和值投影有 n 个输入向量 X。我们可以利用这一点并按如下方式改写 A(其中使用了任意矩阵 B 和 C 乘积的转置恒等式 (BC)^⊤ = C^⊤ B^⊤):


Figure 4: Slim attention with X-cache (instead of KV or V-cache) for the generate-phase of transformers with non-square weight matrices with e > d.

图 4: 使用 X-cache(而非 KV 或 V-cache)的 Slim attention,适用于权重矩阵非方阵且 e > d 的 Transformer 生成阶段。

$$A_i = Q_i K_i^\top = x_n W_{Q,i} \left(X W_{K,i}\right)^\top = \left(x_n W_{Q,i} W_{K,i}^\top\right) X^\top$$

For each iteration of the generate-phase, we now have to calculate the term x_n W_{Q,i} W_{K,i}^⊤ only once for each attention-head, which is independent of the sequence length. Calculating this term involves multiplying the d-dimensional vector x_n with matrices W_{Q,i} ∈ R^{d×d_k} and W_{K,i}^⊤ ∈ R^{d_k×d}, which requires 2·d·e multiplications for the h heads, so 4·d·e operations in total (where we count a multiply-add operation as 2 operations).

在生成阶段的每次迭代中,我们现在只需为每个注意力头计算一次 x_n W_{Q,i} W_{K,i}^⊤ 项,这与序列长度无关。计算该项涉及将 d 维向量 x_n 与矩阵 W_{Q,i} ∈ R^{d×d_k} 和 W_{K,i}^⊤ ∈ R^{d_k×d} 相乘,h 个头总共需要 2·d·e 次乘法,即总计 4·d·e 次操作(我们将一次乘加运算计为 2 次操作)。
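A per-head NumPy sketch of this reordering (all names are ours, for illustration only): it computes the softmax argument A_i directly from the X-cache without materializing K_i, and checks the result against the vanilla computation.

```python
import numpy as np

rng = np.random.default_rng(5)
n, d, d_k = 256, 64, 16
X     = rng.standard_normal((n, d))        # X-cache: pre-projection inputs of all n tokens
x_n   = X[-1]                              # input vector of the current token
W_Q_i = rng.standard_normal((d, d_k))      # per-head query projection
W_K_i = rng.standard_normal((d, d_k))      # per-head key projection

# Vanilla ordering: build K_i explicitly, then take the dot-products with the query
A_ref = (x_n @ W_Q_i) @ (X @ W_K_i).T      # ~2*n*d*d_k OPs just to rebuild K_i

# X-cache trick: fold W_Q,i and W_K,i^T into one small d-vector first, then hit the X-cache once
A_opt = (x_n @ W_Q_i @ W_K_i.T) @ X.T      # ~4*d*d_k OPs + 2*n*d OPs

print(np.allclose(A_ref, A_opt))           # True
```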

This scheme also works for projection layers with biases (as used by the Whisper models for example). Recall from the previous section that we can eliminate the biases from the key and value projections, but not from the query projection. Adding a constant query bias-vector b to the equation above is straightforward and also illustrated in Fig. 4:

该方案也适用于带有偏置的投影层(例如 Whisper 模型所使用的)。回顾上一节,我们可以消除键和值投影中的偏置,但不能消除查询投影中的偏置。在上述方程中添加一个常数查询偏置向量 b 是直接的,如图 4 所示:

$$A_i = Q_i K_i^\top = \left(x_n W_{Q,i} + b\right)\left(X W_{K,i}\right)^\top = \left(\left(x_n W_{Q,i} + b\right) W_{K,i}^\top\right) X^\top$$

However, this scheme doesn’t work if there is a positional encoding such as RoPE located between the projection layers and the dot-product calculation. But option 2 fully supports other relative position encoding (PE) schemes such as RPE of the T5 model, Alibi, Kerple and FIRE [32] which add a variable bias to the softmax arguments (instead of modifying the queries and keys before the dot-product calculation). See for example the FIRE paper [32], which shows that FIRE and even NoPE can outperform RoPE for long context.

然而,如果投影层和点积计算之间存在位置编码(如 RoPE),此方案将无法工作。但选项 2 完全支持其他相对位置编码(PE)方案,例如 T5 模型的 RPE、Alibi、Kerple 和 FIRE [32],这些方案在 softmax 参数中添加了一个可变偏差(而不是在点积计算之前修改查询和键)。例如,参见 FIRE 论文 [32],该论文表明,对于长上下文,FIRE 甚至 NoPE 可以优于 RoPE。

6 Support for encoder-decoder transformers

6 支持编码器-解码器 Transformer

In general, calculating K from V is not only possible for self-attention (see Fig. 1) but also for cross-attention. In this section, we present two context memory options for encoder-decoder transformers such as Whisper (speech-to-text), language translation models such as Google’s T5, and time series forecasting models such as Amazon’s Chronos models. One option is not limited to MHA only, but is also applicable to MQA and GQA. The table below compares the options, which are summarized as follows:

通常来说,从 V 计算 K 不仅适用于自注意力机制(见图 1),也适用于交叉注意力机制。在本节中,我们为编码器-解码器 Transformer(如 Whisper(语音转文本))、语言翻译模型(如 Google 的 T5)以及时间序列预测模型(如 Amazon 的 Chronos 模型)提供了两种上下文记忆选项。其中一种选项不仅适用于多头注意力机制 (MHA),也适用于多查询注意力机制 (MQA) 和分组查询注意力机制 (GQA)。下表对比了这些选项,总结如下:

• The baseline implementation uses complete KV-caches for both self-attention and cross-attention.