[论文翻译]Tensor Product Attention Is All You Need


原文地址:https://arxiv.org/pdf/2501.06425


Tensor Product Attention Is All You Need


Abstract

摘要

Scaling language models to handle longer input sequences typically necessitates large key-value (KV) caches, resulting in substantial memory overhead during inference. In this paper, we propose Tensor Product Attention (TPA), a novel attention mechanism that uses tensor decompositions to represent queries, keys, and values compactly, significantly shrinking KV cache size at inference time. By factorizing these representations into contextual low-rank components (contextual factorization) and seamlessly integrating with RoPE, TPA achieves improved model quality alongside memory efficiency. Based on TPA, we introduce the Tensor ProducT ATTenTion Transformer (T6), a new model architecture for sequence modeling. Through extensive empirical evaluation of language modeling tasks, we demonstrate that T6 exceeds the performance of standard Transformer baselines including MHA, MQA, GQA, and MLA across various metrics, including perplexity and a range of renowned evaluation benchmarks. Notably, TPA’s memory efficiency enables the processing of significantly longer sequences under fixed resource constraints, addressing a critical scalability challenge in modern language models. The code is available at https://github.com/tensorgi/T6.

扩展语言模型以处理更长的输入序列通常需要较大的键值(KV)缓存,导致推理过程中内存开销显著增加。本文提出了一种新颖的注意力机制——张量积注意力(Tensor Product Attention,TPA),该机制利用张量分解技术紧凑地表示查询、键和值,从而在推理时显著缩小KV缓存的大小。通过将这些表示分解为上下文低秩分量(上下文分解)并与RoPE无缝集成,TPA在提高模型质量的同时实现了内存效率的提升。基于TPA,我们引入了张量积注意力Transformer(T6),这是一种用于序列建模的新模型架构。通过对语言建模任务进行广泛的实证评估,我们证明T6在包括困惑度和一系列著名评估基准在内的各种指标上均优于标准Transformer基线模型,包括MHA、MQA、GQA和MLA。值得注意的是,TPA的内存效率使得在固定资源约束下能够处理显著更长的序列,解决了现代语言模型中的一个关键扩展性挑战。代码可在https://github.com/tensorgi/T6获取。

1 Introduction

1 引言

Large language models (LLMs) have revolutionized natural language processing, demonstrating exceptional performance across tasks (Brown et al., 2020; Chowdhery et al., 2023; Touvron et al., 2023; Bubeck et al., 2023). As these models evolve, their ability to process longer contexts becomes increasingly important for sophisticated applications such as document analysis, complex reasoning, and code completions. However, managing longer sequences during inference poses significant computational and memory challenges, particularly due to the storage of key-value (KV) caches (Zhang et al., 2023c; Liu et al., 2024c). Because memory consumption grows linearly with sequence length, the maximum context window is limited by practical hardware constraints.

大语言模型 (LLMs) 彻底改变了自然语言处理领域,在各种任务中展现了卓越的性能 (Brown et al., 2020; Chowdhery et al., 2023; Touvron et al., 2023; Bubeck et al., 2023)。随着这些模型的演进,它们处理更长上下文的能力对于文档分析、复杂推理和代码补全等复杂应用变得越来越重要。然而,在推理过程中管理更长的序列带来了显著的计算和内存挑战,特别是由于键值 (KV) 缓存的存储 (Zhang et al., 2023c; Liu et al., 2024c)。由于内存消耗随序列长度线性增长,最大上下文窗口受到实际硬件限制的制约。

A variety of solutions have been explored to address this memory bottleneck. Some approaches compress or selectively prune cached states through sparse attention patterns (Child et al., 2019) or token eviction strategies (Zhang et al., 2023c; Xiao et al., 2024; Ribar et al., 2024), though such methods risk discarding tokens that may later prove important. Other work proposes off-chip storage of key-value states (He & Zhai, 2024), at the expense of increased I/O latency. Attention variants like multi-query attention (MQA) (Shazeer, 2019) and grouped-query attention (GQA) (Ainslie et al., 2023) reduce per-token cache requirements by sharing keys and values across heads, but often compromise flexibility or require significant architectural modifications. Meanwhile, low-rank weight factorization methods such as LoRA (Hu et al., 2022) effectively reduce fine-tuning memory, yet do not address the KV cache overhead that dominates runtime. The recently introduced Multi-head Latent Attention (MLA) in Deepseek-V2 (Liu et al., 2024a) caches compressed key-value representations but, because it cannot integrate efficiently with Rotary Position Embedding (RoPE) (Su et al., 2024b), requires additional position-encoded parameters per head.

为了解决这一内存瓶颈,人们探索了多种解决方案。一些方法通过稀疏注意力模式 (Child et al., 2019) 或 Token 淘汰策略 (Zhang et al., 2023c; Xiao et al., 2024; Ribar et al., 2024) 来压缩或选择性修剪缓存状态,尽管这些方法可能会丢弃后来可能被证明重要的 Token。其他工作提出了将键值状态存储在片外 (He & Zhai, 2024),但代价是增加了 I/O 延迟。像多查询注意力 (MQA) (Shazeer, 2019) 和分组查询注意力 (GQA) (Ainslie et al., 2023) 这样的注意力变体通过跨头共享键和值来减少每个 Token 的缓存需求,但通常会牺牲灵活性或需要显著的架构修改。与此同时,低秩权重分解方法如 LoRA (Hu et al., 2022) 有效减少了微调内存,但并未解决主导运行时开销的 KV 缓存问题。最近在 Deepseek-V2 (Liu et al., 2024a) 中引入的多头潜在注意力 (MLA) 缓存了压缩的键值表示,但由于与旋转位置嵌入 (RoPE) 不兼容,每个头需要额外的位置编码参数 (Su et al., 2024b)。


Figure 1: Tensor Product Attention (TPA) in the Tensor ProducT ATTenTion Transformer (T6). Different from multi-head attention, in each layer, firstly the hidden state goes through different linear layers to get the latent factor matrices A’s and B’s for query, key, and value. We additionally apply RoPE to $\mathbf{B}_{Q}$ and $\mathbf{B}_{K}$ for query and key. Then the multi-head query, key, and value vectors are attained by the tensor product of $\mathbf{A}_{(\cdot)}$ and $\mathbf{B}_{(\cdot)}$. Finally, the output of TPA is produced by scaled dot-product attention followed by linear projection of the concatenated results of multiple heads.

图 1: Tensor ProducT ATTenTion Transformer (T6) 中的张量积注意力 (Tensor Product Attention, TPA)。与多头注意力不同,在每一层中,首先隐藏状态通过不同的线性层得到查询、键和值的潜在因子矩阵 A 和 B。我们额外对查询和键的 $\mathbf{B}_{Q}$ 和 $\mathbf{B}_{K}$ 应用了 RoPE。然后,通过 $\mathbf{A}_{(\cdot)}$ 和 $\mathbf{B}_{(\cdot)}$ 的张量积得到多头查询、键和值向量。最后,TPA 的输出通过缩放点积注意力生成,随后对多头连接的结果进行线性投影。

In order to overcome the limitations of existing approaches, we introduce Tensor Product Attention (TPA), as illustrated in Figure 1, a novel architecture that uses higher-order tensors to factorize queries (Q), keys (K), and values (V) during attention computation. By dynamically factorizing activations rather than static weights (e.g., LoRA), TPA constructs low-rank, contextual representations that substantially reduce KV cache memory usage with improved representational capacity. In practice, TPA can reduce the memory overhead by an order of magnitude compared to standard multi-head attention (MHA) with lower pretraining validation loss (perplexity) and improved downstream performance.

为了克服现有方法的局限性,我们引入了张量积注意力 (Tensor Product Attention, TPA),如图 1 所示。这是一种新颖的架构,在注意力计算过程中使用高阶张量来分解查询 (Q)、键 (K) 和值 (V)。通过动态分解激活值而非静态权重(例如 LoRA),TPA 构建了低秩的上下文表示,显著减少了 KV 缓存的内存使用,同时提高了表示能力。在实际应用中,与标准的多头注意力 (MHA) 相比,TPA 可以将内存开销降低一个数量级,同时具有更低的预训练验证损失(困惑度)和更好的下游性能。

A key advantage of TPA is its native compatibility with rotary positional embeddings (RoPE) (Su et al., 2024b), enabling a straightforward drop-in replacement for multi-head attention (MHA) layers in modern LLM architectures such as LLaMA (Touvron et al., 2023) and Gemma (Team et al., 2024).

TPA 的一个关键优势是其与旋转位置嵌入 (RoPE) (Su et al., 2024b) 的原生兼容性,这使得它能够直接替代现代大语言模型架构(如 LLaMA (Touvron et al., 2023) 和 Gemma (Team et al., 2024))中的多头注意力 (MHA) 层。

Our primary contributions are summarized as follows:

我们的主要贡献总结如下:


Figure 2: Training loss and validation loss of pretraining large-size (773M) models with different attention mechanisms on the FineWeb-Edu-100B dataset.

图 2: 在 FineWeb-Edu-100B 数据集上,使用不同注意力机制预训练大尺寸 (773M) 模型的训练损失和验证损失。

2 Background

2 背景

In this section, we review several classical forms of attention: Scaled Dot-Product Attention, Multi-Head Attention (MHA) (Vaswani et al., 2017), Multi-Query Attention (MQA) (Shazeer, 2019), and Grouped Query Attention (GQA) (Ainslie et al., 2023), as well as Rotary Position Embedding (RoPE, Su et al. (2024b)). We also introduce a recent method called Multi-head Latent Attention (MLA) used in DeepSeek-V2 (Liu et al., 2024a) and DeepSeek-V3 (Liu et al., 2024b).

在本节中,我们回顾了几种经典的注意力机制形式:缩放点积注意力 (Scaled Dot-Product Attention)、多头注意力 (MultiHead Attention, MHA) (Vaswani et al., 2017)、多查询注意力 (Multi-Query Attention, MQA) (Shazeer, 2019) 和分组查询注意力 (Grouped Query Attention, GQA) (Ainslie et al., 2023),以及旋转位置嵌入 (Rotary Position Embedding, RoPE, Su et al. (2024b))。我们还介绍了最近在 DeepSeek-V2 (Liu et al., 2024a) 和 DeepSeek-V3 (Liu et al., 2024b) 中使用的一种新方法,称为多头潜在注意力 (Multi-head Latent Attention, MLA)。

Notations. We use bold uppercase letters (e.g., $\mathbf{X}$, $\mathbf{Q}$) for matrices, bold lowercase (e.g., $\mathbf{a}$, $\mathbf{b}$) for vectors, and italic uppercase (e.g., $W_{i}^{Q}$) for learnable parameter matrices. We denote by $[n]$ the set $\{1,\ldots,n\}$ for a positive integer $n$, and use $\intercal$ to denote the transpose of a vector or a matrix. Let $d_{\mathrm{model}}$ be the embedding dimension, $h$ the number of attention heads, $d_{h}$ the dimension per head, $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$ the input for the $t$-th token at a given attention layer, and $\mathbf{X}\in\mathbb{R}^{T\times d_{\mathrm{model}}}$ the input embeddings for $T$ tokens; $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}\in\mathbb{R}^{T\times h\times d_{h}}$ denote the queries, keys, and values of the $h$ heads for $T$ tokens. With a slight abuse of notation, $\mathbf{Q}_{i}$, $\mathbf{K}_{i}$, $\mathbf{V}_{i}\in\mathbb{R}^{T\times d_{h}}$ denote the $i$-th head of the queries, keys, and values, and $\mathbf{Q}_{t}$, $\mathbf{K}_{t}$, $\mathbf{V}_{t}\in\mathbb{R}^{h\times d_{h}}$ denote the query, key, and value heads for the $t$-th token.

符号说明。我们用粗体大写字母(例如 $\mathbf{X}$、$\mathbf{Q}$)表示矩阵,粗体小写字母(例如 $\mathbf{a}$、$\mathbf{b}$)表示向量,斜体大写字母(例如 $W_{i}^{Q}$)表示可学习的参数矩阵。我们用 $[n]$ 表示集合 $\{1,\ldots,n\}$,其中 $n$ 是正整数;用 $\intercal$ 表示向量或矩阵的转置。设 $d_{\mathrm{model}}$ 为嵌入维度,$h$ 为注意力头的数量,$d_{h}$ 为每个头的维度,$\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$ 为给定注意力层中第 $t$ 个 token 的输入,$\mathbf{X}\in\mathbb{R}^{T\times d_{\mathrm{model}}}$ 表示 $T$ 个 token 的输入嵌入,$\mathbf{Q}$、$\mathbf{K}$、$\mathbf{V}\in\mathbb{R}^{T\times h\times d_{h}}$ 表示 $T$ 个 token 的 $h$ 个头的查询、键和值。稍微滥用符号,$\mathbf{Q}_{i}$、$\mathbf{K}_{i}$、$\mathbf{V}_{i}\in\mathbb{R}^{T\times d_{h}}$ 表示查询、键和值的第 $i$ 个头,$\mathbf{Q}_{t}$、$\mathbf{K}_{t}$、$\mathbf{V}_{t}\in\mathbb{R}^{h\times d_{h}}$ 表示第 $t$ 个 token 的查询、键和值的头。

Throughout the paper, $W^{Q},W^{K},W^{V}$ denote projection matrices for queries, keys, and values, respectively. In multi-head attention, each head is associated with its own set of $W_{i}^{Q},W_{i}^{K},W_{i}^{V}$, and each has dimension $W_{i}^{Q},W_{i}^{K},W_{i}^{V}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{k}}$, where $d_{k}$ is typically set to $d_{h}$, the dimension of each head. Similarly, we have an output projection matrix $W^{O}\in\mathbb{R}^{(h\cdot d_{h})\times d_{\mathrm{model}}}$. For methods like MQA and GQA, some of these are shared or partially shared across heads, but their shapes remain consistent.

在整篇论文中,$W^{Q},W^{K},W^{V}$ 分别表示查询、键和值的投影矩阵。在多头注意力机制中,每个头都与其自己的一组 $W_{i}^{Q},W_{i}^{K},W_{i}^{V}$ 相关联,每个矩阵的维度为 $W_{i}^{Q},W_{i}^{K},W_{i}^{V}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{k}}$,其中 $d_{k}$ 通常设置为 $d_{h}$,即每个头的维度。同样地,我们有一个输出投影矩阵 $W^{O}\in\mathbb{R}^{(h\cdot d_{h})\times d_{\mathrm{model}}}$。对于像 MQA 和 GQA 这样的方法,这些矩阵中的一些在头之间是共享或部分共享的,但它们的形状保持一致。

We define the tensor product of two vectors as follows: for vectors $\mathbf{a}\in\mathbb{R}^{m},\mathbf{b}\in\mathbb{R}^{n}$ , the tensor product of a and $\mathbf{b}$ is:

我们定义两个向量的张量积如下:对于向量 $\mathbf{a}\in\mathbb{R}^{m},\mathbf{b}\in\mathbb{R}^{n}$,$\mathbf{a}$ 和 $\mathbf{b}$ 的张量积为:

\mathbf{a}\otimes\mathbf{b}=\mathbf{C}\in\mathbb{R}^{m\times n},\quad\text{with }C_{ij}=a_{i}b_{j},

where $a_{i}$ and $b_{j}$ are the $i$-th and $j$-th elements of $\mathbf{a}$ and $\mathbf{b}$ respectively, and $C_{ij}$ is the $(i,j)$-th entry of $\mathbf{C}$. We also define the vectorization of a matrix $\mathbf{C}\in\mathbb{R}^{m\times n}$ by:

其中 $a_{i}$ 和 $b_{j}$ 分别是向量 $\mathbf{a}$ 和 $\mathbf{b}$ 的第 $i$ 个和第 $j$ 个元素,$C_{ij}$ 是矩阵 $\mathbf{C}$ 的第 $(i,j)$ 个元素。我们还定义了矩阵 $\mathbf{C}\in\mathbb{R}^{m\times n}$ 的向量化操作:

\operatorname{vec}(\mathbf{C})=\mathbf{d}\in\mathbb{R}^{mn},\quad\text{with }d_{i\cdot n+j}=C_{ij},

where $d_{i\cdot n+j}$ is the $(i\cdot n+j)$ -th element of $\mathbf{d}$ .

其中 $d_{i\cdot n+j}$ 是 $\mathbf{d}$ 的第 $(i\cdot n+j)$ 个元素。
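As a quick illustration of the two definitions above (a minimal sketch of our own, not from the paper; variable names are illustrative), the tensor product corresponds to an outer product and the vectorization to a row-major flattening:

```python
import torch

m, n = 3, 4
a, b = torch.randn(m), torch.randn(n)

# Tensor product: C[i, j] = a[i] * b[j], so C has shape (m, n).
C = torch.outer(a, b)

# Vectorization: row-major flattening gives d[i * n + j] = C[i, j] (0-indexed).
d = C.reshape(-1)
i, j = 1, 2
assert torch.isclose(d[i * n + j], C[i, j])
```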

2.1 Scaled Dot-Product Attention

2.1 缩放点积注意力 (Scaled Dot-Product Attention)

Scaled dot-product attention (Vaswani et al., 2017) determines how to focus on different parts of an input sequence by comparing queries $(\mathbf{Q})$ and keys $({\bf K})$ . It produces a weighted combination of the

缩放点积注意力机制 (Vaswani et al., 2017) 通过比较查询 $(\mathbf{Q})$ 和键 $({\bf K})$ 来确定如何关注输入序列的不同部分。它生成一个加权组合的

values $({\bf V})$ . Formally, the attention output is:

值 $({\bf V})$。形式上,注意力输出为:

\mathrm{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V})=\mathrm{Softmax}\Big(\frac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{d_{k}}}\Big)\,\mathbf{V},

where each of $\mathbf{Q},\mathbf{K},\mathbf{V}$ is an $(n\times d_{k})$ matrix for $n$ tokens and key dimension $d_{k}$ . The division by $\sqrt{d_{k}}$ stabilizes training by controlling the scale of the inner products.

其中,$\mathbf{Q}$、$\mathbf{K}$、$\mathbf{V}$ 都是 $(n\times d_{k})$ 的矩阵,$n$ 表示 token 的数量,$d_{k}$ 表示键的维度。除以 $\sqrt{d_{k}}$ 通过控制内积的规模来稳定训练。
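For concreteness, a minimal single-head sketch of the formula above (our own illustration; no masking, dropout, or batching):

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """Q, K, V: (n, d_k) matrices over n tokens; returns the (n, d_k) attention output."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (n, n) scaled inner products
    return F.softmax(scores, dim=-1) @ V            # weighted combination of the values

n, d_k = 8, 64
Q, K, V = (torch.randn(n, d_k) for _ in range(3))
print(scaled_dot_product_attention(Q, K, V).shape)  # torch.Size([8, 64])
```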

2.2 Multi-Head Attention (MHA)

2.2 多头注意力机制 (Multi-Head Attention, MHA)

Multi-Head Attention (MHA) extends scaled dot-product attention by dividing the model’s internal representation into several heads. Each head learns different projections for queries, keys, and values, allowing the model to attend to different types of information. For each token embedding $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$, MHA computes each head $i$ as follows:

多头注意力机制 (Multi-Head Attention, MHA) 通过将模型的内部表示划分为多个头来扩展缩放点积注意力。每个头为查询、键和值学习不同的投影,使模型能够关注不同类型的信息。对于每个 Token 嵌入 $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$,MHA 计算每个头 $i$ 如下:

\mathbf{Q}_{t,i}=(W_{i}^{Q})^{\top}\,\mathbf{x}_{t}\in\mathbb{R}^{d_{h}},\quad\mathbf{K}_{t,i}=(W_{i}^{K})^{\top}\,\mathbf{x}_{t}\in\mathbb{R}^{d_{h}},\quad\mathbf{V}_{t,i}=(W_{i}^{V})^{\top}\,\mathbf{x}_{t}\in\mathbb{R}^{d_{h}},
\mathrm{head}_{i}=\mathrm{Attention}\bigl(\mathbf{Q}_{i},\mathbf{K}_{i},\mathbf{V}_{i}\bigr),

where $W_{i}^{Q},W_{i}^{K},W_{i}^{V}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{h}}$ are learnable projection matrices for the $i$-th head and $\mathbf{Q}_{i},\mathbf{K}_{i},\mathbf{V}_{i}\in\mathbb{R}^{T\times d_{h}}$. After computing each head’s attention, the outputs are concatenated and mapped back to the original dimension via another matrix $W^{O}\in\mathbb{R}^{(h\cdot d_{h})\times d_{\mathrm{model}}}$:

其中 $W_{i}^{Q},W_{i}^{K},W_{i}^{V}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{h}}$ 是第 $i$ 个头的可学习投影矩阵,$\mathbf{Q}_{i},\mathbf{K}_{i},\mathbf{V}_{i}\in\mathbb{R}^{T\times d_{h}}$。在计算每个头的注意力后,输出被拼接并通过另一个矩阵 $W^{O}\in\mathbb{R}^{(h\cdot d_{h})\times d_{\mathrm{model}}}$ 映射回原始维度:

\mathrm{MHA}(\mathbf{Q},\mathbf{K},\mathbf{V})=\mathrm{Concat}\bigl(\mathbf{head}_{1},\dots,\mathbf{head}_{h}\bigr)\,W^{O}.

MHA can capture a rich set of dependencies while each head focuses on different subspaces.

MHA 可以捕获丰富的依赖关系,而每个头专注于不同的子空间。
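A compact MHA sketch matching the equations above (our own illustration; the per-head projections are stored as one fused linear layer each for queries, keys, and values):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, h):
        super().__init__()
        self.h, self.d_h = h, d_model // h
        self.W_q = nn.Linear(d_model, d_model, bias=False)   # stacks all W_i^Q
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):                                    # x: (T, d_model)
        T = x.shape[0]
        q, k, v = (W(x).view(T, self.h, self.d_h).transpose(0, 1)
                   for W in (self.W_q, self.W_k, self.W_v))  # each (h, T, d_h)
        scores = q @ k.transpose(-2, -1) / self.d_h ** 0.5   # (h, T, T)
        heads = F.softmax(scores, dim=-1) @ v                # (h, T, d_h)
        # Concatenate heads along the last dimension and project back to d_model.
        return self.W_o(heads.transpose(0, 1).reshape(T, -1))

mha = MultiHeadAttention(d_model=256, h=4)
print(mha(torch.randn(10, 256)).shape)  # torch.Size([10, 256])
```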

2.3 Multi-Query Attention (MQA)

2.3 多查询注意力机制 (Multi-Query Attention, MQA)

Multi-Query Attention (MQA) (Shazeer, 2019) significantly reduces memory usage by sharing keys and values across heads, while still preserving unique query projections. For a sequence of embeddings $\mathbf{X}\in\mathbb{R}^{T\times d_{\mathrm{model}}}$ ,

多查询注意力 (Multi-Query Attention, MQA) (Shazeer, 2019) 通过在多个头之间共享键和值,显著减少了内存使用,同时仍保留了唯一的查询投影。对于嵌入序列 $\mathbf{X}\in\mathbb{R}^{T\times d_{\mathrm{model}}}$,

\mathbf{Q}_{i}=\mathbf{X}W_{i}^{Q},\quad\mathbf{K}_{\mathrm{shared}}=\mathbf{X}W_{\mathrm{shared}}^{K},\quad\mathbf{V}_{\mathrm{shared}}=\mathbf{X}W_{\mathrm{shared}}^{V}.

Hence, each head $i$ only has a distinct query $\mathbf{Q}_{i}\in\mathbb{R}^{T\times d_{k}}$, but shares the same key $\mathbf{K}_{\mathrm{shared}}\in\mathbb{R}^{T\times d_{k}}$ and value $\mathbf{V}_{\mathrm{shared}}\in\mathbb{R}^{T\times d_{k}}$. In practice, this means:

因此,每个头 $i$ 仅有一个独特的查询 $\mathbf{Q}_{i}\in\mathbb{R}^{T\times d_{k}}$,但共享相同的键 $\mathbf{K}_{\mathrm{shared}}\in\mathbb{R}^{T\times d_{k}}$ 和值 $\mathbf{V}_{\mathrm{shared}}\in\mathbb{R}^{T\times d_{k}}$。在实践中,这意味着:

W_{i}^{Q}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{k}},\quad W_{\mathrm{shared}}^{K},\,W_{\mathrm{shared}}^{V}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{k}}.

The resulting MQA operation is:

生成的 MQA 操作为:

\operatorname{MQA}(\mathbf{X})=\operatorname{Concat}\bigl(\mathbf{head}_{1},\dots,\mathbf{head}_{h}\bigr)W^{O},

where

其中

\mathbf{head}_{i}=\operatorname{Attention}\bigl(\mathbf{Q}_{i},\mathbf{K}_{\mathrm{shared}},\mathbf{V}_{\mathrm{shared}}\bigr).

By sharing these key and value projections, MQA cuts down on memory usage (especially for the key-value cache in autoregressive inference) but loses some expressivity since all heads must rely on the same key/value representations.

通过共享这些键和值投影,MQA 减少了内存使用(特别是在自回归推理中的键值缓存),但由于所有头必须依赖相同的键/值表示,因此失去了一些表达能力。

2.4 Grouped Query Attention (GQA)

2.4 分组查询注意力机制 (Grouped Query Attention, GQA)

Grouped Query Attention (GQA) (Ainslie et al., 2023) generalizes MHA and MQA by grouping heads. Specifically, we partition the $h$ total heads into $G$ groups. Each group has a single set of keys and values, but each individual head within that group still retains its own query projection. Formally, if $g(i)$ maps a head $i\in[h]$ to its group index $g\in[G]$ , then:

分组查询注意力 (Grouped Query Attention, GQA) (Ainslie et al., 2023) 通过将头部分组来泛化多头注意力 (MHA) 和多查询注意力 (MQA)。具体来说,我们将总头数 $h$ 划分为 $G$ 组。每组共享一组键和值,但组内的每个头仍然保留自己的查询投影。形式上,如果 $g(i)$ 将头 $i\in[h]$ 映射到其组索引 $g\in[G]$,则:

\mathbf{K}_{g(i)}=\mathbf{X}\,W_{g(i)}^{K},\quad\mathbf{V}_{g(i)}=\mathbf{X}\,W_{g(i)}^{V},\quad\mathbf{Q}_{i}=\mathbf{X}\,W_{i}^{Q},

and

\mathrm{head}_{i}=\mathrm{Attention}\Bigl(\mathbf{Q}_{i},\mathbf{K}_{g(i)},\mathbf{V}_{g(i)}\Bigr).

Again, $W_{g}^{K},W_{g}^{V}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{k}}$ for each group $g$, and $W_{i}^{Q}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{k}}$ for each head $i$. The complete output is again a concatenation of all heads:

再次,对于每个组 $g$,$W_{g}^{K},W_{g}^{V}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{k}}$,而对于每个头 $i$,$W_{i}^{Q}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{k}}$。完整的输出再次是所有头的拼接:

\operatorname{GQA}(\mathbf{X})=\operatorname{Concat}\bigl(\operatorname{head}_{1},\dots,\operatorname{head}_{h}\bigr)W^{O}.

By adjusting $G$ between 1 and $h$ , GQA can interpolate between sharing all key/value projections across heads (i.e., MQA) and having one set of projections per head (i.e., MHA).

通过调整 $G$ 在 1 到 $h$ 之间的值,GQA 可以在多头之间共享所有键/值投影(即 MQA)和每个头拥有一组投影(即 MHA)之间进行插值。
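The interpolation can be seen directly in code (a sketch under our own naming; $G=1$ recovers MQA and $G=h$ recovers MHA):

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(x, W_q, W_k, W_v, W_o, h, G):
    """GQA sketch: h query heads share G groups of key/value heads."""
    T, d_model = x.shape
    d_h = d_model // h
    q = (x @ W_q).view(T, h, d_h).transpose(0, 1)     # (h, T, d_h)
    k = (x @ W_k).view(T, G, d_h).transpose(0, 1)     # (G, T, d_h)
    v = (x @ W_v).view(T, G, d_h).transpose(0, 1)     # (G, T, d_h)
    # Head i uses the keys/values of its group g(i): replicate each group's K/V.
    k = k.repeat_interleave(h // G, dim=0)            # (h, T, d_h)
    v = v.repeat_interleave(h // G, dim=0)
    att = F.softmax(q @ k.transpose(-2, -1) / d_h ** 0.5, dim=-1) @ v
    return att.transpose(0, 1).reshape(T, -1) @ W_o   # (T, d_model)

h, G, d_model, T = 8, 2, 256, 16
d_h = d_model // h
x = torch.randn(T, d_model)
W_q = torch.randn(d_model, h * d_h)
W_k = torch.randn(d_model, G * d_h)
W_v = torch.randn(d_model, G * d_h)
W_o = torch.randn(h * d_h, d_model)
print(grouped_query_attention(x, W_q, W_k, W_v, W_o, h, G).shape)
```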

2.5 Rotary Position Embedding (RoPE)

2.5 旋转位置编码 (Rotary Position Embedding, RoPE)

Many recent LLMs use rotary position embedding (RoPE; Su et al., 2024b) to encode positional information in the query/key vectors. Specifically, let $\mathrm{RoPE}_{t}$ denote the rotation operator $\mathbf{T}_{t}\in\mathbb{R}^{d_{h}\times d_{h}}$ corresponding to the $t$-th position. $\mathbf{T}_{t}$ is a block-diagonal matrix consisting of $2\times 2$ rotation blocks $\begin{pmatrix}\cos(t\theta_{j}) & -\sin(t\theta_{j})\\ \sin(t\theta_{j}) & \cos(t\theta_{j})\end{pmatrix}$, $j\in\{1,\cdots,d_{h}/2\}$, where the $\theta_{j}$ are pre-defined frequency parameters, e.g., $\theta_{j}=1/10000^{2j/d_{h}}$. Then we define

许多最近的大语言模型使用旋转位置嵌入 (RoPE; Su 等, 2024b) 来在查询/键向量中编码位置信息。具体来说,令 $\mathrm{RoPE}_{t}$ 表示对应于第 $t$ 个位置的旋转算子 $\mathbf{T}_{t}\in\mathbb{R}^{d_{h}\times d_{h}}$。$\mathbf{T}_{t}$ 是一个块对角矩阵,由 $2\times 2$ 旋转块 $\begin{pmatrix}\cos(t\theta_{j}) & -\sin(t\theta_{j})\\ \sin(t\theta_{j}) & \cos(t\theta_{j})\end{pmatrix}$ 组成,其中 $j\in\{1,\cdots,d_{h}/2\}$,${\theta_{j}}$ 是预定义的频率参数,例如 $\theta_{j}=1/10000^{2j/d_{h}}$。然后我们定义

\mathrm{RoPE}\left(\mathbf{Q}_{t}\right)\triangleq\mathbf{Q}_{t}\mathbf{T}_{t},\quad\mathrm{where~}\mathbf{Q}_{t}\in\mathbb{R}^{h\times d_{h}}.

A fundamental property is that

一个基本属性是

\mathbf{T}_{t}\,\mathbf{T}_{s}^{\top}=\mathbf{T}_{t-s},

which ensures that relative positions $(t-s)$ are preserved, thereby providing a form of translation invariance in the rotary position embedding.

这确保了相对位置 $(t-s)$ 得以保留,从而在旋转位置嵌入中提供了一种平移不变性。
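A small numeric check of this property (our own sketch; it builds $\mathbf{T}_{t}$ explicitly as the block-diagonal matrix described above):

```python
import math
import torch

def rope_matrix(t, d_h, base=10000.0):
    """Block-diagonal RoPE rotation matrix T_t (2x2 blocks with angle t * theta_j)."""
    T = torch.zeros(d_h, d_h)
    for j in range(d_h // 2):
        theta_j = base ** (-2 * j / d_h)          # theta_j = 1 / 10000^{2j/d_h}
        c, s = math.cos(t * theta_j), math.sin(t * theta_j)
        T[2 * j, 2 * j], T[2 * j, 2 * j + 1] = c, -s
        T[2 * j + 1, 2 * j], T[2 * j + 1, 2 * j + 1] = s, c
    return T

d_h, t, s = 8, 7, 3
T_t, T_s = rope_matrix(t, d_h), rope_matrix(s, d_h)
# Property: T_t @ T_s^T = T_{t-s}, so rotated dot products depend only on t - s.
assert torch.allclose(T_t @ T_s.T, rope_matrix(t - s, d_h), atol=1e-6)
```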

2.6 Multi-head Latent Attention (MLA)

2.6 多头潜在注意力机制 (Multi-head Latent Attention, MLA)

Below, we briefly outline the Multi-head Latent Attention (MLA) approach used by DeepSeekV2 (Liu et al., 2024a) and DeepSeek-V3 (Liu et al., 2024b). MLA introduces a low-rank compression of the keys and values to reduce the Key-Value (KV) caching cost at inference.

下面,我们简要概述了 DeepSeekV2 (Liu et al., 2024a) 和 DeepSeek-V3 (Liu et al., 2024b) 使用的多头潜在注意力 (Multi-head Latent Attention, MLA) 方法。MLA 引入了键和值的低秩压缩,以减少推理时的键值 (Key-Value, KV) 缓存成本。

\begin{aligned}
\mathbf{C}^{KV}&=\mathbf{X}W^{DKV},\quad(W^{DKV}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{c}}),\\
\mathrm{Concat}\bigl(\mathbf{K}_{1}^{C},\mathbf{K}_{2}^{C},\dots,\mathbf{K}_{h}^{C}\bigr)&=\mathbf{K}^{C}=\mathbf{C}^{KV}W^{UK},\quad(W^{UK}\in\mathbb{R}^{d_{c}\times d_{h}h}),\\
\mathbf{K}^{R}&=\mathrm{RoPE}\bigl(\mathbf{X}W^{KR}\bigr),\quad(W^{KR}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{h}^{R}}),\\
\mathbf{K}_{i}&=\mathrm{Concat}\bigl(\mathbf{K}_{i}^{C},\mathbf{K}^{R}\bigr),\\
\mathrm{Concat}\bigl(\mathbf{V}_{1}^{C},\mathbf{V}_{2}^{C},\dots,\mathbf{V}_{h}^{C}\bigr)&=\mathbf{V}^{C}=\mathbf{C}^{KV}W^{UV},\quad(W^{UV}\in\mathbb{R}^{d_{c}\times d_{h}h}),
\end{aligned}

where $\mathbf{C}^{KV}\in\mathbb{R}^{T\times d_{c}}$ is the compressed KV latent (with $d_{c}\ll d_{h}h$), and $\mathrm{RoPE}(\cdot)$ represents the RoPE transform applied to the separate key embeddings $\mathbf{K}^{R}$ of dimension $d_{h}^{R}$. Thus, only $\mathbf{C}^{KV}$ and $\mathbf{K}^{R}$ need to be cached, reducing KV memory usage while largely preserving performance compared to standard MHA (Vaswani et al., 2017).

其中 $\mathbf{C}^{KV}\in\mathbb{R}^{T\times d_{c}}$ 是压缩的 KV 潜在表示(其中 $d_{c}\ll d_{h}h$),$\mathrm{RoPE}(\cdot)$ 表示应用于维度为 $d_{h}^{R}$ 的独立键嵌入 $\mathbf{K}^{R}$ 的 RoPE 变换。因此,只需要缓存 $\mathbf{C}^{KV}$ 和 $\mathbf{K}^{R}$,从而在减少 KV 内存使用的同时,与标准 MHA (Vaswani et al., 2017) 相比,性能基本保持不变。

MLA also compresses the queries, lowering their training-time memory footprint:

MLA 还压缩了查询,降低了它们在训练时的内存占用:

\begin{aligned}
\mathbf{C}^{Q}&=\mathbf{X}W^{DQ},\quad(W^{DQ}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{c}^{\prime}}),\\
\mathrm{Concat}\bigl(\mathbf{Q}_{1}^{C},\mathbf{Q}_{2}^{C},\dots,\mathbf{Q}_{h}^{C}\bigr)&=\mathbf{Q}^{C}=\mathbf{C}^{Q}W^{UQ},\quad(W^{UQ}\in\mathbb{R}^{d_{c}^{\prime}\times d_{h}h}),\\
\mathrm{Concat}\bigl(\mathbf{Q}_{1}^{R},\mathbf{Q}_{2}^{R},\dots,\mathbf{Q}_{h}^{R}\bigr)&=\mathbf{Q}^{R}=\mathrm{RoPE}\bigl(\mathbf{C}^{Q}W^{QR}\bigr),\quad(W^{QR}\in\mathbb{R}^{d_{c}^{\prime}\times d_{h}^{R}h}),\\
\mathbf{Q}&=\mathrm{Concat}\bigl(\mathbf{Q}^{C},\mathbf{Q}^{R}\bigr).
\end{aligned}

Here, $\mathbf{C}^{Q}\in\mathbb{R}^{T\times d_{c}^{\prime}}$ (with $d_{c}^{\prime}\ll d_{h}h$) is the compressed query latent. As above, each of $W^{DQ}$, $W^{UQ}$, and $W^{QR}$ connects these lower-dimensional query latents back to $h$ heads of dimension $d_{h}+d_{h}^{R}$.

这里,$\mathbf{C}^{Q}\in\mathbb{R}^{T\times d_{c}^{\prime}}$(其中 $d_{c}^{\prime}\ll d_{h}h$)是压缩后的查询潜在表示。如上所述,$W^{DQ}$、$W^{UQ}$ 和 $W^{QR}$ 将这些低维查询潜在表示连接回维度为 $d_{h}+d_{h}^{R}$ 的 $h$ 个头。

Given compressed queries, keys, and values, the final attention output for the $t$ -th token is:

给定压缩后的查询、键和值,第 $t$ 个 Token 的最终注意力输出为:

\begin{aligned}
\mathbf{O}_{i}&=\mathrm{Softmax}\Big(\frac{\mathbf{Q}_{i}\mathbf{K}_{i}^{\top}}{\sqrt{d_{h}+d_{h}^{R}}}\Big)\,\mathbf{V}_{i}^{C},\\
\mathbf{U}&=\mathrm{Concat}\bigl(\mathbf{O}_{1},\mathbf{O}_{2},\dots,\mathbf{O}_{h}\bigr)W^{O},
\end{aligned}

where $W^{O}\in\mathbb{R}^{(d_{h}h)\times{d_{\mathrm{model}}}}$ is the output projection.

其中 $W^{O}\in\mathbb{R}^{(d_{h}h)\times{d_{\mathrm{model}}}}$ 是输出投影。

At inference time, $\mathbf{C}^{KV}$ and $\mathbf{K}^{R}$ can be cached to accelerate decoding. In detail, when RoPE is ignored, the inner product $\mathbf{q}_{t,i}^{\top}\mathbf{k}_{s,i}$ (where $\mathbf{q}_{t,i},\mathbf{k}_{s,i}\in\mathbb{R}^{d_{h}}$) of the $i$-th head between the $t$-th and $s$-th tokens can be calculated using the hidden state $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$ of the $t$-th token and the cached latent state $\mathbf{c}_{s}^{KV}\in\mathbb{R}^{d_{c}}$ of the $s$-th token:

在推理时,$\mathbf{C}^{KV}$ 和 $\mathbf{K}^{R}$ 可以被缓存以加速解码。具体来说,当忽略 RoPE 时,第 $i$ 个头在第 $t$ 个和第 $s$ 个 Token 之间的内积 $\mathbf{q}_{t,i}^{\top}\mathbf{k}_{s,i}$(其中 $\mathbf{q}_{t,i},\mathbf{k}_{s,i}\in\mathbb{R}^{d_{h}}$)可以通过第 $t$ 个 Token 的隐藏状态 $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$ 和第 $s$ 个 Token 的缓存潜在状态 $\mathbf{c}_{s}^{KV}\in\mathbb{R}^{d_{c}}$ 来计算:

\mathbf{q}_{t,i}^{\top}\mathbf{k}_{s,i}=\bigl[(W_{i}^{UQ})^{\top}(W_{i}^{DQ})^{\top}\mathbf{x}_{t}\bigr]^{\top}\bigl[(W_{i}^{UK})^{\top}\mathbf{c}_{s}^{KV}\bigr]=\mathbf{x}_{t}^{\top}\bigl[W_{i}^{DQ}W_{i}^{UQ}(W_{i}^{UK})^{\top}\bigr]\mathbf{c}_{s}^{KV},

where $W_{i}^{(\cdot)}$ is the $i$-th head of the original weight, and $[W_{i}^{DQ}W_{i}^{UQ}(W_{i}^{UK})^{\top}]$ can be pre-computed for faster decoding. However, according to Su (2024), this process fails when RoPE is considered. Since RoPE can be considered as multiplication with a block-diagonal matrix $\mathbf{T}_{t}\in\mathbb{R}^{d_{h}\times d_{h}}$ (see Section 2.5) with the property (2.1) that $\mathbf{T}_{t}\mathbf{T}_{s}^{\top}=\mathbf{T}_{t-s}$, we have

其中 $W_{i}^{(\cdot)}$ 是原始权重的第 $i$ 个头,且 $[W_{i}^{DQ}W_{i}^{UQ}(W_{i}^{UK})^{\top}]$ 可以预先计算以加快解码速度。然而,根据 Su (2024) 的分析,当考虑 RoPE 时,这一过程会失效。由于 RoPE 可以被视为与块对角矩阵 $\mathbf{T}_{t}\in\mathbb{R}^{d_{h}\times d_{h}}$ 的乘法(见第 2.5 节),并且具有性质 (2.1),即 $\mathbf{T}_{t}\mathbf{T}_{s}^{\top}=\mathbf{T}_{t-s}$,因此有

\begin{aligned}
\mathbf{q}_{t,i}^{\top}\mathbf{k}_{s,i}&=\bigl[\mathbf{T}_{t}^{\top}(W_{i}^{UQ})^{\top}(W_{i}^{DQ})^{\top}\mathbf{x}_{t}\bigr]^{\top}\bigl[\mathbf{T}_{s}^{\top}(W_{i}^{UK})^{\top}\mathbf{c}_{s}^{KV}\bigr]\\
&=\mathbf{x}_{t}^{\top}\bigl[W_{i}^{DQ}W_{i}^{UQ}\mathbf{T}_{t-s}(W_{i}^{UK})^{\top}\bigr]\mathbf{c}_{s}^{KV}.
\end{aligned}

Different from (2.2), acceleration by pre-computing $[W_{i}^{DQ}W_{i}^{UQ}\mathbf{T}_{t-s}(W_{i}^{UK})^{\top}]$ fails since it varies for different $(t,s)$ position pairs. Therefore, MLA adds the additional $\mathbf{k}_{t}^{R}$ part with a relatively smaller size for RoPE compatibility. In Section 3.2, we will show that TPA addresses the issue of RoPE incompatibility by applying the tensor product.

与 (2.2) 不同,通过预计算 $[W_{i}^{DQ}W_{i}^{UQ}\mathbf{T}_{t-s}(W_{i}^{UK})^{\top}]$ 来加速的方法失效了,因为它会因不同的 $(t,s)$ 位置对而变化。因此,MLA 添加了额外的 $\mathbf{k}_{t}^{R}$ 部分,其大小相对较小,以兼容 RoPE。在第 3.2 节中,我们将展示 TPA 如何通过应用张量积来解决 RoPE 不兼容的问题。

3 Tensor Product Attention

3 张量积注意力 (Tensor Product Attention)

In this section, we provide a detailed description of our proposed Tensor Product Attention (TPA), which allows contextual low-rank factorization for queries, keys, and values. First, we explain how TPA factorizes queries, keys, and values with explicit tensor shapes. Next, we describe how TPA can be integrated into the multi-head attention framework and how it reduces memory consumption in KV caching at inference time. Finally, we show how RoPE can seamlessly integrate with TPA (including a pre-rotated variant).

在本节中,我们详细介绍了我们提出的张量积注意力机制 (Tensor Product Attention, TPA),它允许对查询、键和值进行上下文低秩分解。首先,我们解释了 TPA 如何通过显式的张量形状对查询、键和值进行分解。接着,我们描述了 TPA 如何集成到多头注意力框架中,以及它如何在推理时减少 KV 缓存的内存消耗。最后,我们展示了 RoPE 如何与 TPA 无缝集成(包括预旋转变体)。

3.1 Tensor Factorization of Queries, Keys, and Values

3.1 查询、键和值的张量分解

Let $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$ for $t=1,\ldots,T$ be the hidden-state vector corresponding to the $t$-th token in a sequence of length $T$. A typical multi-head attention block has $h$ heads, each of dimension $d_{h}$, satisfying $d_{\mathrm{model}}=h\times d_{h}$. Standard attention projects the entire sequence into three tensors, $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}\in\mathbb{R}^{T\times h\times d_{h}}$, where $\mathbf{Q}_{t},\mathbf{K}_{t},\mathbf{V}_{t}\in\mathbb{R}^{h\times d_{h}}$ denote the slices for the $t$-th token.

设 $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$($t=1,\ldots,T$)是长度为 $T$ 的序列中第 $t$ 个 Token 对应的隐藏状态向量。一个典型的多头注意力块有 $h$ 个头,每个头的维度为 $d_{h}$,满足 $d_{\mathrm{model}}=h\times d_{h}$。标准注意力将整个序列投影到三个张量 $\mathbf{Q}$、$\mathbf{K}$、$\mathbf{V}\in\mathbb{R}^{T\times h\times d_{h}}$ 中,其中 $\mathbf{Q}_{t},\mathbf{K}_{t},\mathbf{V}_{t}\in\mathbb{R}^{h\times d_{h}}$ 表示第 $t$ 个 Token 的切片。

Contextual Factorization (CF). Instead of forming each head’s query, key, or value via a single linear map, TPA factorizes each $\mathbf{Q}_{t},\mathbf{K}_{t},\mathbf{V}_{t}$ into a sum of (contextual) tensor products whose ranks are $R_{Q},R_{K}$, and $R_{V}$, respectively, and may differ. Specifically, for each token $t$, with a small abuse of notation, we define:

上下文分解 (Contextual Factorization, CF)。TPA 不是通过单一的线性映射来形成每个头的查询、键或值,而是将每个 $\mathbf{Q}_{t},\mathbf{K}_{t},\mathbf{V}_{t}$ 分解为秩分别为 $R_{Q},R_{K}$ 和 $R_{V}$ 的(上下文)张量积之和,这些秩可以不同。具体来说,对于每个 token $t$,我们定义(稍微滥用符号):

\mathbf{Q}_{t}=\frac{1}{R_{Q}}\sum_{r=1}^{R_{Q}}\mathbf{a}_{r}^{Q}(\mathbf{x}_{t})\otimes\mathbf{b}_{r}^{Q}(\mathbf{x}_{t}),\qquad\mathbf{a}_{r}^{Q}(\mathbf{x}_{t})\in\mathbb{R}^{h},\;\mathbf{b}_{r}^{Q}(\mathbf{x}_{t})\in\mathbb{R}^{d_{h}},

\mathbf{K}_{t}=\frac{1}{R_{K}}\sum_{r=1}^{R_{K}}\mathbf{a}_{r}^{K}(\mathbf{x}_{t})\otimes\mathbf{b}_{r}^{K}(\mathbf{x}_{t}),\qquad\mathbf{a}_{r}^{K}(\mathbf{x}_{t})\in\mathbb{R}^{h},\;\mathbf{b}_{r}^{K}(\mathbf{x}_{t})\in\mathbb{R}^{d_{h}},

\mathbf{V}_{t}=\frac{1}{R_{V}}\sum_{r=1}^{R_{V}}\mathbf{a}_{r}^{V}(\mathbf{x}_{t})\otimes\mathbf{b}_{r}^{V}(\mathbf{x}_{t}),\qquad\mathbf{a}_{r}^{V}(\mathbf{x}_{t})\in\mathbb{R}^{h},\;\mathbf{b}_{r}^{V}(\mathbf{x}_{t})\in\mathbb{R}^{d_{h}}.

Hence, for queries, the $R_{Q}$ tensor products $\mathbf{a}_{r}^{Q}(\mathbf{x}_{t})\otimes\mathbf{b}_{r}^{Q}(\mathbf{x}_{t})\in\mathbb{R}^{h\times d_{h}}$ add up to form the query slice $\mathbf{Q}_{t}\in\mathbb{R}^{h\times d_{h}}$. Analogous definitions apply to the key slice $\mathbf{K}_{t}$ and the value slice $\mathbf{V}_{t}$.

因此,对于查询,$R_{Q}$ 个张量积 $\mathbf{a}_{r}^{Q}(\mathbf{x}_{t})\otimes\mathbf{b}_{r}^{Q}(\mathbf{x}_{t})\in\mathbb{R}^{h\times d_{h}}$ 相加形成查询切片 $\mathbf{Q}_{t}\in\mathbb{R}^{h\times d_{h}}$。类似的定义适用于键切片 $\mathbf{K}_{t}$ 和值切片 $\mathbf{V}_{t}$。

Latent Factor Maps. Each factor in the tensor product depends on the token’s hidden state $\mathbf{x}_{t}$ . For example, for queries, we can write:

潜在因子映射。张量积中的每个因子都依赖于 Token 的隐藏状态 $\mathbf{x}_{t}$。例如,对于查询,我们可以写成:

\mathbf{a}_{r}^{Q}(\mathbf{x}_{t})=W_{r}^{a^{Q}}\,\mathbf{x}_{t}\in\mathbb{R}^{h},\quad\mathbf{b}_{r}^{Q}(\mathbf{x}_{t})=W_{r}^{b^{Q}}\,\mathbf{x}_{t}\in\mathbb{R}^{d_{h}},

and similarly for keys and values.

同样适用于键和值。

One often merges the rank index into a single output dimension. For instance, for queries:

通常会将秩索引合并为单一的输出维度。例如,对于查询:

\mathbf{a}^{Q}(\mathbf{x}_{t})=W^{a^{Q}}\,\mathbf{x}_{t}\in\mathbb{R}^{R_{Q}\cdot h},\quad\mathbf{b}^{Q}(\mathbf{x}_{t})=W^{b^{Q}}\,\mathbf{x}_{t}\in\mathbb{R}^{R_{Q}\cdot d_{h}},

which are then reshaped into $\mathbf{A}_{Q}(\mathbf{x}_{t})\in\mathbb{R}^{R_{Q}\times h}$ and $\mathbf{B}_{Q}(\mathbf{x}_{t})\in\mathbb{R}^{R_{Q}\times d_{h}}$. Summing over the $R_{Q}$ rank-1 components and scaling by $\frac{1}{R_{Q}}$ yields

然后将其重塑为 $\mathbf{A}_{Q}(\mathbf{x}_{t})\in\mathbb{R}^{R_{Q}\times h}$ 和 $\mathbf{B}_{Q}(\mathbf{x}_{t})\in\mathbb{R}^{R_{Q}\times d_{h}}$。对这 $R_{Q}$ 个秩-1 分量求和并乘以 $\frac{1}{R_{Q}}$ 进行缩放,得到

\mathbf{Q}_{t}=\frac{1}{R_{Q}}\mathbf{A}_{Q}(\mathbf{x}_{t})^{\top}\,\mathbf{B}_{Q}(\mathbf{x}_{t})\in\mathbb{R}^{h\times d_{h}}.

Repeating for all tokens reconstitutes $\mathbf{Q}\in\mathbb{R}^{T\times h\times d_{h}}$. Similar procedures can be applied to obtain $\mathbf{K}$ and $\mathbf{V}$ with ranks $R_{K}$ and $R_{V}$, respectively.

对所有 Token 重复上述过程,即可重构 $\mathbf{Q}\in\mathbb{R}^{T\times h\times d_{h}}$。类似的过程可以应用于获取 $\mathbf{K}$ 和 $\mathbf{V}$,其秩分别为 $R_{K}$ 和 $R_{V}$。
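The factorization can be sketched in a few lines (an illustration with hypothetical layer names, not the released T6 code); keys and values are handled analogously with ranks $R_{K}$ and $R_{V}$:

```python
import torch
import torch.nn as nn

T, d_model, h, d_h, R_Q = 16, 256, 4, 64, 6

# Hypothetical factor projections producing A_Q(x_t) and B_Q(x_t).
W_aQ = nn.Linear(d_model, R_Q * h, bias=False)
W_bQ = nn.Linear(d_model, R_Q * d_h, bias=False)

x = torch.randn(T, d_model)                        # hidden states for T tokens
A_Q = W_aQ(x).view(T, R_Q, h)                      # (T, R_Q, h)
B_Q = W_bQ(x).view(T, R_Q, d_h)                    # (T, R_Q, d_h)

# Q_t = (1/R_Q) A_Q(x_t)^T B_Q(x_t): a sum of R_Q rank-1 tensor products per token.
Q = torch.einsum('trh,trd->thd', A_Q, B_Q) / R_Q   # (T, h, d_h)
print(Q.shape)  # torch.Size([16, 4, 64])
```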

Scaled Dot-Product Attention. Once $\mathbf{Q},\mathbf{K},\mathbf{V}$ are factorized, multi-head attention proceeds as in standard Transformers. For each head $i\in\{1,\ldots,h\}$:

缩放点积注意力 (Scaled Dot-Product Attention)。一旦 $\mathbf{Q},\mathbf{K},\mathbf{V}$ 被分解,多头注意力 (multi-head attention) 就会像标准 Transformer 中那样进行。对于每个头 $i\in\{1,\ldots,h\}$:

\mathbf{head}_{i}=\mathrm{Softmax}\Big(\frac{1}{\sqrt{d_{h}}}\,\mathbf{Q}_{i}\,(\mathbf{K}_{i})^{\top}\Big)\,\mathbf{V}_{i},

where $\mathbf{Q}_{i},\mathbf{K}_{i},\mathbf{V}_{i}\in\mathbb{R}^{T\times d_{h}}$ are the slices along the head dimension. Concatenating these $h$ heads along the last dimension yields an $\mathbb{R}^{T\times(h\cdot d_{h})}$ tensor, which is projected back to $\mathbb{R}^{T\times d_{\mathrm{model}}}$ by an output weight matrix $W^{O}\in\mathbb{R}^{(h\cdot d_{h})\times d_{\mathrm{model}}}$:

其中 $\mathbf{Q}_{i},\mathbf{K}_{i},\mathbf{V}_{i}\in\mathbb{R}^{T\times d_{h}}$ 是沿头维度切片的张量。将这些 $h$ 个头沿最后一个维度拼接,得到一个 $\mathbb{R}^{T\times(h\cdot d_{h})}$ 的张量,然后通过输出权重矩阵 $W^{O}\in\mathbb{R}^{(h\cdot d_{h})\times d_{\mathrm{model}}}$ 将其投影回 $\mathbb{R}^{T\times d_{\mathrm{model}}}$:

\mathrm{TPA}(\mathbf{Q},\mathbf{K},\mathbf{V})=\mathrm{Concat}\bigl(\mathbf{head}_{1},\dots,\mathbf{head}_{h}\bigr)W^{O}.

Parameter Initialization. We initialize the weight matrices $W_{r}^{a^{Q}}$, $W_{r}^{a^{K}}$, $W_{r}^{a^{V}}$, $W_{r}^{b^{Q}}$, $W_{r}^{b^{K}}$, $W_{r}^{b^{V}}$ using Xavier initialization (Glorot & Bengio, 2010). Specifically, each entry of a weight matrix is drawn from a uniform distribution with bounds $[-\sqrt{6/(n_{\mathrm{in}}+n_{\mathrm{out}})},\sqrt{6/(n_{\mathrm{in}}+n_{\mathrm{out}})}]$, where $n_{\mathrm{in}}$ and $n_{\mathrm{out}}$ are the input and output dimensions of the respective weight matrix. This initialization strategy helps maintain the variance of activations and gradients across the network.

参数初始化。我们使用 Xavier 初始化 (Glorot & Bengio, 2010) 来初始化权重矩阵 $W_{r}^{a^{Q}}$、$W_{r}^{a^{K}}$、$W_{r}^{a^{V}}$、$W_{r}^{b^{Q}}$、$W_{r}^{b^{K}}$、$W_{r}^{b^{V}}$。具体来说,权重矩阵的每个元素都从均匀分布中抽取,其范围为 $[-\sqrt{6/(n_{\mathrm{in}}+n_{\mathrm{out}})},\sqrt{6/(n_{\mathrm{in}}+n_{\mathrm{out}})}]$,其中 $n_{\mathrm{in}}$ 和 $n_{\mathrm{out}}$ 分别是相应权重矩阵的输入和输出维度。这种初始化策略有助于保持网络中激活值和梯度的方差。

3.2 RoPE Compatibility and Acceleration

3.2 RoPE 兼容性与加速

In a typical workflow of adding RoPE to standard multi-head attention, one first computes $\mathbf{Q}_{t},\mathbf{K}_{s}\in\mathbb{R}^{h\times d_{h}}$ for the $t$-th and $s$-th tokens and then applies:

在将 RoPE 添加到标准多头注意力机制的典型工作流程中,首先计算第 $t$ 个和第 $s$ 个 Token 的 $\mathbf{Q}_{t},\mathbf{K}_{s}\in\mathbb{R}^{h\times d_{h}}$,然后应用:

\mathbf{Q}_{t}\mapsto\widetilde{\mathbf{Q}}_{t}=\mathrm{RoPE}_{t}(\mathbf{Q}_{t}),\qquad\mathbf{K}_{s}\mapsto\widetilde{\mathbf{K}}_{s}=\mathrm{RoPE}_{s}(\mathbf{K}_{s}).

Direct Integration. A useful optimization is to integrate RoPE directly into the TPA factorization. For example, one can pre-rotate the token-dimension factors:

直接集成。一个有用的优化是将 RoPE 直接集成到 TPA 分解中。例如,可以预先旋转 Token 维度的因子:

\widetilde{\mathbf{B}}_{K}(\mathbf{x}_{t})\,\longleftarrow\,\mathrm{RoPE}_{t}\bigl(\mathbf{B}_{K}(\mathbf{x}_{t})\bigr),

yielding a pre-rotated key representation:

从而得到预旋转的键表示:

\widetilde{\mathbf{K}}_{t}=\frac{1}{R_{K}}\sum_{r=1}^{R_{K}}\mathbf{a}_{r}^{K}(\mathbf{x}_{t})\otimes\mathrm{RoPE}_{t}\bigl(\mathbf{b}_{r}^{K}(\mathbf{x}_{t})\bigr)=\frac{1}{R_{K}}\mathbf{A}_{K}(\mathbf{x}_{t})^{\top}\,\mathrm{RoPE}_{t}\bigl(\mathbf{B}_{K}(\mathbf{x}_{t})\bigr).

Thus, each $\mathbf{K}_{t}$ is already rotated before caching, removing the need for explicit rotation at decoding time and accelerating autoregressive inference. Depending on hardware and performance requirements, one can also adopt different RoPE integration approaches for training and inference.

因此,每个 $\mathbf{K}_{t}$ 在缓存之前已经进行了旋转,从而在解码时无需显式旋转,并加速了自回归推理。根据硬件和性能需求,还可以在训练和推理时采用不同的 RoPE 集成方法。
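A sketch of caching pre-rotated key factors (our own code, assuming one common interleaved-pair RoPE convention; `apply_rope` and the variable names are illustrative):

```python
import torch

def apply_rope(B, positions, base=10000.0):
    """Rotate the last dimension of B (shape (T, R, d_h)) by RoPE at the given positions."""
    d_h = B.shape[-1]
    theta = base ** (-2 * torch.arange(d_h // 2, dtype=torch.float32) / d_h)
    ang = positions.float()[:, None, None] * theta          # (T, 1, d_h/2)
    cos, sin = torch.cos(ang), torch.sin(ang)
    B1, B2 = B[..., 0::2], B[..., 1::2]                      # interleaved pairs
    out = torch.empty_like(B)
    out[..., 0::2] = B1 * cos - B2 * sin
    out[..., 1::2] = B1 * sin + B2 * cos
    return out

T, R_K, h, d_h = 16, 2, 4, 64
A_K = torch.randn(T, R_K, h)
B_K = torch.randn(T, R_K, d_h)
pos = torch.arange(T)

# Cache the pre-rotated factor; no rotation is needed again at decode time.
B_K_tilde = apply_rope(B_K, pos)
K_tilde = torch.einsum('trh,trd->thd', A_K, B_K_tilde) / R_K   # (T, h, d_h)
```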

Theorem 1 (RoPE’s Compatibility with TPA). Let $\mathbf{Q}_{t}$ be factorized by TPA as

定理 1 (RoPE 与 TPA 的兼容性). 令 $\mathbf{Q}_{t}$ 由 TPA 分解为

\mathbf{Q}_{t}=\frac{1}{R_{Q}}\,\mathbf{A}_{Q}(\mathbf{x}_{t})^{\top}\,\mathbf{B}_{Q}(\mathbf{x}_{t})\,\in\mathbb{R}^{h\times d_{h}},

where $\mathbf{A}_{Q}(\mathbf{x}_{t})\in\mathbb{R}^{R_{Q}\times h}$ and $\mathbf{B}_{Q}(\mathbf{x}_{t})\in\mathbb{R}^{R_{Q}\times d_{h}}$. Then we have:

其中 $\mathbf{A}_{Q}(\mathbf{x}_{t})\in\mathbb{R}^{R_{Q}\times h}$,$\mathbf{B}_{Q}(\mathbf{x}_{t})\in\mathbb{R}^{R_{Q}\times d_{h}}$。然后我们有:

\mathrm{RoPE}(\mathbf{Q}_{t})=\frac{1}{R_{Q}}\mathbf{A}_{Q}(\mathbf{x}_{t})^{\top}\,\widetilde{\mathbf{B}}_{Q}(\mathbf{x}_{t}),\qquad\mathrm{where~}\widetilde{\mathbf{B}}_{Q}(\mathbf{x}_{t})=\mathrm{RoPE}_{t}\bigl(\mathbf{B}_{Q}(\mathbf{x}_{t})\bigr).

In addition, assume $\mathbf{Q}_{t}$ and $\mathbf{K}_{s}$ are factorized by TPA and then rotated by $\mathrm{RoPE}_{t}$ and $\mathrm{RoPE}_{s}$, respectively. Let $\widetilde{\mathbf{Q}}_{t}=\mathrm{RoPE}_{t}(\mathbf{Q}_{t})$ and $\widetilde{\mathbf{K}}_{s}=\mathrm{RoPE}_{s}(\mathbf{K}_{s})$. Then we have

此外,假设 $\mathbf{Q}_{t}$ 和 $\mathbf{K}_{s}$ 通过 TPA 进行分解,然后分别通过 $\mathrm{RoPE}_{t}$ 和 $\mathrm{RoPE}_{s}$ 进行旋转。设 $\widetilde{\mathbf{Q}}_{t}=\mathrm{RoPE}_{t}(\mathbf{Q}_{t})$,$\widetilde{\mathbf{K}}_{s}=\mathrm{RoPE}_{s}(\mathbf{K}_{s})$。那么我们有

\mathrm{RoPE}_{t-s}(\mathbf{Q}_{t})\,\mathbf{K}_{s}^{\top}=\widetilde{\mathbf{Q}}_{t}\,\widetilde{\mathbf{K}}_{s}^{\top}.

Focusing on individual heads $i$, the above matrix equality implies:

关注单个头 $i$,上述矩阵等式意味着:

\mathrm{RoPE}_{t-s}\bigl(\mathbf{q}_{t,i}\bigr)^{\top}\mathbf{k}_{s,i}=\widetilde{\mathbf{q}}_{t,i}^{\top}\,\widetilde{\mathbf{k}}_{s,i},

where $\mathbf{q}_{t,i}\in\mathbb{R}^{d_{h}}$ is the $i$-th query head of the $t$-th token, $\mathbf{k}_{s,i}\in\mathbb{R}^{d_{h}}$ is the $i$-th key head of the $s$-th token, and

其中 $\mathbf{q}_{t,i}\in\mathbb{R}^{d_{h}}$ 是第 $t$ 个 token 的第 $i$ 个查询头,$\mathbf{k}_{s,i}\in\mathbb{R}^{d_{h}}$ 是第 $s$ 个 token 的第 $i$ 个键头,且

\widetilde{\mathbf{q}}_{t,i}=\mathrm{RoPE}(\mathbf{q}_{t,i})=\mathbf{T}_{t}\mathbf{q}_{t,i}\in\mathbb{R}^{d_{h}},\quad\widetilde{\mathbf{k}}_{s,i}=\mathrm{RoPE}(\mathbf{k}_{s,i})=\mathbf{T}_{s}\mathbf{k}_{s,i}\in\mathbb{R}^{d_{h}}.

Theorem 1 indicates that TPA does not break RoPE’s relative translational property. We prove Theorem 1 in Appendix A. In short, $\mathrm{RoPE}_{t}$ acts as a block-diagonal orthogonal transform (i.e., the matrix $\mathbf{T}_{t}$) on $\mathbf{B}_{Q}(\mathbf{x}_{t})$. Consequently, $\mathbf{A}_{Q}(\mathbf{x}_{t})$ remains unchanged, while each row of $\mathbf{B}_{Q}(\mathbf{x}_{t})$ is rotated appropriately, preserving the TPA structure.

定理 1 表明 TPA 不会破坏 RoPE 的相对平移性质。我们在附录 A 中证明了定理 1。简而言之,$\mathrm{RoPE}_{t}$ 作为块对角正交变换(即矩阵 $\mathbf{T}_{t}$)作用于 $\mathbf{B}_{Q}(\mathbf{x}_{t})$。因此,$\mathbf{A}_{Q}(\mathbf{x}_{t})$ 保持不变,而 $\mathbf{B}_{Q}(\mathbf{x}_{t})$ 的每一行都被适当旋转,从而保留了 TPA 结构。
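Theorem 1 is, at its core, associativity of matrix products, which is easy to check numerically (our own verification sketch, rebuilding the block-diagonal $\mathbf{T}_{t}$ from Section 2.5):

```python
import math
import torch

def rope_matrix(t, d_h, base=10000.0):
    """Block-diagonal RoPE rotation matrix T_t (2x2 blocks with angle t * theta_j)."""
    T = torch.zeros(d_h, d_h)
    for j in range(d_h // 2):
        th = base ** (-2 * j / d_h)
        c, s = math.cos(t * th), math.sin(t * th)
        T[2 * j, 2 * j], T[2 * j, 2 * j + 1] = c, -s
        T[2 * j + 1, 2 * j], T[2 * j + 1, 2 * j + 1] = s, c
    return T

h, d_h, R_Q, t = 4, 8, 3, 5
A = torch.randn(R_Q, h)          # A_Q(x_t)
B = torch.randn(R_Q, d_h)        # B_Q(x_t)
T_t = rope_matrix(t, d_h)

Q_t = A.T @ B / R_Q              # assemble the query slice, (h, d_h)
lhs = Q_t @ T_t                  # RoPE applied to the assembled heads
rhs = A.T @ (B @ T_t) / R_Q      # RoPE applied to the B factor first
assert torch.allclose(lhs, rhs, atol=1e-6)
```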

3.3 KV Caching and Memory Reduction

3.3 KV 缓存与内存优化

In autoregressive decoding, standard attention caches $\mathbf{K}_{t},\mathbf{V}_{t}\in\mathbb{R}^{h\times d_{h}}$ for each past token $t$. This accumulates to $\mathbb{R}^{T\times h\times d_{h}}$ for keys and $\mathbb{R}^{T\times h\times d_{h}}$ for values, i.e., $2\,T\,h\,d_{h}$ numbers in total.

在自回归解码中,标准注意力为每个过去的 token $t$ 缓存 $\mathbf{K}_{t},\mathbf{V}_{t}\in\mathbb{R}^{h\times d_{h}}$。键和值分别累积为 $\mathbb{R}^{T\times h\times d_{h}}$,即总共 $2\,T\,h\,d_{h}$ 个数值。

TPA Factorized KV Caching. Instead of storing the full $\mathbf{K}_{t}$ and $\mathbf{V}_{t}$, TPA stores only their factorized components. Specifically, we keep

TPA 分解 KV 缓存。TPA 不存储完整的 $\mathbf{K}_{t}$ 和 $\mathbf{V}_{t}$,而是仅存储它们的分解因子。具体来说,我们保留

\mathbf{A}_{K}(\mathbf{x}_{t}),\,\widetilde{\mathbf{B}}_{K}(\mathbf{x}_{t})\quad\mathrm{and}\quad\mathbf{A}_{V}(\mathbf{x}_{t}),\,\mathbf{B}_{V}(\mathbf{x}_{t}),

where

\mathbf{A}_{K}(\mathbf{x}_{t})\in\mathbb{R}^{R_{K}\times h},\;\widetilde{\mathbf{B}}_{K}(\mathbf{x}_{t})\in\mathbb{R}^{R_{K}\times d_{h}},\;\mathbf{A}_{V}(\mathbf{x}_{t})\in\mathbb{R}^{R_{V}\times h},\;\mathbf{B}_{V}(\mathbf{x}_{t})\in\mathbb{R}^{R_{V}\times d_{h}}.

Hence, the memory cost per token is

因此,每个 Token 的内存成本为

\underbrace{R_{K}(h+d_{h})}_{\mathrm{for\,K}}\,+\,\underbrace{R_{V}(h+d_{h})}_{\mathrm{for\,V}}=\left(R_{K}+R_{V}\right)\left(h+d_{h}\right).

Compared to the standard caching cost of $2\,h\,d_{h}$, the ratio is:

与 $2\,h\,d_{h}$ 的标准缓存成本相比,比率为:


\frac{\left(R_{K}+R_{V}\right)\left(h+d_{h}\right)}{2\,h\,d_{h}}.

For large $h$ and $d_{h}$ (typically $d_{h}=64$ or 128), setting $R_{K},R_{V}\ll d_{h}$ (e.g., 1 or 2) often yields $10\times$ or more reduction.

对于较大的 $h$ 和 $d_{h}$(通常 $d_{h}=64$ 或 128),设置 $R_{K},R_{V}\ll d_{h}$(例如 1 或 2)通常可以减少 $10\times$ 或更多。
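Plugging in concrete (illustrative) values makes the saving explicit:

```python
# Per-token KV-cache ratio of TPA to MHA: (R_K + R_V)(h + d_h) / (2 h d_h).
h, d_h = 16, 64                       # illustrative head count and head dimension
for R in (1, 2, 4):                   # R_K = R_V = R
    ratio = 2 * R * (h + d_h) / (2 * h * d_h)
    print(f"R_K = R_V = {R}: ratio = {ratio:.3f} (~{1 / ratio:.1f}x smaller cache)")
# Here R = 1 gives ~12.8x and R = 2 gives ~6.4x; larger d_h (e.g., 128) or smaller
# ranks push the reduction past 10x.
```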

Table 1: Comparison of different attention mechanisms. Here, $R_{Q}$ , $R_{K}$ , and $R_{V}$ denote the ranks for queries, keys, and values in TPA, respectively. Variants of TPA, such as TPA (KVonly), TPA (Non-contextual A), and TPA (Non-contextual B), are detailed in Section 3.5. For MLA, $d_{h}^{R}$ and $d_{h}$ are the dimensions for RoPE and non-RoPE parts; $d_{c}^{\prime}$ and $d_{c}$ are the dimensions of compressed vectors for query and key-value, respectively.

表 1: 不同注意力机制的比较。其中,$R_{Q}$、$R_{K}$ 和 $R_{V}$ 分别表示 TPA 中查询、键和值的秩。TPA 的变体,如 TPA (KVonly)、TPA (Non-contextual A) 和 TPA (Non-contextual B),详见第 3.5 节。对于 MLA,$d_{h}^{R}$ 和 $d_{h}$ 分别是 RoPE 部分和非 RoPE 部分的维度;$d_{c}^{\prime}$ 和 $d_{c}$ 分别是查询和键值压缩向量的维度。

| 方法 | KV 缓存 | 参数量 | 查询头数 | KV 头数 |
|---|---|---|---|---|
| MHA | $2hd_{h}$ | $4d_{\mathrm{model}}hd_{h}$ | $h$ | $h$ |
| MQA | $2d_{h}$ | $2d_{\mathrm{model}}hd_{h}+2d_{\mathrm{model}}d_{h}$ | $h$ | $1$ |
| GQA | $2gd_{h}$ | $2d_{\mathrm{model}}hd_{h}+2d_{\mathrm{model}}gd_{h}$ | $h$ | $g$ |
| MLA | $d_{c}+d_{h}^{R}$ | $d_{c}^{\prime}(d_{\mathrm{model}}+hd_{h}+hd_{h}^{R})+d_{c}(d_{\mathrm{model}}+2hd_{h})+d_{\mathrm{model}}d_{h}^{R}+d_{\mathrm{model}}hd_{h}$ | $h$ | $h$ |
| TPA | $(R_{K}+R_{V})(h+d_{h})$ | $d_{\mathrm{model}}(R_{Q}+R_{K}+R_{V})(h+d_{h})+d_{\mathrm{model}}hd_{h}$ | $h$ | $h$ |
| TPA (KV only) | $(R_{K}+R_{V})(h+d_{h})$ | $d_{\mathrm{model}}(R_{K}+R_{V})(h+d_{h})+2d_{\mathrm{model}}hd_{h}$ | $h$ | $h$ |
| TPA (Non-contextual A) | $(R_{K}+R_{V})d_{h}$ | $(R_{Q}+R_{K}+R_{V})(d_{\mathrm{model}}d_{h}+h)+d_{\mathrm{model}}hd_{h}$ | $h$ | $h$ |
| TPA (Non-contextual B) | $(R_{K}+R_{V})h$ | $(R_{Q}+R_{K}+R_{V})(d_{\mathrm{model}}h+d_{h})+d_{\mathrm{model}}hd_{h}$ | $h$ | $h$ |

3.4 Unifying MHA, MQA, and GQA as Non-contextual TPA

3.4 将 MHA、MQA 和 GQA 统一为非上下文 TPA

3.4.1 MHA as Non-contextual TPA

3.4.1 多头注意力机制 (MHA) 作为非上下文 TPA

Standard multi-head attention (MHA) can be viewed as a specific instance of TPA in which: 1) the rank is set equal to the number of heads; 2) the head dimension factor is non-contextual (i.e., independent of the $t$-th token embedding $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$); 3) the token dimension factor is a linear function of $\mathbf{x}_{t}$.

标准多头注意力机制 (MHA) 可以被视为 TPA 的一个特定实例,其中:1) 秩被设置为头的数量;2) 头维度因子是非上下文相关的(即独立于第 $t$ 个 token 嵌入 $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$);3) token 维度因子是 $\mathbf{x}_{t}$ 的线性函数。

To match MHA with TPA, let $R_{Q}=R_{K}=R_{V}=h$ . Focusing on $\mathbf{Q}_{t}$ :

为了将 MHA 与 TPA 匹配,令 $R_{Q}=R_{K}=R_{V}=h$。重点关注 $\mathbf{Q}_{t}$:

(a) Non-contextual head factors. Define

(a) 非上下文头因子。定义


\mathbf{a}_{i}^{Q}=R_{Q}\,\mathbf{e}_{i}\in\mathbb{R}^{h},\quad(\mathbf{e}_{i}\in\mathbb{R}^{h}\text{ is the }i\text{-th standard basis vector}),

so that $\mathbf{e}_{i}\otimes\cdot$ corresponds to the $i$-th head of $\mathbf{Q}_{t}$.

使得 $\mathbf{e}_{i}\otimes\cdot$ 对应于 $\mathbf{Q}_{t}$ 的第 $i$ 个头。

(b) Contextual token factors. Define

(b) 上下文 Token 因素。定义


\mathbf{b}_{i}^{Q}(\mathbf{x}_{t})=(W_{i}^{Q})^{\top}\mathbf{x}_{t}\in\mathbb{R}^{d_{h}},

where $W_{i}^{Q}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{h}}$ is the per-head query projection defined before; hence $\mathbf{b}_{i}^{Q}(\mathbf{x}_{t})$ depends on $\mathbf{x}_{t}$.

其中 $W_{i}^{Q}\in\mathbb{R}^{d_{\mathrm{model}}\times d_{h}}$ 是之前定义的每个头的查询投影,因此 $\mathbf{b}_{i}^{Q}(\mathbf{x}_{t})$ 依赖于 $\mathbf{x}_{t}$。

Substituting (3.8)–(3.9) into (3.1) gives:

将 (3.8)–(3.9) 代入 (3.1) 可得:


\mathbf{Q}_{t}=\sum_{i=1}^{h}\Bigl[\mathbf{e}_{i}\otimes\bigl((W_{i}^{Q})^{\top}\,\mathbf{x}_{t}\bigr)\Bigr]\in\mathbb{R}^{h\times d_{h}}.

Each term $\mathbf{e}_{i}\otimes\bigl((W_{i}^{Q})^{\top}\mathbf{x}_{t}\bigr)$ in (3.10) contributes only to the $i$-th row, reconstituting the usual MHA form of $\mathbf{Q}_{t}$. Analogous constructions hold for $\mathbf{K}_{t}$ and $\mathbf{V}_{t}$ using $W_{i}^{K},W_{i}^{V}$. Thus, MHA is a non-contextual, full-rank variant of TPA.

(3.10) 式中的每一项 $\mathbf{e}_{i}\otimes\bigl((W_{i}^{Q})^{\top}\mathbf{x}_{t}\bigr)$ 仅对第 $i$ 行有贡献,重构了 $\mathbf{Q}_{t}$ 的常规多头注意力 (MHA) 形式。类似地,使用 $W_{i}^{K},W_{i}^{V}$ 可以构造 $\mathbf{K}_{t}$ 和 $\mathbf{V}_{t}$。因此,MHA 是 TPA 的一种非上下文、满秩变体。

TPA with Non-contextual A. More broadly, TPA can use non-contextual head-dimension factors $\mathbf{a}_{r}^{Q},\mathbf{a}_{r}^{K},\mathbf{a}_{r}^{V}\in\mathbb{R}^{h}$ (i.e., independent of $\mathbf{x}_{t}$), while allowing $\mathbf{b}_{r}^{Q}(\mathbf{x}_{t}),\mathbf{b}_{r}^{K}(\mathbf{x}_{t}),\mathbf{b}_{r}^{V}(\mathbf{x}_{t})$ to remain context-dependent. Then, for keys:

TPA 与非上下文 A。更广泛地说,TPA 可以使用非上下文的头维度因子 $\mathbf{a}_{r}^{Q},\mathbf{a}_{r}^{K},\mathbf{a}_{r}^{V}\in\mathbb{R}^{h}$(即与 $\mathbf{x}_{t}$ 无关),同时允许 $\mathbf{b}_{r}^{Q}(\mathbf{x}_{t}),\mathbf{b}_{r}^{K}(\mathbf{x}_{t}),\mathbf{b}_{r}^{V}(\mathbf{x}_{t})$ 保持上下文依赖。然后,对于键:


\mathbf{K}_{t}=\frac{1}{R_{K}}\sum_{r=1}^{R_{K}}\mathbf{a}_{r}^{K}\otimes\mathbf{b}_{r}^{K}(\mathbf{x}_{t}),

and similarly for queries/values. This reduces per-token computations and can be effective when head-dimension relationships are relatively stable across all tokens.

对于查询/值也是如此。这减少了每个 Token 的计算量,并且当所有 Token 之间的头维度关系相对稳定时,这种方法可能有效。

MQA and GQA as Non-Contextual TPA. Multi-Query Attention (MQA) (Shazeer, 2019) and Grouped Query Attention (GQA) (Ainslie et al., 2023) also emerge naturally from TPA by restricting the head-dimension factors to be non-contextual and low-rank:

MQA 和 GQA 作为非上下文 TPA。多查询注意力 (MQA) (Shazeer, 2019) 和分组查询注意力 (GQA) (Ainslie et al., 2023) 也通过将头维度因子限制为非上下文和低秩,自然地从 TPA 中产生:

• MQA as Rank-1 TPA. In MQA, all heads share a single set of keys/values, corresponding to $R_{K}=R_{V}=1$ along the head dimension. Concretely,

• MQA 作为 Rank-1 TPA。在 MQA 中,所有头共享一组键/值,对应于沿头维度的 $R_{K}=R_{V}=1$。具体来说,


\mathbf{K}_{t}=(1,\ldots,1)^{\top}\,\otimes\,\mathbf{b}^{K}(\mathbf{x}_{t}),\quad\mathbf{V}_{t}=(1,\ldots,1)^{\top}\,\otimes\,\mathbf{b}^{V}(\mathbf{x}_{t}),

forces every head to use the same $\mathbf{K}{t},\mathbf{V}{t}$ . Each head retains a distinct query projection, matching the MQA design.

强制每个头使用相同的 $\mathbf{K}{t},\mathbf{V}{t}$。每个头保留一个独特的查询投影,与 MQA 设计相匹配。

• GQA as Grouped Rank-1 TPA. GQA partitions the $h$ heads into $G$ groups, each sharing keys/values within that group. In TPA form, each group $g$ has a dedicated non-contextual factor pair $\mathbf{a}_{g}^{K},\mathbf{a}_{g}^{V}\in\mathbb{R}^{h}$, which acts as a “mask” for the heads in that group. Varying $G$ from 1 to $h$ interpolates from MQA to standard MHA.

• GQA 作为分组秩-1 TPA。GQA 将 $h$ 个头划分为 $G$ 个组,每个组内的头共享键/值。在 TPA 形式中,每个组 $g$ 都有一个专用的非上下文因子对 $\mathbf{a}_{g}^{K},\mathbf{a}_{g}^{V}\in\mathbb{R}^{h}$,作为该组内头部的“掩码”。将 $G$ 从 1 变化到 $h$,即可从 MQA 插值过渡到标准 MHA。

Hence, by constraining TPA’s head-dimension factors to be constant masks (one for MQA; multiple for GQA), these popular variants are recovered as special cases.

因此,通过将 TPA 的头维度因子限制为常数掩码(MQA 为一个;GQA 为多个),这些流行的变体可以作为特例恢复。
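These special cases can be written out directly (a toy sketch with our own variable names):

```python
import torch

h, d_h, T = 4, 8, 6
b_K = torch.randn(T, d_h)                      # shared contextual key factor b^K(x_t)

# MQA as rank-1 TPA: the non-contextual head factor is the all-ones vector,
# so every head receives the same key slice.
a_K = torch.ones(h)
K_mqa = torch.einsum('h,td->thd', a_K, b_K)    # (T, h, d_h)
assert torch.allclose(K_mqa[:, 0], K_mqa[:, 1])        # all heads identical

# GQA with G groups: one 0/1 mask per group selects the heads in that group.
G = 2
masks = torch.repeat_interleave(torch.eye(G), h // G, dim=1)   # (G, h)
b_K_groups = torch.randn(T, G, d_h)            # one contextual factor per group
K_gqa = torch.einsum('gh,tgd->thd', masks, b_K_groups)
assert torch.allclose(K_gqa[:, 0], K_gqa[:, 1])        # heads 0 and 1 share group 0
assert not torch.allclose(K_gqa[:, 0], K_gqa[:, 2])    # head 2 belongs to group 1
```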

3.5 Other Variants of TPA

3.5 TPA 的其他变体

TPA with Non-contextual B. Conversely, one may fix the token-dimension factors $\mathbf{b}_{r}^{Q},\mathbf{b}_{r}^{K},\mathbf{b}_{r}^{V}\in\mathbb{R}^{d_{h}}$ as learned parameters, while allowing $\mathbf{a}_{r}^{Q}(\mathbf{x}_{t}),\mathbf{a}_{r}^{K}(\mathbf{x}_{t}),\mathbf{a}_{r}^{V}(\mathbf{x}_{t})$ to adapt to $\mathbf{x}_{t}$. For keys:

TPA 与非上下文 B。相反,可以将 Token 维度因子 $\mathbf{b}_{r}^{Q},\mathbf{b}_{r}^{K},\mathbf{b}_{r}^{V}\in\mathbb{R}^{d_{h}}$ 固定为学习得到的参数,同时允许 $\mathbf{a}_{r}^{Q}(\mathbf{x}_{t}),\mathbf{a}_{r}^{K}(\mathbf{x}_{t}),\mathbf{a}_{r}^{V}(\mathbf{x}_{t})$ 随 $\mathbf{x}_{t}$ 变化。对于键:


\mathbf{K}_{t}=\frac{1}{R_{K}}\sum_{r=1}^{R_{K}}\mathbf{a}_{r}^{K}(\mathbf{x}_{t})\otimes\mathbf{b}_{r}^{K},

and similarly for queries/values. This arrangement is effective if the token-dimension structure remains mostly uniform across the sequence, while the head-dimension factors capture context.

同样适用于查询/值。如果 Token 维度的结构在整个序列中基本保持一致,而头维度因子捕捉上下文,这种安排是有效的。
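
For symmetry with the non-contextual-A sketch above, here is a minimal sketch of this variant under the same assumed shapes and illustrative names: the token-dimension basis $\mathbf{b}_{r}^{K}$ is a learned parameter and only the head factors depend on $\mathbf{x}_{t}$.

```python
import torch
import torch.nn as nn

class NonContextualBKeys(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_head=64, rank_k=2):
        super().__init__()
        self.h, self.dh, self.rk = n_heads, d_head, rank_k
        self.W_ak = nn.Linear(d_model, rank_k * n_heads, bias=False)  # contextual a_r^K(x_t)
        self.b_k = nn.Parameter(torch.randn(rank_k, d_head))          # fixed token-dimension basis

    def forward(self, x):                                   # x: (B, T, d_model)
        B, T, _ = x.shape
        a_k = self.W_ak(x).view(B, T, self.rk, self.h)
        return torch.einsum('btrh,rd->bthd', a_k, self.b_k) / self.rk

print(NonContextualBKeys()(torch.randn(2, 16, 512)).shape)  # torch.Size([2, 16, 8, 64])
```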

TPA KV Only. One can preserve a standard query mapping,

TPA KV Only。可以保留标准的查询映射,


$$
\mathbf{Q}_{t}=W^{Q}\,\mathbf{x}_{t}\in\mathbb{R}^{h\times d_{h}},
$$

and factorize only the keys and values. This leaves the query projection as the original linear transformation while reducing memory usage via factorized KV caching.

仅对键和值进行分解。这使得查询投影保持原始的线性变换,同时通过分解的键值缓存减少内存使用。
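
A sketch of this KV-only arrangement, under the same assumed shapes as above (names are illustrative): queries keep a full-rank linear projection, while only the key/value factor pairs would need to be cached.

```python
import torch
import torch.nn as nn

class TPAKVOnly(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_head=64, rank_kv=2):
        super().__init__()
        self.h, self.dh, self.r = n_heads, d_head, rank_kv
        self.W_q = nn.Linear(d_model, n_heads * d_head, bias=False)   # standard query map
        self.W_ak = nn.Linear(d_model, rank_kv * n_heads, bias=False)
        self.W_bk = nn.Linear(d_model, rank_kv * d_head, bias=False)
        self.W_av = nn.Linear(d_model, rank_kv * n_heads, bias=False)
        self.W_bv = nn.Linear(d_model, rank_kv * d_head, bias=False)

    def forward(self, x):                                  # x: (B, T, d_model)
        B, T, _ = x.shape
        q = self.W_q(x).view(B, T, self.h, self.dh)        # full-rank Q_t
        # Only these four factor tensors need to enter the KV cache.
        a_k = self.W_ak(x).view(B, T, self.r, self.h)
        b_k = self.W_bk(x).view(B, T, self.r, self.dh)
        a_v = self.W_av(x).view(B, T, self.r, self.h)
        b_v = self.W_bv(x).view(B, T, self.r, self.dh)
        k = torch.einsum('btrh,btrd->bthd', a_k, b_k) / self.r
        v = torch.einsum('btrh,btrd->bthd', a_v, b_v) / self.r
        return q, k, v
```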

TPA KV with Shared B. Another variant is to share the token-dimension factors of keys and values:

TPA KV 与共享 B。另一种变体是共享键和值的 Token 维度因子:


$$
\mathbf{b}_{r}^{K}(\mathbf{x}_{t})=\mathbf{b}_{r}^{V}(\mathbf{x}_{t}),
$$

lowering parameter counts and the KV cache footprint. While it constrains $\mathbf{K}$ and $\mathbf{V}$ to be formed from the same token basis, it can still perform well and provide additional memory savings.

降低参数量和 KV 缓存占用。虽然它限制了 $\mathbf{K}$ 和 $\mathbf{V}$ 必须基于相同的 Token 基础,但它仍然可以表现良好并提供额外的内存节省。
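
A minimal sketch of the shared-B variant (illustrative names): one token-dimension projection serves both keys and values, so a single $\mathbf{b}_{r}(\mathbf{x}_{t})$ per rank is produced and cached.

```python
import torch
import torch.nn as nn

class SharedBKV(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_head=64, rank=2):
        super().__init__()
        self.h, self.dh, self.r = n_heads, d_head, rank
        self.W_ak = nn.Linear(d_model, rank * n_heads, bias=False)
        self.W_av = nn.Linear(d_model, rank * n_heads, bias=False)
        self.W_b = nn.Linear(d_model, rank * d_head, bias=False)   # shared: b_r^K = b_r^V

    def forward(self, x):                                   # x: (B, T, d_model)
        B, T, _ = x.shape
        a_k = self.W_ak(x).view(B, T, self.r, self.h)
        a_v = self.W_av(x).view(B, T, self.r, self.h)
        b = self.W_b(x).view(B, T, self.r, self.dh)         # cached once, reused for K and V
        k = torch.einsum('btrh,btrd->bthd', a_k, b) / self.r
        v = torch.einsum('btrh,btrd->bthd', a_v, b) / self.r
        return k, v
```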

Nonlinear Head Factors. Rather than applying purely linear mappings to the head-dimension factors $\mathbf{a}_{r}^{Q},\mathbf{a}_{r}^{K},\mathbf{a}_{r}^{V}$, one may introduce element-wise nonlinearities such as $\sigma(\cdot)$ or $\mathrm{softmax}(\cdot)$. This effectively yields a Mixture of Heads Attention (MoH Attention), where each component becomes a learned mixture weight modulated by the nonlinearity.

非线性头因子。与对头维度因子 $\mathbf{a}_{r}^{Q},\mathbf{a}_{r}^{K},\mathbf{a}_{r}^{V}$ 应用纯线性映射不同,可以引入逐元素的非线性,例如 $\sigma(\cdot)$ 或 $\mathrm{softmax}(\cdot)$。这实际上产生了一种多头注意力混合(MoH Attention),其中每个分量都成为由非线性调制的学习混合权重。
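
A small sketch of how such a nonlinearity might be inserted (the softmax placement here is an assumption for illustration, not the paper's specification): applying softmax over the head dimension of each $\mathbf{a}_{r}^{K}(\mathbf{x}_{t})$ turns every rank component into a learned mixture over heads.

```python
import torch
import torch.nn as nn

class MoHStyleKeys(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_head=64, rank_k=2):
        super().__init__()
        self.h, self.dh, self.r = n_heads, d_head, rank_k
        self.W_ak = nn.Linear(d_model, rank_k * n_heads, bias=False)
        self.W_bk = nn.Linear(d_model, rank_k * d_head, bias=False)

    def forward(self, x):                                   # x: (B, T, d_model)
        B, T, _ = x.shape
        a_k = self.W_ak(x).view(B, T, self.r, self.h)
        a_k = torch.softmax(a_k, dim=-1)                    # nonlinear mixture weights over heads
        b_k = self.W_bk(x).view(B, T, self.r, self.dh)
        return torch.einsum('btrh,btrd->bthd', a_k, b_k) / self.r

print(MoHStyleKeys()(torch.randn(2, 16, 512)).shape)        # torch.Size([2, 16, 8, 64])
```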

Discussion. These variants illustrate TPA’s versatility in balancing memory cost, computational overhead, and representation power. By choosing which dimensions (heads or tokens) remain contextual and adjusting ranks $(R_{Q},R_{K},R_{V})$, TPA unifies multiple existing attention mechanisms—such as MHA, MQA, and GQA—under one framework, while potentially reducing the KV cache size by an order of magnitude during autoregressive inference.

讨论。这些变体展示了 TPA 在平衡内存成本、计算开销和表示能力方面的多功能性。通过选择哪些维度(头或 Token)保持上下文并调整秩 $(R_{Q},R_{K},R_{V})$,TPA 将多种现有的注意力机制(如 MHA、MQA 和 GQA)统一在一个框架下,同时可能在自回归推理期间将 KV 缓存大小减少一个数量级。
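
To make the cache claim concrete, here is a back-of-the-envelope comparison of per-token, per-layer cache entries (the head count, head dimension, and ranks below are illustrative choices, not the paper's reported configuration):

```python
h, d_h = 32, 128          # assumed head count and head dimension
R_K = R_V = 2             # assumed TPA ranks for keys and values

mha_cache = 2 * h * d_h                          # store full K_t and V_t
tpa_cache = R_K * (h + d_h) + R_V * (h + d_h)    # store (a, b) factor pairs instead

print(mha_cache, tpa_cache, mha_cache / tpa_cache)
# 8192 640 12.8  -> roughly an order-of-magnitude smaller cache at these settings
```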

3.6 Model Architectures

3.6 模型架构

We propose a new architecture called Tensor ProducT ATTenTion Transformer (T6), which uses our Tensor Product Attention (TPA) in place of standard MHA (multi-head attention) or GQA (grouped-query attention). Building upon the query, key, and value tensors $\mathbf{Q},\mathbf{K},\mathbf{V}\in\mathbb{R}^{N\times h\times d_{h}}$ defined in Section 3.1, T6 utilizes the overall architecture of LLaMA (Touvron et al., 2023) while changing the self-attention block to our TPA-based version. The feed-forward network (FFN) adopts a SwiGLU layer, as in (Shazeer, 2020; Touvron et al., 2023).

我们提出了一种名为 Tensor ProducT ATTenTion Transformer (T6) 的新架构,该架构使用我们的 Tensor Product Attention (TPA) 替代了标准的 MHA(多头注意力)或 GQA(分组查询注意力)。基于第 3.1 节中定义的查询、键和值张量 $\mathbf{Q},\mathbf{K},\mathbf{V}\in\mathbb{R}^{N\times h\times d_{h}}$,T6 采用了 LLaMA (Touvron et al., 2023) 的整体架构,同时将自注意力块更改为基于 TPA 的版本。前馈网络 (FFN) 采用了 SwiGLU 层,如 (Shazeer, 2020; Touvron et al., 2023) 所述。

TPA QKV Factorization. Let each token’s hidden-state vector be $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$. Following Section 3.1, we project the entire sequence into three tensors $\mathbf{Q},\mathbf{K},\mathbf{V}\in\mathbb{R}^{T\times h\times d_{h}}$, where $\mathbf{Q}_{t},\mathbf{K}_{t},\mathbf{V}_{t}\in\mathbb{R}^{h\times d_{h}}$ denote the slices for the $t$-th token. The factor components $\mathbf{a}_{r}^{Q}(\mathbf{x}_{t}),\mathbf{b}_{r}^{Q}(\mathbf{x}_{t}),\mathbf{a}_{r}^{K}(\mathbf{x}_{t}),\mathbf{b}_{r}^{K}(\mathbf{x}_{t}),\mathbf{a}_{r}^{V}(\mathbf{x}_{t}),\mathbf{b}_{r}^{V}(\mathbf{x}_{t})$ are produced by linear transformations on $\mathbf{x}_{t}$. For instance, letting $W_{r}^{a^{Q}}\in\mathbb{R}^{h\times d_{\mathrm{model}}}$ and $W_{r}^{b^{Q}}\in\mathbb{R}^{d_{h}\times d_{\mathrm{model}}}$, we have:

TPA QKV 分解。设每个 Token 的隐藏状态向量为 $\mathbf{x}_{t}\in\mathbb{R}^{d_{\mathrm{model}}}$,我们按照第 3.1 节将整个序列投影为三个张量 $\mathbf{Q},\mathbf{K},\mathbf{V}\in\mathbb{R}^{T\times h\times d_{h}}$,其中 $\mathbf{Q}_{t},\mathbf{K}_{t},\mathbf{V}_{t}\in\mathbb{R}^{h\times d_{h}}$ 表示第 $t$ 个 Token 的切片。因子分量 $\mathbf{a}_{r}^{Q}(\mathbf{x}_{t}),\mathbf{b}_{r}^{Q}(\mathbf{x}_{t}),\mathbf{a}_{r}^{K}(\mathbf{x}_{t}),\mathbf{b}_{r}^{K}(\mathbf{x}_{t}),\mathbf{a}_{r}^{V}(\mathbf{x}_{t}),\mathbf{b}_{r}^{V}(\mathbf{x}_{t})$ 通过对 $\mathbf{x}_{t}$ 进行线性变换生成。例如,设 $W_{r}^{a^{Q}}\in\mathbb{R}^{h\times d_{\mathrm{model}}}$ 和 $W_{r}^{b^{Q}}\in\mathbb{R}^{d_{h}\times d_{\mathrm{model}}}$,我们有:


$$
\mathbf{a}_{r}^{Q}(\mathbf{x}_{t})=W_{r}^{a^{Q}}\,\mathbf{x}_{t},\quad\mathbf{b}_{r}^{Q}(\mathbf{x}_{t})=W_{r}^{b^{Q}}\,\mathbf{x}_{t}.
$$

In practice, we merge all ranks $r$ into a single dimension of the output, reshape, and sum over rank indices; see Section 3.1 for details. The factorization for $\mathbf{K}$ and $\mathbf{V}$ follows the same pattern.

在实践中,我们将所有秩 $r$ 合并到输出的单一维度中,进行重塑,并对秩索引求和;详见第 3.1 节。$\mathbf{K}$ 和 $\mathbf{V}$ 的分解遵循相同的模式。
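
The following condensed sketch shows one way this merge-reshape-sum pattern can be written in PyTorch (ranks and sizes are assumed for illustration; the released T6 code may organize the projections differently):

```python
import torch
import torch.nn as nn

class TPAFactorization(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_head=64, r_q=6, r_k=2, r_v=2):
        super().__init__()
        self.h, self.dh = n_heads, d_head
        self.r_q, self.r_k, self.r_v = r_q, r_k, r_v
        # One weight per factor type; rank and head (or d_h) dims are merged in the output.
        self.W_aq = nn.Linear(d_model, r_q * n_heads, bias=False)
        self.W_bq = nn.Linear(d_model, r_q * d_head, bias=False)
        self.W_ak = nn.Linear(d_model, r_k * n_heads, bias=False)
        self.W_bk = nn.Linear(d_model, r_k * d_head, bias=False)
        self.W_av = nn.Linear(d_model, r_v * n_heads, bias=False)
        self.W_bv = nn.Linear(d_model, r_v * d_head, bias=False)

    def _assemble(self, a, b, rank):
        # a: (B, T, rank, h), b: (B, T, rank, d_h) -> (B, T, h, d_h), averaged over ranks
        return torch.einsum('btrh,btrd->bthd', a, b) / rank

    def forward(self, x):                                     # x: (B, T, d_model)
        B, T, _ = x.shape
        q = self._assemble(self.W_aq(x).view(B, T, self.r_q, self.h),
                           self.W_bq(x).view(B, T, self.r_q, self.dh), self.r_q)
        k = self._assemble(self.W_ak(x).view(B, T, self.r_k, self.h),
                           self.W_bk(x).view(B, T, self.r_k, self.dh), self.r_k)
        v = self._assemble(self.W_av(x).view(B, T, self.r_v, self.h),
                           self.W_bv(x).view(B, T, self.r_v, self.dh), self.r_v)
        return q, k, v                                        # each (B, T, h, d_h)
```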

Rotary Positional Embedding (RoPE). As discussed in Section 3.2, RoPE (Su et al., 2024b) is applied to $\mathbf{Q}$ and $\mathbf{K}$. Within TPA, we pre-rotate the factors $\mathbf{b}_{t}^{Q}(\mathbf{x}_{t})$ and $\mathbf{b}_{s}^{K}(\mathbf{x}_{s})$ directly, so that each $\mathbf{K}_{s}$ is already rotated prior to caching; see (3.6) and Theorem 1.

旋转位置嵌入 (Rotary Positional Embedding, RoPE)。如第 3.2 节所述,RoPE (Su et al., 2024b) 被应用于 $\mathbf{Q}$ 和 $\mathbf{K}$。在 TPA 中,我们直接预旋转因子 $\mathbf{b}_{t}^{Q}(\mathbf{x}_{t})$ 和 $\mathbf{b}_{s}^{K}(\mathbf{x}_{s})$,使得每个 $\mathbf{K}_{s}$ 在缓存之前已经旋转,参见 (3.6) 和定理 1。
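
A minimal sketch of this pre-rotation (a standard rotary formulation is assumed here; the exact RoPE convention should follow Section 3.2): the rotation is applied to $\mathbf{b}^{K}(\mathbf{x}_{s})$ before the factors enter the cache, so cached entries never need re-rotation.

```python
import torch

def rope(b: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """b: (..., seq, d) with even d; pos: (seq,) integer positions."""
    d = b.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    angles = pos[:, None].float() * inv_freq[None, :]          # (seq, d/2)
    cos, sin = torch.cos(angles), torch.sin(angles)
    b1, b2 = b[..., 0::2], b[..., 1::2]
    out = torch.empty_like(b)
    out[..., 0::2] = b1 * cos - b2 * sin
    out[..., 1::2] = b1 * sin + b2 * cos
    return out

B, T, R, d_h = 2, 16, 2, 64
b_k = torch.randn(B, T, R, d_h)                                # token-dimension key factors
pos = torch.arange(T)
# Rotate per position; moving the rank axis aside lets the same rotation hit every b_r^K.
b_k_rot = rope(b_k.transpose(1, 2), pos).transpose(1, 2)       # (B, T, R, d_h), pre-rotated for caching
```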

Attention Step and Output Projection. Once we have $\mathbf{Q},\mathbf{K},\mathbf{V}$ factorized per token with RoPE applied on $\mathbf{Q}$ and $\mathbf{K}$, the attention step proceeds for each head $i\in\{1,\ldots,h\}$ using (3.4). Finally, concatenating these $h$ heads and then projecting them back using an output weight matrix gives the final attention result, as shown in (3.5).

注意力步骤和输出投影。一旦我们为每个 Token 分解了 $\mathbf{Q},\mathbf{K},\mathbf{V}$,并在 $\mathbf{Q}$ 和 $\mathbf{K}$ 上应用了 RoPE,注意力步骤就会为每个头 $i\in\{1,\ldots,h\}$ 使用 (3.4) 进行。最后,将这些 $h$ 个头连接起来,然后使用输出权重矩阵将它们投影回去,得到最终的注意力结果,如 (3.5) 所示。
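
Given the assembled per-token $\mathbf{Q},\mathbf{K},\mathbf{V}$ from the factorization sketch above, the attention step and output projection can be sketched as follows (causal masking assumed, as in decoder-only language modeling; the random tensors stand in for the assembled factorized values):

```python
import torch
import torch.nn.functional as F

B, T, h, d_h, d_model = 2, 16, 8, 64, 512
q = torch.randn(B, T, h, d_h)                                   # stand-ins for assembled Q, K, V
k = torch.randn(B, T, h, d_h)
v = torch.randn(B, T, h, d_h)
W_o = torch.randn(h * d_h, d_model) / (h * d_h) ** 0.5          # output projection

# scaled_dot_product_attention expects (B, h, T, d_h)
out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2),
                                     v.transpose(1, 2), is_causal=True)
out = out.transpose(1, 2).reshape(B, T, h * d_h) @ W_o          # concat heads, project back
print(out.shape)                                                # torch.Size([2, 16, 512])
```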

SwiGLU Feed-Forward Network. Following Shazeer (2020); Touvron et al. (2023), our T6 uses a SwiGLU-based Feed-Forward Network (FFN):

SwiGLU 前馈网络。遵循 Shazeer (2020) 和 Touvron et al. (2023) 的研究,我们的 T6 使用了基于 SwiGLU 的前馈网络 (FFN):

$$
\mathrm{FFN}(\mathbf{x})=\left[\sigma(\mathbf{x}\,W_{1})\,\odot\,(\mathbf{x}\,W_{2})\right]W_{3},
$$

where $\sigma$ is the SiLU (a.k.a. swish) nonlinearity, $\odot$ is the element-wise product, and $W_{1},W_{2},W_{3}$ are learnable parameters. Note that other activation functions can also be used.

其中,$\sigma$ 是 SiLU(也称为 swish)非线性函数,$\odot$ 是逐元素乘积,$W_{1},W_{2},W_{3}$ 是可学习的参数。注意,其他激活函数也可以使用。
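
A standard SwiGLU FFN matching the formula above can be written as follows (the hidden width here is an illustrative choice; LLaMA-style models typically use roughly $8/3\cdot d_{\mathrm{model}}$, rounded for hardware efficiency):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    def __init__(self, d_model=512, d_hidden=1376):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_hidden, bias=False)   # gate branch, passed through SiLU
        self.w2 = nn.Linear(d_model, d_hidden, bias=False)   # linear branch
        self.w3 = nn.Linear(d_hidden, d_model, bias=False)   # down projection

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

print(SwiGLUFFN()(torch.randn(2, 16, 512)).shape)             # torch.Size([2, 16, 512])
```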

Overall T6 Block Structure. Putting everything together, one T6 block consists of:

总体 T6 块结构。将所有内容整合在一起,一个 T6 块由以下部分组成:

$$
\begin{aligned}
\mathbf{x} &\leftarrow \mathbf{x}+\mathrm{TPA}\big(\mathrm{RMSNorm}(\mathbf{x})\big),\\
\mathbf{x} &\leftarrow \mathbf{x}+\text{SwiGLU-FFN}\big(\mathrm{RMSNorm}(\mathbf{x})\big).
\end{aligned}
$$

We place norm layers (e.g., RMSNorm) before each sub-layer. Stacking $L$ such blocks yields a T6 model architecture with $L$ layers.

我们将归一化层(例如 RMSNorm)放置在每个子层之前。堆叠 $L$ 个这样的块会生成一个具有 $L$ 层的 T6 模型架构。
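
Putting the pieces together, the following self-contained sketch wires a simplified TPA sub-layer and a SwiGLU FFN into one pre-norm block following the update rule above (sizes, ranks, and the omission of RoPE are simplifications for illustration, not the paper's exact configuration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimplifiedTPA(nn.Module):
    """TPA-style attention with contextual rank factors; RoPE omitted for brevity."""
    def __init__(self, d_model, n_heads=8, d_head=64, r_q=6, r_kv=2):
        super().__init__()
        self.h, self.dh, self.rq, self.rkv = n_heads, d_head, r_q, r_kv
        self.W_aq = nn.Linear(d_model, r_q * n_heads, bias=False)
        self.W_bq = nn.Linear(d_model, r_q * d_head, bias=False)
        self.W_ak = nn.Linear(d_model, r_kv * n_heads, bias=False)
        self.W_bk = nn.Linear(d_model, r_kv * d_head, bias=False)
        self.W_av = nn.Linear(d_model, r_kv * n_heads, bias=False)
        self.W_bv = nn.Linear(d_model, r_kv * d_head, bias=False)
        self.W_o = nn.Linear(n_heads * d_head, d_model, bias=False)

    def _factor(self, Wa, Wb, x, rank):
        B, T, _ = x.shape
        a = Wa(x).view(B, T, rank, self.h)
        b = Wb(x).view(B, T, rank, self.dh)
        return torch.einsum('btrh,btrd->bhtd', a, b) / rank     # (B, h, T, d_h)

    def forward(self, x):
        q = self._factor(self.W_aq, self.W_bq, x, self.rq)
        k = self._factor(self.W_ak, self.W_bk, x, self.rkv)
        v = self._factor(self.W_av, self.W_bv, x, self.rkv)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        B, _, T, _ = out.shape
        return self.W_o(out.transpose(1, 2).reshape(B, T, self.h * self.dh))

class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class SwiGLUFFN(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_hidden, bias=False)
        self.w2 = nn.Linear(d_model, d_hidden, bias=False)
        self.w3 = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

class T6Block(nn.Module):
    def __init__(self, d_model=512):
        super().__init__()
        self.norm1, self.norm2 = RMSNorm(d_model), RMSNorm(d_model)
        self.attn = SimplifiedTPA(d_model)
        self.ffn = SwiGLUFFN(d_model, 4 * d_model)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))   # x <- x + TPA(RMSNorm(x))
        x = x + self.ffn(self.norm2(x))    # x <- x + SwiGLU-FFN(RMSNorm(x))
        return x

print(T6Block()(torch.randn(2, 16, 512)).shape)   # torch.Size([2, 16, 512])
```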

4 Experiments

4 实验

4.1 Language Modeling Tasks

4.1 语言建模任务

All experiments reported in this paper are implemented on the nanoGPT code base (Karpathy, 2022), using the FineWeb-Edu 100B dataset (Lozhkov et al., 2024). The dataset contains 100 billion tokens for training and 0.1 billion tokens for validation. We compare T6 against the baseline Llama architecture (Touvron et al., 2023) with SwiGLU activation (Shazeer, 2020) and RoPE embeddings (Su et al., 2024a), as well as Llama variants that replace Multi-Head Attention (MHA; Vaswani et al., 2017) with Multi-Query Attention (MQA; Shazeer, 2019), Grouped Query Attention (GQA; Ainslie et al., 2023), or Multi-head Latent Attention (MLA; Liu et al., 2024a). In our experiments, the number of heads $h$ is adjusted for each attention mechanism so that all attention mechanisms have the same number of parameters as standard Multi-Head Attention (MHA), which has $4d_{\mathrm{model}}^{2}$ parameters per attention layer.