Tensor Product Attention Is All You Need
Abstract
Scaling language models to handle longer input sequences typically requires large key-value (KV) caches, resulting in substantial memory overhead during inference. In this paper, we propose Tensor Product Attention (TPA), a novel attention mechanism that uses tensor decompositions to represent queries, keys, and values compactly, substantially shrinking the KV cache at inference time. By factorizing these representations into contextual low-rank components (contextual factorization) and integrating seamlessly with RoPE, TPA improves model quality while also improving memory efficiency. Based on TPA, we introduce the Tensor ProducT ATTenTion Transformer (T6), a new model architecture for sequence modeling. Through extensive empirical evaluation on language modeling tasks, we demonstrate that T6 outperforms standard Transformer baselines, including MHA, MQA, GQA, and MLA, across various metrics, including perplexity and a range of well-known evaluation benchmarks. Notably, TPA's memory efficiency enables the processing of significantly longer sequences under fixed resource constraints, addressing a critical scalability challenge in modern language models. The code is available at https://github.com/tensorgi/T6.
1 Introduction
Large language models (LLMs) have revolutionized natural language processing, demonstrating exceptional performance across tasks (Brown et al., 2020; Chowdhery et al., 2023; Touvron et al., 2023; Bubeck et al., 2023). As these models evolve, their ability to process longer contexts becomes increasingly important for sophisticated applications such as document analysis, complex reasoning, and code completion. However, managing longer sequences during inference poses significant computational and memory challenges, particularly the storage of key-value (KV) caches (Zhang et al., 2023c; Liu et al., 2024c). Because KV cache memory grows linearly with sequence length, the maximum context window is limited by practical hardware constraints.
A variety of solutions have been explored to address this memory bottleneck. Some approaches compress or selectively prune cached states through sparse attention patterns (Child et al., 2019) or token eviction strategies (Zhang et al., 2023c; Xiao et al., 2024; Ribar et al., 2024), though such methods risk discarding tokens that may later prove important. Other work proposes off-chip storage of key-value states (He & Zhai, 2024), at the expense of increased I/O latency. Attention variants such as multi-query attention (MQA) (Shazeer, 2019) and grouped-query attention (GQA) (Ainslie et al., 2023) reduce per-token cache requirements by sharing keys and values across heads, but often compromise flexibility or require significant architectural modifications. Meanwhile, low-rank weight factorization methods such as LoRA (Hu et al., 2022) effectively reduce fine-tuning memory, yet do not address the KV cache overhead that dominates inference-time memory. The recently introduced Multi-head Latent Attention (MLA) in DeepSeek-V2 (Liu et al., 2024a) caches compressed key-value representations. However, because this compression cannot be combined efficiently with Rotary Position Embedding (RoPE) (Su et al., 2024b), MLA needs additional position-encoded parameters per head.

Figure 1: Tensor Product Attention (TPA) in the Tensor ProducT ATTenTion Transformer (T6). Unlike multi-head attention, in each layer the hidden state first passes through separate linear layers that produce the latent factor matrices $\mathbf{A}$ and $\mathbf{B}$ for the query, key, and value. We additionally apply RoPE to $\mathbf{B}_{Q}$ and $\mathbf{B}_{K}$ for the query and key. The multi-head query, key, and value vectors are then obtained as the tensor product of $\mathbf{A}_{(\cdot)}$ and $\mathbf{B}_{(\cdot)}$. Finally, the output of TPA is produced by scaled dot-product attention, followed by a linear projection of the concatenated outputs of all heads.
To overcome the limitations of existing approaches, we introduce Tensor Product Attention (TPA), illustrated in Figure 1. TPA is a novel architecture that uses higher-order tensors to factorize queries (Q), keys (K), and values (V) during attention computation. It dynamically factorizes activations rather than static weights (as LoRA does). In this way, TPA constructs low-rank, contextual representations that substantially reduce KV cache memory usage while improving representational capacity. In practice, TPA can reduce memory overhead by an order of magnitude compared to standard multi-head attention (MHA). It also achieves lower pretraining validation loss (perplexity) and better downstream performance.
A key advantage of TPA is its native compatibility with rotary position embeddings (RoPE) (Su et al., 2024b). This compatibility makes TPA a straightforward drop-in replacement for multi-head attention (MHA) layers in modern LLM architectures such as LLaMA (Touvron et al., 2023) and Gemma (Team et al., 2024).
Our primary contributions are summarized as follows:

Figure 2: Training loss and validation loss when pretraining large-size (773M) models with different attention mechanisms on the FineWeb-Edu-100B dataset.
2 Background
In this section, we review several classical forms of attention: Scaled Dot-Product Attention, Multi-Head Attention (MHA) (Vaswani et al., 2017), Multi-Query Attention (MQA) (Shazeer, 2019), and Grouped Query Attention (GQA) (Ainslie et al., 2023). We also review Rotary Position Embedding (RoPE) (Su et al., 2024b). Finally, we describe a recent method, Multi-head Latent Attention (MLA), used in DeepSeek-V2 (Liu et al., 2024a) and DeepSeek-V3 (Liu et al., 2024b).
Notations. We use bold uppercase letters (e.g., $\mathbf{X}$, $\mathbf{Q}$) for matrices, bold lowercase letters (e.g., $\mathbf{a}$, $\mathbf{b}$) for vectors, and italic uppercase letters (e.g., $W_{i}^{Q}$) for learnable parameter matrices. For a positive integer $n$, we denote by $[n]$ the set $\{1,\ldots,n\}$. We use $^{\top}$ to denote the transpose of a vector or a matrix. Let $d_{\mathrm{model}}$ be the embedding dimension, $h$ the number of attention heads, $d_{h}$ the dimension per head, and $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$ the input for the $t$-th token at a given attention layer. $\mathbf{X}\in\mathbb{R}^{T\times d_{\mathrm{model}}}$ denotes the input embeddings for $T$ tokens, and $\mathbf{Q},\mathbf{K},\mathbf{V}\in\mathbb{R}^{T\times h\times d_{h}}$ denote the queries, keys, and values of $h$ heads for $T$ tokens. With a slight abuse of notation, $\mathbf{Q}_{i},\mathbf{K}_{i},\mathbf{V}_{i}\in\mathbb{R}^{T\times d_{h}}$ denote the $i$-th head of the queries, keys, and values. Likewise, $\mathbf{Q}_{t},\mathbf{K}_{t},\mathbf{V}_{t}\in\mathbb{R}^{h\times d_{h}}$ denote the heads of the query, key, and value for the $t$-th token.
Throughout the paper, $W^{Q},W^{K},W^{V}$ denote the projection matrices for queries, keys, and values, respectively. In multi-head attention, each head $i$ has its own set of matrices $W_{i}^{Q},W_{i}^{K},W_{i}^{V}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{k}}$, where $d_{k}$ is typically set to $d_{h}$, the dimension of each head. Similarly, the output projection matrix is $W^{O}\in\mathbb{R}^{(h\cdot d_{h})\times d_{\mathrm{model}}}$. In methods such as MQA and GQA, some of these matrices are shared or partially shared across heads, but their shapes remain the same.
We define the tensor product of two vectors as follows: for vectors $\mathbf{a}\in\mathbb{R}^{m}$ and $\mathbf{b}\in\mathbb{R}^{n}$, the tensor product of $\mathbf{a}$ and $\mathbf{b}$ is:
\begin{array}{r}{\mathbf{a}\otimes\mathbf{b}=\mathbf{C}\in\mathbb{R}^{m\times n},\mathrm{with~}C_{i j}=a_{i}b_{j},}\end{array}
where $a_{i}$ and $b_{j}$ are the $i$-th and $j$-th elements of $\mathbf{a}$ and $\mathbf{b}$, respectively, and $C_{ij}$ is the $(i,j)$-th entry of $\mathbf{C}$. We also define the vectorization of a matrix $\mathbf{C}\in\mathbb{R}^{m\times n}$ by:
\operatorname{vec}(\mathbf{C})=\mathbf{d}\in\mathbb{R}^{mn},\quad\text{with } d_{(i-1)\cdot n+j}=C_{ij},
where $d_{(i-1)\cdot n+j}$ is the $((i-1)\cdot n+j)$-th element of $\mathbf{d}$ for $i\in[m]$ and $j\in[n]$. In other words, the rows of $\mathbf{C}$ are stacked in order.
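As a quick sanity check of these two definitions, the NumPy sketch below builds $\mathbf{a}\otimes\mathbf{b}$ with `np.outer` and then flattens it row by row. NumPy indices are 0-based, so the stacking rule reads `d[i*n + j] == C[i, j]` in code. The helper names and vectors are illustrative only.

```python
import numpy as np

def tensor_product(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """a ⊗ b for a in R^m, b in R^n: the m x n matrix with C[i, j] = a[i] * b[j]."""
    return np.outer(a, b)

def vec(C: np.ndarray) -> np.ndarray:
    """Row-major vectorization: stacks the rows of C, so d[i*n + j] == C[i, j] (0-based)."""
    return C.reshape(-1)

a = np.array([1.0, 2.0])
b = np.array([3.0, 4.0, 5.0])
C = tensor_product(a, b)   # shape (2, 3): rows a[0]*b and a[1]*b
d = vec(C)                 # shape (6,): row 0 followed by row 1
```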
2.1 Scaled Dot-Product Attention
Scaled dot-product attention (Vaswani et al., 2017) determines how to focus on different parts of an input sequence by comparing queries $(\mathbf{Q})$ with keys $(\mathbf{K})$. It produces a weighted combination of the values $(\mathbf{V})$. Formally, the attention output is:
\begin{array}{r}{\mathrm{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V})=\mathrm{Softmax}\Big(\frac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{d_{k}}}\Big)\,\mathbf{V},}\end{array}
where each of $\mathbf{Q},\mathbf{K},\mathbf{V}$ is an $(n\times d_{k})$ matrix for $n$ tokens and key dimension $d_{k}$ . The division by $\sqrt{d_{k}}$ stabilizes training by controlling the scale of the inner products.
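The following is a minimal NumPy sketch of the formula above. It assumes 2-D arrays and applies no causal mask, since the equation has none. The helper names `softmax` and `attention` are illustrative and are not taken from the paper's code.

```python
import numpy as np

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max before exponentiating for numerical stability.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Softmax(Q K^T / sqrt(d_k)) V for Q: (n_q, d_k), K: (n, d_k), V: (n, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n_q, n) scaled inner products
    weights = softmax(scores, axis=-1)   # each row is a probability distribution
    return weights @ V                   # weighted combination of value rows

rng = np.random.default_rng(0)
n, d_k = 4, 8
Q, K, V = (rng.standard_normal((n, d_k)) for _ in range(3))
out = attention(Q, K, V)
```

A convenient check: if a query is all zeros, every score is equal. The weights are then uniform, and the output is simply the mean of the value rows.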
2.2 Multi-Head Attention (MHA)
Multi-Head Attention (MHA) extends scaled dot-product attention by dividing the model's internal representation into several heads. Each head learns different projections for queries, keys, and values, allowing the model to attend to different types of information. For each token embedding $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$, MHA computes each head $i$ as follows:
\begin{array}{l}\mathbf{Q}_{t,i}=(W_{i}^{Q})^{\top}\mathbf{x}_{t}\in\mathbb{R}^{d_{h}},\quad\mathbf{K}_{t,i}=(W_{i}^{K})^{\top}\mathbf{x}_{t}\in\mathbb{R}^{d_{h}},\quad\mathbf{V}_{t,i}=(W_{i}^{V})^{\top}\mathbf{x}_{t}\in\mathbb{R}^{d_{h}},\\ \mathrm{head}_{i}=\mathrm{Attention}\big(\mathbf{Q}_{i},\mathbf{K}_{i},\mathbf{V}_{i}\big),\end{array}
where $W_{i}^{Q},W_{i}^{K},W_{i}^{V}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{h}}$ are learnable projection matrices for the $i$-th head. The matrices $\mathbf{Q}_{i},\mathbf{K}_{i},\mathbf{V}_{i}\in\mathbb{R}^{T\times d_{h}}$ collect the head-$i$ queries, keys, and values of all $T$ tokens. After computing each head's attention, the outputs are concatenated and mapped back to the original dimension via another matrix $W^{O}\in\mathbb{R}^{hd_{h}\times d_{\mathrm{model}}}$:
\mathrm{MHA}(\mathbf{Q},\mathbf{K},\mathbf{V})=\mathrm{Concat}\big(\mathrm{head}_{1},\dots,\mathrm{head}_{h}\big)\,W^{O}.
MHA can capture a rich set of dependencies while each head focuses on different subspaces.
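The two MHA equations translate almost line for line into NumPy, as the sketch below shows. It works in row form ($\mathbf{X}W_{i}^{Q}$ rather than $(W_{i}^{Q})^{\top}\mathbf{x}_{t}$), stores per-head weights in Python lists, and uses random, illustratively scaled weights. It applies no causal mask. The function names are hypothetical.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    # Row-wise softmax; subtracting the row max avoids overflow in exp.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    # Scaled dot-product attention without a causal mask.
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def mha(X, W_Q, W_K, W_V, W_O):
    """X: (T, d_model). W_Q, W_K, W_V: lists of h matrices, each (d_model, d_h).
    W_O: (h * d_h, d_model)."""
    heads = [attention(X @ wq, X @ wk, X @ wv) for wq, wk, wv in zip(W_Q, W_K, W_V)]
    return np.concatenate(heads, axis=-1) @ W_O   # (T, h*d_h) -> (T, d_model)

rng = np.random.default_rng(0)
T, h, d_h = 5, 4, 8
d_model = h * d_h

def init(rows, cols):
    # Random weights scaled by 1/sqrt(fan_in); purely illustrative.
    return rng.standard_normal((rows, cols)) / np.sqrt(rows)

X = rng.standard_normal((T, d_model))
W_Q = [init(d_model, d_h) for _ in range(h)]
W_K = [init(d_model, d_h) for _ in range(h)]
W_V = [init(d_model, d_h) for _ in range(h)]
Y = mha(X, W_Q, W_K, W_V, init(h * d_h, d_model))
```

Passing an identity matrix as `W_O` exposes the raw concatenation. Columns `i*d_h:(i+1)*d_h` then hold head $i$ unchanged.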
2.3 Multi-Query Attention (MQA)
Multi-Query Attention (MQA) (Shazeer, 2019) significantly reduces memory usage by sharing keys and values across heads while keeping a separate query projection for each head. For a sequence of embeddings $\mathbf{X}\in\mathbb{R}^{T\times d_{\mathrm{model}}}$,
\mathbf{Q}_{i}=\mathbf{X}W_{i}^{Q},\quad\mathbf{K}_{\mathrm{shared}}=\mathbf{X}W_{\mathrm{shared}}^{K},\quad\mathbf{V}_{\mathrm{shared}}=\mathbf{X}W_{\mathrm{shared}}^{V}.
Hence, each head $i$ has its own query $\mathbf{Q}_{i}\in\mathbb{R}^{T\times d_{k}}$ but shares the same key $\mathbf{K}_{\mathrm{shared}}\in\mathbb{R}^{T\times d_{k}}$ and value $\mathbf{V}_{\mathrm{shared}}\in\mathbb{R}^{T\times d_{k}}$ with all other heads. In practice, this means:
W_{i}^{Q}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{k}},\quad W_{\mathrm{shared}}^{K},W_{\mathrm{shared}}^{V}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{k}}.
The resulting MQA operation is:
\operatorname{MQA}(\mathbf{X})=\operatorname{Concat}\left(\mathrm{head}_{1},\dots,\mathrm{head}_{h}\right)W^{O},
where
\mathrm{head}_{i}=\operatorname{Attention}\left(\mathbf{Q}_{i},\mathbf{K}_{\mathrm{shared}},\mathbf{V}_{\mathrm{shared}}\right).
By sharing these key and value projections, MQA reduces memory usage, especially for the key-value cache in autoregressive inference. However, it loses some expressivity, since all heads must rely on the same key/value representations.
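The sketch below shows MQA in NumPy, under the same illustrative assumptions as before: random weights, no causal mask, and hypothetical helper names. The shared key and value matrices are computed once and reused by every head. The last two lines count the floats each token adds to the KV cache per layer under MHA and under MQA.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    # Row-wise softmax with max subtraction for numerical stability.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    # Scaled dot-product attention (no causal mask).
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def mqa(X, W_Q, W_K_shared, W_V_shared, W_O):
    """X: (T, d_model); W_Q: list of h (d_model, d_k) query matrices;
    W_K_shared, W_V_shared: single (d_model, d_k) matrices; W_O: (h*d_k, d_model)."""
    K_shared = X @ W_K_shared   # one key matrix for all heads (this is what gets cached)
    V_shared = X @ W_V_shared   # one value matrix for all heads
    heads = [attention(X @ wq, K_shared, V_shared) for wq in W_Q]
    return np.concatenate(heads, axis=-1) @ W_O

rng = np.random.default_rng(0)
T, h, d_k = 5, 4, 8
d_model = h * d_k
X = rng.standard_normal((T, d_model))
W_Q = [rng.standard_normal((d_model, d_k)) / np.sqrt(d_model) for _ in range(h)]
W_K = rng.standard_normal((d_model, d_k)) / np.sqrt(d_model)
W_V = rng.standard_normal((d_model, d_k)) / np.sqrt(d_model)
Y = mqa(X, W_Q, W_K, W_V, np.eye(h * d_k))   # identity W_O exposes the raw heads

# Per-token, per-layer KV cache: MHA stores h keys and h values, MQA one of each.
mha_cache = 2 * h * d_k   # 64 floats with these sizes
mqa_cache = 2 * d_k       # 16 floats with these sizes
```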
2.4 Grouped Query Attention (GQA)
Grouped Query Attention (GQA) (Ainslie et al., 2023) generalizes MHA and MQA by grouping heads. Specifically, we partition the $h$ total heads into $G$ groups. Each group has a single set of keys and values, but each individual head within that group still retains its own query projection. Formally, if $g(i)$ maps a head $i\in[h]$ to its group index $g\in[G]$ , then:
\mathbf{K}_{g(i)}=\mathbf{X}W_{g(i)}^{K},\quad\mathbf{V}_{g(i)}=\mathbf{X}W_{g(i)}^{V},\quad\mathbf{Q}_{i}=\mathbf{X}W_{i}^{Q},
and
\mathrm{head}_{i}=\mathrm{Attention}\big(\mathbf{Q}_{i},\mathbf{K}_{g(i)},\mathbf{V}_{g(i)}\big).
Again, $W_{g}^{K},W_{g}^{V}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{k}}$ for each group $g$, and $W_{i}^{Q}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{k}}$ for each head $i$. The complete output is again a concatenation of all heads:
\operatorname{GQA}(\mathbf{X})=\operatorname{Concat}\left(\mathrm{head}_{1},\dots,\mathrm{head}_{h}\right)W^{O}.
By adjusting $G$ between 1 and $h$, GQA interpolates between two extremes. At $G=1$, a single key/value projection is shared across all heads (MQA). At $G=h$, each head has its own set of projections (MHA).
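The paper does not specify the mapping $g(i)$. The sketch below assumes a common contiguous grouping, $g(i)=\lfloor i/(h/G)\rfloor$ with 0-based head indices, and requires $G$ to divide $h$. The helper `group_of` is hypothetical. Setting $G=h$ reproduces per-head MHA, and $G=1$ reproduces MQA. The cache counter shows how the per-token KV cache scales with $G$.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # stabilize exp
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    # Scaled dot-product attention (no causal mask).
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def group_of(i, h, G):
    # Hypothetical contiguous grouping: heads 0..h/G-1 -> group 0, the next h/G -> group 1, ...
    return i // (h // G)

def gqa(X, W_Q, W_K, W_V, W_O):
    """W_Q: h query matrices; W_K, W_V: G key/value matrices, each (d_model, d_k)."""
    h, G = len(W_Q), len(W_K)
    assert h % G == 0, "the number of groups must divide the number of heads"
    K = [X @ w for w in W_K]   # G key matrices (cached at inference)
    V = [X @ w for w in W_V]   # G value matrices (cached at inference)
    heads = [attention(X @ W_Q[i], K[group_of(i, h, G)], V[group_of(i, h, G)])
             for i in range(h)]
    return np.concatenate(heads, axis=-1) @ W_O

def kv_cache_floats_per_token(G, d_k):
    # Keys plus values for G groups: G = h gives MHA, G = 1 gives MQA.
    return 2 * G * d_k

rng = np.random.default_rng(0)
T, h, d_k = 6, 8, 4
d_model = h * d_k
X = rng.standard_normal((T, d_model))

def init():
    return rng.standard_normal((d_model, d_k)) / np.sqrt(d_model)

W_Q = [init() for _ in range(h)]
W_K = [init() for _ in range(h)]
W_V = [init() for _ in range(h)]
W_O = np.eye(h * d_k)                             # identity exposes the raw heads
Y_mha = gqa(X, W_Q, W_K, W_V, W_O)                # G = h: one K/V per head
Y_g2 = gqa(X, W_Q, W_K[:2], W_V[:2], W_O)         # G = 2: heads 0-3 and 4-7 share K/V
```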
2.5 Rotary Position Embedding (RoPE)
Many recent LLMs use rotary position embedding (RoPE; Su et al., 2024b) to encode positional information in the query/key vectors. Specifically, let $\mathrm{RoPE}_{t}$ denote the rotation operator $\mathbf{T}_{t}\in\mathbb{R}^{d_{h}\times d_{h}}$ corresponding to the $t$-th position. $\mathbf{T}_{t}$ is a block-diagonal matrix whose $j$-th diagonal block, for $j\in\{1,\ldots,d_{h}/2\}$, is the $2\times 2$ rotation $\begin{pmatrix}\cos(t\theta_{j}) & -\sin(t\theta_{j})\\ \sin(t\theta_{j}) & \cos(t\theta_{j})\end{pmatrix}$. Here $\{\theta_{j}\}$ are predefined frequency parameters, e.g., $\theta_{j}=1/10000^{2j/d_{h}}$. Then we define
\mathrm{RoPE}\left(\mathbf{Q}_{t}\right)\triangleq\mathbf{Q}_{t}\mathbf{T}_{t},\quad\text{where } \mathbf{Q}_{t}\in\mathbb{R}^{h\times d_{h}}.
A fundamental property is that
\mathbf{T}_{t}\,\mathbf{T}_{s}^{\top}=\mathbf{T}_{t-s},\tag{2.1}
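This property ensures that the score between a query rotated at position $t$ and a key rotated at position $s$ depends on the two positions only through their relative offset $(t-s)$. It thereby gives rotary position embedding a form of translation invariance.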
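The property $\mathbf{T}_{t}\mathbf{T}_{s}^{\top}=\mathbf{T}_{t-s}$ can be checked numerically. The sketch below builds $\mathbf{T}_{t}$ exactly as described above, with $\theta_{j}=1/10000^{2j/d_{h}}$ and 2×2 blocks on the diagonal. It then applies RoPE in the row-vector convention $\mathbf{Q}_{t}\mathbf{T}_{t}$. A dense matrix is used for clarity; practical implementations apply the rotation elementwise instead. Function names are illustrative.

```python
import numpy as np

def rope_matrix(t: int, d_h: int, base: float = 10000.0) -> np.ndarray:
    """T_t: block-diagonal with 2x2 blocks [[cos(t θ_j), -sin(t θ_j)], [sin(t θ_j), cos(t θ_j)]]."""
    assert d_h % 2 == 0, "RoPE pairs up coordinates, so d_h must be even"
    T = np.zeros((d_h, d_h))
    for j in range(1, d_h // 2 + 1):
        theta = 1.0 / base ** (2 * j / d_h)   # θ_j = 1 / 10000^(2j/d_h)
        c, s = np.cos(t * theta), np.sin(t * theta)
        k = 2 * (j - 1)                        # top-left index of the j-th block
        T[k:k + 2, k:k + 2] = [[c, -s], [s, c]]
    return T

def rope(M: np.ndarray, t: int) -> np.ndarray:
    # RoPE(Q_t) = Q_t T_t: each row (one per head) is multiplied on the right by T_t.
    return M @ rope_matrix(t, M.shape[-1])

def score(q, k, t, s):
    # Logit between a query rotated at position t and a key rotated at position s.
    return rope(q, t) @ rope(k, s).T

rng = np.random.default_rng(0)
d_h = 8
q = rng.standard_normal((1, d_h))
k = rng.standard_normal((1, d_h))
```

Shifting both positions by the same amount leaves the score unchanged. For example, `score(q, k, 10, 4)` equals `score(q, k, 6, 0)`, because both depend only on the offset 6.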
2.6 Multi-head Latent Attention (MLA)
Below, we briefly outline the Multi-head Latent Attention (MLA) approach used by DeepSeek-V2 (Liu et al., 2024a) and DeepSeek-V3 (Liu et al., 2024b). MLA introduces a low-rank compression of the keys and values to reduce the key-value (KV) caching cost at inference.
\begin{array}{l}\mathbf{C}^{KV}=\mathbf{X}W^{DKV},\quad(W^{DKV}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{c}}),\\ \mathrm{Concat}\big(\mathbf{K}_{1}^{C},\mathbf{K}_{2}^{C},\dots,\mathbf{K}_{h}^{C}\big)=\mathbf{K}^{C}=\mathbf{C}^{KV}W^{UK},\quad(W^{UK}\in\mathbb{R}^{d_{c}\times d_{h}h}),\\ \mathbf{K}^{R}=\mathrm{RoPE}\big(\mathbf{X}W^{KR}\big),\quad(W^{KR}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{h}^{R}}),\\ \mathbf{K}_{i}=\mathrm{Concat}\big(\mathbf{K}_{i}^{C},\mathbf{K}^{R}\big),\\ \mathrm{Concat}\big(\mathbf{V}_{1}^{C},\mathbf{V}_{2}^{C},\dots,\mathbf{V}_{h}^{C}\big)=\mathbf{V}^{C}=\mathbf{C}^{KV}W^{UV},\quad(W^{UV}\in\mathbb{R}^{d_{c}\times d_{h}h}),\end{array}
where $\mathbf{C}^{KV}\in\mathbb{R}^{T\times d_{c}}$ is the compressed KV latent (with $d_{c}\ll d_{h}h$). $\mathrm{RoPE}(\cdot)$ denotes the RoPE transform applied to the separate key embeddings $\mathbf{K}^{R}$ of dimension $d_{h}^{R}$, which are shared by all heads. Thus, only $\mathbf{C}^{KV}$ and $\mathbf{K}^{R}$ need to be cached. This reduces KV memory usage while largely preserving performance compared with standard MHA (Vaswani et al., 2017).
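To make the saving concrete, the plain-Python arithmetic below compares per-token cache sizes. MHA caches $2hd_{h}$ numbers per token per layer, while MLA caches only $\mathbf{c}^{KV}$ and $\mathbf{k}^{R}$, that is, $d_{c}+d_{h}^{R}$ numbers. The configuration is hypothetical and chosen only for illustration; it is not the DeepSeek setting. The byte count assumes 2-byte (fp16/bf16) storage.

```python
def mha_kv_floats_per_token(h: int, d_h: int) -> int:
    # MHA caches a d_h-dimensional key and value for every head.
    return 2 * h * d_h

def mla_kv_floats_per_token(d_c: int, d_h_rope: int) -> int:
    # MLA caches the latent c^KV (d_c) plus the RoPE key k^R (d_h^R) shared by all heads.
    return d_c + d_h_rope

def kv_cache_bytes(floats_per_token: int, seq_len: int, n_layers: int,
                   bytes_per_elem: int = 2) -> int:
    # Total cache over all layers; bytes_per_elem=2 assumes fp16/bf16 storage.
    return floats_per_token * seq_len * n_layers * bytes_per_elem

# Hypothetical configuration, for illustration only.
h, d_h, d_c, d_h_rope = 32, 128, 512, 64
seq_len, n_layers = 32_768, 32

mha_floats = mha_kv_floats_per_token(h, d_h)          # 8192
mla_floats = mla_kv_floats_per_token(d_c, d_h_rope)   # 576
print(mha_floats / mla_floats)                               # about 14.2x smaller
print(kv_cache_bytes(mha_floats, seq_len, n_layers) / 2**30)  # 16.0 (GiB)
print(kv_cache_bytes(mla_floats, seq_len, n_layers) / 2**30)  # 1.125 (GiB)
```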
MLA also compresses the queries, lowering their training-time memory footprint:
\begin{array}{l}\mathbf{C}^{Q}=\mathbf{X}W^{DQ},\quad(W^{DQ}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{c}^{\prime}}),\\ \mathrm{Concat}\big(\mathbf{Q}_{1}^{C},\mathbf{Q}_{2}^{C},\dots,\mathbf{Q}_{h}^{C}\big)=\mathbf{Q}^{C}=\mathbf{C}^{Q}W^{UQ},\quad(W^{UQ}\in\mathbb{R}^{d_{c}^{\prime}\times d_{h}h}),\\ \mathrm{Concat}\big(\mathbf{Q}_{1}^{R},\mathbf{Q}_{2}^{R},\dots,\mathbf{Q}_{h}^{R}\big)=\mathbf{Q}^{R}=\mathrm{RoPE}\big(\mathbf{C}^{Q}W^{QR}\big),\quad(W^{QR}\in\mathbb{R}^{d_{c}^{\prime}\times d_{h}^{R}h}),\\ \mathbf{Q}=\mathrm{Concat}\big(\mathbf{Q}^{C},\mathbf{Q}^{R}\big).\end{array}
Here, $\mathbf{C}^{Q}\in\mathbb{R}^{T\times d_{c}^{\prime}}$ (with $d_{c}^{\prime}\ll d_{h}h$) is the compressed query latent. As with the keys, $W^{DQ}$, $W^{UQ}$, and $W^{QR}$ map this low-dimensional query latent back to $h$ heads of dimension $d_{h}+d_{h}^{R}$.
Given compressed queries, keys, and values, the final attention output for the $t$ -th token is:
\begin{array}{l}\mathbf{O}_{i}=\mathrm{Softmax}\Big(\frac{\mathbf{Q}_{i}\mathbf{K}_{i}^{\top}}{\sqrt{d_{h}+d_{h}^{R}}}\Big)\,\mathbf{V}_{i}^{C},\\ \mathbf{U}=\mathrm{Concat}\big(\mathbf{O}_{1},\mathbf{O}_{2},\dots,\mathbf{O}_{h}\big)W^{O},\end{array}
where $W^{O}\in\mathbb{R}^{(d_{h}h)\times{d_{\mathrm{model}}}}$ is the output projection.
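Putting the MLA equations together, here is a NumPy sketch of one forward pass. It follows only the equations above and uses several assumptions. The weights are random and stored in a dict whose keys mirror $W^{DKV},W^{UK},\ldots$ (hypothetical names). Positions are $0,\ldots,T-1$. RoPE uses the 2×2 block layout of Section 2.5 in the row-vector convention. The sketch has no causal mask and no decoding cache; it is not DeepSeek's implementation.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # stabilize exp
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def rope_rows(M, positions, base=10000.0):
    """RoPE along the last axis of M (shape (T, ..., d), d even); row t uses positions[t].
    Row-vector convention x -> x T_t with 2x2 blocks [[cos, -sin], [sin, cos]]."""
    d = M.shape[-1]
    theta = 1.0 / base ** (2 * np.arange(1, d // 2 + 1) / d)   # θ_j, j = 1..d/2
    ang = positions[:, None] * theta[None, :]                   # (T, d/2)
    shape = (M.shape[0],) + (1,) * (M.ndim - 2) + (d // 2,)     # broadcast over heads
    c, s = np.cos(ang).reshape(shape), np.sin(ang).reshape(shape)
    x1, x2 = M[..., 0::2], M[..., 1::2]
    out = np.empty_like(M)
    out[..., 0::2] = x1 * c + x2 * s     # first coordinate of x @ [[c, -s], [s, c]]
    out[..., 1::2] = -x1 * s + x2 * c    # second coordinate
    return out

def mla(X, p, h, d_h, d_hr):
    T = X.shape[0]
    pos = np.arange(T, dtype=float)
    C_kv = X @ p["W_DKV"]                                  # (T, d_c), cached at inference
    K_R = rope_rows(X @ p["W_KR"], pos)                    # (T, d_hr), shared, cached
    K_C = (C_kv @ p["W_UK"]).reshape(T, h, d_h)            # per-head content keys
    V_C = (C_kv @ p["W_UV"]).reshape(T, h, d_h)            # per-head values
    C_q = X @ p["W_DQ"]                                    # (T, d_c'), query latent
    Q_C = (C_q @ p["W_UQ"]).reshape(T, h, d_h)
    Q_R = rope_rows((C_q @ p["W_QR"]).reshape(T, h, d_hr), pos)
    outs = []
    for i in range(h):
        Q_i = np.concatenate([Q_C[:, i], Q_R[:, i]], axis=-1)   # (T, d_h + d_hr)
        K_i = np.concatenate([K_C[:, i], K_R], axis=-1)         # (T, d_h + d_hr)
        outs.append(softmax(Q_i @ K_i.T / np.sqrt(d_h + d_hr)) @ V_C[:, i])
    return np.concatenate(outs, axis=-1) @ p["W_O"]             # (T, d_model)

rng = np.random.default_rng(0)
T, d_model, h, d_h, d_hr, d_c, d_cq = 5, 32, 4, 8, 4, 6, 10
shapes = {"W_DKV": (d_model, d_c), "W_UK": (d_c, h * d_h), "W_KR": (d_model, d_hr),
          "W_UV": (d_c, h * d_h), "W_DQ": (d_model, d_cq), "W_UQ": (d_cq, h * d_h),
          "W_QR": (d_cq, h * d_hr), "W_O": (h * d_h, d_model)}
p = {name: rng.standard_normal(s) / np.sqrt(s[0]) for name, s in shapes.items()}
X = rng.standard_normal((T, d_model))
U = mla(X, p, h, d_h, d_hr)
```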
At inference time, $\mathbf{C}^{KV}$ and $\mathbf{K}^{R}$ can be cached to accelerate decoding. Specifically, ignore RoPE for the moment and consider the inner product $\mathbf{q}_{t,i}^{\top}\mathbf{k}_{s,i}$ (where $\mathbf{q}_{t,i},\mathbf{k}_{s,i}\in\mathbb{R}^{d_{h}}$) of the $i$-th head between the $t$-th and $s$-th tokens. It can be calculated from the hidden state $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$ of the $t$-th token and the cached latent state $\mathbf{c}_{s}^{KV}\in\mathbb{R}^{d_{c}}$ of the $s$-th token:
\mathbf{q}_{t,i}^{\top}\mathbf{k}_{s,i}=[(W_{i}^{UQ})^{\top}(W_{i}^{DQ})^{\top}\mathbf{x}_{t}]^{\top}[(W_{i}^{UK})^{\top}\mathbf{c}_{s}^{KV}]=\mathbf{x}_{t}^{\top}[W_{i}^{DQ}W_{i}^{UQ}(W_{i}^{UK})^{\top}]\mathbf{c}_{s}^{KV},\tag{2.2}
where $W_{i}^{(\cdot)}$ denotes the $i$-th head's slice of the original weight. The matrix $[W_{i}^{DQ}W_{i}^{UQ}(W_{i}^{UK})^{\top}]$ can be precomputed for faster decoding. However, as pointed out by Su (2024), this precomputation fails once RoPE is taken into account. RoPE can be viewed as multiplication by a block-diagonal matrix $\mathbf{T}_{t}\in\mathbb{R}^{d_{h}\times d_{h}}$ (see Section 2.5) with property (2.1), $\mathbf{T}_{t}\mathbf{T}_{s}^{\top}=\mathbf{T}_{t-s}$. Therefore, we have
\begin{array}{rl}\mathbf{q}_{t,i}^{\top}\mathbf{k}_{s,i}&=[\mathbf{T}_{t}^{\top}(W_{i}^{UQ})^{\top}(W_{i}^{DQ})^{\top}\mathbf{x}_{t}]^{\top}[\mathbf{T}_{s}^{\top}(W_{i}^{UK})^{\top}\mathbf{c}_{s}^{KV}]\\ &=\mathbf{x}_{t}^{\top}[W_{i}^{DQ}W_{i}^{UQ}\mathbf{T}_{t-s}(W_{i}^{UK})^{\top}]\mathbf{c}_{s}^{KV}.\end{array}
Unlike in (2.2), precomputing $[W_{i}^{DQ}W_{i}^{UQ}\mathbf{T}_{t-s}(W_{i}^{UK})^{\top}]$ no longer accelerates decoding. This matrix depends on the relative position $t-s$ and therefore differs across $(t,s)$ position pairs. For RoPE compatibility, MLA therefore adds a separate, relatively small key component $\mathbf{k}_{t}^{R}$. In Section 3.2, we show that TPA resolves this RoPE incompatibility by applying the tensor product.
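The failure of weight absorption can be checked numerically. The sketch below uses random matrices for a single head, with hypothetical sizes. Without RoPE, one position-independent matrix reproduces the score for every cached token, as in (2.2). With RoPE, the middle factor becomes $W_{i}^{DQ}W_{i}^{UQ}\mathbf{T}_{t-s}(W_{i}^{UK})^{\top}$. The check confirms that this matrix reproduces the rotated score but changes with the offset $t-s$, so no single precomputed matrix serves all cached positions.

```python
import numpy as np

def rope_matrix(t, d_h, base=10000.0):
    """Block-diagonal T_t with 2x2 blocks [[cos(t θ_j), -sin(t θ_j)], [sin(t θ_j), cos(t θ_j)]]."""
    T = np.zeros((d_h, d_h))
    for j in range(1, d_h // 2 + 1):
        theta = 1.0 / base ** (2 * j / d_h)
        c, s = np.cos(t * theta), np.sin(t * theta)
        k = 2 * (j - 1)
        T[k:k + 2, k:k + 2] = [[c, -s], [s, c]]
    return T

rng = np.random.default_rng(0)
d_model, d_cq, d_c, d_h = 16, 6, 6, 8          # hypothetical sizes
W_DQ = rng.standard_normal((d_model, d_cq))    # query down-projection
W_UQ_i = rng.standard_normal((d_cq, d_h))      # head-i columns of W^{UQ}
W_UK_i = rng.standard_normal((d_c, d_h))       # head-i columns of W^{UK}
x_t = rng.standard_normal(d_model)             # hidden state of query token t
c_s = rng.standard_normal(d_c)                 # cached latent c^KV of key token s

q = W_UQ_i.T @ W_DQ.T @ x_t                    # q_{t,i} without RoPE
k = W_UK_i.T @ c_s                             # k_{s,i} without RoPE

# Without RoPE: one position-independent matrix, precomputable once, as in (2.2).
absorbed = W_DQ @ W_UQ_i @ W_UK_i.T            # (d_model, d_c)
score_plain = q @ k
score_absorbed = x_t @ absorbed @ c_s

def score_rope(t, s):
    # (T_t^T q)^T (T_s^T k) = q^T T_{t-s} k by property (2.1).
    return (rope_matrix(t, d_h).T @ q) @ (rope_matrix(s, d_h).T @ k)

def absorbed_rope(offset):
    # The "absorbed" matrix now depends on the relative offset t - s.
    return W_DQ @ W_UQ_i @ rope_matrix(offset, d_h) @ W_UK_i.T
```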
3 Tensor Product Attention
In this section, we provide a detailed description of our proposed Tensor Product Attention (TPA), which allows contextual low-rank factorization for queries, keys, and values. First, we explain how TPA factorizes queries, keys, and values with explicit tensor shapes. Next, we describe how TPA can be integrated into the multi-head attention framework and how it reduces memory consumption in KV caching at inference time. Finally, we show how RoPE can seamlessly integrate with TPA (including a pre-rotated variant).
3.1 Tensor Factorization of Queries, Keys, and Values
Let $\mathbf{x}_t\in\mathbb{R}^{d_{\mathrm{model}}}$ for $t=1,\ldots,T$ be the hidden-state vector of the $t$-th token in a sequence of length $T$. A typical multi-head attention block has $h$ heads of dimension $d_h$ each, with $d_{\mathrm{model}}=h\times d_h$. Standard attention projects the entire sequence into three tensors $\mathbf{Q},\mathbf{K},\mathbf{V}\in\mathbb{R}^{T\times h\times d_h}$. The slices for the $t$-th token are $\mathbf{Q}_t,\mathbf{K}_t,\mathbf{V}_t\in\mathbb{R}^{h\times d_h}$.
Contextual Factorization (CF). TPA does not form each head's query, key, or value through a single linear map. Instead, it factorizes each of $\mathbf{Q}_t,\mathbf{K}_t,\mathbf{V}_t$ into a sum of (contextual) tensor products with ranks $R_Q$, $R_K$, and $R_V$, respectively, and these ranks may differ. Specifically, for each token $t$, with a slight abuse of notation, we define:
$$
\begin{aligned}
\mathbf{Q}_t &= \frac{1}{R_Q}\sum_{r=1}^{R_Q}\mathbf{a}_r^Q(\mathbf{x}_t)\otimes\mathbf{b}_r^Q(\mathbf{x}_t), && \mathbf{a}_r^Q(\mathbf{x}_t)\in\mathbb{R}^{h},\ \mathbf{b}_r^Q(\mathbf{x}_t)\in\mathbb{R}^{d_h},\\
\mathbf{K}_t &= \frac{1}{R_K}\sum_{r=1}^{R_K}\mathbf{a}_r^K(\mathbf{x}_t)\otimes\mathbf{b}_r^K(\mathbf{x}_t), && \mathbf{a}_r^K(\mathbf{x}_t)\in\mathbb{R}^{h},\ \mathbf{b}_r^K(\mathbf{x}_t)\in\mathbb{R}^{d_h},\\
\mathbf{V}_t &= \frac{1}{R_V}\sum_{r=1}^{R_V}\mathbf{a}_r^V(\mathbf{x}_t)\otimes\mathbf{b}_r^V(\mathbf{x}_t), && \mathbf{a}_r^V(\mathbf{x}_t)\in\mathbb{R}^{h},\ \mathbf{b}_r^V(\mathbf{x}_t)\in\mathbb{R}^{d_h}.
\end{aligned}
$$
Hence, for queries, each tensor product $\mathbf{a}_r^Q(\mathbf{x}_t)\otimes\mathbf{b}_r^Q(\mathbf{x}_t)$ is a rank-one matrix in $\mathbb{R}^{h\times d_h}$. The $R_Q$ such terms are summed and scaled by $1/R_Q$ to form the query slice $\mathbf{Q}_t\in\mathbb{R}^{h\times d_h}$. Analogous definitions apply to the key slice $\mathbf{K}_t$ and the value slice $\mathbf{V}_t$.
Latent Factor Maps. Each factor in the tensor product depends on the token’s hidden state $\mathbf{x}_{t}$ . For example, for queries, we can write:
$$
\mathbf{a}_r^Q(\mathbf{x}_t)=W_r^{a^Q}\mathbf{x}_t\in\mathbb{R}^{h},\quad \mathbf{b}_r^Q(\mathbf{x}_t)=W_r^{b^Q}\mathbf{x}_t\in\mathbb{R}^{d_h},
$$
and similarly for keys and values.
One often merges the rank index into a single output dimension. For instance, for queries:
$$
\mathbf{a}^Q(\mathbf{x}_t)=W^{a^Q}\mathbf{x}_t\in\mathbb{R}^{R_Q\cdot h},\quad \mathbf{b}^Q(\mathbf{x}_t)=W^{b^Q}\mathbf{x}_t\in\mathbb{R}^{R_Q\cdot d_h},
$$
which are then reshaped into $\mathbf{A}_Q(\mathbf{x}_t)\in\mathbb{R}^{R_Q\times h}$ and $\mathbf{B}_Q(\mathbf{x}_t)\in\mathbb{R}^{R_Q\times d_h}$. Summing over the $R_Q$ rank components and scaling by $\frac{1}{R_Q}$ yields
$$
\mathbf{Q}_t=\frac{1}{R_Q}\mathbf{A}_Q(\mathbf{x}_t)^{\top}\mathbf{B}_Q(\mathbf{x}_t)\in\mathbb{R}^{h\times d_h}.
$$
Repeating this for all tokens reconstitutes $\mathbf{Q}\in\mathbb{R}^{T\times h\times d_h}$. The same procedure yields $\mathbf{K}$ and $\mathbf{V}$ with ranks $R_K$ and $R_V$, respectively.
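The NumPy sketch below shows the merged-projection form of the query factorization. One matrix multiply produces all $R_Q$ head factors and another produces all token factors. A reshape then yields $\mathbf{A}_Q(\mathbf{x}_t)$ and $\mathbf{B}_Q(\mathbf{x}_t)$, and a product gives $\mathbf{Q}_t$. The weights are random stand-ins and the sizes are hypothetical. The einsum-based sequence version is one possible vectorization, not necessarily the one used in the released code.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, h, d_h, R_Q = 32, 4, 8, 2       # toy sizes, chosen for illustration

# Merged projections W^{a^Q} and W^{b^Q}: the R_Q rank components are stacked
# along the output dimension, so one matmul produces all of them at once.
W_aQ = rng.normal(size=(R_Q * h, d_model)) / np.sqrt(d_model)
W_bQ = rng.normal(size=(R_Q * d_h, d_model)) / np.sqrt(d_model)

def tpa_query(x_t):
    """Q_t = (1/R_Q) A_Q(x_t)^T B_Q(x_t), shape (h, d_h)."""
    A = (W_aQ @ x_t).reshape(R_Q, h)     # A_Q(x_t): row r is a_r^Q(x_t)
    B = (W_bQ @ x_t).reshape(R_Q, d_h)   # B_Q(x_t): row r is b_r^Q(x_t)
    return A.T @ B / R_Q

def tpa_query_explicit(x_t):
    """Same quantity written as an explicit sum of outer products a_r (x) b_r."""
    A = (W_aQ @ x_t).reshape(R_Q, h)
    B = (W_bQ @ x_t).reshape(R_Q, d_h)
    return sum(np.outer(A[r], B[r]) for r in range(R_Q)) / R_Q

def tpa_query_seq(X):
    """Whole sequence X of shape (T, d_model) -> Q of shape (T, h, d_h)."""
    T = X.shape[0]
    A = (X @ W_aQ.T).reshape(T, R_Q, h)
    B = (X @ W_bQ.T).reshape(T, R_Q, d_h)
    return np.einsum("trh,trd->thd", A, B) / R_Q

X = rng.normal(size=(5, d_model))
Q = tpa_query_seq(X)
print(Q.shape)                                # (5, 4, 8)
print(np.linalg.matrix_rank(Q[0]) <= R_Q)     # True: each slice has rank at most R_Q
```

Because each $\mathbf{Q}_t$ is a sum of $R_Q$ outer products, its rank is at most $R_Q$. This constraint is what makes caching the factors cheaper than caching the full slice.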
Scaled Dot-Product Attention. Once $\mathbf{Q},\mathbf{K},\mathbf{V}$ are factorized, multi-head attention proceeds as in standard Transformers. For each head $i\in\{1,\ldots,h\}$:
$$
\mathbf{head}_i=\mathrm{Softmax}\Big(\frac{1}{\sqrt{d_h}}\mathbf{Q}_i(\mathbf{K}_i)^{\top}\Big)\mathbf{V}_i,
$$
where $\mathbf{Q}_i,\mathbf{K}_i,\mathbf{V}_i\in\mathbb{R}^{T\times d_h}$ are the slices along the head dimension. Concatenating the $h$ heads along the last dimension yields a tensor in $\mathbb{R}^{T\times(h\cdot d_h)}$. An output weight matrix $W^O\in\mathbb{R}^{(h\cdot d_h)\times d_{\mathrm{model}}}$ projects it back to $\mathbb{R}^{T\times d_{\mathrm{model}}}$:
$$
\mathrm{TPA}(\mathbf{Q},\mathbf{K},\mathbf{V})=\mathrm{Concat}\big(\mathbf{head}_1,\ldots,\mathbf{head}_h\big)W^O.
$$
Parameter Initialization. We initialize the weight matrices $W_r^{a^Q}, W_r^{a^K}, W_r^{a^V}, W_r^{b^Q}, W_r^{b^K}, W_r^{b^V}$ using Xavier initialization (Glorot & Bengio, 2010). Each entry of a weight matrix is drawn from the uniform distribution on $[-\sqrt{6/(n_{\mathrm{in}}+n_{\mathrm{out}})},\sqrt{6/(n_{\mathrm{in}}+n_{\mathrm{out}})}]$, where $n_{\mathrm{in}}$ and $n_{\mathrm{out}}$ are the input and output dimensions of that matrix. This initialization helps keep the variance of activations and gradients stable across the network.
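The NumPy sketch below applies the Xavier-uniform bound to the query factor matrices. It assumes that each per-rank matrix $W_r^{a^Q}\in\mathbb{R}^{h\times d_{\mathrm{model}}}$ and $W_r^{b^Q}\in\mathbb{R}^{d_h\times d_{\mathrm{model}}}$ uses its own fan-in and fan-out, and that the $R_Q$ matrices are then stacked into the merged projection. An implementation could instead initialize the merged matrix as a whole, which would give a slightly different bound. The sizes are illustrative only.

```python
import numpy as np

def xavier_uniform(n_out, n_in, rng):
    """Glorot/Xavier uniform init for an (n_out, n_in) matrix."""
    bound = np.sqrt(6.0 / (n_in + n_out))
    return rng.uniform(-bound, bound, size=(n_out, n_in))

rng = np.random.default_rng(0)
d_model, h, d_h, R_Q = 768, 12, 64, 6     # illustrative sizes

# Assumption: each per-rank matrix W_r^{a^Q} (h x d_model) and W_r^{b^Q} (d_h x d_model)
# gets its own fan-in/fan-out; the R_Q matrices are then stacked into the merged projection.
W_aQ = np.concatenate([xavier_uniform(h, d_model, rng) for _ in range(R_Q)])
W_bQ = np.concatenate([xavier_uniform(d_h, d_model, rng) for _ in range(R_Q)])
print(W_aQ.shape, W_bQ.shape)   # (72, 768) (384, 768)
```

The variance of $U[-b,b]$ is $b^2/3=2/(n_{\mathrm{in}}+n_{\mathrm{out}})$, which is the variance-preserving target that Xavier initialization is designed to hit.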
3.2 RoPE Compatibility and Acceleration
In a typical workflow for adding RoPE to standard multi-head attention, one first computes $\mathbf{Q}_t,\mathbf{K}_s\in\mathbb{R}^{h\times d_h}$ for the $t$-th and $s$-th tokens and then applies:
$$
\mathbf{Q}_t\mapsto\widetilde{\mathbf{Q}}_t=\mathrm{RoPE}_t(\mathbf{Q}_t),\qquad\mathbf{K}_s\mapsto\widetilde{\mathbf{K}}_s=\mathrm{RoPE}_s(\mathbf{K}_s).
$$
Direct Integration. A useful optimization is to integrate RoPE directly into the TPA factorization. For example, one can pre-rotate the token-dimension factors:
$$
\widetilde{\mathbf{B}}_K(\mathbf{x}_t)\longleftarrow\mathrm{RoPE}_t\big(\mathbf{B}_K(\mathbf{x}_t)\big),
$$
yielding a pre-rotated key representation:
$$
\widetilde{\mathbf{K}}_t=\frac{1}{R_K}\sum_{r=1}^{R_K}\mathbf{a}_r^K(\mathbf{x}_t)\otimes\mathrm{RoPE}_t\big(\mathbf{b}_r^K(\mathbf{x}_t)\big)=\frac{1}{R_K}\mathbf{A}_K(\mathbf{x}_t)^{\top}\mathrm{RoPE}_t\big(\mathbf{B}_K(\mathbf{x}_t)\big).
$$
Thus, each $\widetilde{\mathbf{K}}_t$ is already rotated before it is cached. No explicit rotation is needed at decoding time, which speeds up autoregressive inference. Depending on hardware and performance requirements, training and inference can also use different RoPE integration approaches.
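The NumPy sketch below shows the pre-rotation applied at cache-write time. The class `FactorizedKVCache` and its methods are hypothetical names introduced for illustration. The `rope` helper assumes the common interleaved-pair RoPE layout with $\theta_j=10000^{-2j/d_h}$, which other implementations may replace with a split-half layout. The sketch checks that expanding the cached factors reproduces $\mathrm{RoPE}_t(\mathbf{K}_t)$ for every position.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate the last axis of x (even length) by position pos, pairing entries (2j, 2j+1)."""
    d = x.shape[-1]
    angle = pos * base ** (-np.arange(0, d, 2) / d)      # assumed frequency schedule
    c, s = np.cos(angle), np.sin(angle)
    even, odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = even * c - odd * s
    out[..., 1::2] = even * s + odd * c
    return out

class FactorizedKVCache:
    """Hypothetical per-layer cache that keeps only TPA factors, with B_K rotated at write time."""
    def __init__(self):
        self.A_K, self.B_K, self.A_V, self.B_V = [], [], [], []

    def append(self, A_K, B_K, A_V, B_V, pos):
        self.A_K.append(A_K)
        self.B_K.append(rope(B_K, pos))   # rotate each row b_r^K once, never again at decode time
        self.A_V.append(A_V)
        self.B_V.append(B_V)

    @staticmethod
    def _expand(As, Bs):
        A, B = np.stack(As), np.stack(Bs)                 # (T, R, h), (T, R, d_h)
        return np.einsum("trh,trd->thd", A, B) / A.shape[1]

    def keys(self):
        return self._expand(self.A_K, self.B_K)           # rotated keys, (T, h, d_h)

    def values(self):
        return self._expand(self.A_V, self.B_V)           # values, (T, h, d_h)

rng = np.random.default_rng(0)
h, d_h, R_K, R_V = 4, 8, 2, 2
cache, raw = FactorizedKVCache(), []
for pos in range(5):
    A_K, B_K = rng.normal(size=(R_K, h)), rng.normal(size=(R_K, d_h))
    A_V, B_V = rng.normal(size=(R_V, h)), rng.normal(size=(R_V, d_h))
    cache.append(A_K, B_K, A_V, B_V, pos)
    raw.append(A_K.T @ B_K / R_K)          # unrotated K_t, kept only to check the result
```

Per token, this toy cache holds $(R_K+R_V)(h+d_h)=48$ numbers instead of $2hd_h=64$. The saving grows quickly for realistic head counts and head sizes.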
Theorem 1 (RoPE’s Compatibility with TPA). Let $\mathbf{Q}_{t}$ be factorized by TPA as
$$
\mathbf{Q}_t=\frac{1}{R_Q}\mathbf{A}_Q(\mathbf{x}_t)^{\top}\mathbf{B}_Q(\mathbf{x}_t)\in\mathbb{R}^{h\times d_h},
$$
where $\mathbf{A}_Q(\mathbf{x}_t)\in\mathbb{R}^{R_Q\times h}$ and $\mathbf{B}_Q(\mathbf{x}_t)\in\mathbb{R}^{R_Q\times d_h}$. Then we have:
$$
\mathrm{RoPE}_t(\mathbf{Q}_t)=\frac{1}{R_Q}\mathbf{A}_Q(\mathbf{x}_t)^{\top}\widetilde{\mathbf{B}}_Q(\mathbf{x}_t),\qquad\text{where }\widetilde{\mathbf{B}}_Q(\mathbf{x}_t)=\mathrm{RoPE}_t\big(\mathbf{B}_Q(\mathbf{x}_t)\big).
$$
In addition, assume $\mathbf{Q}_t$ and $\mathbf{K}_s$ are factorized by TPA and then rotated by $\mathrm{RoPE}_t$ and $\mathrm{RoPE}_s$, respectively. Let $\widetilde{\mathbf{Q}}_t=\mathrm{RoPE}_t(\mathbf{Q}_t)$ and $\widetilde{\mathbf{K}}_s=\mathrm{RoPE}_s(\mathbf{K}_s)$. Then we have
$$
\mathrm{RoPE}_{t-s}(\mathbf{Q}_t)\mathbf{K}_s^{\top}=\widetilde{\mathbf{Q}}_t\widetilde{\mathbf{K}}_s^{\top}.
$$
Focusing on an individual head $i$, the above matrix equality implies:
$$
\mathrm{RoPE}_{t-s}\big(\mathbf{q}_{t,i}\big)^{\top}\mathbf{k}_{s,i}=\widetilde{\mathbf{q}}_{t,i}^{\top}\widetilde{\mathbf{k}}_{s,i},
$$
where $\mathbf{q}_{t,i}\in\mathbb{R}^{d_h}$ is the $i$-th query head of the $t$-th token, $\mathbf{k}_{s,i}\in\mathbb{R}^{d_h}$ is the $i$-th key head of the $s$-th token, and
$$
\widetilde{\mathbf{q}}_{t,i}=\mathrm{RoPE}_t(\mathbf{q}_{t,i})=\mathbf{T}_t\mathbf{q}_{t,i}\in\mathbb{R}^{d_h},\quad\widetilde{\mathbf{k}}_{s,i}=\mathrm{RoPE}_s(\mathbf{k}_{s,i})=\mathbf{T}_s\mathbf{k}_{s,i}\in\mathbb{R}^{d_h}.
$$
Theorem 1 shows that TPA preserves RoPE's relative translational property. We prove Theorem 1 in Appendix A. In short, $\mathrm{RoPE}_t$ acts on $\mathbf{B}_Q(\mathbf{x}_t)$ as a block-diagonal orthogonal transform, namely the matrix $\mathbf{T}_t$. Consequently, $\mathbf{A}_Q(\mathbf{x}_t)$ remains unchanged, while each row of $\mathbf{B}_Q(\mathbf{x}_t)$ is rotated. Each such row is a factor $\mathbf{b}_r^Q(\mathbf{x}_t)\in\mathbb{R}^{d_h}$, so the TPA structure is preserved.
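The NumPy sketch below checks Theorem 1 numerically on random factors. It rotates only the token factors $\mathbf{B}_Q$ and $\mathbf{B}_K$ and compares the per-head scores with $\mathrm{RoPE}_{t-s}(\mathbf{q}_{t,i})^{\top}\mathbf{k}_{s,i}$. It also confirms that the scores are unchanged when $t$ and $s$ are shifted by the same amount. The `rope` helper again assumes the interleaved-pair layout with base 10000. This is a sanity check under those assumptions, not a substitute for the proof in Appendix A.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate the last axis of x (even length) by position pos, pairing entries (2j, 2j+1)."""
    d = x.shape[-1]
    angle = pos * base ** (-np.arange(0, d, 2) / d)
    c, s = np.cos(angle), np.sin(angle)
    out = np.empty_like(x)
    out[..., 0::2] = x[..., 0::2] * c - x[..., 1::2] * s
    out[..., 1::2] = x[..., 0::2] * s + x[..., 1::2] * c
    return out

rng = np.random.default_rng(0)
h, d_h, R_Q, R_K = 4, 8, 3, 2
A_Q, B_Q = rng.normal(size=(R_Q, h)), rng.normal(size=(R_Q, d_h))   # factors of token t
A_K, B_K = rng.normal(size=(R_K, h)), rng.normal(size=(R_K, d_h))   # factors of token s
Q_t, K_s = A_Q.T @ B_Q / R_Q, A_K.T @ B_K / R_K

def per_head_scores(t, s):
    """q~_{t,i}^T k~_{s,i} for every head i, rotating only the token factors B."""
    Qt_rot = A_Q.T @ rope(B_Q, t) / R_Q
    Ks_rot = A_K.T @ rope(B_K, s) / R_K
    return np.einsum("id,id->i", Qt_rot, Ks_rot)

t, s = 9, 4
lhs = np.einsum("id,id->i", rope(Q_t, t - s), K_s)      # RoPE_{t-s}(q_{t,i})^T k_{s,i}
print(np.allclose(lhs, per_head_scores(t, s)))          # True
print(np.allclose(per_head_scores(t, s), per_head_scores(t + 3, s + 3)))   # True: only t - s matters
```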
3.3 KV Caching and Memory Reduction
In autoregressive decoding, standard attention caches $\mathbf{K}_t,\mathbf{V}_t\in\mathbb{R}^{h\times d_h}$ for each past token $t$. Over $T$ tokens, the cache grows to $\mathbb{R}^{T\times h\times d_h}$ for keys and $\mathbb{R}^{T\times h\times d_h}$ for values, i.e., $2Thd_h$ entries in total.
TPA Factorized KV Caching. Instead of storing the full $\mathbf{K}_t$ and $\mathbf{V}_t$, TPA stores only their low-rank factors. Specifically, we keep
$$
\mathbf{A}_K(\mathbf{x}_t),\ \widetilde{\mathbf{B}}_K(\mathbf{x}_t)\quad\text{and}\quad\mathbf{A}_V(\mathbf{x}_t),\ \mathbf{B}_V(\mathbf{x}_t),
$$
with
$$
\mathbf{A}_K(\mathbf{x}_t)\in\mathbb{R}^{R_K\times h},\ \widetilde{\mathbf{B}}_K(\mathbf{x}_t)\in\mathbb{R}^{R_K\times d_h},\ \mathbf{A}_V(\mathbf{x}_t)\in\mathbb{R}^{R_V\times h},\ \mathbf{B}_V(\mathbf{x}_t)\in\mathbb{R}^{R_V\times d_h}.
$$
Hence, the memory cost per token is
$$
\underbrace{R_K(h+d_h)}_{\text{for }\mathbf{K}}+\underbrace{R_V(h+d_h)}_{\text{for }\mathbf{V}}=(R_K+R_V)(h+d_h).
$$
Compared with the standard caching cost of $2hd_h$, the ratio is:
$$
\frac{(R_K+R_V)(h+d_h)}{2hd_h}.
$$
For large $h$ and $d_h$ (typically $d_h=64$ or 128), choosing $R_K,R_V\ll d_h$ (e.g., 1 or 2) often reduces the cache by a factor of $10\times$ or more.
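As a worked example, the short Python snippet below evaluates the per-token cache sizes and the ratio for a few head configurations. These configurations were chosen for illustration and are not taken from the experiments.

```python
def standard_cache_per_token(h, d_h):
    return 2 * h * d_h                      # K_t and V_t, each h x d_h

def tpa_cache_per_token(h, d_h, R_K, R_V):
    return (R_K + R_V) * (h + d_h)          # A_K, B~_K, A_V, B_V

for h, d_h, R in [(32, 128, 1), (32, 128, 2), (16, 64, 2)]:
    full = standard_cache_per_token(h, d_h)
    tpa = tpa_cache_per_token(h, d_h, R, R)
    print(f"h={h:2d} d_h={d_h:3d} R_K=R_V={R}: {tpa}/{full} = {tpa / full:.4f} "
          f"({full / tpa:.1f}x smaller)")
```

With 32 heads of size 128, the reduction is 25.6× for rank 1 and 12.8× for rank 2. With 16 heads of size 64 and rank 2, it is only 6.4×. The $10\times$ figure therefore holds for large $h$ and $d_h$ rather than unconditionally.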
Table 1: Comparison of different attention mechanisms. Here, $R_Q$, $R_K$, and $R_V$ denote the ranks for queries, keys, and values in TPA, respectively, and $g$ is the number of key-value groups in GQA. Variants of TPA, such as TPA (KVonly), TPA (Non-contextual A), and TPA (Non-contextual B), are detailed in Sections 3.4 and 3.5. For MLA, $d_h^R$ and $d_h$ are the dimensions of the RoPE and non-RoPE parts, and $d_c'$ and $d_c$ are the dimensions of the compressed vectors for queries and key-values, respectively.
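| Method | KV Cache | # Parameters | # Query Heads | # KV Heads |
|---|---|---|---|---|
| MHA | $2hd_h$ | $4d_{\mathrm{model}}^2$ | $h$ | $h$ |
| MQA | $2d_h$ | $(2+2/h)\,d_{\mathrm{model}}^2$ | $h$ | $1$ |
| GQA | $2gd_h$ | $(2+2g/h)\,d_{\mathrm{model}}^2$ | $h$ | $g$ |
| MLA | $d_c+d_h^R$ | $d_c'(d_{\mathrm{model}}+hd_h+hd_h^R)+d_{\mathrm{model}}d_h^R+d_c(d_{\mathrm{model}}+2hd_h)+d_{\mathrm{model}}hd_h$ | $h$ | $h$ |
| TPA | $(R_K+R_V)(h+d_h)$ | $d_{\mathrm{model}}(R_Q+R_K+R_V)(h+d_h)+d_{\mathrm{model}}hd_h$ | $h$ | $h$ |
| TPA (KVonly) | $(R_K+R_V)(h+d_h)$ | $d_{\mathrm{model}}(R_K+R_V)(h+d_h)+2d_{\mathrm{model}}hd_h$ | $h$ | $h$ |
| TPA (Non-contextual A) | $(R_K+R_V)d_h$ | $(R_Q+R_K+R_V)(d_{\mathrm{model}}d_h+h)+d_{\mathrm{model}}hd_h$ | $h$ | $h$ |
| TPA (Non-contextual B) | $(R_K+R_V)h$ | $(R_Q+R_K+R_V)(d_{\mathrm{model}}h+d_h)+d_{\mathrm{model}}hd_h$ | $h$ | $h$ |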
3.4 Unifying MHA, MQA, and GQA as Non-contextual TPA
3.4.1 MHA as Non-contextual TPA
Standard multi-head attention (MHA) can be viewed as a specific instance of TPA in which: 1) the rank equals the number of heads; 2) the head-dimension factor is non-contextual (i.e., independent of the $t$-th token embedding $\mathbf{x}_t\in\mathbb{R}^{d_{\mathrm{model}}}$); 3) the token-dimension factor is a linear function of $\mathbf{x}_t$.
To match MHA with TPA, let $R_{Q}=R_{K}=R_{V}=h$ . Focusing on $\mathbf{Q}_{t}$ :
(a) Non-contextual head factors. Define
$$
\mathbf{a}_i^Q=R_Q\,\mathbf{e}_i\in\mathbb{R}^{h},\qquad(\mathbf{e}_i\in\mathbb{R}^{h}\text{ is the }i\text{-th standard basis vector}),
$$
so that $\mathbf{e}_i\otimes\cdot$ corresponds to the $i$-th head of $\mathbf{Q}_t$.
(b) Contextual token factors. Define
$$
\mathbf{b}_i^Q(\mathbf{x}_t)=(W_i^Q)^{\top}\mathbf{x}_t\in\mathbb{R}^{d_h},
$$
where $W_i^Q\in\mathbb{R}^{d_{\mathrm{model}}\times d_h}$ is the per-head query projection defined earlier. Hence $\mathbf{b}_i^Q(\mathbf{x}_t)$ depends on $\mathbf{x}_t$.
Substituting (3.8)–(3.9) into (3.1) gives:
$$
\mathbf{Q}_t=\sum_{i=1}^{h}\Big[\mathbf{e}_i\otimes\big((W_i^Q)^{\top}\mathbf{x}_t\big)\Big]\in\mathbb{R}^{h\times d_h}.
$$
Each term $\mathbf{e}_i\otimes\big((W_i^Q)^{\top}\mathbf{x}_t\big)$ in (3.10) contributes only to the $i$-th row, which reconstitutes the usual MHA form of $\mathbf{Q}_t$. Analogous constructions hold for $\mathbf{K}_t$ and $\mathbf{V}_t$ using $W_i^K$ and $W_i^V$. Thus, MHA is a non-contextual, full-rank variant of TPA.
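The NumPy sketch below checks the MHA-as-TPA construction numerically with random per-head projections $W_i^Q$. It sets $R_Q=h$, uses the fixed head factors $\mathbf{a}_i^Q=R_Q\mathbf{e}_i$ and the contextual token factors $\mathbf{b}_i^Q(\mathbf{x}_t)=(W_i^Q)^{\top}\mathbf{x}_t$, and compares the TPA sum with the stacked MHA query. The sizes are arbitrary toy values.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, h, d_h = 32, 4, 8
R_Q = h                                                     # rank equals number of heads
W_Q = [rng.normal(size=(d_model, d_h)) for _ in range(h)]  # per-head projections W_i^Q
x_t = rng.normal(size=d_model)

# Standard MHA: row i of Q_t is (W_i^Q)^T x_t.
Q_mha = np.stack([W.T @ x_t for W in W_Q])                  # (h, d_h)

# Non-contextual TPA: fixed head factors a_i = R_Q e_i, contextual token factors b_i(x_t).
a = R_Q * np.eye(h)                                         # row i is a_i^Q (independent of x_t)
b = np.stack([W.T @ x_t for W in W_Q])                      # row i is b_i^Q(x_t)
Q_tpa = sum(np.outer(a[i], b[i]) for i in range(R_Q)) / R_Q

print(np.allclose(Q_mha, Q_tpa))   # True
```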
TPA with Non-contextual A. More broadly, TPA can use non-contextual head-dimension factors $\mathbf{a}_r^Q,\mathbf{a}_r^K,\mathbf{a}_r^V\in\mathbb{R}^{h}$ (i.e., independent of $\mathbf{x}_t$) while keeping $\mathbf{b}_r^Q(\mathbf{x}_t),\mathbf{b}_r^K(\mathbf{x}_t),\mathbf{b}_r^V(\mathbf{x}_t)$ context-dependent. Then, for keys:
$$
\mathbf{K}_t=\frac{1}{R_K}\sum_{r=1}^{R_K}\mathbf{a}_r^K\otimes\mathbf{b}_r^K(\mathbf{x}_t),
$$
and similarly for queries/values. This reduces per-token computations and can be effective when head-dimension relationships are relatively stable across all tokens.
MQA and GQA as Non-Contextual TPA. Multi-Query Attention (MQA) (Shazeer, 2019) and Grouped Query Attention (GQA) (Ainslie et al., 2023) also emerge naturally from TPA by restricting the head-dimension factors to be non-contextual and low-rank:
• MQA as Rank-1 TPA. In MQA, all heads share a single set of keys/values, corresponding to $R_{K}=R_{V}=1$ along the head dimension. Concretely,
$$
\mathbf{K}_t=(1,\ldots,1)^{\top}\otimes\mathbf{b}^K(\mathbf{x}_t),\quad\mathbf{V}_t=(1,\ldots,1)^{\top}\otimes\mathbf{b}^V(\mathbf{x}_t),
$$
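This choice forces every head to use the same key vector $\mathbf{b}^K(\mathbf{x}_t)$ and value vector $\mathbf{b}^V(\mathbf{x}_t)$, so all rows of $\mathbf{K}_t$ and $\mathbf{V}_t$ are identical. Each head retains a distinct query projection, matching the MQA design.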
• GQA as Grouped Rank-1 TPA. GQA partitions the $h$ heads into $G$ groups, and the heads within each group share keys/values. In TPA form, each group $g$ has a dedicated non-contextual factor pair $\mathbf{a}_g^K,\mathbf{a}_g^V\in\mathbb{R}^{h}$, which acts as a “mask” for the heads in that group. Varying $G$ from 1 to $h$ interpolates from MQA to standard MHA.
Hence, by constraining TPA’s head-dimension factors to be constant masks (one for MQA; multiple for GQA), these popular variants are recovered as special cases.
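The NumPy sketch below builds grouped keys (or values) in this TPA form. The function name `grouped_kv_as_tpa` is hypothetical. The head factors are fixed 0/1 masks over each group's heads. The masks are scaled by $G$ so that they cancel the $1/R_K$ normalization with $R_K=G$, the same trick as $\mathbf{a}_i^Q=R_Q\mathbf{e}_i$ in the MHA case. $G=1$ recovers MQA and $G=h$ gives one key/value per head.

```python
import numpy as np

def grouped_kv_as_tpa(b_groups, h):
    """b_groups: (G, d_h); row g is the shared vector b_g(x_t) of group g.

    Builds K_t (or V_t) = (1/G) * sum_g a_g (x) b_g(x_t), where a_g is a fixed
    0/1 mask over the heads of group g, scaled by G (= R_K) to cancel the 1/R_K.
    Assumes h is divisible by G.
    """
    G, d_h = b_groups.shape
    per_group = h // G
    a = np.zeros((G, h))
    for g in range(G):
        a[g, g * per_group:(g + 1) * per_group] = G          # non-contextual head factor
    return np.einsum("gh,gd->hd", a, b_groups) / G

rng = np.random.default_rng(0)
h, d_h = 8, 4
mqa = grouped_kv_as_tpa(rng.normal(size=(1, d_h)), h)      # G = 1: MQA, every head identical
gqa = grouped_kv_as_tpa(rng.normal(size=(4, d_h)), h)      # G = 4: pairs of heads share K/V
mha_like = grouped_kv_as_tpa(rng.normal(size=(h, d_h)), h) # G = h: one K/V per head
```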
3.5 Other Variants of TPA
TPA with Non-contextual B. Conversely, one may fix the token-dimension factors $\mathbf{b}_r^Q,\mathbf{b}_r^K,\mathbf{b}_r^V\in\mathbb{R}^{d_h}$ as learned parameters, while allowing $\mathbf{a}_r^Q(\mathbf{x}_t),\mathbf{a}_r^K(\mathbf{x}_t),\mathbf{a}_r^V(\mathbf{x}_t)$ to adapt to $\mathbf{x}_t$. For keys:
$$
\mathbf{K}_t=\frac{1}{R_K}\sum_{r=1}^{R_K}\mathbf{a}_r^K(\mathbf{x}_t)\otimes\mathbf{b}_r^K,
$$
and similarly for queries and values. This arrangement works well when the token-dimension structure stays mostly uniform across the sequence, leaving the head-dimension factors to capture context.
TPA KV Only. One can preserve a standard query mapping,
$$
\mathbf{Q}_t=W^Q\mathbf{x}_t\in\mathbb{R}^{h\times d_h},
$$
and factorize only the keys and values. This leaves the query projection as the original linear transformation while reducing memory usage via factorized KV caching.
TPA KV with Shared B. Another variant is to share the token-dimension factors of keys and values:
$$
\mathbf{b}_r^K(\mathbf{x}_t)=\mathbf{b}_r^V(\mathbf{x}_t),
$$
lowering parameter counts and the KV cache footprint. While it constrains $\mathbf{K}$ and $\mathbf{V}$ to be formed from the same token basis, it can still perform well and provide additional memory savings.
Nonlinear Head Factors. Instead of applying purely linear mappings to the head-dimension factors $\mathbf{a}_r^Q,\mathbf{a}_r^K,\mathbf{a}_r^V$, one may introduce element-wise nonlinearities such as $\sigma(\cdot)$ or $\mathrm{softmax}(\cdot)$. This effectively yields a Mixture-of-Heads Attention (MoH Attention), in which each component becomes a learned mixture weight modulated by the nonlinearity.
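The NumPy sketch below applies a nonlinearity to the head factors of the keys. The text does not fix the softmax axis. Here the softmax is taken over the rank axis for each head, which is an assumption. Under that choice, row $i$ of $R\,\mathbf{K}_t$ is a convex mixture of the token factors $\mathbf{b}_1(\mathbf{x}_t),\ldots,\mathbf{b}_R(\mathbf{x}_t)$. The function name and the `kind` switch are hypothetical.

```python
import numpy as np

def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def nonlinear_head_keys(x_t, W_a, W_b, R, h, d_h, kind="softmax"):
    """K_t = (1/R) sum_r phi(a_r(x_t)) (x) b_r(x_t), with a nonlinearity phi on the head factors.

    kind="sigmoid": element-wise sigma(.) on every head factor.
    kind="softmax": normalise over the rank axis for each head (an assumption), so that
    row i of R*K_t is a convex mixture of the token factors b_1(x_t), ..., b_R(x_t).
    """
    A = (W_a @ x_t).reshape(R, h)
    B = (W_b @ x_t).reshape(R, d_h)
    if kind == "sigmoid":
        A = 1.0 / (1.0 + np.exp(-A))
    elif kind == "softmax":
        A = softmax(A, axis=0)
    else:
        raise ValueError(f"unknown kind: {kind}")
    return A.T @ B / R, A

rng = np.random.default_rng(0)
d_model, h, d_h, R = 16, 4, 8, 3
W_a = rng.normal(size=(R * h, d_model))
W_b = rng.normal(size=(R * d_h, d_model))
x_t = rng.normal(size=d_model)
K_t, weights = nonlinear_head_keys(x_t, W_a, W_b, R, h, d_h)
```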
Discussion. These variants illustrate TPA's versatility in balancing memory cost, computational overhead, and representational power. Two choices control this balance: which dimensions (heads or tokens) remain contextual, and the ranks $(R_Q,R_K,R_V)$. Through these choices, TPA unifies several existing attention mechanisms, such as MHA, MQA, and GQA, under one framework. It can also reduce the KV cache size by up to an order of magnitude during autoregressive inference.
3.6 Model Architectures
We propose a new architecture called the Tensor ProducT ATTenTion Transformer (T6), which uses our Tensor Product Attention (TPA) in place of standard MHA (multi-head attention) or GQA (grouped-query attention). T6 builds on the query, key, and value tensors $\mathbf{Q},\mathbf{K},\mathbf{V}\in\mathbb{R}^{T\times h\times d_h}$ defined in Section 3.1. It follows the overall architecture of LLaMA (Touvron et al., 2023) but replaces the self-attention block with our TPA-based version. The feed-forward network (FFN) adopts a SwiGLU layer, as in (Shazeer, 2020; Touvron et al., 2023).
TPA QKV Factorization. Let each token's hidden-state vector be $\mathbf{x}_t\in\mathbb{R}^{d_{\mathrm{model}}}$. Following Section 3.1, we project the entire sequence into three tensors $\mathbf{Q},\mathbf{K},\mathbf{V}\in\mathbb{R}^{T\times h\times d_h}$, where $\mathbf{Q}_t,\mathbf{K}_t,\mathbf{V}_t\in\mathbb{R}^{h\times d_h}$ denote the slices for the $t$-th token. The factor components $\mathbf{a}_r^Q(\mathbf{x}_t),\mathbf{b}_r^Q(\mathbf{x}_t),\mathbf{a}_r^K(\mathbf{x}_t),\mathbf{b}_r^K(\mathbf{x}_t),\mathbf{a}_r^V(\mathbf{x}_t),\mathbf{b}_r^V(\mathbf{x}_t)$ are produced by linear transformations of $\mathbf{x}_t$. For instance, with $W_r^{a^Q}\in\mathbb{R}^{h\times d_{\mathrm{model}}}$ and $W_r^{b^Q}\in\mathbb{R}^{d_h\times d_{\mathrm{model}}}$, we have:
$$
\mathbf{a}_r^Q(\mathbf{x}_t)=W_r^{a^Q}\mathbf{x}_t,\quad\mathbf{b}_r^Q(\mathbf{x}_t)=W_r^{b^Q}\mathbf{x}_t.
$$
In practice, we merge all ranks $r$ into a single output dimension, reshape, and sum over the rank index; see Section 3.1 for details. The factorization of $\mathbf{K}$ and $\mathbf{V}$ follows the same pattern.
Rotary Positional Embedding (RoPE). As discussed in Section 3.2, RoPE (Su et al., 2024b) is applied to $\mathbf{Q}$ and $\mathbf{K}$. Within TPA, we pre-rotate the token-dimension factors $\mathbf{B}_Q(\mathbf{x}_t)$ and $\mathbf{B}_K(\mathbf{x}_s)$ directly, so that each $\mathbf{K}_s$ is already rotated before it is cached; see (3.6) and Theorem 1.
Attention Step and Output Projection. Once $\mathbf{Q},\mathbf{K},\mathbf{V}$ have been factorized per token, with RoPE applied to $\mathbf{Q}$ and $\mathbf{K}$, the attention step proceeds for each head $i\in\{1,\ldots,h\}$ using (3.4). Finally, concatenating the $h$ heads and projecting them back with an output weight matrix gives the final attention result, as shown in (3.5).
SwiGLU Feed-Forward Network. Following Shazeer (2020); Touvron et al. (2023), our T6 uses a SwiGLU-based Feed-Forward Network (FFN):
$$
\mathrm{FFN}(\mathbf{x})=\big[\sigma(\mathbf{x}W_1)\odot(\mathbf{x}W_2)\big]W_3,
$$
where $\sigma$ is the SiLU (a.k.a. Swish) nonlinearity, $\odot$ is the element-wise product, and $W_{1},W_{2},W_{3}$ are learnable parameters. Other activation functions can also be used.
其中,$\sigma$ 是 SiLU(也称为 swish)非线性函数,$\odot$ 是逐元素乘积,$W_{1},W_{2},W_{3}$ 是可学习的参数。注意,其他激活函数也可以使用。
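The FFN formula translates directly into a few lines of plain Python. This is a minimal single-token sketch with weights as lists of rows; the function names are hypothetical, and the SiLU is written in a split form only to avoid overflow in `math.exp` for large negative inputs.

```python
import math


def silu(z):
    """SiLU / Swish: z * sigmoid(z), split by sign to avoid overflow for large |z|."""
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)
    return z * e / (1.0 + e)


def matvec(x, W):
    """Return x @ W for a vector x and a matrix W given as a list of rows."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def swiglu_ffn(x, W1, W2, W3):
    """FFN(x) = [silu(x W1) * (x W2)] W3, where * is the element-wise product."""
    gate = [silu(v) for v in matvec(x, W1)]   # gating branch
    up = matvec(x, W2)                        # linear branch
    hidden = [g * u for g, u in zip(gate, up)]
    return matvec(hidden, W3)                 # project back to d_model


print(swiglu_ffn([1.0], [[1.0]], [[2.0]], [[3.0]]))  # about [4.386], i.e. 6 * silu(1)
```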
Overall T6 Block Structure. Putting everything together, one T6 block consists of:
总体 T6 块结构。将所有内容整合在一起,一个 T6 块由以下部分组成:
$$\begin{array}{rl}&\mathbf{x}\leftarrow\mathbf{x}+\mathrm{TPA}\big(\mathrm{RMSNorm}(\mathbf{x})\big),\\ &\mathbf{x}\leftarrow\mathbf{x}+\text{SwiGLU-FFN}\big(\mathrm{RMSNorm}(\mathbf{x})\big).\end{array}$$
$$\begin{array}{rl}&\mathbf{x}\leftarrow\mathbf{x}+\mathrm{TPA}\big(\mathrm{RMSNorm}(\mathbf{x})\big),\\ &\mathbf{x}\leftarrow\mathbf{x}+\text{SwiGLU-FFN}\big(\mathrm{RMSNorm}(\mathbf{x})\big).\end{array}$$
We place norm layers (e.g., RMSNorm) before each sub-layer. Stacking $L$ such blocks yields a T6 model architecture with $L$ layers.
我们将归一化层(例如 RMSNorm)放置在每个子层之前。堆叠 $L$ 个这样的块会生成一个具有 $L$ 层的 T6 模型架构。
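The pre-norm block structure can be sketched with the sub-layers passed in as functions, so TPA (or any other attention) and any FFN plug in unchanged. The sketch assumes an RMSNorm with a learnable per-dimension gain and a small `eps`; `t6_block`, `t6_stack`, and their arguments are hypothetical names, and the final norm and output head of a full model are omitted.

```python
import math


def rmsnorm(x, gain, eps=1e-6):
    """RMSNorm: x / sqrt(mean(x^2) + eps), scaled element-wise by a learnable gain."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v * inv_rms for g, v in zip(gain, x)]


def t6_block(X, attn, ffn, gain_attn, gain_ffn):
    """One pre-norm block over a sequence X (list of d_model vectors).

    attn: maps the list of normalized token vectors to a list of outputs (mixes tokens).
    ffn:  maps one normalized token vector to an output vector (per token).
    """
    # x <- x + TPA(RMSNorm(x))
    attn_out = attn([rmsnorm(x, gain_attn) for x in X])
    X = [[a + b for a, b in zip(x, y)] for x, y in zip(X, attn_out)]
    # x <- x + SwiGLU-FFN(RMSNorm(x))
    return [[a + b for a, b in zip(x, ffn(rmsnorm(x, gain_ffn)))] for x in X]


def t6_stack(X, blocks):
    """Apply L blocks in sequence; each entry is (attn, ffn, gain_attn, gain_ffn)."""
    for attn, ffn, g1, g2 in blocks:
        X = t6_block(X, attn, ffn, g1, g2)
    return X
```

Keeping the residual path free of normalization is what lets a stack of $L$ such blocks reduce to the identity when every sub-layer outputs zeros.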
4 Experiments
4 实验
4.1 Language Modeling Tasks
4.1 语言建模任务
All experiments reported in this paper are implemented on the nanoGPT code base (Karpathy, 2022) and use the FineWeb-Edu 100B dataset (Lozhkov et al., 2024), which contains 100 billion tokens for training and 0.1 billion tokens for validation. We compare T6 against the baseline Llama architecture (Touvron et al., 2023) with SwiGLU activation (Shazeer, 2020) and RoPE embeddings (Su et al., 2024a), as well as Llama variants that replace Multi-Head Attention (MHA; Vaswani et al., 2017) with Multi-Query Attention (MQA; Shazeer, 2019), Grouped-Query Attention (GQA; Ainslie et al., 2023), or Multi-head Latent Attention (MLA; Liu et al., 2024a). In our experiments, the number of heads $h$ is adjusted for each attention mechanism so that all attention mechanisms have the same number of parameters as standard Multi-Head Attention (MHA), which has $4d_{\mathrm{model}}^{2}$ parameters per attention layer. We train models at three scales: small (124M parameters), medium (353M), and large (773M). Details on architecture hyperparameters and training hardware appear in Appendix B.1.
本文报告的所有实验均在nanoGPT代码库(Karpathy, 2022)上实现,使用了FineWeb-Edu 100B数据集(Lozhkov等, 2024)。该数据集包含1000亿个训练token和1亿个验证token。我们将T6与基线Llama架构(Touvron等, 2023)进行比较,后者使用了SwiGLU激活函数(Shazeer, 2020)和RoPE嵌入(Su等, 2024a),以及将多头注意力机制(MHA;Vaswani等, 2017)替换为多查询注意力机制(MQA;Shazeer, 2019)、分组查询注意力机制(GQA;Ainslie等, 2023)或多头潜在注意力机制(MLA;Liu等, 2024a)的Llama变体。在我们的实验中,每种注意力机制的头数$h$都经过调整,使所有注意力机制的参数量与标准多头注意力机制(MHA)相同,后者每个注意力层有$4d_{\mathrm{model}}^{2}$个参数。我们在三个规模上训练模型:小型(1.24亿参数)、中型(3.53亿)和大型(7.73亿)。架构超参数和训练硬件的详细信息见附录B.1。
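The head-count adjustment can be illustrated with a simple parameter count. The sketch assumes no biases, an MHA with $h\,d_h=d_{\mathrm{model}}$ (four $d_{\mathrm{model}}\times d_{\mathrm{model}}$ projections, hence $4d_{\mathrm{model}}^{2}$), and, for TPA, one projection from $d_{\mathrm{model}}$ to $h$ (for $\mathbf{a}_r$) and one from $d_{\mathrm{model}}$ to $d_h$ (for $\mathbf{b}_r$) per rank of each of $\mathbf{Q},\mathbf{K},\mathbf{V}$, plus an $(h\,d_h)\times d_{\mathrm{model}}$ output projection. The ranks, $d_h$, and the rule "take the largest $h$ within the MHA budget" are hypothetical choices for illustration, not the paper's configurations.

```python
def mha_params(d_model):
    """Q, K, V, and output projections, each d_model x d_model, no biases."""
    return 4 * d_model * d_model


def tpa_params(d_model, h, d_h, r_q, r_k, r_v):
    """Factor projections for Q, K, V plus the output projection, no biases.

    Each rank of each factorized tensor has an a-projection (d_model -> h)
    and a b-projection (d_model -> d_h).
    """
    factor_weights = (r_q + r_k + r_v) * d_model * (h + d_h)
    output_proj = h * d_h * d_model
    return factor_weights + output_proj


def max_heads_within_budget(d_model, d_h, r_q, r_k, r_v):
    """Hypothetical rule: the largest h whose TPA layer fits in the MHA budget."""
    budget = mha_params(d_model)
    h = 0
    while tpa_params(d_model, h + 1, d_h, r_q, r_k, r_v) <= budget:
        h += 1
    return h


# Illustrative values only: d_model = 1024, d_h = 64, ranks (R_Q, R_K, R_V) = (6, 2, 2).
h = max_heads_within_budget(1024, 64, 6, 2, 2)
print(h, tpa_params(1024, h, 64, 6, 2, 2), mha_params(1024))  # 46 4141056 4194304
```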
Training Setup. We follow the nanoGPT training configuration. In particular, we use the AdamW optimizer (Loshchilov, 2017) with $(\beta_{1},\beta_{2})=(0.9,0.95)$, a weight decay of 0.1, and gradient clipping at 1.0. As in nanoGPT, the learning rate follows a cosine annealing schedule (Loshchilov & Hutter, 2016) with 2,000 warmup steps, and the global batch size is 480. For the small, medium, and large models, the maximum learning rates are $6\times10^{-4}$, $3\times10^{-4}$, and $2\times10^{-4}$, and the minimum learning rates are $3\times10^{-5}$, $3\times10^{-5}$, and $1\times10^{-5}$, respectively.
训练设置。我们遵循 nanoGPT 的训练配置。具体来说,我们使用 AdamW (Loshchilov, 2017) 优化器,其参数为 $(\beta_{1},\beta_{2})=(0.9,0.95)$,权重衰减为 0.1,梯度裁剪为 1.0。我们遵循与 nanoGPT 相同的设置,即学习率由余弦退火调度器 (Loshchilov & Hutter, 2016) 管理,预热步数为 2,000,全局批量大小为 480。对于小型、中型和大型模型,我们分别设置最大学习率为 $6\times10^{-4}$、$3\times10^{-4}$ 和 $2\times10^{-4}$,最小学习率为 $3\times10^{-5}$、$3\times10^{-5}$ 和 $1\times10^{-5}$。
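The learning-rate schedule can be sketched as a linear warmup followed by cosine decay to the minimum rate, in the style of nanoGPT's schedule. The total number of decay steps is not stated in this section, so `decay_steps` is a hypothetical parameter; the per-scale peak and floor values are the ones listed above.

```python
import math

# (max_lr, min_lr) per model scale, as listed in the training setup.
LR_BY_SCALE = {
    "small": (6e-4, 3e-5),
    "medium": (3e-4, 3e-5),
    "large": (2e-4, 1e-5),
}


def cosine_lr(step, max_lr, min_lr, warmup_steps=2000, decay_steps=100_000):
    """Linear warmup to max_lr, then cosine decay to min_lr at decay_steps.

    decay_steps is a hypothetical default; the text does not give the total.
    """
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step >= decay_steps:
        return min_lr
    progress = (step - warmup_steps) / (decay_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))  # falls from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)
```

For example, `cosine_lr(step, *LR_BY_SCALE["medium"])` reaches $3\times10^{-4}$ at the end of warmup and settles at $3\times10^{-5}$ once `step` reaches `decay_steps`.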
Training & Validation Curves. Figures 2 and 3 compare training and validation loss curves for the large (773M) and medium (353M) models on FineWeb-Edu-100B. Overall, TPA (red curves) and its simpler variant TPA-KVonly (pink curves) converge as fast as or faster than the baselines (MHA, MQA, GQA, MLA) and reach visibly lower final losses. For instance, in Figure 2(b), TPA and TPA-KVonly remain below the MHA baseline in validation loss at nearly all training stages. Meanwhile, Multi-head Latent Attention (MLA; Liu et al., 2024a; blue curves) generally trains more slowly and yields higher losses.
训练与验证曲线。图2和图3比较了在FineWeb-Edu-100B上训练的大型(773M)和中型(353M)模型的训练和验证损失曲线。总体而言,TPA(红色曲线)及其简化变体TPA-KVonly(粉色曲线)的收敛速度与基线(MHA、MQA、GQA、MLA)相当或更快,同时最终损失也明显更低。例如,在图2(b)中,TPA和TPA-KVonly在几乎所有训练阶段的验证损失都低于MHA基线。与此同时,多头潜在注意力(MLA,Liu et al., 2024a;蓝色曲线)通常训练速度较慢,且损失较高。
Validation Perplexity. Figure 4 shows the validation perplexities of the medium- and large-scale models. Mirroring the loss curves, TPA and TPA-KVonly steadily outperform MHA, MQA, GQA, and MLA over the course of training. By the end of pretraining (around 49B tokens), TPA-based approaches achieve the lowest perplexities in most configurations.
验证困惑度。图 4 展示了中规模和大规模模型的验证困惑度。与损失曲线一致,TPA 和 TPA-KVonly 在训练过程中始终优于 MHA、MQA、GQA 和 MLA。在预训练结束时(约 49B tokens),基于 TPA 的方法在大多数配置中实现了最低的困惑度。
Downstream Evaluation. We evaluate zero-shot and two-shot performance on standard benchmarks, including ARC (Yadav et al., 2019), BoolQ (Clark et al., 2019), HellaSwag (Zellers et al., 2019), OBQA (Mihaylov et al., 2018), PIQA (Bisk et al., 2020), WinoGrande (Sakaguchi et al., 2020), MMLU (Hendrycks et al., 2021), and SciQ, using the lm-evaluation-harness codebase (Gao et al., 2024). For ARC-E, ARC-C, HellaSwag, OBQA, PIQA, and SciQ, we report normalized accuracy (acc_norm); for the other tasks, we report standard accuracy. Tables 8–9 in the appendix present results for small models, Tables 2–3 for medium models, and Tables 4–5 for large models.
下游评估。我们在标准基准上评估零样本和两样本性能,包括 ARC (Yadav et al., 2019)、BoolQ (Clark et al., 2019)、HellaSwag (Zellers et al., 2019)、OBQA (Mihaylov et al., 2018)、PIQA (Bisk et al., 2020)、WinoGrande (Sakaguchi et al., 2020)、MMLU (Hendrycks et al., 2021) 和 SciQ,使用 lm-evaluation-harness 代码库 (Gao et al., 2024)。对于 ARC-E、ARC-C、HellaSwag、OBQA、PIQA 和 SciQ,我们报告归一化准确率 (acc_norm);对于其他任务,我们报告标准准确率。附录中的表 8-9 展示了小型模型的结果;表 2-3 展示了中型模型的结果;表 4-5 展示了大型模型的结果。
For the medium-size (353M) models (Tables 2–3), TPA generally matches or outperforms all competing methods; for example, it averages 50.88% in zero-shot mode versus 49.44% for MHA, 49.29% for MQA, and 48.96% for MLA. With two-shot prompts, TPA again leads with 53.04% average accuracy. A similar trend appears for the large-size (773M) models (Tables 4–5), where TPA-KVonly attains the highest average (53.52% zero-shot, 55.33% two-shot), closely followed by full TPA.
对于中等规模(353M)的模型(表 2-3),TPA 通常与所有竞争方法持平或优于它们,例如在零样本模式下平均达到 50.88%,而 MHA 为 49.44%,MQA 为 49.29%,MLA 为 48.96%。当提供两样本提示时,TPA 再次以 53.04% 的平均准确率领先。对于大规模(773M)模型(表 4-5),TPA-KVonly 取得了最高的平均准确率(零样本 53.52%,两样本 55.33%),紧随其后的是完整的 TPA。
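The Avg. column appears to be the unweighted mean of the nine task scores (this reproduces, for example, the TPA averages in Tables 2 and 3 and the MHA average in Table 2), so recomputing it from a row is a quick consistency check. The rows below are copied from the tables; `row_average` is a hypothetical helper.

```python
def row_average(scores):
    """Unweighted mean of per-task scores, rounded to two decimals like the tables."""
    return round(sum(scores) / len(scores), 2)


# Column order: ARC-E, ARC-C, BoolQ, HellaSw., OBQA, PIQA, W.G., MMLU, SciQ
table3_tpa = [66.54, 34.47, 58.96, 45.35, 33.00, 69.21, 53.99, 24.51, 91.30]
table2_mha = [56.52, 29.27, 58.84, 44.06, 35.00, 68.44, 51.07, 25.35, 76.40]
table2_tpa = [59.30, 31.91, 60.98, 45.57, 34.60, 69.48, 53.91, 24.93, 77.20]

print(row_average(table3_tpa), row_average(table2_mha), row_average(table2_tpa))
# 53.04 49.44 50.88
```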
Our experiments confirm that TPA consistently matches or exceeds the performance of established attention mechanisms (MHA, MQA, GQA, MLA) at medium and large model scales. Fully factorized TPA excels on mid-scale models, while TPA-KVonly can rival or surpass it at larger scales. In both cases, factorizing the attention activations shrinks autoregressive KV cache requirements by up to $5\times$–$10\times$, enabling much longer context windows under fixed memory budgets. In summary, tensor product attention provides a flexible, memory-efficient alternative to standard multi-head attention, advancing the scalability of modern language models.
我们的实验证实,TPA 在中型和大型模型规模上始终匹配或超越现有注意力机制(MHA、MQA、GQA、MLA)的性能。完全分解的 TPA 在中等规模模型上表现出色,而 TPA-KVonly 在更大规模上可以与之匹敌或超越。在这两种情况下,分解注意力激活将自回归 KV 缓存需求最多缩小 $5\times$–$10\times$,从而在固定内存预算下实现了更长的上下文窗口。总之,张量积注意力为标准的多头注意力提供了一种灵活且内存高效的替代方案,提升了现代大语言模型的可扩展性。
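The size of the reduction depends on the head and rank settings. Per token and per layer, MHA caches $2h\,d_h$ numbers, while TPA can cache only the key and value factor matrices, $(R_K+R_V)(h+d_h)$ numbers, since $\mathbf{K}_s=\frac{1}{R_K}\mathbf{A}_K(\mathbf{x}_s)^{\top}\mathbf{B}_K(\mathbf{x}_s)$ with $\mathbf{A}_K$ of size $R_K\times h$ and $\mathbf{B}_K$ of size $R_K\times d_h$ (and likewise for values). The configuration below is a hypothetical example, not one of the trained models.

```python
def mha_kv_per_token(h, d_h):
    """Full keys and values for every head: 2 * h * d_h numbers."""
    return 2 * h * d_h


def mqa_kv_per_token(d_h):
    """One shared key head and one shared value head."""
    return 2 * d_h


def gqa_kv_per_token(groups, d_h):
    """One key head and one value head per query group."""
    return 2 * groups * d_h


def tpa_kv_per_token(h, d_h, r_k, r_v):
    """Factor matrices A_K (r_k x h), B_K (r_k x d_h), A_V (r_v x h), B_V (r_v x d_h)."""
    return (r_k + r_v) * (h + d_h)


# Hypothetical layer: 16 heads of size 64, TPA ranks R_K = R_V = 2.
mha = mha_kv_per_token(16, 64)
tpa = tpa_kv_per_token(16, 64, 2, 2)
print(mha, tpa, mha / tpa)  # 2048 320 6.4
```

Multiplying these per-token counts by the bytes per element, the number of layers, and the sequence length gives the total cache size.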
5 Related Work
5 相关工作
Transformers and Attention. As a sequence-to-sequence architecture, the Transformer (Vaswani et al., 2017) introduced Multi-Head Attention (MHA), enabling more effective capture of long-range dependencies. Subsequent work has explored a variety of attention mechanisms aimed at improving scalability and efficiency, including sparse patterns (Child et al., 2019; Shi et al., 2023; Han et al., 2024; Liang et al., 2024a; Li et al., 2024; Liang et al., 2024b), kernel-based projections (Choromanski et al., 2021), and linearized transformers (Tsai et al., 2019; Katharopoulos et al., 2020; Schlag et al., 2021; Zhang et al., 2023b; Sun et al., 2023; Zhang et al., 2024). To reduce memory usage and work around the memory-bandwidth bottleneck of incremental decoding, Shazeer (2019) proposed Multi-Query Attention (MQA), in which all query heads share a single key head and value head. To address the resulting quality degradation and training instability, Grouped-Query Attention (GQA) (Ainslie et al., 2023) divides the query heads into several groups, each of which shares a single key head and value head. More recently, DeepSeek-V2 (Liu et al., 2024a) applied Multi-head Latent Attention (MLA), which reduces the inference-time KV cache by sharing a single low-rank representation of keys and values while achieving better performance than MHA. In contrast to these approaches, TPA uses low-rank tensor products to compute the queries, keys, and values, so the cached key and value representations are much smaller than those in MHA, substantially reducing KV cache memory consumption at inference time.
Transformer 与注意力机制。作为一种序列到序列的架构,Transformer (Vaswani et al., 2017) 引入了多头注意力机制 (Multi-Head Attention, MHA),使其能够更有效地捕捉长距离依赖关系。后续的研究探索了多种注意力机制,旨在提高可扩展性和效率,包括稀疏模式 (Child et al., 2019; Shi et al., 2023; Han et al., 2024; Liang et al., 2024a; Li et al., 2024; Liang et al., 2024b)、基于核的投影 (Choromanski et al., 2021) 以及线性化 Transformer (Tsai et al., 2019; Katharopoulos et al., 2020; Schlag et al., 2021; Zhang et al., 2023b; Sun et al., 2023; Zhang et al., 2024)。为了减少内存使用并规避增量解码中内存带宽的限制,Shazeer (2019) 提出了多查询注意力 (Multi-Query Attention, MQA),其中所有查询头共享同一个键头和值头。为了解决由此带来的质量下降和训练不稳定问题,分组查询注意力 (Grouped-Query Attention, GQA) (Ainslie et al., 2023) 将查询头分成若干组,每组查询共享一个键头和值头。最近,DeepSeek-V2 (Liu et al., 2024a) 应用多头潜在注意力 (Multi-head Latent Attention, MLA),通过共享键和值的低秩表示减少了推理时的 KV 缓存,同时实现了比 MHA 更好的性能。与上述方法相比,TPA 应用低秩张量积来计算查询、键和值,其中键和值的缓存表示比 MHA 中的小得多,从而显著减少了推理时 KV 缓存的内存占用。

Figure 3: The training loss and validation loss of medium-size (353M) models with different attention mechanisms on the FineWeb-Edu 100B dataset.
图 3: 在 FineWeb-Edu 100B 数据集上,不同注意力机制的中等规模 (353M) 模型的训练损失和验证损失。

Figure 4: The validation perplexity of medium-size (353M) models and large-size (773M) models with different attention mechanisms on the FineWeb-Edu 100B dataset.
图 4: 在 FineWeb-Edu 100B 数据集上,中等规模 (353M) 模型和大规模 (773M) 模型使用不同注意力机制时的验证困惑度。
Table 2: The evaluation results of medium models with different attention mechanisms pretrained using the FineWeb-Edu 100B dataset (0-shot with lm-evaluation-harness). The best scores in each column are bolded. Abbreviations: HellaSw. $=$ HellaSwag, W.G. $=$ WinoGrande.
表 2: 使用 FineWeb-Edu 100B 数据集预训练的不同注意力机制的中等模型的评估结果 (0-shot 使用 lm-evaluation-harness)。每列中的最佳分数以粗体显示。缩写:HellaSw. $=$ HellaSwag, W.G. $=$ WinoGrande。
| Method | ARC-E | ARC-C | BoolQ | HellaSw. | OBQA | PIQA | W.G. | MMLU | SciQ | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|
| MHA | 56.52 | 29.27 | 58.84 | 44.06 | 35.00 | 68.44 | 51.07 | 25.35 | 76.40 | 49.44 |
| MQA | 55.68 | 28.24 | 60.86 | 44.17 | 35.20 | 68.66 | 52.72 | 25.14 | 72.90 | 49.29 |
| GQA | 54.88 | 29.61 | 56.36 | 43.77 | 35.20 | 68.82 | 52.57 | 25.41 | 74.80 | 49.05 |
| MLA | 55.30 | 29.27 | 58.96 | 41.92 | 35.40 | 67.25 | 51.78 | 25.20 | 75.60 | 48.96 |
| TPA-KVonly | 57.11 | 30.03 | 61.25 | 44.83 | 34.60 | 69.04 | 54.54 | 23.35 | 74.60 | 49.93 |
| TPA | 59.30 | 31.91 | 60.98 | 45.57 | 34.60 | 69.48 | 53.91 | 24.93 | 77.20 | 50.88 |
Table 3: The evaluation results of medium models with different attention mechanisms pre-trained using the FineWeb-Edu 100B dataset (2-shot with lm-evaluation-harness). The best scores in each column are bolded. Abbreviations: HellaSw. $=$ HellaSwag, W.G. $=$ WinoGrande.
表 3: 使用 FineWeb-Edu 100B 数据集预训练的不同注意力机制中等模型的评估结果 (2-shot 使用 lm-evaluation-harness)。每列中的最佳分数加粗显示。缩写:HellaSw. $=$ HellaSwag, W.G. $=$ WinoGrande。
| Method | ARC-E | ARC-C | BoolQ | HellaSw. | OBQA | PIQA | W.G. | MMLU | SciQ | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|
| MHA | 64.44 | 32.85 | 59.05 | 44.18 | 33.20 | 68.72 | 50.12 | 26.01 | 87.40 | 51.77 |
| MQA | 64.27 | 32.94 | 57.71 | 44.36 | 31.80 | 68.01 | 51.70 | 25.99 | 86.00 | 51.42 |
| GQA | 61.70 | 32.17 | 52.81 | 43.99 | 33.80 | 68.50 | 53.35 | 24.44 | 86.40 | 50.80 |
| MLA | 62.75 | 30.80 | 59.17 | 42.02 | 34.80 | 67.08 | 52.41 | 26.11 | 84.80 | 51.10 |
| TPA-KVonly | 65.99 | 33.70 | 57.49 | 44.47 | 34.20 | 69.53 | 53.28 | 24.23 | 86.50 | 52.15 |
| TPA | 66.54 | 34.47 | 58.96 | 45.35 | 33.00 | 69.21 | 53.99 | 24.51 | 91.30 | 53.04 |
Table 4: The evaluation results of large models with different attention mechanisms pre-trained using the FineWeb-Edu 100B dataset (0-shot with lm-evaluation-harness). The best scores in each column are bolded. Abbreviations: HellaSw. $=$ HellaSwag, W.G. $=$ WinoGrande.
表 4: 使用 FineWeb-Edu 100B 数据集预训练的不同注意力机制大模型的评估结果 (零样本,使用 lm-evaluation-harness)。每列中的最佳分数加粗显示。缩写:HellaSw. $=$ HellaSwag, W.G. $=$ WinoGrande。
| Method | ARC-E | ARC-C | BoolQ | HellaSw. | OBQA | PIQA | W.G. | MMLU | SciQ | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|
| MHA | 59.93 | 33.62 | 61.93 | 50.63 | 36.00 | 71.06 | 55.41 | 22.87 | 81.20 | 52.52 |
| MQA | 60.73 | 33.62 | 57.34 | 50.09 | 37.00 | 69.97 | 55.49 | 25.30 | 79.60 | 52.13 |
| GQA | 61.66 | 34.30 | 58.72 | 49.85 | 38.40 | 71.16 | 53.75 | 25.23 | 77.60 | 52.30 |
| MLA | 60.73 | 31.57 | 61.74 | 48.96 | 35.40 | 69.59 | 55.09 | 26.39 | 76.70 | 51.80 |
| TPA-KVonly | 63.26 | 34.13 | 61.96 | 50.66 | 37.20 | 72.09 | 55.25 | 26.06 | 81.10 | 53.52 |
| TPA | 63.22 | 35.58 | 60.03 | 51.26 | 36.80 | 71.44 | 55.56 | 24.77 | 79.60 | 53.10 |
Table 5: The evaluation results of large models with different attention mechanisms pre-trained using the FineWeb-Edu 100B dataset (2-shot with lm-evaluation-harness). The best scores in each column are bolded. Abbreviations: HellaSw. $=$ HellaSwag, W.G. $=$ WinoGrande.
表 5: 使用 FineWeb-Edu 100B 数据集预训练的不同注意力机制大模型的评估结果 (2-shot 使用 lm-evaluation-harness)。每列中的最佳分数加粗显示。缩写:HellaSw. $=$ HellaSwag, W.G. $=$ WinoGrande。
| Method | ARC-E | ARC-C | BoolQ | HellaSw. | OBQA | PIQA | W.G. | MMLU | SciQ | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|
| MHA | 67.85 | 36.35 | 59.82 | 50.22 | 35.00 | 70.67 | 53.35 | 23.92 | 91.10 | 54.25 |
| MQA | 68.86 | 36.09 | 53.79 | 50.50 | 37.00 | 70.89 | 54.70 | 25.01 | 88.00 | 53.87 |
| GQA | 69.15 | 36.09 | 58.84 | 50.29 | 36.20 | 70.73 | 54.22 | 26.08 | 90.00 | 54.62 |
| MLA | 68.56 | 35.41 | 60.12 | 49.18 | 38.00 | 69.21 | 55.25 | 25.29 | 88.20 | 54.36 |
| TPA-KVonly | 71.34 | 37.71 | 59.76 | 51.10 | 36.00 | 71.49 | 54.62 | 25.83 | 90.10 | 55.33 |
| TPA | 70.41 | 37.71 | 60.06 | 51.30 | 34.00 | 71.06 | 54.54 | 25.79 | 90.30 | 55.02 |
Low-Rank Factorizations. Low-rank approximations have been applied to compress model parameters and reduce complexity, including LoRA (Hu et al., 2022), which factorizes weight updates during fine-tuning, and its derivatives for other training scenarios such as efficient pretraining (ReLoRA (Lialin et al., 2023), MoRA (Jiang et al., 2024)), long-context training (LongLoRA (Chen et al., 2024), SinkLoRA (Zhang, 2024)), and continual training (InfLoRA (Liang & Li, 2024), GS-LoRA (Zhao et al., 2024), I-LoRA (Ren et al., 2024)). These approaches typically produce static low-rank expansions that do not explicitly depend on the input context. Malladi et al. (2023) and Zeng & Lee (2024) provided theoretical proofs of the expressiveness of low-rank approximation. For initializing the factorization matrices, OLoRA (Büyükakyüz, 2024) applies a QR decomposition of the pretrained weights to improve language-model performance, while LoLDU (Shi et al., 2024) uses an LDU decomposition to accelerate LoRA training. Moreover, AdaLoRA (Zhang et al., 2023a) uses the Singular Value Decomposition (SVD) of the pretrained weights and introduces an importance score for each parameter to adjust the rank dynamically. TPA, by contrast, constructs Q, K, and V as contextually factorized tensors, enabling dynamic adaptation to the input.
低秩分解 (Low-Rank Factorization)。低秩近似已被应用于压缩模型参数并降低复杂度,包括 LoRA (Hu et al., 2022),它在微调过程中对权重更新进行分解,以及其衍生方法用于其他训练场景,如高效预训练 (ReLoRA (Lialin et al., 2023), MoRA (Jiang et al., 2024))、长上下文训练 (LongLoRA (Chen et al., 2024), SinkLoRA (Zhang, 2024)) 以及持续训练 (InfLoRA (Liang & Li, 2024), GS-LoRA (Zhao et al., 2024), I-LoRA (Ren et al., 2024))。这些方法通常生成静态的低秩扩展,不显式依赖于输入上下文。Malladi et al. (2023) 和 Zeng & Lee (2024) 提供了低秩近似表达能力的理论证明。对于分解矩阵的初始化,OLoRA (Büyükakyüz, 2024) 应用了预训练权重的 QR 分解以提高语言模型的性能,而 LoLDU (Shi et al., 2024) 使用 LDU 分解来加速 LoRA 的训练。此外,AdaLoRA (Zhang et al., 2023a) 利用预训练权重的奇异值分解 (SVD) 并为每个参数引入重要性评分作为度量,以实现秩的动态调整。相比之下,TPA 将 Q、K 和 V 构建为上下文分解的张量,从而实现动态适应。
KV Cache Optimization. Because Transformers decode autoregressively, the key and value tensors of previous tokens would otherwise be recomputed at every decoding step. To improve efficiency, these tensors can be cached in memory for subsequent decoding steps, a technique first proposed by Ott et al. (2019) and known as the KV cache. However, the KV cache requires additional memory and can add latency because of memory-bandwidth limits (Adnan et al., 2024). Previous studies have therefore explored diverse approaches to mitigate these issues, including KV cache eviction to discard less significant tokens (Zhang et al., 2023c; Xiao et al., 2024; Cai et al., 2024; Adnan et al., 2024), dynamic sparse attention over selected keys and values (Ribar et al., 2024; Tang et al., 2024; Singhania et al., 2024), KV cache offloading to the CPU (He & Zhai, 2024; Lee et al., 2024; Sun et al., 2024), and quantization of the KV cache (Xiao et al., 2023; Liu et al., 2024c; Hooper et al., 2024). Besides these methods, it is also effective to reduce the size of the KV cache per token, for example by reducing the number of KV heads (Shazeer, 2019; Ainslie et al., 2023), reusing KV across layers (Xiao et al., 2019; Mu et al., 2024; Wu et al., 2024), or using low-rank KV representations (Saxena et al., 2024). Unlike the methods above, TPA reduces the size of the KV cache by using tensor-decomposed keys and values.
KV Cache 优化。在 Transformer 的推理过程中,由于自回归解码的特性,若不加缓存,先前 Token 的键和值张量会在每个解码步骤中被重复计算。为了提高效率,Ott 等人 (2019) 首次提出可以将这些张量缓存在内存中供未来解码使用,称为 KV cache。然而,KV cache 需要额外的内存使用,并且由于带宽限制可能会增加延迟 (Adnan 等人, 2024)。因此,先前的研究探索了多种方法来缓解这些问题,包括通过 KV cache 淘汰机制丢弃不太重要的 Token (Zhang 等人, 2023c; Xiao 等人, 2024; Cai 等人, 2024; Adnan 等人, 2024),在选定的键和值之间进行动态稀疏注意力 (Ribar 等人, 2024; Tang 等人, 2024; Singhania 等人, 2024),将 KV cache 卸载到 CPU (He & Zhai, 2024; Lee 等人, 2024; Sun 等人, 2024),以及 KV cache 的量化 (Xiao 等人, 2023; Liu 等人, 2024c; Hooper 等人, 2024)。除了这些方法外,通过减少每个 Token 的 KV cache 量也是有效的,例如减少 KV 头的数量 (Shazeer, 2019; Ainslie 等人, 2023),跨层 KV 重用 (Xiao 等人, 2019; Mu 等人, 2024; Wu 等人, 2024),以及低秩 KV 表示 (Saxena 等人, 2024)。与上述方法不同,TPA 通过使用张量分解的键和值来减少 KV cache 的大小。
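The effect of a KV cache can be seen in a single-head, plain-Python sketch: at each decoding step only the newest token's key and value are projected and appended, and attending over the cache gives the same output as recomputing every key and value from scratch. The class and function names are hypothetical; real implementations store tensors per layer and per head.

```python
import math


def matvec(x, W):
    """Return x @ W for a vector x and a matrix W given as a list of rows."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def attend(q, keys, values):
    """Single-head softmax attention of query q over the given keys and values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(w[s] * values[s][j] for s in range(len(values))) / z
            for j in range(len(values[0]))]


class KVCache:
    """Hypothetical single-head cache holding keys and values of all decoded tokens."""

    def __init__(self, W_q, W_k, W_v):
        self.W_q, self.W_k, self.W_v = W_q, W_k, W_v
        self.keys, self.values = [], []

    def step(self, x):
        # Project only the newest token; earlier keys/values are reused as-is.
        self.keys.append(matvec(x, self.W_k))
        self.values.append(matvec(x, self.W_v))
        return attend(matvec(x, self.W_q), self.keys, self.values)


def recompute_all(xs, W_q, W_k, W_v):
    """Reference without a cache: re-project every prefix token at every step."""
    outs = []
    for t in range(len(xs)):
        keys = [matvec(x, W_k) for x in xs[:t + 1]]
        values = [matvec(x, W_v) for x in xs[:t + 1]]
        outs.append(attend(matvec(xs[t], W_q), keys, values))
    return outs


W_q = [[0.2, -0.1], [0.4, 0.3]]
W_k = [[0.5, 0.1], [-0.3, 0.8]]
W_v = [[1.0, 0.0], [0.5, -1.0]]
tokens = [[1.0, 0.0], [0.0, 1.0], [0.7, -0.2]]
cache = KVCache(W_q, W_k, W_v)
cached = [cache.step(x) for x in tokens]
reference = recompute_all(tokens, W_q, W_k, W_v)
print(cached == reference)  # True
```

The cache trades the repeated projections for memory that grows linearly with sequence length, which is exactly the cost that the per-token size reductions discussed above target.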
6 Conclusion
6 结论
We introduced Tensor Product Attention (TPA), which factorizes the query, key, and value matrices into rank-$R$ tensor products that depend on each token's hidden state. Storing only the factorized key/value components during autoregressive decoding substantially reduces KV cache memory while improving performance relative to MHA, MQA, GQA, and MLA. The approach is fully compatible with RoPE and can store pre-rotated keys. Variants of TPA include factorizing only the keys and values, or sharing basis vectors across tokens. Overall, TPA offers a powerful mechanism for compressing KV storage while improving model performance, thereby enabling longer sequence contexts under constrained memory.
我们引入了张量积注意力 (Tensor Product Attention, TPA),它将查询、键和值矩阵分解为依赖于 Token 隐藏状态的秩为 $R$ 的张量积。在自回归解码过程中,仅存储分解后的键/值组件,与多头注意力 (MHA)、多查询注意力 (MQA)、分组查询注意力 (GQA) 和多头潜在注意力 (MLA) 相比,显著减少了键值存储大小并提升了性能。该方法完全兼容旋转位置编码 (RoPE)(并且可以存储预旋转的键)。TPA 的变体包括仅分解键/值或在 Token 之间共享基向量。总体而言,TPA 提供了一种强大的机制,可以在压缩键值存储的同时提升模型性能,从而在内存受限的情况下支持更长的序列上下文。
References
参考文献
Muhammad Adnan, Akhil Arunkumar, Gaurav Jain, Prashant Nair, Ilya Soloveychik, and Purushotham Kamath. Keyformer: KV cache reduction through key tokens selection for efficient generative inference. Proceedings of Machine Learning and Systems, 6:114–127, 2024.
Muhammad Adnan, Akhil Arunkumar, Gaurav Jain, Prashant Nair, Ilya Soloveychik, 和 Purushotham Kamath. Keyformer: 通过关键 Token 选择减少 KV 缓存以实现高效生成推理. 机器学习与系统会议论文集, 6:114–127, 2024.
Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit Sanghai. GQA: training generalized multi-query transformer models from multi-head checkpoints. In Houda Bouamor, Juan Pino, and Kalika Bali (eds.), Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing, EMNLP 2023, Singapore, December 6-10, 2023, pp. 4895–4901. Association for Computational Linguistics, 2023. doi: 10.18653/V1/2023.EMNLP-MAIN.298. URL https://doi.org/10.18653/v1/2023.emnlp-main.298.
Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, 和 Sumit Sanghai. GQA: 从多头检查点训练广义多查询Transformer模型. 在 Houda Bouamor, Juan Pino, 和 Kalika Bali (编), 2023年自然语言处理经验方法会议论文集, EMNLP 2023, 新加坡, 2023年12月6-10日, 第4895–4901页. 计算语言学协会, 2023. doi: 10.18653/V1/2023.EMNLP-MAIN.298. URL https://doi.org/10.18653/v1/2023.emnlp-main.298.
Yonatan Bisk, Rowan Zellers, Ronan Le Bras, Jianfeng Gao, and Yejin Choi. PIQA: reasoning about physical commonsense in natural language. In The Thirty-Fourth AAAI Conference on Artificial Intelligence, AAAI 2020, The Thirty-Second Innovative Applications of Artificial Intelligence Conference, IAAI 2020, The Tenth AAAI Symposium on Educational Advances in Artificial Intelligence, EAAI 2020, New York, NY, USA, February 7-12, 2020, pp. 7432–7439. AAAI Press, 2020.
Yonatan Bisk, Rowan Zellers, Ronan Le Bras, Jianfeng Gao, 和 Yejin Choi. PIQA: 自然语言中的物理常识推理. 在第三十四届 AAAI 人工智能会议 (AAAI 2020), 第三十二届人工智能创新应用会议 (IAAI 2020), 第十届 AAAI 人工智能教育进展研讨会 (EAAI 2020), 美国纽约, 2020 年 2 月 7-12 日, 第 7432–7439 页. AAAI Press, 2020.
Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, 等. 语言模型是少样本学习者. 神经信息处理系统进展, 33:1877–1901, 2020.
Sébastien Bubeck, Varun Chandrasekaran, Ronen Eldan, Johannes Gehrke, Eric Horvitz, Ece Kamar, Peter Lee, Yin Tat Lee, Yuanzhi Li, Scott Lundberg, et al. Sparks of artificial general intelligence: Early experiments with gpt-4. arXiv preprint arXiv:2303.12712, 2023.
Sébastien Bubeck, Varun Chandrasekaran, Ronen Eldan, Johannes Gehrke, Eric Horvitz, Ece Kamar, Peter Lee, Yin Tat Lee, Yuanzhi Li, Scott Lundberg 等. 通用人工智能的火花:GPT-4 的早期实验. arXiv 预印本 arXiv:2303.12712, 2023.
Kerim Büyükakyüz. OLoRA: Orthonormal low-rank adaptation of large language models. arXiv preprint arXiv:2406.01775, 2024.
Kerim Büyükakyüz. OLoRA: 大语言模型的正交低秩适应 (Orthonormal Low-Rank Adaptation of Large Language Models). arXiv 预印本 arXiv:2406.01775, 2024.
Zefan Cai, Yichi Zhang, Bofei Gao, Yuliang Liu, Tianyu Liu, Keming Lu, Wayne Xiong, Yue Dong, Baobao Chang, Junjie Hu, et al. Pyramidkv: Dynamic kv cache compression based on pyramidal information funneling. arXiv preprint arXiv:2406.02069, 2024.
Zefan Cai, Yichi Zhang, Bofei Gao, Yuliang Liu, Tianyu Liu, Keming Lu, Wayne Xiong, Yue Dong, Baobao Chang, Junjie Hu, 等. Pyramidkv: 基于金字塔信息漏斗的动态 KV 缓存压缩. arXiv 预印本 arXiv:2406.02069, 2024.
Yukang Chen, Shengju Qian, Haotian Tang, Xin Lai, Zhijian Liu, Song Han, and Jiaya Jia. Longlora: Efficient fine-tuning of long-context large language models. In The Twelfth International Conference on Learning Representations, ICLR 2024, Vienna, Austria, May 7-11, 2024, 2024.
Yukang Chen, Shengju Qian, Haotian Tang, Xin Lai, Zhijian Liu, Song Han, and Jiaya Jia. Longlora: 长上下文大语言模型的高效微调。在第十二届国际学习表征会议 (ICLR 2024) 上,2024年5月7-11日,奥地利维也纳,2024年。
Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509, 2019.
Rewon Child、Scott Gray、Alec Radford 和 Ilya Sutskever。使用稀疏 Transformer 生成长序列。arXiv 预印本 arXiv:1904.10509,2019。
Krzysztof Marcin Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamás Sarlós, Peter Hawkins, Jared Quincy Davis, Afroz Mohiuddin, Lukasz Kaiser, David Benjamin Belanger, Lucy J. Colwell, and Adrian Weller. Rethinking attention with performers. In 9th International Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021, 2021.
Krzysztof Marcin Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamás Sarlós, Peter Hawkins, Jared Quincy Davis, Afroz Mohiuddin, Lukasz Kaiser, David Benjamin Belanger, Lucy J. Colwell, 和 Adrian Weller。重新思考注意力机制与Performers。在第九届国际学习表征会议 (ICLR 2021) 上,虚拟会议,奥地利,2021年5月3-7日,2021年。
Ting Jiang, Shaohan Huang, Shengyue Luo, Zihan Zhang, Haizhen Huang, Furu Wei, Weiwei Deng, Feng Sun, Qi Zhang, Deqing Wang, et al. Mora: High-rank updating for parameter-efficient finetuning. arXiv preprint arXiv:2405.12130, 2024.
Ting Jiang, Shaohan Huang, Shengyue Luo, Zihan Zhang, Haizhen Huang, Furu Wei, Weiwei Deng, Feng Sun, Qi Zhang, Deqing Wang, 等. Mora: 用于参数高效微调的高秩更新. arXiv 预印本 arXiv:2405.12130, 2024.
Andrej Karpathy. NanoGPT. https://github.com/karpathy/nanoGPT, 2022.
Andrej Karpathy. NanoGPT. https://github.com/karpathy/nanoGPT, 2022.
Yongyu Mu, Yuzhang Wu, Yuchun Fan, Chenglong Wang, Hengyu Li, Qiaozhi He, Murun Yang, Tong Xiao, and Jingbo Zhu. Cross-layer attention sharing for large language models. arXiv preprint arXiv:2408.01890, 2024.
Yongyu Mu, Yuzhang Wu, Yuchun Fan, Chenglong Wang, Hengyu Li, Qiaozhi He, Murun Yang, Tong Xiao, and Jingbo Zhu. 大语言模型的跨层注意力共享. arXiv preprint arXiv:2408.01890, 2024.
Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, and Michael Auli. fairseq: A fast, extensible toolkit for sequence modeling. In Waleed Ammar, Annie Louis, and Nasrin Mostafazadeh (eds.), Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, NAACL-HLT 2019, Minneapolis, MN, USA, June 2-7, 2019, Demonstrations, pp. 48–53. Association for Computational Linguistics, 2019.
Myle Ott、Sergey Edunov、Alexei Baevski、Angela Fan、Sam Gross、Nathan Ng、David Grangier 和 Michael Auli。fairseq:一个快速、可扩展的序列建模工具包。在 Waleed Ammar、Annie Louis 和 Nasrin Mostafazadeh(编)的《2019年北美计算语言学协会会议:人类语言技术》(NAACL-HLT 2019)中,美国明尼苏达州明尼阿波利斯,2019年6月2-7日,演示部分,第48-53页。计算语言学协会,2019年。
Weijieying Ren, Xinlong Li, Lei Wang, Tianxiang Zhao, and Wei Qin. Analyzing and reducing catastrophic forgetting in parameter efficient tuning. arXiv preprint arXiv:2402.18865, 2024.
Weijieying Ren, Xinlong Li, Lei Wang, Tianxiang Zhao, 和 Wei Qin。分析和减少参数高效调优中的灾难性遗忘。arXiv预印本 arXiv:2402.18865,2024。
Luka Ribar, Ivan Chelombiev, Luke Hudlass-Galley, Charlie Blake, Carlo Luschi, and Douglas Orr. Sparq attention: Bandwidth-efficient LLM inference. In Forty-first International Conference on Machine Learning, ICML 2024, Vienna, Austria, July 21-27, 2024, 2024.
Luka Ribar、Ivan Chelombiev、Luke Hudlass-Galley、Charlie Blake、Carlo Luschi 和 Douglas Orr。Sparq attention: 带宽高效的大语言模型推理。在第四十一届国际机器学习会议(ICML 2024)上,2024年7月21-27日,奥地利维也纳,2024年。
Keisuke Sakaguchi, Ronan Le Bras, Chandra Bhagavatula, and Yejin Choi. Winogrande: An adversarial winograd schema challenge at scale. In The Thirty-Fourth AAAI Conference on Artificial Intelligence, AAAI 2020, The Thirty-Second Innovative Applications of Artificial Intelligence Conference, IAAI 2020, The Tenth AAAI Symposium on Educational Advances in Artificial Intelligence, EAAI 2020, New York, NY, USA, February 7-12, 2020, pp. 8732–8740. AAAI Press, 2020.
Keisuke Sakaguchi, Ronan Le Bras, Chandra Bhagavatula, 和 Yejin Choi. Winogrande: 大规模对抗性 Winograd 模式挑战. 在第三十四届 AAAI 人工智能会议 (AAAI 2020), 第三十二届人工智能创新应用会议 (IAAI 2020), 第十届 AAAI 人工智能教育进展研讨会 (EAAI 2020), 美国纽约, 2020年2月7-12日, 第8732–8740页. AAAI Press, 2020.
Utkarsh Saxena, Gobinda Saha, Sakshi Choudhary, and Kaushik Roy. Eigen attention: Attention in low-rank space for KV cache compression. In Yaser Al-Onaizan, Mohit Bansal, and YunNung Chen (eds.), Findings of the Association for Computational Linguistics: EMNLP 2024, Miami, Florida, USA, November 12-16, 2024, pp. 15332–15344. Association for Computational Linguistics, 2024.
Utkarsh Saxena、Gobinda Saha、Sakshi Choudhary 和 Kaushik Roy。Eigen attention: Attention in low-rank space for KV cache compression。在 Yaser Al-Onaizan、Mohit Bansal 和 YunNung Chen(编),《计算语言学协会发现:EMNLP 2024》,美国佛罗里达州迈阿密,2024 年 11 月 12-16 日,第 15332–15344 页。计算语言学协会,2024 年。
Imanol Schlag, Kazuki Irie, and Jürgen Schmidhuber. Linear transformers are secretly fast weight programmers. In International Conference on Machine Learning, pp. 9355–9366. PMLR, 2021.
Imanol Schlag, Kazuki Irie, 和 Jürgen Schmidhuber。线性 Transformer 是隐秘的快速权重编程器。在国际机器学习会议上,第 9355–9366 页。PMLR,2021。
Noam Shazeer. Fast transformer decoding: One write-head is all you need. arXiv preprint arXiv:1911.02150, 2019.
Noam Shazeer. 快速Transformer解码:一个写头足矣。arXiv预印本 arXiv:1911.02150, 2019.
Noam Shazeer. Glu variants improve transformer. arXiv preprint arXiv:2002.05202, 2020.
Noam Shazeer. Glu变体改进Transformer. arXiv预印本 arXiv:2002.05202, 2020.
Yiming Shi, Jiwei Wei, Yujia Wu, Ran Ran, Chengwei Sun, Shiyuan He, and Yang Yang. Loldu: Low-rank adaptation via lower-diag-upper decomposition for parameter-efficient fine-tuning. arXiv preprint arXiv:2410.13618, 2024.
Yiming Shi, Jiwei Wei, Yujia Wu, Ran Ran, Chengwei Sun, Shiyuan He, 和 Yang Yang. Loldu: 通过下三角-对角-上三角分解进行低秩适应的参数高效微调. arXiv 预印本 arXiv:2410.13618, 2024.
Zhenmei Shi, Jiefeng Chen, Kunyang Li, Jayaram Raghuram, Xi Wu, Yingyu Liang, and Somesh Jha. The trade-off between universality and label efficiency of representations from contrastive learning. In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023, 2023.
Zhenmei Shi, Jiefeng Chen, Kunyang Li, Jayaram Raghuram, Xi Wu, Yingyu Liang, 和 Somesh Jha. 对比学习表示在通用性和标签效率之间的权衡. 在第十一届国际学习表示会议 (ICLR 2023) 上, 卢旺达基加利, 2023年5月1-5日, 2023.
Prajwal Singhania, Siddharth Singh, Shwai He, Soheil Feizi, and Abhinav Bhatele. Loki: Low-rank keys for efficient sparse attention. arXiv preprint arXiv:2406.02542, 2024.
Prajwal Singhania, Siddharth Singh, Shwai He, Soheil Feizi, 和 Abhinav Bhatele. Loki: 用于高效稀疏注意力的低秩键。arXiv 预印本 arXiv:2406.02542, 2024.
Jianlin Su. The extreme pull between cache and effect: From MHA, MQA, GQA to MLA. https://spaces.ac.cn/archives/10091, May 2024.
Jianlin Su. 缓存与效果之间的极端拉扯:从 MHA、MQA、GQA 到 MLA。https://spaces.ac.cn/archives/10091,2024 年 5 月。
Jianlin Su, Murtadha Ahmed, Yu Lu, Shengfeng Pan, Wen Bo, and Yunfeng Liu. Roformer: Enhanced transformer with rotary position embedding. Neurocomputing, 568:127063, 2024a.
Jianlin Su, Murtadha Ahmed, Yu Lu, Shengfeng Pan, Wen Bo, and Yunfeng Liu. Roformer: 使用旋转位置嵌入增强的 Transformer. Neurocomputing, 568:127063, 2024a.
Jianlin Su, Murtadha H. M. Ahmed, Yu Lu, Shengfeng Pan, Wen Bo, and Yunfeng Liu. Roformer: Enhanced transformer with rotary position embedding. Neurocomputing, 568:127063, 2024b.
Jianlin Su, Murtadha H. M. Ahmed, Yu Lu, Shengfeng Pan, Wen Bo, 和 Yunfeng Liu. Roformer: 使用旋转位置嵌入增强的 Transformer. 神经计算, 568:127063, 2024b.
Hanshi Sun, Li-Wen Chang, Wenlei Bao, Size Zheng, Ningxin Zheng, Xin Liu, Harry Dong, Yuejie Chi, and Beidi Chen. Shadowkv: Kv cache in shadows for high-throughput long-context llm inference. arXiv preprint arXiv:2410.21465, 2024.
Hanshi Sun, Li-Wen Chang, Wenlei Bao, Size Zheng, Ningxin Zheng, Xin Liu, Harry Dong, Yuejie Chi, 和 Beidi Chen. Shadowkv: 高吞吐量长上下文大语言模型推理中的影子键值缓存. arXiv 预印本 arXiv:2410.21465, 2024.
Appendix
A Proofs of Theorems
B More on Experiments
A Proofs of Theorems
Proof of Theorem 1.
Proof. Because RoPE is a linear orthogonal transform, we can write
$$
\widetilde{\mathbf{Q}}_{t}=\mathbf{Q}_{t}\,\mathbf{T}_{t}=\frac{1}{R_{Q}}\bigl(\mathbf{A}_{Q}(\mathbf{x}_{t})^{\top}\,\mathbf{B}_{Q}(\mathbf{x}_{t})\bigr)\,\mathbf{T}_{t}=\frac{1}{R_{Q}}\mathbf{A}_{Q}(\mathbf{x}_{t})^{\top}\bigl(\mathbf{B}_{Q}(\mathbf{x}_{t})\,\mathbf{T}_{t}\bigr),
$$
where $\mathbf{T}_{t}$ is the block-diagonal matrix encoding RoPE. This allows us to define
$$
\widetilde{\mathbf{B}}_{Q}(\mathbf{x}_{t})=\mathbf{B}_{Q}(\mathbf{x}_{t})\,\mathbf{T}_{t},
$$
thereby obtaining
$$
\mathrm{RoPE}(\mathbf{Q}_{t})=\frac{1}{R_{Q}}\mathbf{A}_{Q}(\mathbf{x}_{t})^{\top}\widetilde{\mathbf{B}}_{Q}(\mathbf{x}_{t}).
$$
Similarly, for the key tensor $\mathbf{K}_{s}$, we have
$$
\widetilde{\mathbf{K}}_{s}=\mathbf{K}_{s}\,\mathbf{T}_{s}=\frac{1}{R_{K}}\bigl(\mathbf{A}_{K}(\mathbf{x}_{s})^{\top}\,\mathbf{B}_{K}(\mathbf{x}_{s})\bigr)\,\mathbf{T}_{s}=\frac{1}{R_{K}}\mathbf{A}_{K}(\mathbf{x}_{s})^{\top}\bigl(\mathbf{B}_{K}(\mathbf{x}_{s})\,\mathbf{T}_{s}\bigr),
$$
which defines
$$
\widetilde{\mathbf{B}}_{K}(\mathbf{x}_{s})=\mathbf{B}_{K}(\mathbf{x}_{s})\,\mathbf{T}_{s},
$$
and thus
$$
\mathrm{RoPE}(\mathbf{K}_{s})=\frac{1}{R_{K}}\mathbf{A}_{K}(\mathbf{x}_{s})^{\top}\widetilde{\mathbf{B}}_{K}(\mathbf{x}_{s}).
$$
Now, consider the product of the rotated queries and keys:
$$
\begin{aligned}
\widetilde{\mathbf{Q}}_{t}\,\widetilde{\mathbf{K}}_{s}^{\top}
&=\frac{1}{R_{Q}R_{K}}\bigl(\mathbf{A}_{Q}(\mathbf{x}_{t})^{\top}\widetilde{\mathbf{B}}_{Q}(\mathbf{x}_{t})\bigr)\bigl(\mathbf{A}_{K}(\mathbf{x}_{s})^{\top}\widetilde{\mathbf{B}}_{K}(\mathbf{x}_{s})\bigr)^{\top}\\
&=\frac{1}{R_{Q}R_{K}}\mathbf{A}_{Q}(\mathbf{x}_{t})^{\top}\widetilde{\mathbf{B}}_{Q}(\mathbf{x}_{t})\,\widetilde{\mathbf{B}}_{K}(\mathbf{x}_{s})^{\top}\mathbf{A}_{K}(\mathbf{x}_{s}).
\end{aligned}
$$
Since $\mathbf{T}_{t}$ and $\mathbf{T}_{s}$ encode positional rotations, the product $\mathbf{T}_{t}\mathbf{T}_{s}^{\top}$ corresponds to a relative rotation $\mathbf{T}_{t-s}$: RoPE rotates each two-dimensional block by an angle proportional to the position, and $\mathbf{T}_{s}^{\top}=\mathbf{T}_{s}^{-1}$ undoes the rotation by $s$, so each block ends up rotated by an angle proportional to $t-s$. Therefore, we can express the above as
$$
\begin{aligned}
\widetilde{\mathbf{Q}}_{t}\,\widetilde{\mathbf{K}}_{s}^{\top}
&=\frac{1}{R_{Q}R_{K}}\mathbf{A}_{Q}(\mathbf{x}_{t})^{\top}\bigl(\mathbf{B}_{Q}(\mathbf{x}_{t})\,\mathbf{T}_{t}\mathbf{T}_{s}^{\top}\,\mathbf{B}_{K}(\mathbf{x}_{s})^{\top}\bigr)\mathbf{A}_{K}(\mathbf{x}_{s})\\
&=\frac{1}{R_{Q}R_{K}}\mathbf{A}_{Q}(\mathbf{x}_{t})^{\top}\bigl(\mathbf{B}_{Q}(\mathbf{x}_{t})\,\mathbf{T}_{t-s}\,\mathbf{B}_{K}(\mathbf{x}_{s})^{\top}\bigr)\mathbf{A}_{K}(\mathbf{x}_{s})\\
&=\frac{1}{R_{Q}R_{K}}\mathbf{A}_{Q}(\mathbf{x}_{t})^{\top}\bigl(\mathbf{B}_{Q}(\mathbf{x}_{t})\,\mathbf{T}_{t-s}\bigr)\bigl(\mathbf{B}_{K}(\mathbf{x}_{s})^{\top}\mathbf{A}_{K}(\mathbf{x}_{s})\bigr)\\
&=\Bigl(\frac{1}{R_{Q}}\mathbf{A}_{Q}(\mathbf{x}_{t})^{\top}\mathbf{B}_{Q}(\mathbf{x}_{t})\,\mathbf{T}_{t-s}\Bigr)\Bigl(\frac{1}{R_{K}}\mathbf{A}_{K}(\mathbf{x}_{s})^{\top}\mathbf{B}_{K}(\mathbf{x}_{s})\Bigr)^{\top}.
\end{aligned}
$$
This shows that
$$
\mathrm{RoPE}_{t-s}(\mathbf{Q}_{t})\,\mathbf{K}_{s}^{\top}=\widetilde{\mathbf{Q}}_{t}\,\widetilde{\mathbf{K}}_{s}^{\top}.
$$
Focusing on an individual head $i$, the above matrix equality implies:
$$
\mathrm{RoPE}_{t-s}(\mathbf{q}_{t,i})^{\top}\mathbf{k}_{s,i}=\widetilde{\mathbf{q}}_{t,i}^{\top}\widetilde{\mathbf{k}}_{s,i},
$$
where
$$
\widetilde{\mathbf{q}}_{t,i}=\mathrm{RoPE}(\mathbf{q}_{t,i})=\mathbf{T}_{t}\mathbf{q}_{t,i}\in\mathbb{R}^{d_{h}},\quad\widetilde{\mathbf{k}}_{s,i}=\mathrm{RoPE}(\mathbf{k}_{s,i})=\mathbf{T}_{s}\mathbf{k}_{s,i}\in\mathbb{R}^{d_{h}}.
$$
This equality confirms that the relative positional encoding between queries and keys is preserved under TPA's factorization and RoPE's rotation: the attention score depends on the positions $t$ and $s$ only through $t-s$. Thus, TPA remains compatible with RoPE. This completes the proof of Theorem 1. $\square$
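The proof rests on three identities: rotating only the factor $\mathbf{B}_{Q}$ reproduces RoPE applied to the full $\mathbf{Q}_{t}$, $\mathbf{T}_{t}\mathbf{T}_{s}^{\top}=\mathbf{T}_{t-s}$, and the resulting score identity. The sketch below checks all three numerically on one random instance, which can help catch sign or transpose slips but is of course not a proof. It assumes the row-vector convention $\mathbf{Q}_{t}\mathbf{T}_{t}$ used in the matrix part of the proof and standard RoPE frequencies with base 10000; the helper names (`rope_matrix`, `check_theorem1`, etc.) are hypothetical and written in plain Python to stay dependency-free.

```python
import math
import random


def matmul(a, b):
    """Product of two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def scale(a, c):
    return [[c * x for x in row] for row in a]


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


def rope_matrix(pos, d_h, base=10000.0):
    """Block-diagonal RoPE matrix T_pos (d_h x d_h, d_h even), applied to row vectors as x @ T_pos."""
    t = [[0.0] * d_h for _ in range(d_h)]
    for j in range(d_h // 2):
        theta = pos * base ** (-2 * j / d_h)  # rotation angle of the j-th 2-D block
        c, s = math.cos(theta), math.sin(theta)
        t[2 * j][2 * j], t[2 * j][2 * j + 1] = c, s
        t[2 * j + 1][2 * j], t[2 * j + 1][2 * j + 1] = -s, c
    return t


def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def check_theorem1(h=4, d_h=8, r_q=3, r_k=2, t=7, s=3, seed=0):
    """Return the max absolute error of each identity used in the proof of Theorem 1."""
    rng = random.Random(seed)
    # Random stand-ins for A_Q(x_t) (r_q x h), B_Q(x_t) (r_q x d_h), A_K(x_s), B_K(x_s).
    a_q, b_q = rand_matrix(r_q, h, rng), rand_matrix(r_q, d_h, rng)
    a_k, b_k = rand_matrix(r_k, h, rng), rand_matrix(r_k, d_h, rng)
    q = scale(matmul(transpose(a_q), b_q), 1.0 / r_q)  # Q_t, h x d_h
    k = scale(matmul(transpose(a_k), b_k), 1.0 / r_k)  # K_s, h x d_h
    t_t, t_s = rope_matrix(t, d_h), rope_matrix(s, d_h)

    # (1) Rotating only the B factor gives the same result as rotating the full Q_t.
    q_rot = scale(matmul(transpose(a_q), matmul(b_q, t_t)), 1.0 / r_q)
    k_rot = scale(matmul(transpose(a_k), matmul(b_k, t_s)), 1.0 / r_k)
    err_factor = max_abs_diff(q_rot, matmul(q, t_t))

    # (2) T_t T_s^T equals the relative rotation T_{t-s}.
    t_rel = rope_matrix(t - s, d_h)
    err_rel = max_abs_diff(matmul(t_t, transpose(t_s)), t_rel)

    # (3) Scores from rotated factors equal (Q_t T_{t-s}) K_s^T; the diagonal holds per-head scores.
    err_score = max_abs_diff(matmul(q_rot, transpose(k_rot)),
                             matmul(matmul(q, t_rel), transpose(k)))
    return err_factor, err_rel, err_score


print(check_theorem1())  # each error should be near machine precision
```

Identity (1) is also what makes the factorization practical: the rotation can be applied to the small $R\times d_h$ factor rather than to the full $h\times d_h$ tensor.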
B More on Experiments
B.1 Experimental Settings
We list the main architecture hyper-parameters and training devices in Table 6. We fix $d_{h}=64$ for all models. Moreover, we fix the number of KV heads to 2 for GQA models, $d_{h}^{R}=32$ for MLA models, and $R_{K}=R_{V}=2$, $R_{Q}=6$ for TPA and TPA-KVonly models. Other hyper-parameters are listed in Table 7.
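To make these settings concrete, the sketch below counts the scalars cached per token per layer for the small-size models, taking head counts and $d_c$ from Table 7. The cache layouts are assumptions, not statements from this section: MHA, MQA and GQA store one key and one value vector per KV head; MLA stores a $d_c$-dimensional latent plus a $d_{h}^{R}$-dimensional RoPE key (the DeepSeek-V2 layout); TPA and TPA-KVonly store the key/value factors $\mathbf{A}$ ($R\times h$) and $\mathbf{B}$ ($R\times d_h$), following the factor shapes in the proof of Theorem 1. The function name `kv_cache_per_token` is hypothetical.

```python
D_H = 64          # head dimension d_h (fixed for all models)
GQA_KV_HEADS = 2  # KV heads in GQA models
D_H_ROPE = 32     # d_h^R, decoupled RoPE dimension in MLA models
R_K = R_V = 2     # TPA key/value ranks

# Small-size (124M) settings from Table 7: head counts h and d_c for MLA.
SMALL_HEADS = {"MHA": 12, "MQA": 23, "GQA": 22, "MLA": 12, "TPA-KVonly": 22, "TPA": 34}
MLA_D_C = 256


def kv_cache_per_token(method, h):
    """Scalars cached per token per layer, under the assumed cache layouts."""
    if method == "MHA":
        return 2 * h * D_H              # full K and V for every head
    if method == "MQA":
        return 2 * D_H                  # a single shared K/V head
    if method == "GQA":
        return 2 * GQA_KV_HEADS * D_H   # one K/V head per group
    if method == "MLA":
        return MLA_D_C + D_H_ROPE       # compressed latent plus decoupled RoPE key
    if method in ("TPA", "TPA-KVonly"):
        return (R_K + R_V) * (h + D_H)  # A_K, B_K (rank R_K) and A_V, B_V (rank R_V)
    raise ValueError(f"unknown method: {method}")


cache = {m: kv_cache_per_token(m, h) for m, h in SMALL_HEADS.items()}
for method, n in cache.items():
    print(f"{method:10s} {n:5d} scalars/token/layer ({n / cache['MHA']:.1%} of MHA)")
```

Multiplying by the number of layers, the sequence length, and the bytes per element gives the total cache size; only the per-token, per-layer term differs across mechanisms.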
Table 6: The architecture hyper-parameters and training devices of the models. Abbreviations: BS = Batch Size, GAS = Gradient Accumulation Steps.
| Model Size | Parameters | Devices | Micro BS | GAS | # Layers | $d_{\text{model}}$ |
|---|---|---|---|---|---|---|
| Small | 124M | 4×A100 GPUs | 24 | 5 | 12 | 768 |
| Medium | 353M | 8×A100 GPUs | 20 | 3 | 24 | 1024 |
| Large | 772M | 8×A100 GPUs | 15 | 4 | 36 | 1280 |
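The micro batch size, gradient accumulation steps, and device count in Table 6 determine the number of sequences per optimizer step. The sketch below computes it under the assumption of plain data parallelism, where every GPU processes its own micro-batch on each accumulation step; `global_batch_size` is a hypothetical helper name.

```python
# (model size, number of GPUs, micro batch size per GPU, gradient accumulation steps), from Table 6
CONFIGS = [
    ("Small", 4, 24, 5),
    ("Medium", 8, 20, 3),
    ("Large", 8, 15, 4),
]


def global_batch_size(num_gpus, micro_bs, gas):
    """Sequences per optimizer step, assuming plain data parallelism."""
    return num_gpus * micro_bs * gas


for name, gpus, micro_bs, gas in CONFIGS:
    print(name, global_batch_size(gpus, micro_bs, gas))
```

Under this assumption all three sizes use the same global batch of 480 sequences per step, with the per-GPU micro batch shrinking as the model grows.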
Table 7: The architecture hyper-parameters for different models.
| Model Size | Small | Medium | Large |
|---|---|---|---|
| $h$ (MHA) | 12 | 16 | 20 |
| $h$ (MQA) | 23 | 31 | 39 |
| $h$ (GQA) | 22 | 30 | 38 |
| $h$ (MLA) | 12 | 23 | 34 |
| $h$ (TPA-KVonly) | 22 | 29 | 37 |
| $h$ (TPA) | 34 | 47 | 61 |
| $d_c$ (MLA) | 256 | 512 | 512 |
| $d'$ (MLA) | 512 | 1024 | 1024 |
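One way to read the head counts in Table 7 is as an attention parameter budget. The sketch below counts attention parameters per layer (divided by $d_{\text{model}}$) under assumed bias-free layouts: MHA, MQA and GQA use full $\mathbf{W}_Q$ and $\mathbf{W}_O$ plus one key and one value projection per KV head; TPA-KVonly keeps full $\mathbf{W}_Q$ and $\mathbf{W}_O$ and produces the key/value factors with projections of width $R(h+d_h)$; TPA additionally factors the query with $R_Q=6$. These layouts, and the idea that the head counts were chosen this way, are our assumptions rather than statements from this section; MLA is omitted because its parameter count depends on details not listed here. `attn_width` is a hypothetical helper.

```python
D_H = 64
R_Q, R_K, R_V = 6, 2, 2
GQA_KV_HEADS = 2

# d_model (Table 6) and head counts h (Table 7) per model size.
TABLE7 = {
    "Small": (768, {"MHA": 12, "MQA": 23, "GQA": 22, "TPA-KVonly": 22, "TPA": 34}),
    "Medium": (1024, {"MHA": 16, "MQA": 31, "GQA": 30, "TPA-KVonly": 29, "TPA": 47}),
    "Large": (1280, {"MHA": 20, "MQA": 39, "GQA": 38, "TPA-KVonly": 37, "TPA": 61}),
}


def attn_width(method, h):
    """Attention parameters per layer divided by d_model (assumed bias-free layouts)."""
    q_o = 2 * h * D_H                          # full W_Q plus W_O
    if method == "MHA":
        return q_o + 2 * h * D_H               # full W_K and W_V
    if method == "MQA":
        return q_o + 2 * D_H                   # one shared K/V head
    if method == "GQA":
        return q_o + 2 * GQA_KV_HEADS * D_H    # one K/V head per group
    kv = (R_K + R_V) * (h + D_H)               # projections producing A_K, B_K, A_V, B_V
    if method == "TPA-KVonly":
        return q_o + kv
    if method == "TPA":
        return R_Q * (h + D_H) + h * D_H + kv  # factored query, W_O, factored K/V
    raise ValueError(f"unknown method: {method}")


for size, (d_model, heads) in TABLE7.items():
    for method, h in heads.items():
        print(f"{size:6s} {method:10s} {d_model * attn_width(method, h):>10,d}")
```

Under these assumptions MQA and GQA match MHA exactly at every size, and both TPA variants land within about 3% of it.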
B.2 Additional Experimental Results
We display the evaluation results for small-size (124M) models in Tables 8-9.
Table 8: The evaluation results of small models with different attention mechanisms pre-trained using FineWeb-Edu 100B dataset (0-shot with lm-evaluation-harness). The best scores in each column are bolded. Abbreviations: HellaSw. $=$ HellaSwag, W.G. $=$ WinoGrande.
| Method | ARC-E | ARC-C | BoolQ | HellaSw. | OBQA | PIQA | W.G. | MMLU | SciQ | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|
| MHA | 50.63 | 26.96 | 59.39 | 36.18 | 32.00 | 64.96 | 51.85 | 23.40 | 70.30 | 46.19 |
| MQA | 49.62 | 25.34 | 55.72 | 35.94 | 31.40 | 64.85 | 51.30 | 23.37 | 68.70 | 45.14 |
| GQA | 48.70 | 25.68 | 56.15 | 35.58 | 31.40 | 64.91 | 51.62 | 23.12 | 68.20 | 45.04 |
| MLA | 49.66 | 26.45 | 61.22 | 33.94 | 32.40 | 62.73 | 50.43 | 23.29 | 71.50 | 45.74 |
| TPA-KVonly | 51.05 | 26.54 | 57.25 | 36.77 | 32.60 | 65.02 | 50.91 | 23.64 | 69.70 | 45.94 |
| TPA | 51.26 | 27.39 | 57.00 | 36.68 | 32.80 | 64.47 | 49.72 | 24.61 | 72.00 | 46.21 |
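The bold markers mentioned in the captions did not survive extraction. The sketch below recovers the best method per column of Table 8 and checks that each reported average equals the unweighted mean of the nine benchmark scores, rounded to two decimals; that definition of the average is an assumption, though it is consistent with every row. The helper names are hypothetical.

```python
BENCHMARKS = ["ARC-E", "ARC-C", "BoolQ", "HellaSw.", "OBQA", "PIQA", "W.G.", "MMLU", "SciQ"]

# Table 8 (small models, 0-shot): nine benchmark scores, then the reported average.
TABLE8 = {
    "MHA": ([50.63, 26.96, 59.39, 36.18, 32.00, 64.96, 51.85, 23.40, 70.30], 46.19),
    "MQA": ([49.62, 25.34, 55.72, 35.94, 31.40, 64.85, 51.30, 23.37, 68.70], 45.14),
    "GQA": ([48.70, 25.68, 56.15, 35.58, 31.40, 64.91, 51.62, 23.12, 68.20], 45.04),
    "MLA": ([49.66, 26.45, 61.22, 33.94, 32.40, 62.73, 50.43, 23.29, 71.50], 45.74),
    "TPA-KVonly": ([51.05, 26.54, 57.25, 36.77, 32.60, 65.02, 50.91, 23.64, 69.70], 45.94),
    "TPA": ([51.26, 27.39, 57.00, 36.68, 32.80, 64.47, 49.72, 24.61, 72.00], 46.21),
}


def check_averages(table):
    """Return the methods whose reported average differs from the rounded mean of their scores."""
    return [m for m, (scores, avg) in table.items()
            if round(sum(scores) / len(scores), 2) != avg]


def best_per_column(table):
    """Map each benchmark (and the average) to the method(s) with the highest score."""
    best = {}
    for j, name in enumerate(BENCHMARKS + ["Avg."]):
        values = {m: (scores + [avg])[j] for m, (scores, avg) in table.items()}
        top = max(values.values())
        best[name] = [m for m, v in values.items() if v == top]
    return best


print(check_averages(TABLE8))
print(best_per_column(TABLE8))
```

The same two functions apply unchanged to Tables 9 to 11 by swapping in their rows.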
B.3 Ablation Studies on Learning Rates
We run a parallel set of experiments on medium-size models with learning rate $6\times10^{-4}$; the training loss, validation loss, and validation perplexity curves are shown in Figure 5. Tables 10-11 report the performance of these models on the benchmarks described in Section 4. The results show that TPA and TPA-KVonly models can also outperform other attention mechanisms under a different learning rate.
Table 9: The evaluation results of small models with different attention mechanisms pre-trained using FineWeb-Edu 100B dataset (2-shot with lm-evaluation-harness). The best scores in each column are bolded. Abbreviations: HellaSw. = HellaSwag, W.G. = WinoGrande.
| Method | ARC-E | ARC-C | BoolQ | HellaSw. | OBQA | PIQA | W.G. | MMLU | SciQ | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|
| MHA | 57.66 | 28.24 | 57.28 | 36.43 | 29.60 | 64.09 | 51.14 | 26.57 | 82.00 | 48.11 |
| MQA | 53.79 | 26.35 | 44.95 | 34.18 | 28.80 | 62.79 | 52.01 | 25.91 | 78.10 | 45.21 |
| GQA | 55.01 | 25.94 | 55.72 | 35.68 | 31.80 | 65.29 | 51.93 | 25.27 | 77.80 | 47.16 |
| MLA | 52.78 | 26.19 | 57.25 | 33.19 | 29.60 | 63.98 | 50.43 | 24.90 | 76.00 | 46.04 |
| TPA-KVonly | 54.25 | 27.90 | 57.06 | 36.36 | 31.80 | 64.31 | 53.59 | 26.18 | 79.20 | 47.85 |
| TPA | 57.53 | 28.07 | 56.33 | 36.49 | 31.80 | 64.36 | 51.14 | 25.92 | 79.70 | 47.93 |

Figure 5: The training loss, validation loss and validation perplexity of medium-size (353M) models with learning rate $6\times10^{-4}$ and different attention mechanisms on the FineWeb-Edu 100B dataset.
Table 10: The evaluation results of medium models (learning rate $6\times10^{-4}$) with different attention mechanisms pre-trained using FineWeb-Edu 100B dataset (0-shot with lm-evaluation-harness). The best scores in each column are bolded. Abbreviations: HellaSw. = HellaSwag, W.G. = WinoGrande.
| Method | ARC-E | ARC-C | BoolQ | HellaSw. | OBQA | PIQA | W.G. | MMLU | SciQ | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|
| MHA | 59.51 | 29.52 | 59.60 | 45.68 | 34.20 | 68.82 | 53.43 | 23.33 | 76.90 | 50.11 |
| MQA | 57.62 | 31.91 | 59.45 | 45.69 | 35.40 | 69.31 | 53.51 | 26.47 | 74.60 | 50.44 |
| GQA | 28.67 | 31.48 | 58.29 | 45.45 | 35.20 | 68.50 | 54.46 | 24.58 | 76.50 | 47.01 |
| MLA | 57.49 | 29.44 | 59.97 | 44.09 | 25.77 | 68.66 | 53.04 | 25.77 | 76.40 | 48.96 |
| TPA-KVonly | 58.01 | 30.12 | 58.01 | 45.95 | 35.60 | 69.10 | 53.12 | 25.39 | 75.10 | 50.04 |
| TPA | 58.38 | 31.57 | 59.39 | 46.83 | 37.00 | 70.02 | 54.06 | 25.52 | 79.90 | 51.41 |
Table 11: The evaluation results of medium models (learning rate $6\times10^{-4}$) with different attention mechanisms pre-trained using FineWeb-Edu 100B dataset (2-shot with lm-evaluation-harness). The best scores in each column are bolded. Abbreviations: HellaSw. = HellaSwag, W.G. = WinoGrande.
| Method | ARC-E | ARC-C | BoolQ | HellaSw. | OBQA | PIQA | W.G. | MMLU | SciQ | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|
| MHA | 64.73 | 32.42 | 58.29 | 45.89 | 34.20 | 68.50 | 53.20 | 25.86 | 88.00 | 52.34 |
| MQA | 64.98 | 33.62 | 55.02 | 45.81 | 34.00 | 69.59 | 53.43 | 24.30 | 85.20 | 51.77 |
| GQA | 65.24 | 33.19 | 56.54 | 45.41 | 34.80 | 69.04 | 55.72 | 24.73 | 87.90 | 52.51 |
| MLA | 63.80 | 31.06 | 58.50 | 44.19 | 35.40 | 68.44 | 51.62 | 25.22 | 88.50 | 51.86 |
| TPA-KVonly | 64.69 | 32.34 | 59.48 | 46.23 | 35.40 | 70.08 | 54.06 | 25.64 | 86.30 | 52.69 |
| TPA | 67.97 | 34.56 | 57.22 | 46.87 | 34.60 | 69.91 | 52.01 | 25.07 | 89.90 | 53.12 |
