[论文翻译]通过从数万亿token中检索改进语言模型


原文地址:https://arxiv.org/pdf/2112.04426


Improving language models by retrieving from trillions of tokens

通过从数万亿token中检索改进语言模型

We enhance auto-regressive language models by conditioning on document chunks retrieved from a large corpus, based on local similarity with preceding tokens. With a 2 trillion token database, our Retrieval-Enhanced Transformer (Retro) obtains comparable performance to GPT-3 and Jurassic-1 on the Pile, despite using $\mathbf{25\times}$ fewer parameters. After fine-tuning, Retro performance translates to downstream knowledge-intensive tasks such as question answering. Retro combines a frozen Bert retriever, a differentiable encoder and a chunked cross-attention mechanism to predict tokens based on an order of magnitude more data than what is typically consumed during training. We typically train Retro from scratch, yet can also rapidly Retrofit pre-trained transformers with retrieval and still achieve good performance. Our work opens up new avenues for improving language models through explicit memory at unprecedented scale.

我们通过基于与先前token的局部相似性,从大型语料库中检索文档块作为条件,增强了自回归语言模型。在拥有2万亿token数据库的情况下,我们的检索增强型Transformer (Retro) 在Pile数据集上取得了与GPT-3和Jurassic-1相当的性能,尽管参数数量减少了$\mathbf{25\times}$。经过微调后,Retro的性能可迁移至知识密集型下游任务(如问答)。Retro结合了冻结的Bert检索器、可微分编码器和分块交叉注意力机制,能够基于比训练时通常消耗数据量高出一个数量级的规模来预测token。我们通常从头开始训练Retro,但也能快速为预训练Transformer添加检索功能(Retrofit)并保持良好性能。这项工作为通过显式内存以前所未有的规模改进语言模型开辟了新途径。

1. Introduction

1. 引言

Language modelling (LM) is an unsupervised task that consists of modelling the probability of text, usually by factorising it into conditional next-token predictions $p(x_{1},...,x_{n})=\prod_{i}p(x_{i}|x_{<i})$. Neural networks have proven to be powerful language models, first in the form of recurrent architectures (Graves, 2013; Jozefowicz et al., 2016; Mikolov et al., 2010) and more recently in the form of Transformers (Vaswani et al., 2017), that use attention to contextualise the past. Large performance improvements have come from increasing the amount of data, training compute, or model parameters. Transformers have been scaled from 100 million parameter models in seminal work to over a hundred billion parameters (Brown et al., 2020; Radford et al., 2019) in the last two years, which has led to models that do very well on a wide array of tasks in a zero or few-shot formulation. Increasing model size predictably improves performance on a wide range of downstream tasks (Kaplan et al., 2020). The benefits of increasing the number of parameters come from two factors: additional computations at training and inference time, and increased memorization of the training data.

语言建模 (LM) 是一种无监督任务,旨在对文本概率进行建模,通常通过将其分解为条件性下一token预测来实现 $p(x_{1},...,x_{n})=\prod_{i}p(x_{i}|x_{<i})$ 。神经网络已被证明是强大的语言模型,最初以循环架构形式出现 (Graves, 2013; Jozefowicz et al., 2016; Mikolov et al., 2010),最近则以Transformer形式 (Vaswani et al., 2017) 通过注意力机制对历史上下文进行建模。性能的大幅提升来自数据量、训练算力或模型参数的增加。Transformer的参数量从开创性工作的1亿规模,在过去两年内扩展到超过千亿参数 (Brown et al., 2020; Radford et al., 2019),这使得模型在零样本或少样本设定下能出色完成广泛任务。增大模型规模可预期地提升各类下游任务性能 (Kaplan et al., 2020)。参数增加的优势来自两个因素:训练和推理时的额外计算量,以及对训练数据记忆能力的增强。

In this work, we endeavor to decouple these, by exploring efficient means of augmenting language models with a massive-scale memory without significantly increasing computations. Specifically, we suggest retrieval from a large text database as a complementary path to scaling language models. Instead of increasing the size of the model and training on more data, we equip models with the ability to directly access a large database to perform predictions—a semi-parametric approach. At a high level, our Retrieval Transformer (Retro) model splits the input sequence into chunks and retrieves text similar to the previous chunk to improve the predictions in the current chunk. Existing retrieval work for language modelling only considers small transformers (100 million parameters) and databases of limited size (up to billions of tokens) (Guu et al., 2020; Khandelwal et al., 2020; Lewis et al., 2020; Yogatama et al., 2021). To our knowledge, our work is the first to show the benefits of scaling the retrieval database to trillions of tokens for large parametric language models. Our main

在本研究中,我们致力于通过探索高效方法,在不显著增加计算量的情况下为语言模型扩展海量记忆能力。具体而言,我们提出从大型文本数据库中进行检索,作为扩展语言模型的补充路径。不同于增大模型规模或使用更多训练数据,我们赋予模型直接访问大型数据库进行预测的能力——这是一种半参数化方法。从高层次看,我们的检索式Transformer(Retro)模型将输入序列分割为多个片段,并检索与前一片段相似的文本来提升当前片段的预测效果。现有语言建模的检索工作仅考虑小型Transformer(1亿参数)和有限规模的数据库(至多数十亿token)(Guu et al., 2020; Khandelwal et al., 2020; Lewis et al., 2020; Yogatama et al., 2021)。据我们所知,本研究首次证明了将检索数据库扩展至数万亿token对大型参数化语言模型的益处。我们的主要...


Figure 1 | Scaling of Retro. The performance gain of our retrieval models remains constant with model scale (left), and is comparable to multiplying the parametric model size by $\sim10\times$. The gain increases with the size of the retrieval database (middle) and the number of retrieved neighbours (right) on the C4 validation set, when using up to 40 neighbours. Past this, performance begins to degrade, perhaps due to the reduced quality. At evaluation Retro can be used without retrieval data (Retro[OFF]), bringing limited performance degradation compared to baseline transformers.

图 1 | Retro的扩展性。我们的检索模型性能提升随模型规模保持恒定(左图),其效果相当于将参数模型规模扩大约10倍。在C4验证集上,当使用最多40个邻近检索项时,性能提升随检索数据库规模(中图)和检索邻近项数量(右图)增加而增大。超过该阈值后性能开始下降,可能源于检索质量降低。评估时Retro可不依赖检索数据运行(Retro[OFF]),与基准Transformer相比仅造成有限性能衰减。

contributions are the following.

贡献如下。

2. Method

2. 方法

We design our retrieval-enhanced architecture to be capable of retrieving from a database with trillions of tokens. For this purpose, we retrieve at the level of contiguous token chunks instead of individual tokens which reduces storage and computation requirements by a large linear factor. Our method first constructs a key-value database, where values store raw chunks of text tokens and keys are frozen Bert embeddings (Devlin et al., 2019). We use a frozen model to avoid having to periodically re-compute embeddings over the entire database during training. Each training sequence is then split into chunks, which are augmented with their $k$-nearest neighbours retrieved from the database. An encoder-decoder architecture integrates retrieval chunks into the model’s predictions. We summarize the Retro architecture in Fig. 2, and detail it in this section. We end the section by introducing a new methodology to evaluate language models when an evaluation set is partially present in the training set.

我们设计的检索增强架构能够从包含数万亿token的数据库中进行检索。为此,我们在连续token块级别进行检索而非单个token,这通过线性倍数显著降低了存储和计算需求。该方法首先构建一个键值数据库:值存储原始文本token块,键采用冻结的Bert嵌入向量 (Devlin et al., 2019) 。使用冻结模型可避免训练期间需定期重新计算整个数据库的嵌入向量。每个训练序列被分割成块后,会与其从数据库中检索到的$k$近邻共同增强。编码器-解码器架构将检索块整合至模型预测中。图2总结了Retro架构,本节将详细说明。最后,我们提出一种新方法用于评估当训练集包含部分测试集时的语言模型表现。


Figure 2 | Retro architecture. Left: simplified version where a sequence of length $n=12$ is split into $l=3$ chunks of size $m=4$ . For each chunk, we retrieve $k=2$ neighbours of $r=5$ tokens each. The retrieval pathway is shown on top. Right: Details of the interactions in the Cca operator. Causality is maintained as neighbours of the first chunk only affect the last token of the first chunk and tokens from the second chunk.

图 2 | Retro架构。左图:简化版本,其中长度为 $n=12$ 的序列被分割为 $l=3$ 个大小为 $m=4$ 的块。对于每个块,我们检索 $k=2$ 个邻居,每个邻居包含 $r=5$ 个token。检索路径显示在顶部。右图:Cca算子中交互的细节。因果性得以保持,因为第一个块的邻居仅影响第一个块的最后一个token以及第二个块的token。

2.1. Training dataset

2.1. 训练数据集

We use a multi-lingual version of Massive Text (Rae et al., 2021) for both training and retrieval data. The dataset consists of text documents from multiple sources and multiple languages totalling over 5 trillion tokens (detailed in Table 1). Sequences are sampled from subsets of the training data, with sampling weights given in the right-most column of Table 1. We tokenize the dataset using SentencePiece (Kudo and Richardson, 2018) with a vocabulary of 128,000 tokens. During training (unless otherwise specified), we retrieve from 600B tokens from the training data. The training retrieval database is made of the same subsets as the training data, in proportions that match the training sampling frequencies. During evaluation the retrieval database consists of the full union of these datasets, with the exception of books for which we use a sub-sample of $4\%$. The evaluation retrieval database thus contains 1.75T tokens. To limit test set leakage, we compute the 13-gram Jaccard similarity between train and test documents using the MinHash scheme and remove all training documents with high similarity (0.8 or higher) to a validation or test set document. Additionally, we remove all validation and test articles from Wikitext 103 (Merity et al., 2017) from our Wikipedia training data.

我们使用多语言版本的Massive Text (Rae等人,2021)作为训练和检索数据。该数据集包含来自多个来源和多种语言的文本文档,总计超过5万亿token (详见表1)。训练数据序列从各子集中按表1最右列所示的采样权重进行抽取。我们采用128,000 token词汇表的SentencePiece (Kudo和Richardson,2018)进行分词处理。训练期间(除非另有说明),我们从训练数据的6000亿token中进行检索。训练检索数据库由与训练数据相同的子集构成,其比例与训练采样频率保持一致。评估时检索数据库由这些数据集的完整并集构成(书籍类数据采用$4\%$的子样本),总计包含1.75万亿token。为防止测试集泄露,我们通过MinHash算法计算训练文档与测试文档的13-gram Jaccard相似度,并移除与验证集/测试集文档相似度达0.8及以上的训练文档。此外,我们从维基百科训练数据中剔除了Wikitext 103 (Merity等人,2017)的所有验证集和测试集文章。
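下面给出上述 13-gram Jaccard 去重判定的一个极简示意(Python):为了可读性,这里直接精确计算 Jaccard 相似度,而论文在万亿 token 规模上使用 MinHash 近似;函数名与实现细节均为假设,仅说明判定逻辑。

```python
def ngrams(tokens, n=13):
    """文档的 13-gram 集合(tokens 为分词后的 token 序列)。"""
    return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}

def jaccard(a, b):
    """两个 n-gram 集合的 Jaccard 相似度。"""
    return len(a & b) / len(a | b) if (a or b) else 0.0

def should_drop(train_doc, eval_docs, threshold=0.8, n=13):
    """若训练文档与任一验证/测试文档的 13-gram Jaccard 相似度 >= 0.8,则剔除该训练文档。
    这里为精确计算;论文在万亿 token 规模上用 MinHash 近似该相似度,此处仅作示意。"""
    g = ngrams(train_doc, n)
    return any(jaccard(g, ngrams(d, n)) >= threshold for d in eval_docs)
```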

2.2. Retrieval-enhanced autoregressive token models

2.2. 检索增强的自回归 Token 模型

Our approach uses retrieval as a way to augment input examples at the granularity of small chunks of tokens. Formally, we consider sequences of integer tokens in $\mathbb{V}=[1,\nu]$, obtained using a text tokenizer1. We split each $n$-token-long example $X=(x_{1},\ldots,x_{n})$ into a sequence of $l$ chunks $(C_{1},\dots,C_{l})$ of size $m=\frac{n}{l}$, i.e. $C_{1}\triangleq\left(x_{1},\ldots,x_{m}\right)$, . . . , $C_{l}\triangleq(x_{n-m+1},...,x_{n})\in\mathbb{V}^{m}$. We use $n=2048$ and $m=64$. We augment each chunk $C_{u}$ with a set ${\mathrm{RET}}_ {\mathcal{D}}(C_{u})$ of $k$ neighbours from the database $\mathcal{D}$. $\mathrm{RET}_{\mathcal{D}}$ (or

我们的方法采用检索机制,在细粒度的token小块级别上增强输入样本。形式上,我们考虑由文本分词器1得到的整数token序列$\mathbb{V}=[1,\nu]$。将每个包含$n$个token的样本$X=(x_{1},\ldots,x_{n})$分割为$l$个大小为$m=\frac{n}{l}$的块序列$(C_{1},\dots,C_{l})$,即$C_{1}\triangleq\left(x_{1},\ldots,x_{m}\right)$,...,$C_{l}\triangleq(x_{n-m+1},...,x_{n})\in\mathbb{V}^{m}$。我们设定$n=2048$和$m=64$。对于每个块$C_{u}$,我们从数据库$\mathcal{D}$中检索$k$个近邻集合${\mathrm{RET}}_ {\mathcal{D}}(C_{u})$进行增强。$\mathrm{RET}_{\mathcal{D}}$(或

Ret for brevity) is a non-trainable operator specified in $\S2.3$ . Token likelihoods are provided by a model, parameterized by $\theta$ , that takes as input both previous tokens and their retrieved neighbours. This defines the following retrieval-enhanced sequence log-likelihood:

Ret (为简洁起见) 是 $\S2.3$ 中指定的不可训练算子。Token 似然由模型提供,该模型由 $\theta$ 参数化,并将先前的 Token 及其检索到的邻居作为输入。这定义了以下检索增强的序列对数似然:

$$
L\left(X|\theta,\mathcal{D}\right)\triangleq\sum_{u=1}^{l}\sum_{i=1}^{m}\ell_{\theta}\left(x_{(u-1)m+i}|(x_{j})_ {j<(u-1)m+i},\left(\operatorname{RET}_ {\mathcal{D}}(C_{u^{\prime}})\right)_{u^{\prime}<u}\right).
$$


We set $\mathsf{RET}(C_{1})=\emptyset$, namely the likelihood of tokens from the first chunk does not depend on any retrieval data. This likelihood definition preserves autoregressivity: the probability of the $i$-th token of the $u$-th chunk, $x_{(u-1)m+i}$, only depends on previously seen tokens $(x_{j})_ {1\leqslant j<(u-1)m+i}$ and on the data retrieved from the previous chunks $(\mathrm{RET}\left(C_{u^{\prime}}\right))_ {u^{\prime}<u}$. We can therefore directly sample with log-probability $\ell$, where sampling within the chunk $C_{u}$ is conditioned on the neighbours $(\mathrm{RET}(C_{u^{\prime}}))_{u^{\prime}<u}$. This makes retrieval-enhanced models directly comparable with the largest language models that are evaluated by sampling.

我们设 $\mathsf{R E T}(C_{1})=\emptyset$,即第一个分块中 token 的概率不依赖于任何检索数据。这一似然定义保留了自回归性:第 $u$ 个分块的第 $i$ 个 token $x_{(u-1)m+i}$ 的概率仅取决于先前见过的 token $(x_{j})_ {1\leqslant j<(u-1)m+i}$ 以及从之前分块检索的数据 $(\mathrm{RET}\left(C_{u^{\prime}}\right))_ {u^{\prime}<u}$。因此,我们可以直接采样对数概率 $\ell$,其中分块 $C_{u}$ 内的采样以前邻分块的检索结果 $(\mathrm{RET}(C_{u^{\prime}}))_{u^{\prime}<u}$ 为条件。这使得检索增强模型可直接与通过采样评估的最大规模大语言模型进行对比。
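为直观说明上式中的分块与条件化方式,下面用一个玩具示例演示:序列被切成 $l$ 个大小为 $m$ 的块,预测第 $u$ 个块时只能使用更早块的检索结果,且第一个块不使用任何检索数据。其中 retrieve 为占位函数,并非论文实现,仅示意索引关系。

```python
import numpy as np

n, m = 2048, 64        # 论文设定:序列长 n=2048,块大小 m=64
l = n // m             # l = 32 个块

def retrieve(chunk):
    """占位检索函数:真实实现返回数据库 D 中 k 个 [N, F] 邻居(见 2.3 节)。"""
    return []

x = np.arange(n)                 # 玩具 token 序列
chunks = x.reshape(l, m)         # (C_1, ..., C_l),形状 (32, 64)

# 预测第 u 个块(此处 u 从 0 开始计数)内的 token 时,只能以更早块的检索结果为条件;
# 第一个块不使用任何检索数据,从而保持自回归性
conditioning = {u: [retrieve(chunks[v]) for v in range(u)] for u in range(l)}
```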

2.3. Nearest neighbour retrieval

2.3. 最近邻检索

Retrieval neighbours. Our database consists of a key-value memory. Each value consists of two contiguous chunks of tokens which we denote $[N,F]$ where $N$ is the neighbour chunk which is used to compute the key, and $F$ is its continuation in the original document. The corresponding key is the Bert embedding of $N$ , averaged over time, that we denote ${\mathrm{BERT}}(N)$ . For each chunk $C$ , we retrieve its approximate $k$ -nearest neighbours from our key-value database using the $L_{2}$ distance on BERT embeddings $d(C,N)=||\mathrm{BERT}(C)-\mathrm{BERT}(N)||_ {2}^{2}$ . The model receives the corresponding values $\operatorname{RET}(C)\triangleq{\bigl(}[N^{1},F^{1}],\dots,[N^{k},F^{k}]{\bigr)}$ . Both neighbour chunks and their continuations provide meaningful improvements, as illustrated in our ablation study (Appendix D). We use a length 64 for both $N^{j}$ and $F^{j}$ , thus $\operatorname{RET}(C)$ has a shape of $k\times r$ with $r=128$ . To avoid retrieving the chunk $C_{u+1}$ in the retrieval set $\operatorname{RET}(C_{u})$ , which would break causality during training, we filter out neighbours originating from the same document as the training sequence $X$ .

检索邻居。我们的数据库由键值记忆组成。每个值包含两个连续的token块,记为$[N,F]$,其中$N$是用于计算键的邻居块,$F$是其在原始文档中的延续部分。对应的键是$N$的Bert嵌入随时间平均后的结果,记为${\mathrm{BERT}}(N)$。对于每个块$C$,我们使用BERT嵌入的$L_{2}$距离$d(C,N)=||\mathrm{BERT}(C)-\mathrm{BERT}(N)||_ {2}^{2}$从键值数据库中检索其近似的$k$近邻。模型接收对应的值$\operatorname{RET}(C)\triangleq{\bigl(}[N^{1},F^{1}],\dots,[N^{k},F^{k}]{\bigr)}$。如消融实验所示(附录D),邻居块及其延续部分均能带来显著改进。我们将$N^{j}$和$F^{j}$的长度均设为64,因此$\operatorname{RET}(C)$的形状为$k\times r$($r=128$)。为避免在检索集$\operatorname{RET}(C_{u})$中检索到块$C_{u+1}$(这会破坏训练时的因果关系),我们过滤掉与训练序列$X$同源文档的邻居。
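下面用 numpy 给出键的构造与 $L_{2}$ 最近邻检索的示意:键与查询都是冻结 BERT 逐 token 嵌入在时间维上的平均。这里的 bert_embed 是占位函数,且用暴力搜索代替论文中的 SCaNN 近似检索,仅为说明距离的定义,并非真实实现。

```python
import numpy as np

def bert_embed(chunk_tokens, d=768):
    """占位函数:返回形状 (len(chunk), d) 的逐 token 激活;真实系统使用冻结的 BERT。"""
    rng = np.random.default_rng(abs(hash(tuple(chunk_tokens))) % (2**32))
    return rng.standard_normal((len(chunk_tokens), d))

def chunk_key(chunk_tokens):
    """键/查询 = BERT 逐 token 嵌入在时间维上的平均,即正文中的 BERT(N) 或 BERT(C)。"""
    return bert_embed(chunk_tokens).mean(axis=0)

def retrieve(query_chunk, db_keys, db_values, k=2):
    """按 L2 距离 d(C, N) = ||BERT(C) - BERT(N)||^2 取 k 个最近邻,返回对应的 [N, F] 值。
    这里是 O(T) 的暴力搜索,仅为示意;论文用 SCaNN 在 O(log T) 时间内做近似检索,
    并会过滤掉与训练序列同源文档的邻居。"""
    q = chunk_key(query_chunk)
    dist = ((np.asarray(db_keys) - q) ** 2).sum(axis=1)
    return [db_values[i] for i in np.argsort(dist)[:k]]
```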

For a database of $T$ elements, we can query the approximate nearest neighbours in ${\cal O}(\log T)$ time. We use the SCaNN library (Guo et al., 2020) to achieve this. This means that we can query our 2 trillion token database in $10\,\mathrm{ms}$ whilst evaluating or sampling from the model; this expense is amortized over a chunk length. Performing retrieval on-the-fly is too slow to keep up with the training calculations—we leverage the frozen aspect of the embedding operator Bert to precompute all approximate nearest neighbours and save the results as part of the data. In Fig. 9 in the Appendix, we show results where we only retrieve neighbours within Wikipedia. We find that neighbours tend to come from 2-3 links away from a given article whereas random articles are more than 5 links apart.

对于一个包含 $T$ 个元素的数据库,我们可以在 ${\cal O}(\log T)$ 时间内查询近似最近邻。我们使用 SCaNN 库 (Guo et al., 2020) 实现这一功能。这意味着在模型评估或采样过程中,我们能在 $10\mathrm{m}s$ 内查询包含 2 万亿 token 的数据库;这部分开销会分摊到每个数据块长度上。实时检索的速度过慢,无法跟上训练计算节奏——我们利用嵌入算子 Bert 的冻结特性预先计算所有近似最近邻,并将结果作为数据的一部分保存。附录中的图 9 展示了仅检索维基百科内部邻居的结果,我们发现邻居通常来自给定文章 2-3 个链接范围内的页面,而随机文章间的距离则超过 5 个链接。

Table 1 | Massive Text. The last column indicates the sampling weight during training. The multilingual subsets include documents in 10 languages. The full breakdown is given in $\S\mathrm{A}.1$ .

Source      Token count (M)   Documents (M)   Multilingual   Sampling frequency
Web         977,563           1,208           Yes            55%
Books       3,423,740         20              No             25%
News        236,918           398             No             10%
Wikipedia   13,288            23              Yes            5%
GitHub      374,952           143             No             5%

表 1 | Massive Text。最后一列表示训练时的采样权重。多语言子集包含10种语言的文档。完整分类见 $\S\mathrm{A}.1$。

来源 Token计数(百万) 文档数(百万) 多语言 采样频率
网页 977,563 1,208 是 55%
书籍 3,423,740 20 否 25%
新闻 236,918 398 否 10%
维基百科 13,288 23 是 5%
GitHub 374,952 143 否 5%

2.4. Retro model architecture

2.4. Retro模型架构

Our model relies on an encoder-decoder transformer architecture, integrating the retrieved data through a cross-attention mechanism as introduced in Vaswani et al. (2017). First, the retrieved tokens $\operatorname{RET}(C)$ are fed into an encoder Transformer, which computes the encoded neighbours set $E$. Denoting the intermediate activations by $H$, our transformer decoder then interleaves Retro-blocks $\operatorname{RETRO}(H,E)$ and standard Transformer blocks $\operatorname{LM}(H)$ (the hyperparameter $P\subseteq[1,L]$ determines at which layers we use a Retro-block). These blocks are built from three different residual operators with signature $\mathbb{R}^{n\times d}\rightarrow\mathbb{R}^{n\times d}$: a fully-connected layer Ffw, the standard sequence-level self-attention layer Attn, and a chunked cross-attention layer $\operatorname{CCA}(\cdot,E)$ that incorporates information from the retrieval encoder:

我们的模型基于编码器-解码器Transformer架构,通过Vaswani等人(2017)提出的交叉注意力机制整合检索数据。首先,检索到的token $\operatorname{RET}(C)$ 被输入编码器Transformer,计算得到编码邻居集合 $E$。以 $H$ 表示中间激活,我们的Transformer解码器交替使用Retro块 $\operatorname{RETRO}(H,E)$ 和标准Transformer块 $\operatorname{LM}(H)$(超参数 $P\subseteq[1,L]$ 决定在哪些层使用Retro块)。这些模块由三个签名为 $\mathbb{R}^{n\times d}\rightarrow\mathbb{R}^{n\times d}$ 的残差算子构成:全连接层Ffw、标准序列级自注意力层Attn,以及整合来自检索编码器信息的分块交叉注意力层 $\operatorname{CCA}(\cdot,E)$:

$$
\operatorname{RETRO}\left(H,E\right)\triangleq\operatorname{FFW}\left(\operatorname{CcA}\left(\operatorname{ATTN}\left(H\right),E\right)\right),\quad\mathrm{and}\quad\operatorname{LM}\left(H\right)\triangleq\operatorname{FFW}\left(\operatorname{ATTN}\left(H\right)\right)
$$


Since Ffw, Attn and Cca are all autoregressive operators whose output at position $i$ only depends on $(h_{j})_{j\leqslant i}$, any succession of Retro and Lm layers, followed by a token classification head, defines an autoregressive log-likelihood (1). An overview of the model architecture is given in Algorithm 1 and in Fig. 2. We next describe the retrieval encoder and the chunked cross-attention layer in more detail, and explain how to sample from Retro.

由于 Ffw、Attn 和 Cca 都是自回归算子,其位置 𝑖 的输出仅取决于 $(h_{j})_{j\leqslant i}$,因此任何 Retro 和 lm 层的连续组合,加上一个 token 分类头,就定义了一个自回归对数似然 (1)。算法 1 和图 2 给出了模型架构的概览。接下来我们将更详细地描述检索编码器和分块交叉注意力层,并解释如何从 Retro 中进行采样。
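下面的 Python 示意展示解码器如何按上式交替堆叠标准 LM 块与 Retro 块:即 $\operatorname{LM}(H)=\operatorname{FFW}(\operatorname{ATTN}(H))$,而在 $P$ 指定的层额外插入 $\operatorname{CCA}(\cdot,E)$。attn、ffw、cca 均为调用方提供的残差算子占位,并非真实实现,仅示意数据流。

```python
def decoder(H, E, num_layers, P, attn, ffw, cca):
    """H: (n, d) 中间激活;E: 编码后的邻居;P: 使用 Retro 块的层号集合(1-indexed)。
    attn / ffw / cca 为调用方提供的残差算子 R^{n×d} -> R^{n×d} 占位,仅示意数据流。"""
    for p in range(1, num_layers + 1):
        H = attn(H)            # 自回归的序列级自注意力
        if p in P:             # Retro 块:RETRO(H, E) = FFW(CCA(ATTN(H), E))
            H = cca(H, E)
        H = ffw(H)             # 否则为标准块:LM(H) = FFW(ATTN(H))
    return H

# 例如最小模型:12 层主干,P = {6, 9, 12},即从第 6 层起每 3 层用一次 Retro 块
P_example = {6, 9, 12}
```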

Encoding retrieval neighbours. For each chunk $C_{u}$, the $k$ retrieval neighbours $\operatorname{RET}(C_{u})$ are fed into a bi-directional transformer Encoder, yielding the outputs $E_{u}^{j}\triangleq\operatorname{ENCODER}\left(\operatorname{RET}\left(C_{u}\right)^{j},H_{u}\right)\in\mathbb{R}^{r\times d^{\prime}}$, where $j\in[1,k]$ indexes each neighbour. The retrieval encoder is a non-causal transformer. It is conditioned on $H_{u}$, the activations of chunk $C_{u}$, through cross-attention layers; this allows the representations of the retrieval encoder to be modulated by the retrieving chunk in a differentiable way. More precisely, the encoding of the $j^{\mathrm{th}}$ neighbour of the $u^{\mathrm{th}}$ chunk, $\mathrm{RET}(C_{u})^{j}$, depends on the attended activation $H_{u}\triangleq\left(h_{(u-1)m+i}\right)_ {i\in[1,m]}\in\mathbb{R}^{m\times d}$ of chunk $C_{u}$ at layer $\operatorname*{min}(P)$. All neighbours for all chunks are encoded in parallel, yielding a full encoded set $E\triangleq\big(E_{u}^{j}\big)_ {u\in[1,l],j\in[1,k]}\in\mathbb{R}^{l\times k\times r\times d^{\prime}}$. We denote $E_{u}\in\mathbb{R}^{k\times r\times d^{\prime}}$ as the encoded neighbours for chunk $u\in[1,l]$.

编码检索邻块。对于每个块 $C_{u}$,其 $k$ 个检索邻块 $\operatorname{RET}(C_{u})$ 被输入双向Transformer编码器,生成输出 $E_{u}^{j}\triangleq\operatorname{ENCoDER}\left(\operatorname{RET}\left(C_{u}\right)^{j},H_{u}\right)\in\mathbb{R}^{r\times d^{\prime}}$,其中 $j\in[1,k]$ 表示每个邻块的索引。该检索编码器采用非因果Transformer结构,通过交叉注意力层以块 $C_{u}$ 的激活状态 $H_{u}$ 为条件,使得检索邻块的表征能够以可微分方式受检索块调控。具体而言,第 $\boldsymbol{u}$ 个块的第 $j^{\mathrm{th}}$ 个邻块 $\mathrm{RET}(C_{u})^{j}$ 的编码,取决于块 $C_{u}$ 在 $\operatorname*{min}(P)$ 层的注意力激活状态 $H_{u}\triangleq\left(h_{(u-1)m+i}\right)_ {i\in[1,m]}\in\mathbb{R}^{m\times d}$。所有块的邻块均并行编码,最终生成完整编码集合 $\begin{array}{r}{E\triangleq\big(E_{u}^{j}\big)_ {u\in[1,l],j\in[1,k]}\in\mathbb{R}^{l\times k\times r\times d^{\prime}}}\end{array}$。记 $E_{u}\in\mathbb{R}^{k\times r\times d^{\prime}}$ 为块 $u\in[1,l]$ 的编码邻块集合。

Chunked cross-attention. To perform the Cca operation, we first split a given intermediate activation $H\in\mathbb{R}^{n\times d}$ into $l{-}1$ attending chunks $\left(H_{u}^{+}\triangleq\left(h_{um+i-1}\right)_{i\in[1,m]}\in\mathbb{R}^{m\times d}\right)_{u\in[1,l-1]}$, as depicted on the right of Fig. 2. $H_{u}^{+}$ holds the intermediary embeddings of the last token in chunk $C_{u}$ and of the first $m-1$ tokens in $C_{u+1}$. We compute the cross-attention between $H_{u}^{+}$ and $E_{u}$, the encoded retrieval set obtained from chunk $C_{u}$. Attention is computed across time and across neighbours simultaneously, as we merge the neighbour and time dimensions of $E_{u}$ before applying cross-attention. Since there is a notion of alignment between data chunks and retrieval neighbours, we use relative positional encodings as described in $\S\mathrm{B}.1.2$.

分块交叉注意力。为执行Cca操作,我们首先将给定的中间激活$H\in\mathbb{R}^{n\times d}$拆分为$l{-}1$个注意力块$\left(H_{u}^{+}\triangleq\left(h_{um+i-1}\right)_{i\in[1,m]}\in\mathbb{R}^{m\times d}\right)_{u\in[1,l-1]}$,如图2右侧所示。$H_{u}^{+}$包含块$C_{u}$中最后一个token及$C_{u+1}$中前$m-1$个token的中间嵌入。我们计算$H_{u}^{+}$与$E_{u}$(从块$C_{u}$获取的编码检索集)之间的交叉注意力。注意力同时跨时间和跨邻居计算,因为我们在应用交叉注意力前合并了$E_{u}$的邻居和时间维度。由于数据块与检索邻居间存在对齐关系,我们采用$\S\mathrm{B}.1.2$所述的相对位置编码方法。

We concatenate the $l{-}1$ outputs of the per-chunk cross-attentions (each of shape $m\times d$ ) across time, and properly pad the result; we thus form the output activation $\mathbf{C}\mathbf{C}\mathbf{A}(H,E)\in\mathbb{R}^{n\times d}$ . Formally, for each chunk $C_{u}$ and for each token $i\in[1,m]$ we set

我们将各分块交叉注意力的 $l{-}1$ 个输出(每个形状为 $m\times d$)在时间维度上拼接,并进行适当填充;由此形成输出激活$\mathbf{C}\mathbf{C}\mathbf{A}(H,E)\in\mathbb{R}^{n\times d}$。具体而言,对于每个分块$C_{u}$和每个token $i\in[1,m]$,我们设定

$$
\operatorname{CCA}(H,E)_{um+i-1}\triangleq\operatorname{CA}\left(h_{um+i-1},E_{u}\right),
$$


Algorithm 1: Overview of Retro model architecture.

算法 1: Retro 模型架构概述

def $\operatorname{ENCODER}(\operatorname{RET}(C_{u})_{1\leqslant u\leqslant l},H)$:
  $(H_{u})_{u\in[1,l]}\leftarrow\operatorname{SPLIT}(H)$
  for $j\in[1,k],u\in[1,l]$ do  // Encoder shared across neighbours and chunks
    $E_{u}^{j}\leftarrow\operatorname{EMB}_{\mathrm{enc}}(\operatorname{RET}(C_{u})^{j})$  // May be shared with the decoder EMB
    for $p^{\prime}\in[1,L_{\mathrm{enc}}]$ do
      $E_{u}^{j}\leftarrow\operatorname{ATTN}_{\mathrm{enc}}(E_{u}^{j})$  // Bi-directional attention
      if $p^{\prime}\in P_{\mathrm{enc}}$ then
        $E_{u}^{j}\leftarrow\operatorname{CA}_{\mathrm{enc}}(E_{u}^{j},H_{u})$
      $E_{u}^{j}\leftarrow\operatorname{FFW}_{\mathrm{enc}}(E_{u}^{j})$
  return $E$

def $\operatorname{ENCODER}(\operatorname{RET}(C_{u})_{1\leqslant u\leqslant l},H)$:
  $(H_{u})_{u\in[1,l]}\leftarrow\operatorname{SPLIT}(H)$
  for $j\in[1,k],u\in[1,l]$ do  // 编码器在邻居和分块间共享
    $E_{u}^{j}\leftarrow\operatorname{EMB}_{\mathrm{enc}}(\operatorname{RET}(C_{u})^{j})$  // 可与解码器的 EMB 共享
    for $p^{\prime}\in[1,L_{\mathrm{enc}}]$ do
      $E_{u}^{j}\leftarrow\operatorname{ATTN}_{\mathrm{enc}}(E_{u}^{j})$  // 双向注意力
      if $p^{\prime}\in P_{\mathrm{enc}}$ then
        $E_{u}^{j}\leftarrow\operatorname{CA}_{\mathrm{enc}}(E_{u}^{j},H_{u})$
      $E_{u}^{j}\leftarrow\operatorname{FFW}_{\mathrm{enc}}(E_{u}^{j})$
  return $E$

where Ca is the cross-attention residual operator over time-concatenated encoded neighbours. We recall that this operator is defined in its simplest version by three parameter matrices $K\in\mathbb{R}^{d\times c}$, $Q\in\mathbb{R}^{d\times c}$ and $V\in\mathbb{R}^{d\times d}$. For all $h\in\mathbb{R}^{d}$ and $Y\in\mathbb{R}^{T\times d}$, we define

其中Ca是时间拼接编码邻居的交叉注意力残差算子。我们回顾该算子最简单的版本由三个参数矩阵$K\in\mathbb{R}^{d\times c}$、$Q\in\mathbb{R}^{d\times c}$和$V\in\mathbb{R}^{d\times d}$定义。对于所有$h\in\mathbb{R}^{d}$和$Y\in\mathbb{R}^{T\times d}$,我们定义

$$
\operatorname{CA}(h,Y)\triangleq\operatorname{softmax}(Y K Q^{T}h)Y V,
$$


where the softmax is performed on the second dimension and all products are matrix products. We use multi-head cross-attention, and add positional encodings to the softmax (see $\S\mathrm{B}.1.2$).

其中softmax在第二个维度执行,所有乘积均为矩阵乘积。我们使用多头交叉注意力机制,并在softmax中添加位置编码(参见 $\S\mathrm{B}.1.2$)。

The first $m-1$ tokens cannot attend to any neighbour of a previous chunk; at these positions, we define Cca as the identity, setting $\operatorname{Cca}(H,E)_ {j}\triangleq h_{j}$ for all tokens $j\in[1,m-1]$. Finally, the last token $h_{lm}$ attends to the last retrieval set $E_{l}$ and we set $h_{lm}\triangleq\mathrm{CA}\left(h_{lm},E_{l}\right)$ (not shown in Fig. 2). Listing 1 contains a simplified implementation of Cca. Note that chunked cross-attention is autoregressive: the output of Cca at position $i$ depends on the tokens from 0 to $i$ that are input to Cca.

前 $m-1$ 个 token 无法关注前一个块的任何邻居;在这些位置,我们将 Cca 定义为恒等映射,对所有 $j\in[1,m-1]$ 的 token 设置 $\operatorname{Cca}(H,E)_ {j}\triangleq h_{j}$。最后,末尾 token $h_{l m}$ 关注最后的检索集 $E_{l}$,并设置 $h_{l m}\triangleq\mathbf{C}\mathrm{A}\left(h_{l m},E_{l}\right)$ (图 2 中未展示)。代码清单 1 给出了 Cca 的简化实现。注意分块交叉注意力是自回归的:Cca 在位置 𝑖 的输出取决于输入到 Cca 的 0 至 𝑖 的 token 序列。
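正文提到的 Listing 1(Cca 的简化实现)未包含在本节;下面按上文定义给出一个单头、无相对位置编码、且假设 $d'=d$ 的 numpy 示意:对移位后的块 $H_{u}^{+}$ 与合并了邻居/时间维的 $E_{u}$ 做交叉注意力,前 $m-1$ 个 token 取恒等,最后一个 token 关注 $E_{l}$。这只是按公式写出的草图,并非论文官方代码。

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def ca(h_block, E_u, K, Q, V):
    """对一段查询激活 h_block (t, d) 与合并了邻居/时间维的 E_u (k*r, d) 做单头交叉注意力,
    即 CA(h, Y) = softmax(Y K Q^T h) Y V;省略残差、多头与相对位置编码。"""
    scores = h_block @ Q @ K.T @ E_u.T            # (t, k*r)
    return softmax(scores, axis=-1) @ (E_u @ V)   # (t, d)

def chunked_cross_attention(H, E, m, K, Q, V):
    """H: (n, d);E: (l, k, r, d) 编码后的邻居(此处假设 d' = d);返回 (n, d)。
    前 m-1 个 token 取恒等,H_u^+(块 C_u 的最后一个 token 加 C_{u+1} 的前 m-1 个 token)
    关注 E_u,最后一个 token 关注 E_l。仅为按正文公式写出的草图。"""
    n, d = H.shape
    l = n // m
    out = np.array(H)                    # 位置 0..m-2 保持恒等
    E = E.reshape(l, -1, d)              # 合并邻居与时间维:(l, k*r, d)
    for u in range(1, l):                # 对应正文 1-indexed 的 u = 1..l-1
        lo, hi = u * m - 1, u * m + m - 1
        out[lo:hi] = ca(H[lo:hi], E[u - 1], K, Q, V)    # H_u^+ 关注 E_u
    out[n - 1] = ca(H[n - 1:n], E[l - 1], K, Q, V)[0]   # 末 token 关注 E_l
    return out
```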

With Retro models, even though each Cca cross-attention attends only to the neighbours of the preceding chunk $\mathrm{RET}(C_{u-1})$, the dependencies over previous neighbours are propagated via the self-attention operations. The activations of the $i^{\mathrm{th}}$ token in the $u^{\mathrm{th}}$ chunk therefore potentially depend upon the set of all previous neighbours $\mathrm{RET}\left(C_{u^{\prime}}\right)_{u^{\prime}<u}$, without incurring the quadratic cost of cross attending to that set.

对于Retro模型,尽管每个Cca交叉注意力仅关注前一个块的邻居$\mathrm{RET}(C_{u-1})$,但对更早邻居的依赖会通过自注意力操作传递。因此,第 $u$ 个块中第 $i$ 个token的激活可能依赖于所有先前块的邻居$\mathrm{RET}\left(C_{u^{\prime}}\right)_{u^{\prime}<u}$的集合,而无需承担对该集合进行交叉注意力的二次计算成本。

Sampling. When sampling, at the end of a chunk $C_{u}$, we use SCaNN to retrieve neighbours $\operatorname{RET}(C_{u})$, based on the embedding ${\mathrm{BERT}}(C_{u})$. The encoded neighbours $E_u = \mathrm{ENCODER}\left( \mathrm{RET}\big(C_u\big) \right)$ are then used to condition the generation of the next chunk $C_{u+1}$, which we do incrementally: overall the cost of sampling is thus quadratic in the size of the sampled sequence, as when sampling from regular Transformers; the added cost of retrieval is linear in the number of chunks $l$, and is negligible compared to the token sampling cost in practice.

采样。在采样时,在块 $C_{u}$ 的末尾,我们使用 SCaNN 基于嵌入 ${\mathrm{B}}{\mathrm{E}}{\mathrm{RT}}(C_{u})$ 检索邻居 $\textstyle\operatorname{RET}(C_{u})$。编码后的邻居 $E_u = \mathrm{ENCODER}\left( \mathrm{RET}\big(C_u\big) \right)$ 随后用于生成下一个块 $C_{u+1}$,这一过程是增量式的:因此,采样的总成本与采样序列的大小呈平方关系,这与常规 Transformer 采样时的情况相同;检索的额外成本与块数 $l$ 呈线性关系,在实践中相比 token 采样成本可以忽略不计。
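下面的函数示意了上述采样控制流:逐块自回归生成,每个块结束时检索并编码邻居,用于条件化下一个块。generate_chunk、retrieve_neighbours、encode_neighbours 均为假设的占位接口,并非论文实现。

```python
def sample_sequence(prompt_tokens, num_chunks, generate_chunk,
                    retrieve_neighbours, encode_neighbours, m=64):
    """Retro 采样控制流示意。generate_chunk(tokens, E, m) 自回归生成 m 个 token,
    retrieve_neighbours(chunk) 基于 BERT(C_u) 做 SCaNN 检索,
    encode_neighbours(neighbours, chunk) 计算 E_u;三者均为调用方提供的占位接口。"""
    tokens = list(prompt_tokens)
    encoded = None                       # 第一个块不以任何检索数据为条件
    for _ in range(num_chunks):
        new_chunk = generate_chunk(tokens, encoded, m)   # 以上一块的 E_u 为条件逐 token 生成
        tokens.extend(new_chunk)
        # 块结束时检索并编码邻居,用于条件化下一个块;检索开销与块数 l 成线性
        encoded = encode_neighbours(retrieve_neighbours(new_chunk), new_chunk)
    return tokens
```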

2.5. Baseline Transformer architecture

2.5. 基线 Transformer 架构

We use a transformer (Vaswani et al., 2017) similar to the one described in (Radford et al., 2019), with some minimal changes: we replace LayerNorm with RMSNorm (Zhang and Sennrich, 2019) and use relative position encodings (Dai et al., 2019). As baselines, we train retrieval-free transformers with 132M, 368M, 1.3B and 7.0B parameters (embedding matrices are excluded from parameter counts). The hyperparameters we used are detailed in Table 2. All retrieval models use the same size encoder for the retrieval data, with $d^{\prime}=896$ and 2 layers, which roughly adds 19M parameters. The encoder uses relative positional encodings. The retrieval models contain one Retro-block every 3 blocks, starting from layer 6. For our smallest model, Cca is applied in layers 6, 9 and 12 of the main pathway and also once for query conditioning in the encoder, which adds an additional 12M parameters. The relative number of extra parameters reduces as we increase the baseline model size. All models are implemented using JAX (Bradbury et al., 2018) and Haiku (Hennigan et al., 2020).

我们采用了一种与 (Radford et al., 2019) 中描述的类似的 Transformer (Vaswani et al., 2017) ,仅做了少量修改:将 LayerNorm 替换为 RMSNorm (Zhang and Sennrich, 2019) ,并使用相对位置编码 (Dai et al., 2019) 。作为基线,我们训练了参数量分别为 132M、368M、1.3B 和 7.0B(不计入嵌入矩阵参数)的无检索 Transformer 。具体超参数见表 2 。所有检索模型对检索数据使用相同大小的编码器,其中 $d^{\prime}=896$ ,共 2 层,约增加 19M 参数。编码器采用相对位置编码。检索模型从第 6 层开始,每 3 个块插入一个 Retro 块。对于最小模型,Cca 应用于主路径的第 6、9、12 层,并在编码器中额外执行一次查询条件处理,共增加 12M 参数。随着基线模型规模增大,额外参数的相对比例逐渐降低。所有模型均基于 JAX (Bradbury et al., 2018) 和 Haiku (Hennigan et al., 2020) 实现。

2.6. Quantifying dataset leakage exploitation

2.6. 量化数据集泄露利用

Retro models may arguably benefit more easily from evaluation dataset leakage, i.e. the fact that we evaluate on data that were also present in the training set. To better understand how retrieval improves language modelling performance, we therefore quantify evaluation likelihood as a function of the overlap between the evaluation and training datasets.

Retro模型可能更容易从评估数据集泄露中获益,即评估所用的数据也出现在训练集中。为了更好地理解检索如何提升语言建模性能,我们因此将评估似然量化为评估集与训练集之间重叠程度的函数。

The following approach can be used with any language model, and depends only on the frozen retriever system presented in $\S2.3$. We split the evaluation sequences $(X_{i})_{i}$ into chunks of length $m\leq64$, and we see the training data as a set of chunks $\mathcal{C}$. For each evaluation chunk $C\in\mathcal{C}$, we retrieve the 10 closest neighbours (of length up to 128) in the training data. We then compute the longest token substring common to both the evaluation chunk and its neighbours. This gives a number $s\in[0,m]$. The value $r(C)=\frac{s}{m}$, ranging from 0 (chunk never seen) to 1 (chunk entirely seen), gives a reliable indication of how much overlap there is between the evaluation chunk and the training data. For a given model, we then obtain the log-likelihood $\ell(C)$ of each chunk $C$, and the number of bytes $N(C)$ it encodes. We then consider the filtered bits-per-bytes of the model:

以下方法适用于任何语言模型,且仅依赖于$\S2.3$中提出的冻结检索系统。我们将评估序列$(X_{i})_{i}$分割为长度$m\leq64$的块,并将训练数据视为块集合$\mathcal{C}$。对于每个评估块$C\in\mathcal{C}$,我们从训练数据中检索10个最接近的邻居(长度不超过128),然后计算评估块与其邻居之间最长的公共token子串,得到数值$s\in[0,m]$。比值$r(C)=\frac{s}{m}$范围从0(完全未见过该块)到1(完全见过该块),可靠地反映了评估块与训练数据的重叠程度。对于给定模型,我们获取每个块$C$的对数似然$\ell(C)$及其编码的字节数$N(C)$,进而考察模型过滤后的每字节比特数 (filtered bits-per-bytes):

$$
\forall\alpha\in[0,1],\quad \mathcal{C}_{\alpha}\triangleq\{C\in\mathcal{C},\ r(C)\leqslant\alpha\},\quad\mathrm{bpb}(\alpha)\triangleq\frac{\sum_{C\in\mathcal{C}_{\alpha}}\ell(C)}{\sum_{C\in\mathcal{C}_{\alpha}}N(C)},
$$


Table 2 | Number of parameters for our baseline and Retro models, excluding embeddings, along with the corresponding hyperparameters.

Baseline parameters   RETRO           d       dffw     # heads   Head size   # layers
132M                  172M (+30%)     896     3,584    16        64          12
368M                  425M (+15%)     1,536   6,144    12        128         12
1,309M                1,451M (+11%)   2,048   8,192    16        128         24
6,982M                7,532M (+8%)    4,096   16,384   32        128         32

表 2: 基准模型和Retro模型的参数量(不包括嵌入层)及对应超参数

基准参数量 RETRO d dffw #heads Head size # layers
132M 172M (+30%) 896 3,584 16 64 12
368M 425M (+15%) 1,536 6,144 12 128 12
1,309M 1,451M (+11%) 2,048 8,192 16 128 24
6,982M 7,532M (+8%) 4,096 16,384 32 128 32

which corresponds to the bits-per-bytes on the set of chunks that overlap less than $\alpha\%$ with the training chunks. Note that the full evaluation bits-per-bytes performance is recovered by bpb(1). The function bpb(·) allows us to evaluate the impact of evaluation leakage over predictive performance: for low $\alpha$, bpb($\alpha$) gives an indication on how the model performs on chunks that are entirely new; the slope of bpb(·) shows how much the model exploits evaluation leakage.

上式对应的是与训练数据块重叠度低于 $\alpha\%$ 的数据块集合上的每字节比特数 (bpb)。注意,完整评估集上的 bpb 性能即为 bpb(1)。函数 bpb(·) 使我们能够评估泄漏对预测性能的影响:对于较小的 $\alpha$,bpb($\alpha$) 反映模型在全新数据块上的表现;bpb(·) 的斜率则显示模型利用评估泄漏的程度。
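按上式,过滤后的 bpb 只需每个评估块的三元组 $(r(C),\ell(C),N(C))$ 即可计算,示意如下(chunks 的数据结构为假设,仅说明计算方式)。

```python
def filtered_bpb(chunks, alpha):
    """chunks: 形如 [(r, loss_bits, num_bytes), ...] 的列表(数据结构为假设),
    其中 r = s/m 为块与训练数据的重叠比例,loss_bits 为该块按正文记号的 ℓ(C)(以 bit 计),
    num_bytes 为该块编码的字节数 N(C)。返回重叠度 r <= alpha 的块上的 bits-per-byte。"""
    kept = [(loss, nbytes) for r, loss, nbytes in chunks if r <= alpha]
    total_bits = sum(loss for loss, _ in kept)
    total_bytes = sum(nbytes for _, nbytes in kept)
    return total_bits / total_bytes if total_bytes else float("nan")

# bpb(1) 即完整评估集上的 bits-per-byte;bpb(alpha) 随 alpha 的斜率反映模型利用泄漏的程度
```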

3. Related Work

3. 相关工作

We first review existing work on using retrieval for language modelling, and compare Retro to these works (see Table 3). As we train Retro models on a large dataset containing a substantial section of the internet, our work raises potential privacy, safety, and fairness issues that we then review.

我们首先回顾了利用检索增强语言建模的现有工作,并将Retro与这些研究进行了对比(见表3)。由于我们在包含大量互联网内容的大规模数据集上训练Retro模型,这项工作引发了隐私、安全与公平性等潜在问题,我们将对此展开讨论。

3.1. Retrieval for language modelling

3.1. 面向语言建模的检索

Brants et al. (2007) show that scaling the training data to trillions of tokens improves the machine translation performance of $n$ -gram models. More recently, GPT-2 (Radford et al., 2019), GPT-3 (Brown et al., 2020), and Jurassic-1 (Lieber et al., 2021) show that scaling up language models leads to massive improvements on many downstream tasks. At the same time, Carlini et al. (2021) demonstrate that large-scale language models can perfectly memorise parts of their training data, suggesting that enhancing models with retrieval may lead to further improvements. However, significant leakage between train and test datasets (Lee et al., 2021; Lewis et al., 2021) makes comparing and evaluating large models trained on large datasets difficult, especially once retrieval capabilities over the training dataset are added.

Brants等人 (2007) 研究表明,将训练数据规模扩展到数万亿token可以提升 $n$ 元文法模型的机器翻译性能。近期,GPT-2 (Radford等人,2019)、GPT-3 (Brown等人,2020) 和Jurassic-1 (Lieber等人,2021) 证明扩展大语言模型规模能显著提升下游任务表现。与此同时,Carlini等人 (2021) 指出大规模语言模型会完整记忆部分训练数据,这表明引入检索机制可能带来进一步改进。然而训练集与测试集间存在严重数据泄露 (Lee等人,2021;Lewis等人,2021),这使得评估基于海量数据训练的大模型变得困难,特别是在模型具备训练数据检索能力的情况下。

Historically, information retrieval for text relies on inverted index matching such as TF-IDF and BM25 (Robertson and Zaragoza, 2009). Foundational work uses latent topic modelling approaches like LDA (Blei et al., 2003) to identify relevant neighbours (Wei and Croft, 2006). Work in machine translation such as Zhang et al. (2018) and Gu et al. (2018) retrieves translation pairs based on edit distance between source sentences and guides the translation output using the closest retrieved target sentences. The retrieval database may also be structured — for example, Ahn et al. (2016) use a symbolic knowledge graph to improve an RNN language model.

历史上,文本信息检索依赖于倒排索引匹配技术,如 TF-IDF 和 BM25 (Robertson and Zaragoza, 2009)。基础性工作采用潜在主题建模方法,如 LDA (Blei et al., 2003) 来识别相关邻近内容 (Wei and Croft, 2006)。机器翻译领域的研究,如 Zhang et al. (2018) 和 Gu et al. (2018),基于源语句之间的编辑距离检索翻译对,并使用最接近的检索目标语句指导翻译输出。检索数据库也可以是结构化的——例如,Ahn et al. (2016) 使用符号化知识图谱改进 RNN 语言模型。

With the success of deep learning, retrieval systems have partly switched to dense learned representations based on a neural network’s activations. Continuous cache (Grave et al., 2017) adds probability mass to tokens for which previous activations resemble the current activation vector, extending the model’s context to the local history. 𝑘NN-LM (Khandelwal et al., 2020) applies this idea to transformers and extends the retrieval database to English Wikipedia, resulting in substantial improvements on Wikitext 103 evaluation. Continuous cache and 𝑘NN-LM do not modify the underlying neural-network models, but interpolate at inference between the language model’s output and distributions computed from retrieved tokens. These methods can therefore be plugged into any model without additional training, although this limits the model’s ability to reason about the retrieved text. Spalm (Yogatama et al., 2021) addresses this limitation by adding an extra gating network to post-process the retrieved data; yet most of the network is unaffected by the retrieval during inference.

随着深度学习的成功,检索系统已部分转向基于神经网络激活的密集学习表征。Continuous cache (Grave等人,2017) 为那些先前激活与当前激活向量相似的token增加概率质量,从而将模型的上下文扩展到局部历史。𝑘NN-LM (Khandelwal等人,2020) 将这一思想应用于Transformer,并将检索数据库扩展至英文维基百科,在Wikitext 103评估中取得了显著改进。Continuous cache和𝑘NN-LM并不修改底层神经网络模型,而是在推理时对语言模型的输出与基于检索token计算的分布进行插值。因此,这些方法可以无需额外训练就嵌入任何模型,尽管这会限制模型对检索文本的推理能力。Spalm (Yogatama等人,2021) 通过添加额外的门控网络对检索数据进行后处理来解决这一限制;然而在推理过程中,网络的大部分仍不受检索影响。

Table 3 | Comparison of Retro with existing retrieval approaches.

Model              # Retrieval tokens   Granularity   Retriever training     Retrieval integration
Continuous Cache   O(10^3)              Token         Frozen (LSTM)          Add to probs
kNN-LM             O(10^9)              Token         Frozen (Transformer)   Add to probs
SPALM              O(10^9)              Token         Frozen (Transformer)   Gated logits
DPR                O(10^9)              Prompt        Contrastive proxy      Extractive QA
REALM              O(10^9)              Prompt        End-to-End             Prepend to prompt
RAG                O(10^9)              Prompt        Fine-tuned DPR         Cross-attention
FiD                O(10^9)              Prompt        Frozen DPR             Cross-attention
EMDR2              O(10^9)              Prompt        End-to-End (EM)        Cross-attention
RETRO (ours)       O(10^12)             Chunk         Frozen (BERT)          Chunked cross-attention

表 3 | Retro与现有检索方法的对比。

#检索Token数 粒度 检索器训练 检索集成方式
Continuous Cache O(10^3) Token 冻结 (LSTM) 概率叠加
kNN-LM O(10^9) Token 冻结 (Transformer) 概率叠加
SPALM 10^9 Token 冻结 (Transformer) 门控logits
DPR 10^9 Prompt 对比代理 抽取式QA
REALM O(10^9) Prompt 端到端 前置Prompt
RAG 10^9 Prompt 微调DPR 交叉注意力
FiD O(10^9) Prompt 冻结DPR 交叉注意力
EMDR2 O(10^9) Prompt 端到端 (EM) 交叉注意力
RETRO (ours) 10^12 Chunk 冻结 (BERT) 分块交叉注意力

The retrieval representations may be trained directly instead of relying on a pre-trained model— retriever systems have been developed for this purpose, primarily on open-domain question answering. For example, Dpr (Karpukhin et al., 2020) trains two Bert models (for queries and keys respectively) using a contrastive loss to align the representations of a question and of its answers. Lee et al. (2019) use an inverse cloze task to find semantic representations of passages for retrieval. These works differ from continuous cache and 𝑘NN-LM in that they embed passages (or chunks) of text together, as opposed to each token individually. The retriever network is trained in isolation from the downstream task that uses the retrieval data. This potential issue is specifically addressed by Realm (Guu et al., 2020), which trains the retrieval system end-to-end to maximize the final training cross-entropy. This comes with the extra complexity of searching the database during training and periodically updating the embedding table, severely limiting the scale at which it can operate. RAG (Lewis et al., 2020) and FiD (Izacard and Grave, 2021) build upon Dpr to set the state of the art on question answering benchmarks by training encoder-decoder transformer models. More recently, $\mathrm{EMDR}^{2}$ (Sachan et al., 2021) extends FiD by using an expectation-maximization algorithm to train the retriever end-to-end and achieves state of the art results compared to similarly sized models.

检索表征可以直接训练,而不必依赖预训练模型——为此开发的检索系统主要应用于开放域问答任务。例如,Dpr (Karpukhin等人,2020) 使用对比损失训练两个Bert模型(分别处理查询和键)以对齐问题与其答案的表征。Lee等人(2019)采用逆完形填空任务来获取段落语义表征以进行检索。这些工作与连续缓存和𝑘NN-LM的区别在于它们将文本段落(或块)整体嵌入,而非单独处理每个token。检索网络的训练独立于使用检索数据的下游任务。Realm (Guu等人,2020)专门解决了这一潜在问题,通过端到端训练检索系统以最大化最终训练的交叉熵。这带来了在训练期间搜索数据库和定期更新嵌入表的额外复杂性,严重限制了其可操作的规模。RAG (Lewis等人,2020)和FiD (Izacard和Grave,2021)基于Dpr,通过训练编码器-解码器Transformer模型,在问答基准测试中达到最先进水平。最近,$\mathrm{E}\mathbf{M}\mathbf{D}\mathbf{R}^{2}$ (Sachan等人,2021)扩展了FiD,使用期望最大化算法端到端训练检索器,在同等规模模型中取得了最先进的结果。

In the open-domain dialogue setting, BlenderBot 2.0 (Komeili et al., 2021) learns to issue textual internet queries, outperforming dense retrieval methods when evaluated on a task measuring how close model responses are to those of humans. This involves collecting a dataset of human dialogues with associated search queries, which limits the scalability of this approach. Hashemi et al. (2020) introduce the Guided Transformer, a modified Transformer similar to Retro, for document retrieval and clarifying question selection. Although effective on question answering and other tasks with strong conditioning, none of these methods are designed to model arbitrary text sequences, in contrast with Retro.

在开放域对话场景中,BlenderBot 2.0 (Komeili et al., 2021) 学会了发起文本网络查询,在评估模型响应与人类响应接近程度的任务上表现优于密集检索方法。该方法需要收集附带搜索查询的人类对话数据集,这限制了其扩展能力。Hashemi et al. (2020) 提出了Guided Transformer(一种类似于Retro的改进版Transformer),用于文档检索和澄清问题选择。虽然这些方法在问答和其他强条件任务上表现良好,但与Retro不同,它们都不是为建模任意文本序列而设计的。

Retro shares components with 𝑘NN-LM and Dpr in that it uses frozen retrieval representations. Retro models longer sequences than QA examples; this requires reasoning at a sub-sequence level, and retrieving different documents for the different chunks of a sequence. Similar to FiD, Retro processes the retrieved neighbours separately in the encoder, and assembles them in the chunked cross-attention. This differs from e.g. Realm, that prepends retrieved documents to the prompt. Using chunks allows for repeated retrieval whilst generating a sequence as opposed to retrieving only once based on the prompt alone. Furthermore, retrieval is done during the whole pre-training process in Retro, and is not simply plugged-in to solve a certain downstream task. Finally, previous methods based on dense query vectors use small models and retrieval datasets with less than 3B tokens (English Wikipedia). Table 3 summarizes the difference of Retro with existing approaches.

Retro 与 𝑘NN-LM 和 Dpr 共享组件,均使用冻结检索表征。相比 QA 示例,Retro 能建模更长的序列,这需要在子序列级别进行推理,并为序列的不同片段检索不同文档。与 FiD 类似,Retro 在编码器中分别处理检索到的邻近项,并在分块交叉注意力中组装它们。这与 Realm 等方法不同,后者将检索到的文档前置到提示词中。分块机制支持在生成序列时重复检索,而非仅基于初始提示词单次检索。此外,Retro 在整个预训练过程中持续进行检索,而非简单插入来解决特定下游任务。最后,此前基于稠密查询向量的方法使用小型模型和不足 30 亿 token 的检索数据集(英文维基百科)。表 3 总结了 Retro 与现有方法的差异。

3.2. Privacy, safety and fairness

3.2. 隐私、安全与公平性

Bender et al. (2021); Weidinger et al. (2021) highlight several dangers of large language models. Those stem from their ability to memorise training data, their high training cost, the static nature of their training data (Lazaridou et al., 2021), their tendency of amplifying inherent biases in the training data, and their ability to generate toxic language (Gehman et al., 2020). In this section we inspect these dangers, focusing on how retrieval augmented language models may exacerbate or

Bender等人 (2021) 和Weidinger等人 (2021) 强调了大语言模型的若干风险。这些风险源于其记忆训练数据的能力、高昂的训练成本、训练数据的静态特性 (Lazaridou等人, 2021)、放大训练数据固有偏见的倾向,以及生成有害语言的能力 (Gehman等人, 2020)。本节我们将检视这些风险,重点关注检索增强型语言模型可能加剧或...

mitigate them.

缓解它们。

Large language models can perfectly memorise parts of their training data (Carlini et al., 2021). When coupled with large training datasets gathered from the web or other sources, this has clear privacy and safety implications. Retrieval models such as Retro that have access to the entire training dataset during inference exacerbate these privacy issues by being able to directly copy training data. However, retrieval systems offer a path towards mitigating these concerns via obliteration of the retrievable data at inference time. In addition, differential privacy training (Abadi et al., 2016) of retrieval models could guarantee that no private information is stored in the model weights, while individualisation on private data could be made by updating the retrieval database at inference time.

大语言模型能够完美记忆其训练数据的部分内容 (Carlini et al., 2021)。当这些模型与从网络或其他来源收集的大规模训练数据集结合使用时,会带来明显的隐私和安全问题。像Retro这样的检索模型在推理过程中可以访问整个训练数据集,能够直接复制训练数据,从而加剧了这些隐私问题。然而,检索系统提供了一种通过在推理时删除可检索数据来缓解这些问题的途径。此外,对检索模型进行差分隐私训练 (Abadi et al., 2016) 可以确保模型权重中不存储任何私人信息,同时可以通过在推理时更新检索数据库来实现对私人数据的单独处理。

Due to their high training cost, re-training large language models regularly to incorporate new data, languages, and norms is prohibitively expensive. To keep retrieval models up-to-date, it may be sufficient to update the retrieval database, which is orders of magnitude cheaper than re-training a model from scratch. In addition to the benefits of updating models in terms of fairness and bias, simply training large language models has a significant energy cost (Schwartz et al., 2020; Strubell et al., 2019). Retrieval mechanisms offer a path to reducing the compute requirements needed to train and update language models that reach a certain performance.

由于训练成本高昂,定期重新训练大语言模型以纳入新数据、语言和规范的成本令人望而却步。为了保持检索模型的最新性,仅更新检索数据库可能就足够了,这比从头开始重新训练模型的成本低几个数量级。除了在公平性和偏见方面更新模型的好处外,仅训练大语言模型就会产生巨大的能源消耗 (Schwartz et al., 2020; Strubell et al., 2019)。检索机制为降低达到特定性能的语言模型训练和更新所需的计算需求提供了一条途径。

Large language models are prone to generating toxic outputs, as shown in Gehman et al. (2020). Bender et al. (2021); Jo and Gebru (2020) advocate for the importance of better training data curation and documentation. Additionally, if portions of the training data are found to be eliciting biased or toxic outputs after training, retrieval allows for some correction, as the offending retrieval data can be retroactively filtered. However, it is also the case that without careful analysis and intervention, retrieval models may exacerbate biases that are present in the training data. Retrieval models can also add a further source of bias through the selection mechanism for retrieval documents. Further work in this area is required to better understand how retrieval affects the bias and toxicity of the model outputs.

大语言模型容易生成有害输出,如Gehman等人(2020)所示。Bender等人(2021)与Jo和Gebru(2020)强调了优化训练数据筛选和文档记录的重要性。此外,若发现训练数据的某些部分在训练后诱发偏见或有害输出,检索机制可通过事后过滤问题检索数据来实现部分修正。但需注意的是,若缺乏细致分析和干预,检索模型可能放大训练数据中存在的偏见。检索模型还可能通过检索文档的选择机制引入新的偏见源。该领域需要进一步研究以更好地理解检索如何影响模型输出的偏见和毒性。

Finally, samples from large models are difficult to interpret, making mitigating these issues all the more challenging (Belinkov et al., 2020; Jain and Wallace, 2019). Retrieval provides more insight into the outputs of a model, as one can directly visualise or modify the neighbours that are being used. The examples in Table 6, 7, 20 and 21 illustrate how retrieval makes language models more factual and interpretable by providing more transparent outputs.

最后,大模型的样本难以解释,这使得缓解这些问题变得更加困难 (Belinkov et al., 2020; Jain and Wallace, 2019) 。检索为模型输出提供了更多洞察,因为人们可以直接可视化或修改正在使用的邻近内容。表 6、7、20 和 21 中的示例展示了检索如何通过提供更透明的输出,使语言模型更具事实性和可解释性。

4. Results

4. 结果

We first report results on language modelling benchmarks. Second, we show how to Retrofit pre-trained Transformer language models into retrieval models with few additional FLOPs. Next, we report Retro results on question answering. Finally, we report evaluation metrics with leakage filtering, to better understand the source of the gains with retrieval.

我们首先报告语言建模基准测试的结果。其次,我们展示如何以少量额外FLOPs将预训练的Transformer语言模型改造为检索模型。接着,我们报告问答任务上的Retro表现。最后,我们采用泄漏过滤的评估指标,以更好地理解检索带来增益的来源。

4.1. Language modelling

4.1. 语言建模

Datasets. We evaluate our models on C4 (Raffel et al., 2020), Wikitext 103 (Merity et al., 2017), Curation Corpus (Curation, 2020), Lambada (Paperno et al., 2016) and the Pile (Gao et al., 2020). We also evaluate on a set of manually selected Wikipedia articles that were added or heavily edited in September 2021, months after our pre-training and retrieval dataset was collected (details are given in $\S\mathrm{A}.2$). We construct the dataset with articles from the “future” and manually remove new articles that strongly overlap documents in our training data. This guarantees that the evaluation documents are not leaked in our training data.

数据集。我们在C4 (Raffel et al., 2020)、Wikitext 103 (Merity et al., 2017)、Curation Corpus (Curation, 2020)、Lambada (Paperno et al., 2016) 和 the Pile (Gao et al., 2020) 上评估模型性能。我们还在一组人工筛选的维基百科文章上进行评估,这些文章于2021年9月(即预训练和检索数据收集完成数月后)被新增或经过大量编辑(详见 $\S\mathrm{A}.2$)。我们使用这些来自“未来”的文章构建数据集,并手动剔除与训练数据文档高度重合的新文章,确保评估文档未在训练数据中泄露。


Figure 3 | Scaling with respect to model size. (a) LAMBADA top-1 accuracy. (b) Evaluation loss on curation corpus. (c) Perplexity on Wikitext 103 valid. (d) Bits-per-byte on selected Wikipedia articles from September 2021.

图 3 | 模型尺寸的扩展性分析。(a) LAMBADA top-1准确率。(b) 精选语料库的评估损失。(c) Wikitext 103验证集的困惑度。(d) 2021年9月部分维基百科文章的每字节比特数。

For C4, Wikitext 103, the Pile, and our Wikipedia dataset we evaluate the language modelling performance on entire documents and measure the bits-per-byte (bpb). We favour bits-per-byte over loss as it is tokenizer agnostic. We evaluate with a sequence length of 2048 tokens but use a stride of 1024 within documents to mitigate boundary effects. On Curation Corpus we concatenate the article, the “TL;DR:” string, and the summary, but only evaluate the bpb on the summary. For Lambada we evaluate the accuracy on the last word, using greedy generation.

对于C4、Wikitext 103、the Pile以及我们的维基百科数据集,我们评估整个文档的语言建模性能,并测量每字节比特数(bpb)。我们优先选用bpb而非损失值,因其不受分词器影响。评估时采用2048个token的序列长度,但在文档内部使用1024的步长以减轻边界效应。在Curation Corpus上,我们将文章、"TL;DR:"字符串和摘要拼接起来,但仅计算摘要部分的bpb。对于Lambada数据集,我们通过贪婪生成方式评估最后一个单词的预测准确率。

Model scaling. In Fig. 1(left) and Fig. 3 we show the language modelling performance as we scale models from 150 million to 7 billion (non-embedding) parameters. We see that on all datasets, Retro outperforms the baseline at all model sizes. Furthermore, we observe that improvements do not diminish as we scale the models. The performance is dataset dependent, with the largest gains on Wikitext 103 and C4. Wikipedia articles and other web pages are similar to Wikitext 103 documents, even if not exact copies (§4.4); we thus obtain dramatic improvements on Wikitext 103, as our retrieval model is able to directly exploit these overlaps. The smallest gains are for Curation Corpus, where Retro only slightly outperforms the baseline. This is expected as Curation Corpus summaries are designed to only contain information from the source article and are not included in our retrieval database. On our “future” Wikipedia September 2021 dataset, we also observe consistent gains for all model sizes.

模型缩放。在图1(左)和图3中,我们展示了模型参数从1.5亿扩展到70亿(非嵌入)时的语言建模性能。在所有数据集上,Retro模型均优于各尺寸的基线模型。此外,我们观察到性能提升不会随模型规模扩大而减弱。具体表现因数据集而异,其中Wikitext 103和C4数据集提升最为显著。由于维基百科文章与其他网页内容虽非完全复制但高度类似(见4.4节),当检索模型能够直接利用这些重叠内容时,我们在Wikitext 103上获得了显著改进。提升最小的是Curation Corpus数据集,Retro仅略微超越基线,这符合预期——因为该数据集的摘要设计仅包含源文章信息,且未被纳入我们的检索数据库。在2021年9月版的"未来"维基百科数据集上,所有模型尺寸也都表现出稳定的性能提升。

Data scaling. Fig. 1 (middle) shows how scaling the retrieval database at evaluation improves the language modelling performance. We observe dramatic gains as the retrieval data is increased from Wikipedia (4 billion tokens) to all of Massive text (1.7T tokens). Fig. 1(right) shows how performance scales as we increase the number of retrieved chunks. Despite being only trained with 2 neighbours, we see consistent improvements for all models when the number of neighbours is increased from 1 to 10. Furthermore, we observe that larger models are able to better utilise more neighbours: the 172M model improves with up to 10 neighbours, whereas the 7B model improves with up to 40 neighbours.

数据扩展。图1(中)展示了在评估时扩展检索数据库如何提升语言建模性能。我们观察到当检索数据从维基百科(40亿token)扩展到Massive文本全集(1.7万亿token)时,性能获得显著提升。图1(右)显示了随着检索块数量增加时的性能变化。尽管所有模型仅使用2个邻近项进行训练,但当邻近项数量从1增加到10时,所有模型都表现出持续改进。此外,我们发现更大规模的模型能更有效地利用更多邻近项:1.72亿参数模型在邻近项增至10个时持续提升,而70亿参数模型在邻近项增至40个时仍能保持改进。

The Pile. We evaluate our 7B models on the Pile test sets3 and compare against the 178B parameter Jurassic-1 (Lieber et al., 2021) model and the 280B parameter Gopher (Rae et al., 2021) model. We do not compare against GPT-3 as it is outperformed by Jurassic-1 and Gopher on almost all subsets. Fig. 4 shows the relative improvements in bits-per-byte over our 7B transformer baseline for our

The Pile。我们在Pile测试集上评估了7B参数模型,并与178B参数的Jurassic-1 (Lieber et al., 2021) 模型和280B参数的Gopher (Rae et al., 2021) 模型进行对比。由于GPT-3在几乎所有子集上都表现不如Jurassic-1和Gopher,因此未将其纳入比较范围。图4展示了我们的7.5B Retro模型、Jurassic-1和Gopher相对于7B Transformer基线在每字节比特数上的相对改进。


Figure 4 | The Pile: Comparison of our 7B baseline against Jurassic-1, Gopher, and Retro. We observe that the retrieval model outperforms the baseline on all test sets and outperforms Jurassic-1 on a majority of them, despite being over an order of magnitude smaller.

图 4 | The Pile数据集:我们的7B基线模型与Jurassic-1、Gopher和Retro的对比。尽管检索模型的规模小了一个数量级,但我们观察到其在所有测试集上都优于基线模型,并且在大多数测试集上超越了Jurassic-1。

7.5B Retro model, Jurassic-1 and Gopher. Jurassic-1 outperforms the baseline on all datasets except for books, likely due to the inclusion of books in our training data. Gopher and Retro outperform the baseline on all test sets. Overall, Retro 7.5B outperforms Jurassic-1 and Gopher on a majority of the test sets. On the dm_mathematics and ubuntu_irc subsets, our Retro model does not outperform our 7B baseline and underperforms Jurassic-1. We hypothesise that the retrieved neighbours on these datasets are not helpful, due to a combination of what is in our retrieval dataset and the efficacy of the nearest-neighbour search.

除书籍数据集外,Jurassic-1在所有数据集上均优于基线,这可能是由于我们的训练数据中包含了书籍。Gopher和Retro在所有测试集上都优于基线。总体而言,Retro 7.5B在大多数测试集上优于Jurassic-1和Gopher。在dm_mathematics和ubuntu_irc子集上,我们的Retro模型未能超越7B基线,且表现不及Jurassic-1。我们推测,在这些数据集上检索到的邻居没有帮助,这与检索数据集的内容以及最近邻搜索的效果都有关系。

Wikitext 103. To validate our approach in a controlled setting, we compare our method with 𝑘NN-LM (Khandelwal et al., 2020) on the Wikitext 103 dataset in Table 4. We train a baseline transformer on the training set of Wikitext 103. This transformer has 24 layers, 1024 hidden units, 16 heads and a key size of 64, as in Baevski and Auli (2019). Our baseline does not have adaptive input, and our tokenizer has an open vocabulary, unlike Baevski and Auli (2019), which makes our baseline perplexities a bit higher. The full experiment details and hyper parameters are given in $\S\mathrm{C}.2$ and Table 11.

Wikitext 103。为了在受控环境中验证我们的方法,我们在表 4 中将我们的方法与 𝑘NN-LM (Khandelwal et al., 2020) 在 Wikitext 103 数据集上进行了比较。我们在 Wikitext 103 的训练集上训练了一个基线 Transformer。该 Transformer 具有 24 层、1024 个隐藏单元、16 个头和 64 的键大小,与 Baevski 和 Auli (2019) 相同。我们的基线没有自适应输入,并且我们的分词器 (tokenizer) 具有开放词汇表,这与 Baevski 和 Auli (2019) 不同,这使得我们的基线困惑度略高。完整的实验细节和超参数在 $\S\mathrm{C}.2$ 和表 11 中给出。

Table 4 | Perplexities on Wikitext 103. When using the Wikipedia dataset for retrieval, Retro performs similarly to our implementation of 𝑘NN-LM. As we scale the retrieval dataset, Retro performs much better. The perplexities for retrieving from full Massive Text are quite low, which is partly due to partial overlap with Wikitext 103 not caught by our deduplication.

Model                                      Retrieval Set        # Database tokens   # Database keys   Valid   Test
Adaptive Inputs (Baevski and Auli, 2019)   –                    –                   –                 17.96   18.65
SPALM (Yogatama et al., 2021)              Wikipedia            3B                  3B                17.20   17.60
kNN-LM (Khandelwal et al., 2020)           Wikipedia            3B                  3B                16.06   16.12
Megatron (Shoeybi et al., 2019)            –                    –                   –                 –       10.81
Baseline transformer (ours)                –                    –                   –                 21.53   22.96
kNN-LM (ours)                              Wikipedia            4B                  4B                18.52   19.54
RETRO                                      Wikipedia            4B                  0.06B             18.46   18.97
RETRO                                      C4                   174B                2.9B              12.87   10.23
RETRO                                      MassiveText (1%)     18B                 0.8B              18.92   20.33
RETRO                                      MassiveText (10%)    179B                4B                13.54   14.95
RETRO                                      MassiveText (100%)   1792B               28B               3.21    3.92

表 4 | Wikitext 103上的困惑度。当使用维基百科数据集进行检索时,Retro的表现与我们实现的𝑘NN-LM类似。随着我们扩展检索数据集,Retro的表现要好得多。从完整的Massive Text中检索的困惑度相当低,这部分是由于与Wikitext 103的部分重叠未被我们的去重处理捕获。

模型 检索集 数据库Token数 数据库键数 验证集 测试集
Adaptive Inputs (Baevski and Auli, 2019) 17.96 18.65
SPALM (Yogatama et al., 2021) Wikipedia 3B 3B 17.20 17.60
kNN-LM (Khandelwal et al., 2020) Wikipedia 3B 3B 16.06 16.12
Megatron (Shoeybi et al., 2019) 10.81
Baseline transformer (ours) 21.53 22.96
kNN-LM (ours) Wikipedia 4B 4B 18.52 19.54
RETRO Wikipedia 4B 0.06B 18.46 18.97
RETRO C4 174B 2.9B 12.87 10.23
RETRO MassiveText (1%) 18B 0.8B 18.92 20.33
RETRO MassiveText (10%) 179B 4B 13.54 14.95
RETRO MassiveText (100%) 1792B 28B 3.21 3.92

We re-implement $k\mathrm{NN-LM}$ with our tokenizer and baseline transformer to produce embeddings of size 1024 for every token in Wikitext 103. 𝑘NN-LM has probabilities $p_{k\mathrm{NN-LM}}=\lambda p_{k\mathrm{NN}}+(1-\lambda)p_{\mathrm{LM}}$ with $p_{k\mathrm{NN}}\left(n_{k}\right)\propto\exp\left(-\alpha d_{k}\right)$ . We tune $\lambda=0.118$ and $\alpha=0.00785$ on the validation set (Fig. 7) and report performance for these hyper parameters on both the validation and test set.

我们使用自己的分词器和基线Transformer重新实现了$k\mathrm{NN-LM}$,为Wikitext 103中的每个token生成1024维的嵌入向量。该模型的概率计算公式为$p_{k\mathrm{NN-LM}}=\lambda p_{k\mathrm{NN}}+(1-\lambda)p_{\mathrm{LM}}$,其中$p_{k\mathrm{NN}}\left(n_{k}\right)\propto\exp\left(-\alpha d_{k}\right)$。我们在验证集上调试得到$\lambda=0.118$和$\alpha=0.00785$ (图7),并在验证集和测试集上报告了这些超参数的性能表现。
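作为参照,上式的 𝑘NN-LM 插值可以按如下方式实现(probs_lm 为语言模型在词表上的分布,distances 与 neighbour_token_ids 为 $k$ 个近邻的距离及其对应的下一个 token;这只是按公式写出的示意,并非原实现)。

```python
import numpy as np

def knn_lm_probs(probs_lm, neighbour_token_ids, distances,
                 lam=0.118, alpha=0.00785):
    """p_kNN-LM = lam * p_kNN + (1 - lam) * p_LM,其中 p_kNN(n_k) ∝ exp(-alpha * d_k)。
    probs_lm: (V,) 语言模型的下一 token 分布;neighbour_token_ids: (k,) 各近邻的后继 token;
    distances: (k,) 对应的嵌入距离。lam 与 alpha 为正文在验证集上调得的值,仅作默认示例。"""
    probs_lm = np.asarray(probs_lm, dtype=np.float64)
    weights = np.exp(-alpha * np.asarray(distances, dtype=np.float64))
    weights /= weights.sum()
    p_knn = np.zeros_like(probs_lm)
    np.add.at(p_knn, np.asarray(neighbour_token_ids), weights)  # 同一 token 的近邻权重累加
    return lam * p_knn + (1.0 - lam) * probs_lm
```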

We fine-tune our baseline transformer into a Retro model (Fig. 7), using the Wikitext 103 training data and retrieving from Wikipedia with 2 neighbours. We only train the new weights, as explained in $\S4.2$ , and share the embedding weights between the encoder and the main pathway. This is necessary for Wikitext 103 which is quite small, as training Retro from scratch in this setting leads to over-fitting.

我们将基线Transformer微调为Retro模型(图7),使用Wikitext 103训练数据并从维基百科检索2个邻近结果。如$\S4.2$所述,我们仅训练新权重,并在编码器和主路径之间共享嵌入权重。这对于规模较小的Wikitext 103数据集是必要的,因为在此设置下从头训练Retro会导致过拟合。

We evaluate the fine-tuned Retro model with different retrieval sets. We use 10 neighbours at evaluation for both Retro and 𝑘NN-LM. When retrieving from Wikipedia, we obtain results comparable to our 𝑘NN-LM implementation. Furthermore, scaling the retrieval database to Massive Text yields dramatic improvements, though this is partly due to leakage (see $\S4.4$). For reproducibility, we also include results when retrieving from C4, which are close to previous state-of-the-art and comparable to using $10\%$ of Massive Text.

我们评估了使用不同检索集的微调Retro模型。在评估时,Retro和𝑘NN-LM均采用10个最近邻。从维基百科检索时,我们获得了与𝑘NN-LM实现相当的结果。此外,将检索数据库扩展至Massive Text带来了显著提升,尽管这部分归因于数据泄露(见$\S4.4$)。为确保可复现性,我们还提供了从C4检索的结果,其接近先前最优水平,与使用Massive Text的$10\%$数据效果相当。

It is worth noting that 𝑘NN-LM requires 1024 floats for every token in the retrieval dataset, totalling 15 terabytes (Tb) for the 4 billion tokens in Wikipedia. 𝑘NN-LM and other token-level retrieval approaches therefore don’t scale to retrieval databases with trillions of tokens such as Massive Text. In comparison, Retro only requires 215Gb to index our Wikipedia dataset, and 93Tb for Massive Text. Inspecting the number of retrieval database entries in Table 4 makes it clear why retrieving at the chunk level is necessary when scaling to datasets with trillions of tokens.

值得注意的是,𝑘NN-LM 需要为检索数据集中的每个 token 存储 1024 个浮点数,仅维基百科的 40 亿 token 就需占用 15TB 空间。因此,𝑘NN-LM 等 token 级检索方法难以扩展到包含数万亿 token 的检索库(如 Massive Text)。相比之下,Retro 仅需 215GB 即可索引我们的维基百科数据集,处理 Massive Text 也仅需 93TB。通过观察表 4 中的检索数据库条目数量可以明显看出:在处理数万亿 token 规模的数据集时,分块 (chunk) 级检索的必要性。
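As a rough sanity check on the storage figure above (assuming the 1024-dimensional 𝑘NN-LM keys are stored as 32-bit floats, which the text does not state):

$$
4\times10^{9}\ \text{tokens}\times1024\ \text{floats}\times4\ \text{bytes}\approx1.6\times10^{13}\ \text{bytes}\approx15\ \text{TB}.
$$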

4.2. Retro-fitting baseline models

4.2. 基准模型改造

We extend baseline models into Retro models by freezing the pre-trained weights and training only chunked cross-attention and neighbour encoder parameters (less than $10\%$ of weights for the 7B model) in Fig. 5. This offers an efficient alternative path to enhance transformers with retrieval, requiring only 6 million sequences ($3\%$ of the pre-training sequences that we used). Additionally, by only training the new weights we ensure that when evaluated without retrieval, the original model performance is exactly maintained. Retrofitted models quickly surpass the performance of baseline models and even achieve performance close to that of Retro models trained from scratch. The experiment hyperparameters are given in $\S\mathrm{C}.3$.

我们在图5中通过冻结预训练权重并仅训练分块交叉注意力(chunked cross-attention)和邻近编码器参数(neighbour encoder parameters)(7B模型中不到$10\%$的权重),将基线模型扩展为Retro模型。这为通过检索增强Transformer提供了一条高效路径,仅需600万序列(相当于我们所用预训练序列的$3\%$)。此外,由于仅训练新增权重,当不启用检索评估时,原始模型性能能得到完全保持。改造后的模型迅速超越基线模型性能,甚至接近从头训练的Retro模型水平。实验超参数详见$\S\mathrm{C}.3$。
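As a concrete illustration of the Retro-fitting recipe above, here is a minimal JAX/Optax sketch (not the authors' code) in which all pre-trained weights are frozen and only the new retrieval-path weights receive updates; the parameter-group names "cca" and "neighbour_encoder", the optimiser choice and the loss function are assumptions.

```python
import jax
import optax

def trainable_mask(params):
    # True marks a parameter as trainable (the freshly initialised retrieval
    # weights); every frozen pre-trained parameter receives zero updates.
    # Assumes params is a dict keyed by top-level module name.
    return {
        name: jax.tree_util.tree_map(
            lambda _: name in ("cca", "neighbour_encoder"), subtree)
        for name, subtree in params.items()
    }

# Peak learning rate as in Table 12 for the three smaller models.
optimiser = optax.masked(optax.adamw(2e-4), trainable_mask)

def train_step(params, opt_state, batch, loss_fn):
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = optimiser.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state
```

Because the frozen weights are never touched, evaluating such a model with retrieval disabled reproduces the original baseline exactly, as stated above.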

4.3. Question answering

4.3. 问答

We fine-tune our retrieval models on the Natural Questions (Kwiatkowski et al., 2019) dataset to demonstrate that our retrieval pathway can be used to inject information from arbitrary data sources. We use the version provided by Izacard and Grave (2021), which is augmented with the retrieved passages from Dpr (Karpukhin et al., 2020). We fine-tune all the weights of our 7.5B pre-trained Retro model for 25,000 steps using the top 20 retrieved passages. We format the data as “question: {question} \n answer: {answer}” and left pad the data such that “answer:” coincides with the end of the first chunk of 64 tokens and thus aligns with the first retrieving chunk. The model has access to the question via the previous tokens in the sequence as well as the top 20 DPR Wikipedia passages and their titles via the chunked cross-attention mechanism.

我们在Natural Questions (Kwiatkowski等人,2019) 数据集上微调检索模型,以证明我们的检索路径可用于注入任意数据源的信息。采用Izacard和Grave (2021) 提供的版本,该版本通过Dpr (Karpukhin等人,2020) 检索到的段落进行了增强。我们使用前20个检索段落对预训练的7.5B参数Retro模型所有权重进行25,000步微调。数据格式化为"question: {question} \n answer: {answer}",并进行左填充使"answer:"对齐64个token的首个分块末尾,从而与首个检索分块保持对齐。模型通过序列中的先前token获取问题信息,并通过分块交叉注意力机制获取前20个DPR维基百科段落及其标题。
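A minimal sketch of the prompt construction described above (the tokeniser, the pad id and the exact whitespace handling are assumptions; the chunk length of 64 is from the text):

```python
CHUNK_LEN = 64  # Retro chunk size

def build_qa_example(question, answer, tokenize, pad_id):
    prompt_ids = tokenize(f"question: {question} \n answer:")
    answer_ids = tokenize(f" {answer}")
    # Left-pad so that "answer:" ends exactly at the first 64-token chunk
    # boundary, aligning the answer with the first retrieving chunk.
    pad = CHUNK_LEN - len(prompt_ids)
    assert pad >= 0, "question does not fit in the first chunk"
    return [pad_id] * pad + prompt_ids + answer_ids
```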


Figure 5 | Retro-fitting a baseline transformer. Any transformer can be fine-tuned into a retrieval-enhanced transformer by randomly initializing and training only the chunked cross-attention and retrieval encoder weights. Fine-tuning in this way quickly recovers and surpasses the non-retrieval performance, and almost achieves the same performance as training a retrieval model from scratch (shown by the arrow on the right hand side of each plot). We find good performance when Retro-fitting our models, training on only $3\%$ of the number of tokens seen during pre-training.

图 5 | 基线Transformer的改造适配。通过仅随机初始化并训练分块交叉注意力(chunked cross-attention)和检索编码器(retrieval encoder)权重,任何Transformer都可以微调成检索增强型Transformer。这种微调方式能快速恢复并超越非检索模型的性能,几乎达到从头训练检索模型的同等性能(如各子图右侧箭头所示)。实验表明,仅用预训练阶段3%的Token量进行改造适配训练即可获得良好性能。

The exact match scores are shown in Table 5 and the full fine-tuning details are given in $\S\mathrm{C}.4$. Our method is competitive with previous approaches such as Realm, RAG and Dpr, but underperforms the more recent FiD. In contrast with this work, we find that increasing the number of neighbours past 20 does not improve Retro performance on this task. We hypothesise that the encoder-decoder structure of T5 (the base model in FiD) and the T5 pre-training objective lead to a model that relies more on the encoder output than Retro, which is important in the QA setting. To compete with T5-finetuned models, future work should consider ways of forcing Retro to rely further on the retrieval encoder output when producing tokens.

精确匹配分数如表5所示,完整微调细节见$\S\mathrm{C}.4$。我们的方法与Realm、RAG和Dpr等先前方法具有竞争力,但表现略逊于较新的FiD。与此工作不同的是,我们发现将邻居数量增加到20以上并不会提升Retro在此任务上的性能。我们假设FiD的基础模型T5采用的编码器-解码器结构及其预训练目标,使其比Retro更依赖编码器输出,这在问答场景中至关重要。为达到与T5微调模型相当的水平,未来工作应考虑如何让Retro在生成token时更依赖检索编码器的输出。

4.4. Relating retrieval performance to dataset leakage.

4.4. 检索性能与数据集泄漏的关联

We report the filtered eval losses as detailed in $\S2.6$ on C4, Curation Corpus and Wikitext 103 in Fig. 6. On C4 and Wikitext 103, for which there is leakage into the training set, the slope is negative for both baseline models and Retro models. Retro models exploit leakage more strongly than baseline models, as indicated by the more negative slope. This is due to its explicit ability to copy-paste existing training chunks to predict leaked evaluation chunks (see a qualitative example of this model behavior on a Wikitext 103 article in Table 19). On Curation Corpus, retrieval provides a constant offset, which is expected as there is by design no leakage between Curation Corpus and the training dataset.

我们在图6中报告了C4、Curation Corpus和Wikitext 103数据集上经过过滤的评估损失(详细方法见$\S2.6$)。对于存在训练集泄露的C4和Wikitext 103数据集,基线模型和Retro模型的损失曲线斜率均为负值。Retro模型通过更显著的负斜率表明其比基线模型更善于利用数据泄露,这得益于其显式复制粘贴训练数据块来预测泄露评估块的能力(见表19中Wikitext 103文章的定性示例)。在Curation Corpus数据集上,检索提供了恒定偏移量,这符合预期,因为该数据集与训练集在设计上不存在泄露。

Table 5 | Question answering results. Exact match accuracy on Natural Questions.

| Model | Test Accuracy |
|---|---|
| REALM (Guu et al., 2020) | 40.4 |
| DPR (Karpukhin et al., 2020) | 41.5 |
| RAG (Lewis et al., 2020) | 44.5 |
| EMDR2 (Sachan et al., 2021) | 52.5 |
| FiD (Izacard and Grave, 2021) | 51.4 |
| FiD + Distill. (Izacard et al., 2020) | 54.7 |
| Baseline 7B (closed book) | 30.4 |
| RETRO 7.5B (DPR retrieval) | |
表 5 | 问答结果。Natural Questions 上的精确匹配准确率。

| 模型 | 测试准确率 |
|---|---|
| REALM (Guu et al., 2020) | 40.4 |
| DPR (Karpukhin et al., 2020) | 41.5 |
| RAG (Lewis et al., 2020) | 44.5 |
| EMDR2 (Sachan et al., 2021) | 52.5 |
| FiD (Izacard and Grave, 2021) | 51.4 |
| FiD + Distill. (Izacard et al., 2020) | 54.7 |
| Baseline 7B (closed book) | 30.4 |
| RETRO 7.5B (DPR retrieval) | |


Figure 6 | Performance vs. longest common retrieval substring. Evaluation loss as a function of allowed longest common substring between evaluation data chunks and their nearest neighbours. Retrieval still helps when considering chunks with no more than 8 contiguous tokens overlapping with training dataset chunks.

图 6 | 性能与最长公共检索子串的关系。评估损失随评估数据块与其最近邻之间允许的最长公共子串长度变化。当数据块与训练数据集块重叠的连续token不超过8个时,检索仍能带来性能提升。

On the other hand, Retro outperforms baseline models at all leakage levels, down to $\alpha=12.5\%$. At this level, the loss is computed on chunks with less than 8 contiguous tokens shared with the closest matching chunk in the training dataset; this is a reasonable level of overlap at which we consider that there is no local leakage. Retrieval thus improves predictions on both chunks that are syntactically similar to chunks in the training set, and on chunks that are syntactically different from all training chunks. This points toward a non-trivial Retro capacity of generalizing based on both model parameters and retrieval database. Similar results are found on the Pile dataset (see Fig. 12, $\S\mathrm{F}.3$).

另一方面,Retro在所有泄漏级别(低至$\alpha=12.5\%$)的表现均优于基线模型。在此级别下,损失计算针对的是与训练数据集中最接近匹配块共享连续token少于8个的文本块——我们认为这种重叠程度合理,不存在局部泄漏。因此,检索机制既能提升对训练集中语法相似文本块的预测能力,也能改善对语法相异文本块的预测表现。这表明Retro具备基于模型参数和检索数据库进行泛化的非凡能力。在Pile数据集上也观察到类似结果(见图12,$\S\mathrm{F}.3$)。
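The overlap criterion used above (at most 8 contiguous shared tokens, i.e. $\alpha = 8/64 = 12.5\%$) can be checked with a standard longest-common-substring computation; the sketch below is ours, not the paper's filtering code.

```python
def longest_common_substring(chunk_tokens, neighbour_tokens):
    # Length (in tokens) of the longest contiguous run shared by the two sequences.
    best = 0
    prev = [0] * (len(neighbour_tokens) + 1)
    for a in chunk_tokens:
        curr = [0] * (len(neighbour_tokens) + 1)
        for j, b in enumerate(neighbour_tokens, start=1):
            if a == b:
                curr[j] = prev[j - 1] + 1
                best = max(best, curr[j])
        prev = curr
    return best

def has_local_leakage(chunk_tokens, neighbour_tokens, threshold=8):
    return longest_common_substring(chunk_tokens, neighbour_tokens) >= threshold
```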

4.5. Using Retro for sampling

4.5. 使用Retro进行采样

We show examples of samples obtained using the 7.5B Retro model in Table 6, Table 7 and Appendix E. For each chunk (the first one being the prompt), we juxtapose sampled chunks $C_{u}$ with retrieved neighbours $\textstyle\operatorname{RET}(C_{u})$ . To give an indication of local overlap, we colour each sampled token in chunk $C_{u}$ based on the length of the longest common prefix (LCP) found in the retrieved chunks $\mathrm{RET}\left(C_{u-1}\right)$ . Similarly, we colour the retrieved chunks based on the LCP in the sampled chunk. For the sample in Table 6, for which we chose the prompt, we observe that the retrieved chunks influence the sample as there are overlaps between the sampled tokens and neighbour tokens. Overall, retrieval reduces hallucinations (in line with the findings of Shuster et al. (2021)) and makes the model more knowledgeable, when comparing with samples produced with retrieval disabled. In the sample in Table 7, the model recognises that the prompt is the beginning of the first scene of Hamlet and leverages retrieval data to continue it with only a few mistakes. We provide further examples in Appendix E, including examples from the evaluation sets, as well as the detailed procedure used for colouring the tables.

我们在表6、表7和附录E中展示了使用7.5B Retro模型获得的样本示例。对于每个数据块(第一个为提示词),我们将采样数据块$C_{u}$与检索到的相邻块$\textstyle\operatorname{RET}(C_{u})$并列展示。为显示局部重叠情况,我们根据检索块$\mathrm{RET}\left(C_{u-1}\right)$中找到的最长公共前缀(LCP)长度对$C_{u}$块中每个采样token进行着色。同样地,我们也基于采样块中的LCP对检索块进行着色。对于表6中我们选定提示词的样本,观察到采样token与相邻token存在重叠,表明检索块对样本产生了影响。总体而言,与禁用检索时生成的样本相比,检索机制减少了幻觉现象(与Shuster等人(2021)的研究结果一致),并使模型表现出更强的知识性。在表7的样本中,模型识别出提示词是《哈姆雷特》第一幕的开场,并利用检索数据继续生成文本,仅出现少量错误。附录E提供了更多示例,包括来自评估集的样本以及用于表格着色的详细流程。
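One plausible reading of the colouring rule described above (Appendix E of the paper gives the exact procedure, which is not reproduced in this document): each sampled token is assigned the length of the longest prefix of the sample, starting at that token, that also appears contiguously in the retrieved neighbours of the previous chunk.

```python
def lcp_at(sample_tokens, start, neighbour_tokens):
    # Longest common prefix between sample_tokens[start:] and any position
    # in the retrieved neighbour.
    best = 0
    for j in range(len(neighbour_tokens)):
        length = 0
        while (start + length < len(sample_tokens)
               and j + length < len(neighbour_tokens)
               and sample_tokens[start + length] == neighbour_tokens[j + length]):
            length += 1
        best = max(best, length)
    return best

def colour_buckets(sample_tokens, neighbour_tokens):
    # One bucket per sampled token: 0, 1, 2, 3, 4 or >= 5, as in Tables 6 and 7.
    return [min(lcp_at(sample_tokens, i, neighbour_tokens), 5)
            for i in range(len(sample_tokens))]
```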

5. Conclusion

5. 结论

We present Retrieval-Enhanced Transformers (Retro), a method for modelling arbitrary text sequences whilst retrieving from databases with trillions of tokens, scaling the data available to models by an order of magnitude compared to what is typically consumed during training. The gains from Retro do not diminish for models with up to at least 7B parameters, and on certain datasets Retro matches the performance of non-retrieval models with $10\times$ more parameters. On Wikitext 103 and the Pile, Retro outperforms previous models trained on large scale datasets. We also show that Retro is competitive on retrieval-intensive downstream tasks such as question answering.

我们提出检索增强型Transformer (Retro)方法,该方法能在建模任意文本序列的同时从包含数万亿token的数据库中进行检索,使模型可利用的数据量相比常规训练规模提升一个数量级。对于参数规模至少达70亿的模型,Retro仍能保持性能增益,其表现相当于参数规模扩大10倍的非检索模型在特定数据集上的效果。在Wikitext 103和Pile数据集上,Retro超越了先前基于大规模数据集训练的模型。我们还证明Retro在问答等检索密集型下游任务中具有竞争力。

Retro models are flexible and can be used without retrieval at evaluation and still achieve comparable performance to baseline models. Conversely, baseline models can be rapidly fine-tuned into Retro models to obtain nearly the same performance as if trained from scratch. Careful analysis shows that only a modest fraction of the gains obtained by Retro are due to test set leakage. In general, we caution for such leakage in large-scale language datasets and suggest further work in better understanding the role of test set leakage in the performance of large-scale language models.

Retro模型具有灵活性,在评估时无需检索仍可获得与基线模型相当的性能。相反,基线模型可快速微调为Retro模型,获得接近从头训练的同等性能。细致分析表明,Retro所获增益中仅有少量部分源于测试集泄漏。我们普遍警示大规模语言数据集中此类泄漏现象,并建议进一步研究以更好理解测试集泄漏对大语言模型性能的影响。

Overall, our work demonstrates at an unprecedented scale that semi-parametric approaches can provide an orthogonal, more efficient approach than raw parameter scaling as we seek to build more powerful language models.

总体而言,我们的工作以前所未有的规模证明:在构建更强大语言模型的过程中,半参数化方法能提供与纯参数扩展正交且更高效的路径。

Acknowledgements

致谢

We would like to thank Nikolai Grigorev, Marc’aurelio Ranzato, Cyprien de Masson d’Autume, Po-Sen Huang, Johannes Welbl, Lisa Anne Hendricks, Ethan Perez, Jeff Stanway, Eric Noland, Gregory Wayne, John Jumper, Julian Schrittwieser, Lorrayne Bennett, Devang Agrawal, Dani Yogatama, Susannah Young, Nando de Freitas, Demis Hassabis, and Koray Kavukcuoglu for their help, advice and reviews. Additionally, we would like to thank Zonglin Li, David Simcha, and the ScaNN developers for their help.

我们要感谢Nikolai Grigorev、Marc'aurelio Ranzato、Cyprien de Masson d'Autume、Po-Sen Huang、Johannes Welbl、Lisa Anne Hendricks、Ethan Perez、Jeff Stanway、Eric Noland、Gregory Wayne、John Jumper、Julian Schrittwieser、Lorrayne Bennett、Devang Agrawal、Dani Yogatama、Susannah Young、Nando de Freitas、Demis Hassabis和Koray Kavukcuoglu的帮助、建议和审阅。此外,我们还要感谢Zonglin Li、David Simcha以及ScaNN开发团队的协助。

Table 6 | Sample - Beavers are interesting animals. The Retro[Off] sample quickly diverges to other animals while the Retro[On] sample tends to stay focused on the beaver topic due to neighbour conditioning.

Columns (left to right): prompt and sample of Retro[Off]; prompt and sample of Retro[On], coloured by LCP with Ret($C_{u-1}$); first neighbour $[N^{1},F^{1}]$, coloured by LCP with $C_{u+1}$; second neighbour $[N^{2},F^{2}]$, coloured by LCP with $C_{u+1}$. Colour buckets: LCP $=0,1,2,3,4,\geq5$. [The coloured sample and neighbour text of Table 6 is not reproduced here.]

表 6 | 示例 - 海狸是很有趣的动物。Retro[Off]样本会快速偏离到其他动物话题,而Retro[On]样本由于邻近条件作用更倾向于保持海狸主题。

RETRo[OFF]的提示词与样本 按LCP着色的RETRo[ON]样本 (Cu-1) LCP=0,1,2,3,4,≥5 按LCP着色的[Nl,Fl] (Cu+1) LCP=0,1,2,3,4,>5 按LCP着色的[N2,F²] (Cu+1) LCP=0,1,2,3,4,>5
海狸是生活在河流附近的有趣动物。它们会建造 海狸是生活在河流附近的有趣动物。它们会建造 海狸在林区自建的水塘中筑巢。如同自然界许多事物,野生动物之间存在着奇妙关联。海狸塘会导致树木淹死,但枯木会吸引大蓝鹭年复一年地回归。一个海狸塘随时间推移可吸引超过50个巢穴形成群落,称为鹭巢区。布拉德福德路旁的大池塘就是典型范例 用牙齿将树木咬成小段拖入水中。池塘工程师海狸之所以有趣,是因为它们会改造栖息地。通过阻断溪流创造池塘,并在其中建造名为"巢屋"的住所。海狸的身体构造特别适合水下作业:特殊肌肉能封闭鼻耳喉部防水,宽阔的尾巴在水中起到舵的作用
和水。它们拥有出色的伪装能力来躲避捕食者。金毛寻回犬是体型巨大的犬种,力量强大,厚实的被毛帮助它们适应 海狸用强壮的牙齿和颚部砍伐树木枝条来筑巢,也用其咀嚼食物。宽大的尾巴帮助游泳,它们还会 海狸在溪流中灵活游动,收集石块树枝筑坝。用大门牙啃咬树木,再将木材用于建造房屋。它们是聪明的建筑师,清楚筑坝所需材料:用溪流淤泥加固水坝,用尾巴拍实泥浆。它们会在坝顶为幼崽建造舒适房间 类尾巴和两颗足以啃倒树木的突出门牙。海狸用树木、枝条和淤泥在河溪筑坝,形成深水区建造巢屋。幼崽"海狸宝宝"会在家族巢屋生活到两岁。曾作为帽檐大衣饰边而风靡的海狸皮,如何影响了新大陆的殖民进程?
体型非常庞大。能跳得很高,奔跑速度极快。狐獴是非洲小型群居动物,百只群居,擅长攀树 冬季海狸需要建造坚固的巢屋:枝条构成墙壁, 海狸在溪流中灵活游动,收集石块树枝筑坝。用大门牙啃咬树木,再将木材用于建造房屋。它们是聪明的建筑师,清楚筑坝所需材料:用溪流淤泥加固水坝,用尾巴拍实泥浆。它们会在坝顶为幼崽建造舒适房间 咬断树木后利用木材建房。海狸是精明的建筑师,深知筑坝要领:用溪流淤泥固坝,以尾拍实。它们在坝顶为幼崽布置温暖房间,水下储存食物,甚至食用自己砍伐的树皮!1.第一段的主旨是
奔跑速度很快。企鹅是南极鸟类,厚实的羽毛保暖。兔子是穴居的小型动物 也用强壮的颚部砍树运回巢穴。利齿咀嚼木料,扁尾助游,

Table 7 | Sample - Hamlet, Act 1, Scene 1. The Retro[Off] sample has correct syntax but is hallucinated, and ends with repetition of one character (FRANCISCO Approach me not). The Retro[On] sample is the correct continuation of the original text, and is robust to formatting differences between our prompt and the retrieved data.

Columns (left to right): prompt and sample of Retro[Off]; prompt and sample of Retro[On], coloured by LCP with Ret($C_{u-1}$); first neighbour $[N^{1},F^{1}]$, coloured by LCP with $C_{u+1}$; second neighbour $[N^{2},F^{2}]$, coloured by LCP with $C_{u+1}$. Colour buckets: LCP $=0,1,2,3,4,\geq5$. [The coloured sample and neighbour text of Table 7 is not reproduced here.]

表 7 | 示例 - 《哈姆雷特》第一幕第一场。Retro[Off]样本语法正确但存在幻觉,并以重复单个字符结尾(FRANCISCO Approach me not)。Retro[On]样本是原文的正确延续,并且对我们提示词与检索数据之间的格式差异具有鲁棒性。

RETRo[OFF]的提示词与样本 通过LCP着色的RETRo[ON]样本 (RET(Cu-1) LCP=0,1,2,3,4,>5 通过LCP着色的[Nl,Fl]样本 (Cu+1 LCP=0,1,2,3,4,>5) 通过LCP着色的[N2,F²]样本 (Cu+1 LCP=0,1,2,3,4,≥5)
第一幕第一场 艾尔西诺 城堡前的平台 弗兰西斯科在站岗 波洛涅斯上 波洛涅斯:那边是谁?弗兰西斯科:不,你先回答我;站住,表明你的身份。波洛涅斯 第一幕第一场 艾尔西诺 城堡前的平台 弗兰西斯科在站岗 波洛涅斯上 波洛涅斯:那边是谁?弗兰西斯科:不,你先回答我;站住,表明你的身份。波洛涅斯 第一幕第一场 艾尔西诺 城堡前的平台 (弗兰西斯科在站岗 波洛涅斯上)波洛涅斯:那边是谁?弗兰西斯科:不,你先回答我;站住,表明你的身份。波洛涅斯 吾王万岁!弗兰西斯科:是波洛涅斯吗?波洛涅斯:正是。弗兰西斯科:您来得正是时候。波洛涅斯 ><文本>第一幕第一场 艾尔西诺 城堡前的平台 弗兰西斯科在站岗 波洛涅斯上 波洛涅斯:那边是谁?弗兰西斯科:不,你先回答我;站住,表明你的身份。波洛涅斯 吾王万岁!弗兰西斯科:是波洛涅斯吗?波洛涅斯:正是。弗兰西斯科:您来得正是时候。波洛涅斯 *现在已敲十二点
大人在这里;我,弗兰西斯科,正沉浸在悲伤中。[旁白] 波洛涅斯:国王!弗兰西斯科:我也为此悲伤。波洛涅斯:大人?弗兰西斯科:不要靠近我。波洛涅斯 波洛涅斯:是波洛涅斯吗?波洛涅斯:正是。弗兰西斯科:您来得正是时候。波洛涅斯 *现在已敲十二点;去睡吧,弗兰西斯科。弗兰西斯科:非常感谢这次换岗:天寒地冻,我心如刀绞。 波洛涅斯:是波洛涅斯吗?波洛涅斯:正是。弗兰西斯科:您来得正是时候。波洛涅斯 现在已敲十二点;去睡吧,弗兰西斯科。弗兰西斯科:非常感谢这次换岗:天寒地冻,我心如刀绞。<文档><文档编号>罗密欧<文本>第一幕 开场诗 两个家族,同样尊贵,在维罗纳城,我们故事开始的地方, 万岁!弗兰西斯科:是波洛涅斯吗?波洛涅斯:正是。弗兰西斯科:您来得正是时候。波洛涅斯 现在已敲十二点:去睡吧,弗兰西斯科。弗兰西斯科:非常感谢这次换岗:天寒地冻,我心如刀绞。波洛涅斯:守卫期间安静吗?弗兰西斯科:连老鼠都没动静。波洛涅斯:好吧,晚安。如果你遇到霍拉旭和马西勒斯,我值班的伙伴们,叫他们快点来。弗兰西斯科:我想我听到他们了。站住!那边是谁?上
你。弗兰西斯科:不要靠近我,说话就行。波洛涅斯:你的手,你的声音 弗兰西斯科:我不想听你说话。波洛涅斯:弗兰西斯科,你的手,我恳求你。弗兰西斯科:不要靠近我。波洛涅斯:弗兰西斯科 弗兰西 守卫?弗兰西斯科:连老鼠都没动静。波洛涅斯:好吧,晚安。如果你遇到霍拉旭和马西勒斯,我值班的伙伴们,叫他们快点来。弗兰西斯科:我想我听到他们了。站住!那边是谁?上 连老鼠都没动静。波洛涅斯:好吧,晚安。如果你遇到霍拉旭和马西勒斯,我值班的伙伴们,叫他们快点来。弗兰西斯科:我想我听到他们了。站住!那边是谁?(霍拉旭和马西勒斯上)霍拉旭:这片土地的朋友。马西勒斯:丹麦的臣民。弗兰西斯科:祝你们晚安。马西勒斯:啊,再见,诚实的士兵:谁接替 波洛涅斯:守卫期间安静吗?弗兰西斯科:连老鼠都没动静。波洛涅斯:好吧,晚安。如果你遇到霍拉旭和马西勒斯,我值班的伙伴们,叫他们快点来。弗兰西斯科:我想我听到他们了。- 站住!那边是谁?霍拉旭和马西勒斯上。霍拉旭:这片土地的朋友。马西勒斯:丹麦的臣民。弗兰西斯科:祝你们晚安。马西勒斯:啊,再见,诚实的士兵:谁接替了你?
斯科:不要靠近我。波洛涅斯:我有封信 弗兰西斯科:不要靠近我。波洛涅斯:给国王的。弗兰西斯科:不要靠近我。波洛涅斯:没有叛 霍拉旭和马西勒斯 霍拉旭:这片土地的朋友。马西勒斯:丹麦的臣民。弗兰西斯科:祝你们晚安 马西勒斯:啊,再见,诚

References

参考文献

J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. In Conference of the North American Chapter of the Association for Computational Linguistics, June 2019. URL https://aclanthology.org/N19-1423.

J. Devlin, M.-W. Chang, K. Lee 和 K. Toutanova。BERT: 用于语言理解的深度双向Transformer预训练。载于北美计算语言学协会会议,2019年6月。URL https://aclanthology.org/N19-1423。

J. Kaplan, S. McCandlish, T. Henighan, T. B. Brown, B. Chess, R. Child, S. Gray, A. Radford, J. Wu, and D. Amodei. Scaling laws for neural language models. CoRR, 2020. URL https://arxiv.org/abs/2001.08361.

J. Kaplan、S. McCandlish、T. Henighan、T. B. Brown、B. Chess、R. Child、S. Gray、A. Radford、J. Wu 和 D. Amodei。神经语言模型的缩放定律。CoRR,2020。URL https://arxiv.org/abs/2001.08361

X. Wei and W. B. Croft. LDA-based document models for ad-hoc retrieval. In ACM SIGIR International Conference on Research and Development in Information Retrieval, 2006. URL http://portal.acm.org/citation.cfm?doid=1148170.1148204.

X. Wei 和 W. B. Croft. 基于LDA的文档模型在即席检索中的应用. 发表于: ACM SIGIR国际信息检索研究与发展会议, 2006. URL http://portal.acm.org/citation.cfm?doid=1148170.1148204.

A. Datasets

A. 数据集

We provide a full description of Massive Text and of our extract of recent Wikipedia articles.

我们全面介绍了Massive Text以及我们对近期维基百科文章的提取内容。

A.1. Full description of Massive Text

A.1. Massive Text 完整描述

The full breakdown of Massive Text by source and languages is given in Table 8. For a full description and analysis of Massive Text, see Rae et al. (2021).

表8详细列出了Massive Text按来源和语言的完整分类。有关Massive Text的完整描述和分析,请参阅Rae等人(2021)的研究。

| Source | Language | Token count (M) | Documents | Sampling weight |
|---|---|---|---|---|
| Web | En | 483,002 | 604,938,816 | 0.314 |
| | Ru | 103,954 | 93,004,882 | 0.033 |
| | Es | 95,762 | 126,893,286 | 0.033 |
| | Zh | 95,152 | 121,813,451 | 0.033 |
| | Fr | 59,450 | 76,612,205 | 0.033 |
| | De | 57,546 | 77,242,640 | 0.033 |
| | Pt | 44,561 | 62,524,362 | 0.033 |
| | It | 35,255 | 42,565,093 | 0.033 |
| | Sw | 2,246 | 1,971,234 | 0.0044 |
| | Ur | 631 | 455,429 | 0.0011 |
| Books | En | 3,423,740 | 20,472,632 | 0.25 |
| News | En | 236,918 | 397,852,713 | 0.1 |
| Wikipedia | En | 3,977 | 6,267,214 | 0.0285 |
| | De | 2,155 | 3,307,818 | 0.003 |
| | Fr | 1,783 | 2,310,040 | 0.003 |
| | Ru | 1,411 | 2,767,039 | 0.003 |
| | Es | 1,270 | 2,885,013 | 0.003 |
| | It | 1,071 | 2,014,291 | 0.003 |
| | Zh | 927 | 1,654,772 | 0.003 |
| | Pt | 614 | 1,423,335 | 0.003 |
| | Ur | 61 | 344,811 | 0.0004 |
| | Sw | 15 | 58,090 | 0.0001 |
| Github | – | 374,952 | 142,881,832 | 0.05 |
| Total | – | 5,026,463 | 1,792,260,998 | 1 |

| 来源 | 语言 | Token数量 (百万) | 文档数 | 采样权重 |
|---|---|---|---|---|
| Web | En | 483,002 | 604,938,816 | 0.314 |
| | Ru | 103,954 | 93,004,882 | 0.033 |
| | Es | 95,762 | 126,893,286 | 0.033 |
| | Zh | 95,152 | 121,813,451 | 0.033 |
| | Fr | 59,450 | 76,612,205 | 0.033 |
| | De | 57,546 | 77,242,640 | 0.033 |
| | Pt | 44,561 | 62,524,362 | 0.033 |
| | It | 35,255 | 42,565,093 | 0.033 |
| | Sw | 2,246 | 1,971,234 | 0.0044 |
| | Ur | 631 | 455,429 | 0.0011 |
| Books | En | 3,423,740 | 20,472,632 | 0.25 |
| News | En | 236,918 | 397,852,713 | 0.1 |
| Wikipedia | En | 3,977 | 6,267,214 | 0.0285 |
| | De | 2,155 | 3,307,818 | 0.003 |
| | Fr | 1,783 | 2,310,040 | 0.003 |
| | Ru | 1,411 | 2,767,039 | 0.003 |
| | Es | 1,270 | 2,885,013 | 0.003 |
| | It | 1,071 | 2,014,291 | 0.003 |
| | Zh | 927 | 1,654,772 | 0.003 |
| | Pt | 614 | 1,423,335 | 0.003 |
| | Ur | 61 | 344,811 | 0.0004 |
| | Sw | 15 | 58,090 | 0.0001 |
| Github | – | 374,952 | 142,881,832 | 0.05 |
| 总计 | – | 5,026,463 | 1,792,260,998 | 1 |

Table 8 | Massive Text dataset. The final column indicates the sampling weight for each dataset during training. For the retrieval database, the entire dataset is used, with the exception of books for which we use a sub-sample of $4\%$.

表 8 | 大规模文本数据集。最后一列表示训练期间每个数据集的采样权重。对于检索数据库,除书籍使用 $4\%$ 的子样本外,其余均使用完整数据集。

A.2. Wikipedia September 2021

A.2. 维基百科 2021年9月

We create an evaluation dataset consisting of 23 Wikipedia articles that were added or heavily edited in September 2021, after we collected our training dataset. In addition, we filter out articles that rely too heavily on templated content, using the method detailed in $\S2.6$ to identify articles with chunks that have a high overlap with their neighbours. Fig. 10 shows that little overlap remains between our test dataset and the retrieved neighbours from the training dataset. The full list of included articles is given in Table 9.

我们创建了一个评估数据集,包含23篇2021年9月新增或大幅编辑的维基百科文章(这些内容是在训练数据集收集完成后产生的)。此外,我们采用$\S2.6$章节所述方法过滤了过度依赖模板化内容的文章,通过识别与相邻文本块高度重合的段落进行剔除。图10显示测试数据集与训练数据集中检索到的相邻内容之间几乎不存在重叠。完整文章列表参见表9。

Table 9 | Full set of articles included in our Wikipedia Sept. 2021 evaluation dataset.

表 9: 我们2021年9月维基百科评估数据集中包含的完整文章集。

Megan Rohrer; Emma Raducanu; Ambra Sabatini; WhyDonate; The Juggernaut (company); Angela Diaz; 2020 Summer Paralympics; 2021 Afghan protests; Rexh Xhakli; Julia Laskin; Cuijk; Ghoubet Wind Power Station; Aakashavaani; Junior Eurovision Song Contest 2021; Pavilion Bukit Jalil; Blake Desjarlais; 2021 All-Ireland Senior Football Championship Final; Drift-barrier hypothesis; Venomics; Great Circle (novel); Hurricane Ida; 2021 Montenegrin episcopal enthronement protests; At War With the Silverfish.

Megan Rohrer、Emma Raducanu、Ambra Sabatini、WhyDonate、The Juggernaut(公司)、Angela Diaz、2020年夏季残奥会、2021年阿富汗抗议活动、Rexh Xhakli、Julia Laskin、Cuijk、Ghoubet风力发电站、Aakashavaani、2021年欧洲青少年歌唱大赛、Pavilion Bukit Jalil、Blake Desjarlais、2021年全爱尔兰高级足球锦标赛决赛、漂移屏障假说、毒液组学、大圆环(小说)、飓风艾达、2021年黑山主教就职抗议事件、与衣鱼作战。

We first parse articles using mwparserfromhell. We then remove sections with the following titles: “references”, “external links”, “sources”, “further reading”, “see also”, “citations”, and “note”. In the remaining sections, we remove Wikilinks and remove the following templates: “reflist”, “notelist”, “notelist-ua”, “notelist-lr”, “notelist-ur”, and “notelist-lg”. We also exclude objects with the “ref” or “table” tag and clean the remaining text with the strip_code function. Finally, we concatenate the title and all the sections and use “\n\n” to delimit them.

我们首先使用 mwparserfromhell 解析文章。然后移除包含以下标题的章节:"references"、"external links"、"sources"、"further reading"、"see also"、"citations"和"note"。在剩余章节中,我们移除维基链接并删除以下模板:"reflist"、"notelist"、"notelist-ua"、"notelist-lr"、"notelist-ur"和"notelist-lg"。同时排除带有"ref"或"table"标签的对象,并使用strip_code函数清理剩余文本。最后,我们将标题与所有章节拼接起来,并使用 "\n\n" 作为分隔符。
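A sketch of this preprocessing using the mwparserfromhell API (the helper name and some details, such as dropping wikilinks outright, are our assumptions rather than the original script):

```python
import mwparserfromhell

DROP_SECTIONS = {"references", "external links", "sources", "further reading",
                 "see also", "citations", "note"}
DROP_TEMPLATES = {"reflist", "notelist", "notelist-ua", "notelist-lr",
                  "notelist-ur", "notelist-lg"}

def clean_article(title, wikitext):
    parsed = mwparserfromhell.parse(wikitext)
    kept = []
    for section in parsed.get_sections(include_lead=True, flat=True):
        headings = section.filter_headings()
        if headings and str(headings[0].title).strip().lower() in DROP_SECTIONS:
            continue
        removable = [t for t in section.filter_templates()
                     if str(t.name).strip().lower() in DROP_TEMPLATES]
        removable += section.filter_wikilinks()
        removable += [t for t in section.filter_tags()
                      if str(t.tag).strip().lower() in ("ref", "table")]
        for node in removable:
            try:
                section.remove(node)
            except ValueError:
                pass  # node was nested inside something already removed
        kept.append(section.strip_code().strip())
    # Concatenate the title and the surviving sections with "\n\n".
    return "\n\n".join([title] + kept)
```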

B. Details on the retrieval architecture

B. 检索架构细节

We give details on the Retro architecture, and on the fine-tuning procedure we use for Retrofitting existing language models.

我们详细介绍Retro架构,以及用于改造现有大语言模型的微调流程。

B.1. Retro architecture and implementation

B.1. Retro架构与实现

B.1.1. Feed-forward architecture

B.1.1. 前馈架构

As mentioned in the main text, the overall encoder-decoder architecture is fully feed-forward. The model takes a sequence $X\in\mathbb{V}^{n}=(C_{u})_ {1\leqslant u\leqslant l}$ and its pre-computed neighbours $(\mathrm{RET}(C_{u}))_ {1\leqslant u\leqslant l}$, and returns logits in $\mathbb{R}^{n\times|\mathbb{V}|}$. Along with the Attn, Ffw, Cca and Ca operators introduced in the main text, we define the decoder embedding layer $\mathbf{EMB}:\mathbb{V}^{n}\rightarrow\mathbb{R}^{n\times d}$, the Split operator that extracts chunked intermediary embeddings $\operatorname{Split}(H)\triangleq(H_{u})_ {1\leqslant u\leqslant l}\in\mathbb{R}^{l\times m\times d}$, and the read-out layer $\operatorname{Read}:\mathbb{R}^{n\times d}\rightarrow\mathbb{R}^{n\times|\mathbb{V}|}$. We then describe the forward pass in Algorithm 1. In addition to the usual Transformer ones, the Retro architecture hyperparameters involve the layer indices $P_{\mathrm{enc}}$ and $P$ at which the encoder and the decoder perform cross-attention.

如正文所述,整体编码器-解码器架构采用全前馈设计。我们从序列$X\in\mathbb{V}^{n}=(C_{u})_ {1\leqslant u\leqslant l}$及其预计算邻域$(\mathrm{RET}(C_{u}))_ {1\leqslant u\leqslant l}$出发,最终输出$\mathbb{R}^{n\times|\mathbb{V}|}$维度的逻辑值。除正文介绍的Attn、Ffw、Cca和Ca算子外,我们还定义了解码器嵌入层$\mathbf{EMB}:\mathbb{V}^{n}\rightarrow\mathbb{R}^{n\times d}$、提取分块中间嵌入的Split算子$\operatorname{SpLIT}(H)\triangleq(H_{u})_ {1\leqslant u\leqslant l}\in\mathbb{R}^{l\times m\times d}$,以及读出层Read: $\mathbb{R}^{n\times d}\rightarrow$$\mathbb{R}^{n\times|\mathbb{V}|}$。算法1详细描述了前向传播过程。除常规Transformer参数外,Retro架构的超参数还包括编码器与解码器执行交叉注意力的层索引$P_{\mathrm{enc}}$和$P$。

B.1.2. Relative positional encoding in the chunked cross-attention layer

B.1.2. 分块交叉注意力层中的相对位置编码

The Ca operator uses relative positional logits, that are computed from a specific relative distance separating data tokens from retrieval tokens. Indeed, we expect any retrieval neighbour $\operatorname{RET}(C_{u})^{j}$ and the chunk $C_{u}$ to be relatively well aligned, and assume that they start at the same position. Therefore, when computing $\mathbf{CA}(H_{u}^{+},E_{u})$ , we set the distance between the data token $i\in[1,l]$ of chunk $C_{u}^{+}$ and

Ca 操作符使用相对位置逻辑值,这些逻辑值是根据数据 token 与检索 token 之间的特定相对距离计算得出的。实际上,我们期望任何检索邻居 $\operatorname{RET}(C_{u})^{j}$ 和块 $C_{u}$ 能够相对较好地对齐,并假设它们从同一位置开始。因此,在计算 $\mathbf{CA}(H_{u}^{+},E_{u})$ 时,我们将块 $C_{u}^{+}$ 的数据 token $i\in[1,l]$ 与

the retrieval token $i^{\prime}\in[1,2l]$ of $\mathrm{RET}(C_{u})^{j}$ to be

检索Token $i^{\prime}\in[1,2l]$ 属于 $\mathrm{RET}(C_{u})^{j}$

$$
d(i,i^{\prime})\triangleq i-i^{\prime}+l-1.
$$

$$
d(i,i^{\prime})\triangleq i-i^{\prime}+l-1.
$$

When computing the encoder cross-attentions $\mathsf{C A}(\mathrm{RET}(C_{u})^{j},H_{u})$ , we set the distance between the retrieval token $i^{\prime}\in[1,2l]$ and the data token $i\in[1,l]$ to be

在计算编码器交叉注意力 $\mathsf{C A}(\mathrm{RET}(C_{u})^{j},H_{u})$ 时,我们将检索token $i^{\prime}\in[1,2l]$ 与数据token $i\in[1,l]$ 之间的距离设为

$$
d_{\mathrm{enc}}(i^{\prime},i)\triangleq i^{\prime}-i.
$$

$$
d_{\mathrm{enc}}(i^{\prime},i)\triangleq i^{\prime}-i.
$$

Positional logits are obtained as a linear transform of a cosine vector computed from $(d(i,i^{\prime}))_{i,i^{\prime}}$ , and are added to content logits, as in a regular self-attention block.

位置logits是通过对$(d(i,i^{\prime}))_{i,i^{\prime}}$计算的余弦向量进行线性变换得到的,并像常规自注意力块中那样添加到内容logits中。
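A small sketch of the two distance definitions above, following this section's notation in which data positions run over $[1, l]$ and retrieval positions over $[1, 2l]$ (the cosine featurisation and the linear projection into logits are not shown):

```python
import jax.numpy as jnp

def cca_distances(l):
    # d(i, i') = i - i' + l - 1 for the decoder's chunked cross-attention.
    i = jnp.arange(1, l + 1)[:, None]            # data token positions
    i_prime = jnp.arange(1, 2 * l + 1)[None, :]  # retrieval token positions
    return i - i_prime + l - 1

def encoder_ca_distances(l):
    # d_enc(i', i) = i' - i for the encoder's cross-attention.
    i_prime = jnp.arange(1, 2 * l + 1)[:, None]
    i = jnp.arange(1, l + 1)[None, :]
    return i_prime - i
```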

B.1.3. Chunked cross-attention implementation

B.1.3. 分块交叉注意力实现

Our implementation of the Cca operator, shown in Listing 1, is based on a vectorized application of a cross-attention layer. For simplicity, we omit the multi-head attention logic and use the simplest Q,K,V attention. We omit relative positional logits computation, described above.

我们在代码清单1中实现的Cca算子基于交叉注意力层的向量化应用。为简化起见,省略了多头注意力逻辑,仅使用最基本的Q,K,V注意力机制,同时省略了前文所述的相对位置对数计算部分。

B.1.4. Optional sharing of embedding matrices

B.1.4. 嵌入矩阵的可选共享

We use disjoint embeddings for the encoder and decoder by default, which allows us to use a different dimensionality for the encoder (typically kept at $d_{\mathrm{ENC}}=896$) and for the decoder (that we scale up to $d=8192$). It is possible to share the embeddings, with little difference in training, as we show in the ablation section.

默认情况下,我们为编码器和解码器使用分离的嵌入 (embedding),这使得编码器可以采用不同的维度(通常保持为 $d_{\mathrm{ENC}}=896$),而解码器可扩展至 $d=8192$。如消融实验部分所示,共享嵌入也是可行的,对训练影响甚微。

B.2. Baseline to Retro model fine-tuning

B.2. 从基线模型到Retro模型的微调

As shown in Fig. 5, we found that we were able to take a pre-trained baseline transformer and add Retro through fine-tuning. In all cases, we froze all weights from pre-training and freshly initialised the retrieval encoder and cross-attention weights. In all cases, the cross-attention is added every third layer starting at layer six. The learning rate for the three smaller models was set to $2\times10^{-4}$ and half that for the larger model. We experimented with allowing the entire model to resume training during fine-tuning but consistently found that the best approach was to freeze the pre-trained model. This kept the retrieval-off performance frozen, whereas when all weights were tuned the retrieval-off performance would degrade.

如图 5 所示,我们发现可以通过微调在预训练的基线 Transformer 基础上加入 Retro。在所有实验中,我们冻结了预训练的全部权重,并重新初始化了检索编码器和交叉注意力权重。交叉注意力均从第六层开始每隔三层添加一次。三个较小模型的学习率设置为 $2\times10^{-4}$,较大模型的学习率减半。我们尝试在微调时允许整个模型继续训练,但始终发现最佳方案是冻结预训练模型。这种方法能保持关闭检索时的性能不变,而调整全部权重会导致关闭检索时的性能下降。

C. Training details and hyper parameters

C. 训练细节与超参数

We provide the hyperparameters used in the various experiments of $\S4$.

我们提供了在$\S4$各项实验中使用的超参数。

C.1. Language model pre-training

C.1. 语言模型预训练

In Table 10, we show the hyperparameters of the different models we train. In all cases, we train for 419,430,400,000 training tokens. The three smaller models are trained with a batch size of 256 and the largest model is trained with a batch size of 1024. The minimum learning rate is set to 0.1 times the maximum learning rate, which is shown in Table 10. The learning rate is decayed using a cosine cycle length that matches the total number of training tokens. All models are trained using AdamW (Loshchilov and Hutter, 2019) with a weight decay parameter of 0.1. The learning rate linearly increases from $10^{-7}$ to the maximum learning rate over the first 750 steps of training. All models use ZeRO to shard the optimiser state (Rajbhandari et al., 2020). Additional infrastructure details can be found in Rae et al. (2021).

在表10中,我们展示了所训练不同模型的超参数。所有模型均训练了419,430,400,000个训练token。三个较小模型采用256的批量大小进行训练,最大模型则使用1024的批量大小。最小学习率设置为最大学习率的0.1倍(具体数值见表10)。学习率衰减采用与训练token总数匹配的余弦周期。所有模型均使用AdamW (Loshchilov and Hutter, 2019) 进行训练,权重衰减参数为0.1。学习率在前750个训练步中从 $10^{-7}$ 线性增长至最大学习率。所有模型均采用ZeRO技术对优化器状态进行分片 (Rajbhandari et al., 2020)。更多基础设施细节可参阅Rae等人 (2021) 的研究。
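For illustration, this is roughly how the schedule described above can be written with Optax (the 2048-token sequence length used to convert the 419.4B-token budget into steps is an assumption, as is applying it to the batch-size-256 models):

```python
import optax

TOTAL_TOKENS = 419_430_400_000
BATCH_SIZE, SEQ_LEN = 256, 2048          # SEQ_LEN is an assumption
MAX_LR = 2e-4                            # per-model value from Table 10

schedule = optax.warmup_cosine_decay_schedule(
    init_value=1e-7,                     # linear ramp from 1e-7 ...
    peak_value=MAX_LR,                   # ... to the maximum learning rate
    warmup_steps=750,
    decay_steps=TOTAL_TOKENS // (BATCH_SIZE * SEQ_LEN),  # cosine cycle matches the token budget
    end_value=0.1 * MAX_LR,              # minimum LR is 0.1x the maximum
)
optimiser = optax.adamw(learning_rate=schedule, weight_decay=0.1)
```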

Listing 1 | Jax implementation of the chunked cross attention, simplified.

列表 1 | 简化的分块交叉注意力 Jax 实现

    import jax
    import jax.numpy as jnp

    n = 128     # Sequence length
    m = 16      # Chunk length
    r = 32      # Retrieval length
    k = 4       # Number of neighbours
    d = 16      # Embedding size
    l = n // m  # Number of chunks

    # Parameters
    Q = jnp.zeros((d, d))
    K = jnp.zeros((d, d))
    V = jnp.zeros((d, d))

    def relative_positional_encodings(attending_length, attended_length):
        # Classical relative positional encodings; their computation is omitted
        # here (see B.1.2) and replaced by zeros to keep the listing self-contained.
        return jnp.zeros((attending_length, attended_length))

    def cross_attention(chunk, neighbour):
        m, d = chunk.shape
        r, d = neighbour.shape
        queries = chunk @ Q
        keys = neighbour @ K
        logits = queries @ keys.T
        values = neighbour @ V
        return logits, values

    def multi_neighbour_cross_attention(chunk, neighbours):
        m, d = chunk.shape
        k, r, d = neighbours.shape
        logits, values = jnp.vectorize(
            cross_attention, signature='(m,d),(r,d)->(m,r),(r,d)')(chunk, neighbours)
        assert logits.shape == (k, m, r)
        assert values.shape == (k, r, d)
        # Attend jointly over the retrieval tokens of all k neighbours.
        logits += relative_positional_encodings(m, r)[None, :, :]
        logits = jnp.swapaxes(logits, 0, 1).reshape((m, k * r))
        weights = jax.nn.softmax(logits, axis=-1)
        return weights @ values.reshape((k * r, d))

    def multi_chunk_cross_attention(observation, neighbours):
        # Shift by m - 1 tokens so that the last token of chunk C_u is the first
        # token that attends to the neighbours of C_u.
        attending_chunks = jnp.pad(
            observation[m - 1:], ((0, m - 1), (0, 0)), mode='constant').reshape(l, m, d)
        chunked_output = jnp.vectorize(
            multi_neighbour_cross_attention,
            signature='(m,d),(k,r,d)->(m,d)')(attending_chunks, neighbours)
        assert chunked_output.shape == (l, m, d)
        # Shift back so that the output is aligned with the input sequence.
        output = jnp.pad(
            chunked_output.reshape(n, d), ((m - 1, 0), (0, 0)), mode='constant')[:n]
        return output

Table 10 | Retro model hyperparameters, along with the size of the decoder.

| Baseline | $d_{\mathrm{model}}$ | $d_{\mathrm{ffw}}$ | # heads | Head size | # layers | $P$ | $P_{\mathrm{enc}}$ | Max LR |
|---|---|---|---|---|---|---|---|---|
| 247M | 896 | 3584 | 16 | 64 | 12 | [6, 9, 12] | [1] | $2\times10^{-4}$ |
| 564M | 1536 | 6144 | 12 | 128 | 12 | [6, 9, 12] | [1] | $2\times10^{-4}$ |
| 1,574M | 2048 | 8192 | 16 | 128 | 24 | [9, 12, ..., 24] | [1] | $2\times10^{-4}$ |
| 7,505M | 4096 | 16384 | 32 | 128 | 32 | [9, 12, ..., 32] | [1] | $1\times10^{-4}$ |

表 10 | Retro 模型超参数及解码器尺寸

Baseline dmodel dffw #heads Head size #layers P PENC Max LR
247M 896 3584 16 64 12 [6,9,12] [1] 2x10-4
564M 1536 6144 12 128 12 [6,9,12] [1] 2x10-4
1,574M 2048 8192 16 128 24 [9,12,...,24] [1] 2x10-4
7,505M 4096 16384 32 128 32 [9,12,...,32] [1] 1x10-4

Table 11 | Hyperparameters for the Wikitext 103 experiments presented in Table 4. We use the same learning rate schedule for the baseline and the Retro-fitting. For Retro-fitting, we reset the schedule i.e. the schedule starts from step 0, not from step 35,000.

| Group | Hyperparameter | Value |
|---|---|---|
| Model | Number of layers | 24 |
| | $d$ | 1024 |
| | $d_{\mathrm{FFW}}$ | |
| | Key size | 64 |
| | Value size | |
| | Number of heads | 16 |
| Training data | Dataset | Wikitext103 train |
| | Sequence length | |
| | Batch size | |
| | Tokenizer vocabulary size | |
| Optimisation | Optimiser | Adam |
| | Adam's $\beta_1$ | |
| | Adam's $\beta_2$ | |
| | Adam's $\varepsilon$ | |
| | Dropout rate | |
| Schedule | Learning rate start | $1\times10^{-7}$ |
| | Learning rate max | $2.5\times10^{-4}$ |
| | Learning rate min | $2\times10^{-5}$ |
| | Warmup steps | 4,000 |
| | Cosine cycle steps | 100,000 |
| Evaluation | Overlapping proportion | $87.5\%$ |

表 11 | 表 4 中 Wikitext 103 实验的超参数。我们对基线模型和 Retro-fitting 使用相同的学习率调度策略。对于 Retro-fitting,我们重置了调度策略,即调度从第 0 步开始,而非第 35,000 步。

| 分组 | 超参数 | 值 |
|---|---|---|
| 模型 | 层数 | 24 |
| | $d$ | 1024 |
| | $d_{\mathrm{FFW}}$ | |
| | 键大小 | 64 |
| | 值大小 | |
| | 头数 | 16 |
| 训练数据 | 数据集 | Wikitext103 train |
| | 序列长度 | |
| | 批量大小 | |
| | 分词器词表大小 | |
| 优化 | 优化器 | Adam |
| | Adam $\beta_1$ | |
| | Adam $\beta_2$ | |
| | Adam $\varepsilon$ | |
| | 丢弃率 | |
| 调度 | 学习率起始值 | $1\times10^{-7}$ |
| | 学习率最大值 | $2.5\times10^{-4}$ |
| | 学习率最小值 | $2\times10^{-5}$ |
| | 预热步数 | 4,000 |
| | 余弦周期步数 | 100,000 |
| 评估 | 重叠比例 | $87.5\%$ |

C.2. Wikitext 103 comparison

C.2. Wikitext 103 对比

We provide more details on our Wikitext 103 results presented in $\S4.1$ and Table 4. We train a baseline transformer on the Wikitext 103 training set with the hyperparameters presented in Table 11. The learning rate ramps linearly from $1\times10^{-7}$ to $2.5\times10^{-4}$ in the first 4,000 steps, then decays to $2\times10^{-5}$ at 100,000 steps using a cosine schedule. The baseline checkpoint at step 35,000 has the lowest perplexity on Wikitext 103 valid, of 21.58, for an overlapping proportion of $75\%$ (sliding window evaluation that only uses probabilities for tokens that have at least $75\%$ of the sequence length of context, when available). We use this checkpoint for all our baseline and 𝑘NN-LM numbers reported in Table 4, except that Table 4 reports for an overlapping proportion of $87.5\%$, which slightly lowers the perplexity of our baseline to 21.53 on Wikitext 103 valid.

我们在$\S4.1$和表4中展示了Wikitext 103结果的更多细节。使用表11中的超参数在Wikitext 103训练集上训练了一个基线Transformer模型。学习率在前4,000步从$1\times10^{-7}$线性增长到$2.5\times10^{-4}$,随后通过余弦调度在100,000步时衰减至$2\times10^{-5}$。在步数35,000处的基线检查点在Wikitext 103验证集上达到了最低困惑度21.58(重叠比例为$75\%$的滑动窗口评估,仅当可用时使用至少$75\%$序列长度上下文的token概率)。除表4报告的重叠比例为$87.5\%$(将基线困惑度略微降至21.53)外,我们均使用该检查点生成表4中所有基线和𝑘NN-LM的结果数据。

We also use the 35,000 step baseline checkpoint as initialization for a Retrofit, which otherwise uses the same optimiser and schedule hyperparameters but only trains the new retrieval weights, as explained in $\S4.2$. Our best Retrofit checkpoint has a Wikitext 103 valid perplexity of 18.46, when retrieving from Wikipedia. We use this Retro checkpoint in Table 4 for all other retrieval sets. The evaluation curves for our baseline and Retrofit are shown in Fig. 7 (left). In this particular case, because Wikitext 103 is quite small, training a Retro model from scratch led to weaker results than the baseline, at least when retrieving from Wikipedia, as we couldn’t find an effective way to mitigate the increased over-fitting due to the additional weights of Retro.

我们还使用35,000步的基线检查点作为Retrofit的初始化,其余部分采用相同的优化器和调度超参数,但仅训练新的检索权重,如$\S4.2$所述。当从维基百科检索时,我们最佳的Retrofit检查点在Wikitext 103验证集上达到了18.46的困惑度。在表4中,我们使用该Retro检查点评估所有其他检索集。基线模型与Retrofit的评估曲线如图7(左)所示。在这个特定案例中,由于Wikitext 103数据集规模较小,从头开始训练Retro模型得到的结果弱于基线(至少从维基百科检索时如此),因为我们未能找到有效方法来缓解Retro新增权重带来的过拟合加剧问题。

We also re-implement 𝑘NN-LM using the same tokenizer and dataset that we use for our baseline and Retrofitting experiments. $k\mathrm{NN-LM}$ has probabilities $p_{k\mathrm{NN-LM}}=\lambda p_{k\mathrm{NN}}+(1-\lambda)p_{\mathrm{LM}}$ with $p_{k\mathrm{NN}}(n_{k})\propto\exp(-\alpha d_{k})$. To tune $\lambda$ and $\alpha$, we begin with $\alpha=0.0012$, which corresponds to the inverse of the standard deviation of the norm of the embeddings that we use as keys and queries for $k\mathrm{NN-LM}$. We find the best $\lambda=0.118$. We then find the best $\alpha=0.00785$ for that value of $\lambda$. Fig. 7, centre and right, respectively show the perplexity of $k\mathrm{NN-LM}$ as a function of $\lambda$ and $\alpha$.

我们还使用与基准实验和Retrofitting实验相同的分词器和数据集重新实现了𝑘NN-LM。$k\mathrm{NN-LM}$的概率计算式为$p_{k\mathrm{NN-LM}}=\lambda p_{k\mathrm{NN}}+(1-\lambda)p_{\mathrm{LM}}$,其中$p_{k\mathrm{NN}}(n_{k})\propto\exp(-\alpha d_{k})$。为调整$\lambda$和$\alpha$参数,我们初始设定$\alpha=0.0012$(该值对应嵌入向量范数标准差的倒数,这些嵌入向量用作$k\mathrm{NN-LM}$的键和查询)。最终确定最优参数组合为$\lambda=0.118$,并基于该$\lambda$值进一步优化得到$\alpha=0.00785$。图7中、右两图分别展示了$k\mathrm{NN-LM}$在$\lambda$和$\alpha$参数变化时的困惑度表现。
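A minimal sketch of this interpolation (the vocabulary-sized LM distribution, the retrieved next-token ids and their key distances are assumed to be given):

```python
import jax
import jax.numpy as jnp

def knn_lm_probs(lm_probs, neighbour_tokens, distances, lam=0.118, alpha=0.00785):
    # p_kNN(n_k) proportional to exp(-alpha * d_k), scattered onto the vocabulary.
    knn_weights = jax.nn.softmax(-alpha * distances)
    knn_probs = jnp.zeros_like(lm_probs).at[neighbour_tokens].add(knn_weights)
    # p_kNN-LM = lambda * p_kNN + (1 - lambda) * p_LM
    return lam * knn_probs + (1.0 - lam) * lm_probs
```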


Figure 7 | Wikitext 103 valid perplexities. Left: Baseline and Retrofit (initialized from baseline’s checkpoint at 35,000 steps) perplexities as a function of training steps. Center and right: 𝑘NN-LM perplexity as a function of $\lambda$ (for $\alpha=0.0012$) and $\alpha$ (for $\lambda=0.12$) respectively.

图 7 | Wikitext 103验证集困惑度。左图:基线模型与Retrofit(从基线模型第35,000步检查点初始化)的困惑度随训练步数变化曲线。中图与右图:𝑘NN-LM困惑度分别随$\lambda$(固定$\alpha=0.0012$)和$\alpha$(固定$\lambda=0.12$)的变化曲线。

C.3. Retrofitting baseline models experiments

C.3. 基线模型改造实验

In Table 12, we give the hyperparameters used for Retrofitting the models on Massive Text.

表 12: 我们列出了在Massive Text上用于模型改造(Retrofitting)的超参数。

Table 12 | Hyperparameters for the Retrofitting experiments

| Model | Layers with Retro-block ($P$) | Learning rate | Batch size |
|---|---|---|---|
| 172M | Every 3rd from 6 | $2\times10^{-4}\rightarrow2\times10^{-5}$ | 256 |
| 425M | Every 3rd from 6 | $2\times10^{-4}\rightarrow2\times10^{-5}$ | 256 |
| 1.5B | Every 3rd from 6 | $2\times10^{-4}\rightarrow2\times10^{-5}$ | 256 |
| 7.5B | Every 3rd from 6 | $1\times10^{-4}\rightarrow1\times10^{-5}$ | 256 |

表 12 | 改造实验的超参数

模型 使用RETRo-block的层数(P) 学习率 批量大小
172M 从第6层开始每3层 2×10⁻⁴→2×10⁻⁵ 256
425M 从第6层开始每3层 2×10⁻⁴→2×10⁻⁵ 256
1.5B 从第6层开始每3层 2×10⁻⁴→2×10⁻⁵ 256
7.5B 从第6层开始每3层 1×10⁻⁴→1×10⁻⁵ 256

C.4. Question answering experiments

C.4. 问答实验

We fine-tune our 7.5B Retro model for 25,000 steps, using a batch size of 128, a learning rate cosine scheduled from $10^{-6}$ to $10^{-7}$ , with a linear ramp of 750 steps. We use dropout in the decoder only, as it performs better than using dropout in both the encoder and the decoder. Each neighbour is formatted as title: {title}, source: {source}. We use the top 20 neighbours from Dpr when training and evaluating.

我们对7.5B参数的Retro模型进行了25,000步微调,批处理大小为128,学习率采用余弦调度从$10^{-6}$降至$10^{-7}$,并设置750步线性预热。仅在解码器中使用dropout,因为其表现优于同时在编码器和解码器中采用dropout。每个邻居数据格式化为title: {title}, source: {source}。训练和评估时均使用DPR检索的前20个邻居。

Table 13 | Performance of Retro for different variants. Model performance on the C4 evaluation set, measured in bits-per-byte, for a 247M parameter model trained with a 157 billion token schedule.

| Ablation group | Ablation | C4 eval bpb |
|---|---|---|
| Model | RETRO | 0.822 |
| | No query conditioning | 0.829 |
| | No CA positional encodings | 0.826 |
| | Shared embeddings | 0.823 |
| | 6-layer encoder | 0.821 |
| Retrieval values | Neighbours N | 0.950 |
| | Continuations F | 0.895 |
| | No retrieval | 0.987 |
| Training neighbours | 1 training neighbour | 0.858 |
| | 4 training neighbours | 0.847 |
| Cross attention position | CA top layer (1/12) | 0.827 |
| | CA mid layer (6/12) | 0.823 |
| | CA top layer (12/12) | 0.831 |
| | CA all layers | 0.860 |
| | CA every 3 from 1 | 0.823 |

表 13 | Retro不同变体的性能表现。在C4评估集上,以每字节比特数 (bpb) 衡量的247M参数模型性能,训练时使用了1570亿token的调度计划。

消融组 消融项 C4评估bpb
模型 RETRO 0.822
无查询条件 0.829
无CA位置编码 0.826
共享嵌入 0.823
6层编码器 0.821
检索值 邻居N 0.950
延续F 0.895
无检索 0.987
训练邻居 1个训练邻居 0.858
4个训练邻居 0.847
交叉注意力位置 CA顶层(1/12) 0.827
CA中间层(6/12) 0.823
CA顶层(12/12) 0.831
CA所有层 0.860
CA每3层从1开始 0.823

D. Model ablations

D. 模型消融实验

We validate important design choices by evaluating what happens when we do not include them. We use the 247M parameter model for all experiments and we train on a compressed 157 billion token schedule for all ablation experiments. We describe results relative to the default settings presented in the main text and recalled here. We report the C4 evaluation loss at the end of the training process, and also compare how the evaluation loss decreases with training time, measured relative to the baseline training time. Results are reported in Fig. 8 and Table 13.

我们通过评估不包含某些设计选择时的效果来验证其重要性。所有实验均使用247M参数模型,并在1570亿token的压缩训练计划下进行消融实验。结果描述均相对于正文中提及的默认设置(此处复述)。训练结束时报告C4评估损失值,并对比评估损失下降速度与训练时长的关系(以基线训练时长为参照)。结果见图8和表13。

Using relative encodings in cross-attention. Using relative encodings in cross-attention, as described in $\S\mathrm{B}.1.2$ , provides a pure improvement both in the number of steps to reac