精读笔记:RETRO — Improving Language Models by Retrieving from Trillions of Tokens
论文信息
- 标题:Improving Language Models by Retrieving from Trillions of Tokens
- 作者:Sebastian Borgeaud, Arthur Mensch, Jordan Hoffmann, Trevor Cai, Eliza Rutherford, Katie Millican, George van den Driessche, Jean-Baptiste Lespiau, Bogdan Damoc, Aidan Clark, Diego de Las Casas, Aurelia Guy, Jacob Menick, Roman Ring, Tom Hennigan, Saffron Huang, Loren Ante, Sam Schrittwieser, Oriol Vinyals, Simon Osindero, Karen Simonyan, Jack Rae, Erich Elsen, Laurent Sifre(均来自 DeepMind)
- 机构:DeepMind(Google DeepMind 前身)
- 发表时间:2021 年 12 月(arXiv 2112.04426)
- 期刊/会议:ICML 2022
阅读地图
本文提出了一种叫做 RETRO(Retrieval-Enhanced Transformer,检索增强的 Transformer)的语言模型。读完这篇精读笔记,你将理解:
- 为什么要做检索增强? 模型越来越大但代价高昂,有没有更省参数的方法?
- RETRO 怎么把输入文本切成 chunk? 为什么要按块检索而不是按词检索?
- 检索到的内容怎么融入生成? "分块交叉注意力(chunked cross-attention)"是什么?
- 2 万亿 token 的检索库是什么概念? 和 RAG/REALM 相比大了多少?
- 一个 75 亿参数的模型,怎么媲美 1750 亿参数的 GPT-3?
核心思想一句话: 把知识外包给数据库,模型只负责推理——小模型 + 超大检索库 ≈ 超大模型。
一、摘要(Abstract)
原文(关键句)
"We enhance auto-regressive language models by conditioning on document chunks retrieved from a large corpus, based on local similarity with preceding tokens. With a 2 trillion token database, our Retrieval-Enhanced Transformer (Retro) obtains comparable performance to GPT-3 and Jurassic-1 on the Pile, despite using 25× fewer parameters. After fine-tuning, Retro performance translates to downstream knowledge-intensive tasks such as question answering. Retro combines a frozen Bert retriever, a differentiable encoder and a chunked cross-attention mechanism to predict tokens based on an order of magnitude more data than what is typically consumed during training. We typically train Retro from scratch, yet can also rapidly Retrofit pre-trained transformers with retrieval and still achieve good performance."
翻译
我们通过在生成时融合从大型语料库中检索到的文档片段(基于与前文 token 的局部相似度),来增强自回归语言模型。借助一个包含 2 万亿 token 的数据库,我们的检索增强 Transformer(RETRO)在 Pile 数据集上取得了与 GPT-3 和 Jurassic-1 相当的性能,但参数量仅为它们的 1/25。经过微调后,RETRO 的性能可以迁移到问答等知识密集型下游任务。RETRO 结合了一个冻结的 BERT 检索器、一个可微分的编码器和一个分块交叉注意力机制,能够利用比训练时通常消耗的数据多一个数量级的信息来预测 token。我们通常从头训练 RETRO,但也可以快速地将现有预训练 Transformer 改装(retrofit) 为支持检索的版本,同样能取得良好性能。
新手讲解
这段摘要传达了三个震撼性的信息:
信息 1:2 万亿 token 的数据库
想象你要回答一道历史题。你有两个选择:
- A:把整个图书馆的知识背进脑子里(=训练一个超大模型)
- B:考试时带着一张图书馆索引卡,随时查阅(=RETRO)
RETRO 选择了 B。它的"索引卡"是一个包含 2 万亿个 token 的数据库——相当于大约 400 万本书的文字量。
信息 2:参数量 1/25 却表现相当
GPT-3 有 1750 亿参数,RETRO 最大版本只有约 75 亿参数(约 1/25),但在同一个测试集(The Pile)上性能相近。这说明把知识"外包"出去的策略是可行的。
信息 3:三个核心技术组件(先留个印象,后面会详细讲)
- 冻结的 BERT 检索器(frozen BERT retriever):用 BERT 模型将文本变成向量,用来找相似内容,但训练时这个 BERT 不更新参数("冻结")
- 可微分编码器(differentiable encoder):把检索到的文档片段编码成向量,这个编码器会随训练更新
- 分块交叉注意力(chunked cross-attention):把检索结果"注入"生成过程的核心机制
二、引言(Introduction)
2.1 语言模型的规模化困境
原文
"Large performance improvements have come from increasing the amount of data, training compute, or model parameters. Transformers have been scaled from 100 million parameter models in seminal work to over hundred billion parameters in the last two years... Increasing model size predictably improves performance on a wide range of downstream tasks."
"The benefits of increasing the number of parameters come from two factors: additional computations at training and inference time, and increased memorization of the training data."
翻译
性能的大幅提升主要来自三方面:增加数据量、增加训练算力、增加模型参数。Transformer 从最初的 1 亿参数模型,在过去两年里扩展到了超过 1000 亿参数……扩大模型规模可以预期地提升一系列下游任务的性能。
增加参数量带来的收益来自两个因素:一是训练和推理时的额外计算能力,二是对训练数据的更多记忆。
新手讲解
这里作者在做一道"解构题":大模型为什么更强?他们发现答案是两件事:
1. 更多计算:更大的网络能做更复杂的推理
2. 更多记忆:更多参数能记住更多训练数据里的知识
关键洞察:这两件事是捆绑在一起的——你想要更多记忆,就必须买更多参数,也必须付出更多算力。RETRO 的动机就是:能不能把这两件事拆开? 让模型负责推理,让外部数据库负责记忆。
2.2 RETRO 的核心主张
原文
"In this work, we endeavor to decouple these, by exploring efficient means of augmenting language models with a massive-scale memory without significantly increasing computations. Specifically, we suggest retrieval from a large text database as a complementary path to scaling language models. Instead of increasing the size of the model and training on more data, we equip models with the ability to directly access a large database to perform predictions—a semi-parametric approach."
翻译
在本文中,我们试图将这两者解耦,探索在不显著增加计算量的情况下,用大规模记忆来增强语言模型的有效方法。具体而言,我们建议将从大型文本数据库中检索作为扩展语言模型的补充路径。我们不是增加模型规模和训练数据,而是赋予模型直接访问大型数据库以进行预测的能力——这是一种半参数化(semi-parametric)的方法。
新手讲解
"半参数化"是什么意思?
- 纯参数化(parametric):所有知识都存在模型参数里,比如 GPT-3 的 1750 亿个权重数字里。
- 纯非参数化(non-parametric):完全靠查数据库,比如最简单的搜索引擎。
- 半参数化(semi-parametric):两者结合。RETRO 有一个参数模型负责理解和推理,同时有一个外部数据库负责存储知识。
类比:这就像一个开卷考试和闭卷考试的区别。GPT-3 是闭卷考试选手,把所有知识背在脑子里。RETRO 是开卷考试选手,随时可以翻书,但也需要足够的理解能力才能用好书上的内容。
2.3 与前人工作的区别
原文
"Existing retrieval for language modelling work only considers small transformers (100 millions parameters) and databases of limited size (up to billions of tokens). To our knowledge, our work is the first to show the benefits of scaling the retrieval database to trillions of tokens for large parametric language models."
翻译
现有的检索增强语言建模工作只考虑了小型 Transformer(1 亿参数量级)和规模有限的数据库(最多数十亿 token)。据我们所知,本文是第一个展示将检索数据库扩展到万亿 token 量级对大型参数语言模型带来收益的工作。
新手讲解:RETRO 与前人工作的关键差异
| 系统 | 检索粒度 | 检索库规模 | 是否在预训练阶段用检索 |
|---|---|---|---|
| kNN-LM(Khandelwal et al. 2019) | 每个 token 都检索 | ~数十亿 token(英语维基百科等) | 否(只在推理时用) |
| REALM(Guu et al. 2020) | 按整个 prompt 检索 | ~数十亿 token | 是,但规模小 |
| RAG(Lewis et al. 2020) | 按整个 prompt 检索 | ~数十亿 token | 否(在下游任务微调时用) |
| FiD(Izacard & Grave 2021) | 按整个 prompt 检索 | ~数十亿 token | 否(问答微调) |
| RETRO(本文) | 按 chunk(64 token 块)检索 | 1.75 万亿 token | 是,贯穿整个预训练 |
三大核心差异:
1. 规模跨越:从"十亿"跳到"万亿",差了大约 1000 倍
2. 检索时机:RETRO 在预训练阶段就用检索,而不是只在下游任务微调时才用
3. 检索粒度:RETRO 按"chunk(块)"检索,而不是按整个句子——这是 RETRO 的技术核心,下面重点讲
三、方法(Method)——核心技术章节
术语预习(首次出现,必须理解)
在读方法之前,先把几个核心术语搞清楚:
Chunk(块/片段):RETRO 把每篇文章切成若干个固定长度的小块,每块 64 个 token(大约 40-50 个英语单词)。这是 RETRO 检索的基本单位。为什么是块而不是整篇文章?因为文章太长,一篇文章里不同段落的主题可能不同,按块检索更精确。
Nearest Neighbor(最近邻):给定一个 chunk,在数据库里找和它语义最相似的 chunk。"最近"是用向量距离衡量的——把 chunk 变成向量,距离最近的就是语义最相似的。
Cross-Attention(交叉注意力):Transformer 里的一种注意力机制。普通的"自注意力"是序列内部的 token 互相看;"交叉注意力"是让序列 A 的 token 去看序列 B 的内容。RETRO 用交叉注意力让"正在生成的文本"去看"检索到的文档"。
Frozen Retriever(冻结的检索器):用 BERT 把文本变成向量来做检索。"冻结"意味着这个 BERT 的参数在训练过程中不更新——它的权重是固定的。好处是不需要每次更新后重新给整个数据库建索引(这个操作代价极高)。
3.1 训练数据与检索库(Section 2.1)
原文
"We use a multi-lingual version of MassiveText containing over 5 trillion tokens from multiple sources... During training, we retrieve from 600B tokens from the training data. For evaluation, the retrieval database consists in the full union of these datasets... totaling 1.75 trillion tokens."
"We use n=2048 and m=64" (序列长度和 chunk 大小)
翻译
我们使用一个包含超过 5 万亿 token 的多语言 MassiveText 数据集,来源涵盖网络内容、书籍、新闻、维基百科和 GitHub 代码。训练时,我们从其中 6000 亿 token 的子集中检索。评估时,检索数据库由所有数据集的并集组成,共计约 1.75 万亿 token。
我们使用序列长度 n=2048,chunk 大小 m=64。
新手讲解
数字直觉:
- 序列长度 n=2048:每次模型处理 2048 个 token(大约 1500 个英语单词,约 2-3 页文章)
- chunk 大小 m=64:每个块 64 个 token(大约 40-50 个词,一两句话到一小段)
- 一个序列会被切成 2048÷64 = 32 个 chunk
检索库的规模感:
- 英语维基百科:约 40 亿 token(4B)
- Common Crawl(整个互联网快照之一):约几万亿 token
- RETRO 的 1.75 万亿 token 检索库 ≈ 大约 438 倍的英语维基百科
训练 vs 评估的区别:训练时只从 6000 亿 token 的子集检索(节省成本),评估时用完整的 1.75 万亿 token 数据库(发挥最大性能)。
3.2 检索机制的形式化定义(Section 2.2)
原文
"We split sequences into a sequence of chunks of size m. Each chunk Cu is augmented with a set Ret(Cu) of k neighbours from the database. Ret(C1)=∅, namely the likelihood of tokens from the first chunk does not depend on any retrieval data."
翻译
我们将序列拆分为大小为 m 的 chunk 序列。每个 chunk Cu 被扩充了从数据库中检索到的 k 个近邻集合 Ret(Cu)。第一个 chunk Ret(C1)=∅,即第一个 chunk 的 token 概率不依赖任何检索数据。
新手讲解
RETRO 的核心生成逻辑图解:
原始文本(2048 token)
↓ 切分
[C1: 第1~64 token] → 检索 → 找到 k 个近邻 → 编码 → 融入C2的生成
[C2: 第65~128 token] → 检索 → 找到 k 个近邻 → 编码 → 融入C3的生成
[C3: 第129~192 token] → 检索 → ...
...以此类推...
关键规则:用当前 chunk Cu 的内容去检索,检索结果用于帮助下一个 chunk Cu+1 的预测。这样做保持了因果性(causal):生成 token 时只能看到已经出现过的内容,不能"偷看"未来。
为什么第一个 chunk 不检索? 因为在第一个 chunk 之前没有上下文,我们不知道该检索什么主题。从第二个 chunk 开始,用第一个 chunk 的内容去检索就有意义了。
k 是什么? 每个 chunk 检索 k 个最相似的数据库片段。训练时 k=2,评估时可以用更大的 k(比如 k=10 甚至 k=40)来提升性能。
3.3 检索数据库的结构:键值存储(Section 2.3)
原文
"Each value consists of two contiguous chunks of tokens [N,F], where N is the neighbour chunk and F is its continuation. The corresponding key is the BERT embedding of N, averaged over time. For each chunk C, the system retrieves its approximate k-nearest neighbours using L2 distance on BERT embeddings. To prevent causality violations, we filter out neighbours originating from the same document as the training sequence X."
翻译
数据库中的每个条目是一个键值对:值由两个相邻的 chunk 组成,即 [N, F],其中 N 是邻居块,F 是其后续块(continuation);键是对 N 进行 BERT 编码后的时间维度平均向量。对于每个 chunk C,系统用 L2 距离在 BERT 向量空间中检索近似的 k 个最近邻。为了防止违反因果性,我们过滤掉来自同一篇文档的近邻。
新手讲解
数据库条目的结构([N, F] 键值对):
数据库里存的一个条目:
┌─────────────────────────────────────────────────────────┐
│ 键(Key):N 的 BERT 向量(768维浮点数向量) │
│ 值(Value):[N(64 token)| F(64 token)]共128 token │
└─────────────────────────────────────────────────────────┘
N:被检索到的那段文本("邻居块")
F:N 之后紧跟的那段文本("后续块",F for Future/Following)
为什么要存 F? 只存 N 的话,模型只能看到"和我要生成的内容相似的那段话",但看不到"那段话之后接着说了什么"。存 F 让模型能获得更多上下文,比如在 N 描述了某人的背景后,F 可能包含了具体的成就或事件,这些都有助于生成更准确的下文。
BERT 向量的作用:BERT 是一个预训练的语言模型,能把一段文字转换成一个高维向量(数字列表)。语义相似的文字会产生相近的向量。RETRO 用这个向量来衡量"相似度"。
L2 距离:两个向量之间的欧氏距离。向量越近,说明语义越相似。这就是"最近邻(nearest neighbor)"中"近"的含义。
过滤同文档近邻:如果我正在学习文章 A,检索时把文章 A 自己的其他段落检索出来,那就相当于"作弊"——模型只是在复读同一篇文章。RETRO 明确过滤掉这种情况。
近似检索的实现:用 SCaNN 库进行近似最近邻搜索,复杂度 O(log T)(T 是数据库大小),在 2 万亿 token 的数据库上每次检索约 10 毫秒。
3.4 RETRO 模型架构(Section 2.4)——重中之重
原文
"The architecture uses an encoder-decoder transformer architecture, integrating the retrieved data through a cross-attention mechanism. The main innovation is the Retro-block, which interleaves standard transformer layers with specialized retrieval blocks. The Retro-block: Retro(H,E) ≜ Ffw(Cca(Attn(H),E)), and the standard layer: Lm(H) ≜ Ffw(Attn(H))"
翻译
RETRO 采用编码器-解码器 Transformer 架构,通过交叉注意力机制将检索到的数据融入。核心创新是 Retro-block(RETRO 块),它将标准 Transformer 层与专门的检索块交错排列。
- RETRO 块:Retro(H,E) = Ffw(Cca(Attn(H),E))
- 标准层:Lm(H) = Ffw(Attn(H))
(其中 Ffw=前馈层,Attn=自注意力层,Cca=分块交叉注意力层)
新手讲解
标准 Transformer 层 vs RETRO 块,直观对比:
标准 Transformer 层(GPT 用的那种):
输入 H
↓ 自注意力(Attn):序列内部 token 互相看
↓ 前馈网络(Ffw):逐 token 的变换
输出 H'
RETRO 块(RETRO 用的):
输入 H 和 检索编码 E
↓ 自注意力(Attn):序列内部 token 互相看
↓ 分块交叉注意力(Cca):生成的文本 → 看检索到的内容 E
↓ 前馈网络(Ffw)
输出 H'
不是每一层都是 RETRO 块! RETRO 的深层 Transformer 中,大部分层是普通的标准层,只有少数层是 RETRO 块(每隔几层插入一个)。这样既能融入检索信息,又不会让计算量爆炸。
3.5 检索内容的编码器(Section 2.4 续)
原文
"The retrieval encoder is a non-causal transformer... conditioned on Hu through cross-attention layers. The encoder processes k retrieval neighbours Ret(Cu) conditioned on the activations of chunk Cu through cross-attention. All neighbours are encoded in parallel, yielding encoded set E."
翻译
检索内容的编码器是一个非因果 Transformer(即双向 Transformer,可以同时看前后文),它通过交叉注意力层以 Hu(当前 chunk 的激活)为条件进行编码。编码器对 chunk Cu 检索到的 k 个近邻 Ret(Cu) 进行处理,以 Cu 的激活值为条件,通过交叉注意力编码。所有近邻并行编码,生成编码集合 E。
新手讲解
编码器的两个关键特性:
-
双向(non-causal):普通 GPT 的 Transformer 只能看"过去"(从左到右),因为它在生成时不知道"未来"。但检索到的文档是静态的,既然有完整的文档,为何不双向理解它?RETRO 的编码器用双向 Transformer 来"读懂"检索到的文档,效果更好。
-
以当前 chunk 为条件(conditioned on Cu):编码器不是独立理解检索文档,而是"结合我现在在说什么"来理解检索内容。这让编码器能突出检索文档里与当前主题最相关的部分。
类比:想象你在写一篇关于"爱因斯坦和相对论"的文章,你去查百科全书。编码器不是把整个爱因斯坦词条都平等对待,而是因为你当前在写"相对论",所以会特别关注词条里关于相对论的部分。
3.6 分块交叉注意力(Chunked Cross-Attention, CCA)——最核心机制
原文
"Cca(H,E)um+i−1 ≜ Ca(hum+i−1, Eu)"
"The Cca operation splits intermediate activations into l−1 attending chunks. Cross-attention is computed between Hu+ and Eu — the encoded retrieval set obtained from chunk Cu. The outputs from per-chunk cross-attentions are concatenated across time. The first m−1 tokens cannot attend to any neighbour and Cca becomes the identity."
翻译
分块交叉注意力(CCA)操作将中间激活分为 l-1 个"注意力块"。交叉注意力计算在 Hu+(当前 chunk 扩展的激活)和 Eu(从 chunk Cu 检索到的编码近邻)之间进行。每个块的交叉注意力输出沿时间维度拼接。第一个 chunk 的前 m-1 个 token 无法注意到任何近邻,此时 Cca 退化为恒等映射。
新手讲解
这是 RETRO 最难理解也最核心的部分,我们用一个完整的例子来讲清楚。
假设你在生成一段 192 token 的文章(3 个 chunk,每个 chunk 64 token):
序列:
C1: token 1~64 ("The history of artificial intelligence...")
C2: token 65~128 ("Alan Turing proposed the Turing Test...")
C3: token 129~192 ("Modern deep learning began with...")
分块交叉注意力的工作流程:
步骤1:生成 C1(token 1~64)
→ 无检索,正常生成
→ 同时,用 C1 的内容去数据库里检索 k=2 个最近邻
→ 检索结果:[N1,F1], [N2,F2](关于AI历史的类似文段)
→ 用编码器编码这些近邻 → 得到 E1
步骤2:生成 C2(token 65~128)
→ 从 token 65 开始,模型可以通过分块交叉注意力"读取" E1
→ 也就是说:C2 的每个 token 在生成时都能参考 E1(关于AI历史的检索内容)
→ 同时,用 C2 的内容去检索 → 得到 E2
步骤3:生成 C3(token 129~192)
→ 参考 E2(关于图灵测试的检索内容)
→ 以此类推
"分块"的精妙之处:
- 不同的 chunk 检索不同的近邻
- 生成 C2 时,只用 C1 检索到的近邻(而不是 C2 检索到的——那还没算出来!这保持了因果性)
- 这使得随着文章主题的变化,检索内容也能动态更新,而不是始终使用同一个检索结果
与"普通 RAG"的对比:
普通 RAG(如 RAG 论文 Lewis et al. 2020):
整篇文章 → 检索一次 → 把检索结果拼在开头 → 生成
RETRO:
文章第1块 → 检索 → 辅助第2块生成
文章第2块 → 检索 → 辅助第3块生成
文章第3块 → 检索 → 辅助第4块生成
...(动态、逐块更新的检索)
类比:普通 RAG 像是"考试前查一次百科全书,然后闭卷答题";RETRO 像是"每写一段就查一次相关资料,边查边写"。
3.7 采样(推理)时的工作流程
原文
"When sampling, at the end of a chunk Cu, the system uses SCaNN to retrieve neighbours based on Bert(Cu). The encoded neighbours then condition the generation of the next chunk Cu+1, computed incrementally."
翻译
在采样(推理)时,在每个 chunk Cu 的末尾,系统使用 SCaNN 根据 Bert(Cu) 检索近邻。编码后的近邻随后作为条件,影响下一个 chunk Cu+1 的生成,并进行增量计算。
新手讲解
推理时的完整流程:
1. 输入:用户 prompt(前几个 token)
2. 生成第一个 chunk(64 token):不用检索,正常生成
3. 到第一个 chunk 末尾:
a. 用 BERT 将当前 chunk 变成向量
b. 用 SCaNN 在 1.75 万亿 token 的数据库里找 k 个最近邻(~10 毫秒)
c. 用编码器处理这些近邻
4. 生成第二个 chunk:每个 token 的生成都能"看到"第一步检索到的内容
5. 到第二个 chunk 末尾:再次检索,更新检索内容
6. 重复直到生成完毕
计算开销:每 64 个 token 检索一次,每次约 10 毫秒。对于 2048 token 的序列,需要检索 32 次,总检索时间约 320 毫秒。这对在线推理有一定延迟,但对大多数应用是可接受的。
3.8 冻结的 BERT 检索器的意义
原文
"Retro combines a frozen Bert retriever, a differentiable encoder and a chunked cross-attention mechanism"
翻译(前文已引用,此处补充讲解)
为什么检索器要"冻结"(frozen)?
这是 RETRO 系统设计中一个非常重要的工程决策:
如果检索器(BERT)的参数也随训练更新,会怎样?
- 每次 BERT 参数变化,所有数据库条目的向量(键)都需要重新计算
- 1.75 万亿 token ÷ 64 = 约 273 亿个数据库条目,每个都要重新用新 BERT 算一遍向量
- 这个代价是灾难性的,根本无法在每次训练步骤后都做
"冻结"的代价与收益:
- 代价:检索质量不会随训练提升,用的始终是预训练 BERT 的语义理解
- 收益:数据库索引建一次就够用,大幅降低了系统复杂度
- 实验表明:即使是冻结的 BERT,检索质量也已足够好,能带来显著的性能提升
类比:就像图书馆的索引卡(图书馆的目录系统)一旦建好就不需要每天重建,即使偶尔有新书入库,旧的索引依然有效。
3.9 数据集泄露量化方法(Section 2.6)
原文
"For each evaluation chunk C, we retrieve the 10 closest neighbours (of length up to 128) in the training data. We then compute the longest token substring common to both the evaluation chunk and its neighbours. This gives a number s∈[0,m]. The value r(C)=s/m, ranging from 0 (chunk never seen) to 1 (chunk entirely seen), gives a reliable indication of how much overlap there is between the evaluation chunk and the training data."
"∀α∈[0,1], Cα ≜ {C∈C, r(C)≤α}" — chunks with overlap ratio less than α
翻译
对于每个评估 chunk C,我们在训练数据中检索 10 个最近邻(长度最多 128 个 token)。然后计算评估 chunk 和其近邻之间最长公共 token 子串的长度 s∈[0,m]。r(C)=s/m,范围从 0(从未见过的 chunk)到 1(完全相同的 chunk),可以可靠地反映评估 chunk 与训练数据之间的重叠程度。
对于所有 α∈[0,1],定义 Cα 为重叠率不超过 α 的 chunk 子集。
新手讲解
数据泄露(data contamination)是什么?
测试集里的内容如果在训练集里出现过,模型可能只是"背了答案"而不是真正"学会了"。对于 RETRO 这类检索增强模型,这个问题更严重:模型可能通过检索直接把答案从训练数据"抄"过来。
RETRO 的检测方法:
- 对测试集的每个 chunk,找训练集里最相似的 10 个片段
- 计算最长公共子串长度 s,除以 chunk 长度 m,得到重叠率 r
- 例如:如果一个 64 token 的测试 chunk 和训练集里某个片段有 16 个连续 token 完全相同,则 r=16/64=25%
关键发现:即使只看重叠率接近于 0 的 chunk(几乎没有泄露的 chunk),RETRO 依然显著优于不使用检索的基线模型。这说明 RETRO 的性能提升不仅仅来自"抄答案",更来自真正的知识融合和推理能力。
四、实验结果
4.1 RETRO 的模型规模
根据论文 Table 2,RETRO 训练了四个不同规模的版本(参数量不含词嵌入):
| 基线模型参数量 | RETRO 版本参数量 | 参数增量 |
|---|---|---|
| 1.32 亿(132M) | 1.72 亿(172M) | +30% |
| 3.68 亿(368M) | 4.25 亿(425M) | +15% |
| 13 亿(1.3B) | 14.5 亿(1.45B) | +11% |
| 70 亿(7.0B) | 75 亿(7.53B) | +8% |
关键观察:加入检索组件后,参数量只增加了 8%~30%,但性能提升远不止 8%~30%。随着模型越大,检索组件的参数占比越小(因为检索组件规模相对固定)。
4.2 与大模型的对比(Section 4.1)
原文(图4标题)
"The Pile: Comparison of our 7B baseline against Jurassic-1, Gopher, and Retro. We observe that the retrieval model outperforms the baseline on all test sets and outperforms Jurassic-1 on a majority of them, despite being over an order of magnitude smaller."
翻译
The Pile 基准测试:我们的 7B 基线模型与 Jurassic-1、Gopher 以及 RETRO 的对比。我们观察到,检索模型在所有测试集上均优于基线,并在大多数测试集上优于 Jurassic-1,尽管参数量比 Jurassic-1 小一个数量级以上。
模型参数规模对比:
| 模型 | 参数量 |
|---|---|
| RETRO 7.5B(本文) | 75 亿 |
| GPT-3 | 1750 亿 |
| Jurassic-1 | 1780 亿 |
| Gopher | 2800 亿 |
新手讲解
这个结果的震撼性在哪里?
- Jurassic-1(178B)的参数量是 RETRO 7.5B 的约 24 倍
- Gopher(280B)的参数量是 RETRO 7.5B 的约 37 倍
- RETRO 用 1/24 的参数,在大多数测试集上超过了 Jurassic-1
这直接印证了摘要中的核心主张:把知识外包给数据库,小模型可以媲美甚至超过大模型。
需要注意的是,这个比较有其背景条件:Jurassic-1 和 Gopher 是在自己的数据上训练的,而 RETRO 在测试时可以"查阅"2 万亿 token 的数据库。这类似于"开卷考试 vs 闭卷考试"的对比,需要理性看待。
4.3 检索库规模的影响(Section 4.1,数据缩放实验)
原文
"We observe dramatic gains as the retrieval data is increased from Wikipedia (4 billion tokens) to all of MassiveText (1.7T tokens)."
翻译
我们观察到,随着检索数据从维基百科(40 亿 token)增加到完整的 MassiveText(1.7 万亿 token),性能出现了显著提升。
新手讲解
这个实验回答了一个关键问题:检索库必须那么大吗? 答案是:越大越好,而且提升是非线性的。
实验设置:固定同一个 RETRO 模型,只改变检索库的大小,观察语言建模困惑度(perplexity,越低越好)的变化:
- 检索库 = 英语维基百科(4B token):perplexity 约 18.97
- 检索库 = MassiveText 10%(约 175B token):perplexity 约 14.95
- 检索库 = 完整 MassiveText(1.75T token):perplexity 进一步下降
这说明检索库规模本身就是一个重要的"超参数",而且 RETRO 能充分利用更大的数据库——不只是多存了更多数据,而是在推理时真正有效地利用了它们。
此外,实验还表明,检索更多的近邻(增大 k)对更大的模型帮助更显著,可以用到 k=40 个近邻而依然有性能提升。
4.4 问答任务表现(Section 4.3)
原文
"Our method is competitive with previous approaches such as Realm, RAG and DPR, but underperforms the more recent FiD."
翻译
我们的方法与 REALM、RAG 和 DPR 等先前方法相当,但弱于更新的 FiD 方法。
新手讲解
RETRO 在 NaturalQuestions 问答任务上取得了约 45.5% 的 exact match(精确匹配率),优于 RAG 和 REALM,但略逊于 FiD(Fusion-in-Decoder)。
为什么在问答任务上不是最强?
RETRO 的设计目标是通用语言建模(预测下一个 token),而不是专门为问答任务优化。FiD 是专门针对问答任务设计的系统,在这个领域更专精。这类似于"全能选手 vs 专项选手"的差距,并不影响 RETRO 在通用语言建模上的价值。
4.5 数据泄露分析结论
原文
"Retro outperforms baseline models at all leakage levels, down to α=12.5%." (α 代表重叠率阈值)
翻译
RETRO 在所有泄露水平下都优于基线模型,包括低至 α=12.5% 的情况(即评估 chunk 与训练数据的重叠不超过 8 个 token 的子集)。
新手讲解
这是一个重要的"反驳实验"。有人可能质疑:RETRO 的性能提升是不是只因为测试时检索到了训练集里的内容,本质上是"作弊"?
这个实验的结论是:即使排除掉那些与训练集有任何显著重叠的测试样本,RETRO 依然比基线好。换句话说,即使是测试集里那些"从未出现过"的内容,RETRO 也能更好地生成,这证明了检索带来的是真正的知识融合能力,而不仅仅是记忆复现。
五、改装现有模型(Retrofitting)
原文
"Fine-tune by randomly initializing and training only the chunked cross-attention and retrieval encoder weights. This approach requires only 3% of pre-training tokens to achieve near-scratch performance levels."
翻译
通过随机初始化并只训练分块交叉注意力和检索编码器的权重来微调现有模型。这种方法只需要预训练 token 数量的约 3% 就能达到接近从头训练的性能水平。
新手讲解
这是 RETRO 非常实用的特性:"改装(retrofit)"。如果你已经有一个训练好的标准语言模型(比如 GPT 系列),不需要从头重新训练就能给它加上检索能力:
- 在模型中插入分块交叉注意力层和检索编码器(随机初始化)
- 冻结原来模型的所有参数
- 只用少量数据(约原训练量的 3%)训练新加的组件
- 完成!现在这个模型有了检索能力
这大大降低了应用 RETRO 技术的门槛,不必花费巨大算力从头训练。
六、总结与意义
论文的核心贡献
- 规模突破:将检索库从"十亿"扩展到"万亿",证明规模对检索增强模型同样重要
- 架构创新:分块交叉注意力(CCA)机制,允许在生成过程中动态、逐块地更新检索内容
- 系统设计:冻结 BERT 检索器 + 可学习编码器,在效率与性能之间取得平衡
- 实验验证:7.5B 参数的 RETRO 媲美 175B~280B 参数的模型,验证了"知识外包"策略
对 AI 发展的意义
RETRO 的理念开创了一条与"暴力扩大模型"平行的路径:
- 传统路径:模型 → 更多参数 → 更多计算 → 更好性能
- RETRO 路径:合理规模的模型 + 超大检索库 → 接近甚至超越更大的模型
这一思路深刻影响了后来的 RAG 系统设计,也启发了很多"参数高效"的 AI 研究方向。
RETRO 的局限
- 推理延迟:每 64 个 token 需要一次数据库检索,增加了推理时间
- 基础设施需求:需要维护一个万亿 token 量级的检索数据库,存储和索引成本高
- 检索质量上限:冻结的 BERT 检索器限制了检索质量的天花板,无法随任务微调
- 非端到端优化:检索器和生成器不是联合优化的,可能存在不一致
七、术语速查表
| 术语 | 英文 | 解释 |
|---|---|---|
| 分块 | Chunk | 将序列切分为固定大小(64 token)的片段,是检索的基本单位 |
| 最近邻 | Nearest Neighbor | 在向量空间中距离最近(语义最相似)的数据库条目 |
| 交叉注意力 | Cross-Attention | 让序列 A 的 token 能"看到"序列 B 内容的注意力机制 |
| 分块交叉注意力 | Chunked Cross-Attention (CCA) | RETRO 的核心,让每个 chunk 通过交叉注意力融合其检索到的近邻 |
| 冻结的检索器 | Frozen Retriever | 参数不更新的 BERT 模型,用于将文本映射为检索向量 |
| 半参数化 | Semi-parametric | 知识一部分存在模型参数里,一部分存在外部数据库里 |
| 改装 | Retrofitting | 将已有的预训练模型快速改造为具有检索能力的 RETRO 模型 |
| 困惑度 | Perplexity | 语言模型的评估指标,越低说明模型预测越准确 |
| 近似最近邻 | Approximate Nearest Neighbor (ANN) | 用 SCaNN 等工具快速(近似)地找到向量空间中的最近邻 |
| 键值存储 | Key-Value Store | 数据库的存储结构:键=BERT向量,值=[N,F]两个相邻chunk |
| 后续块 | Continuation (F) | 检索到的 chunk N 之后紧跟的那个 chunk,提供更多上下文 |
| Pile | The Pile | EleutherAI 发布的大规模文本基准数据集,包含多种来源 |
| 精确匹配 | Exact Match | 问答任务的评估指标,模型回答与标准答案完全一致的比例 |
精读笔记基于论文原文(arXiv 2112.04426)整理,所有数字和引文均来自原始论文。