从Mistral 7B到MoE模型Mixtral 8x7B的全面解析:从原理分析到代码解读

0 / 882

前言

本文先全面介绍 Mistral 7B,特别是 Mixtral 8x7B

毕竟 OpenAI 团队一直对 GPT-4 的参数量和训练细节守口如瓶。早些时候,有人爆料 GPT-4 是采用了由 8 个专家模型组成的集成系统。后来又有传闻称,ChatGPT 也只是百亿参数级的模型(大概在 200 亿左右)

传闻无从证明,但 Mixtral 8x7B 可能提供了一种「非常接近 GPT-4」的开源选项,特此,本文全面解析下:从原理解析到代码解读(在此文之前,尚没有资料扒得像本文这样如此之细)

第一部分 23 年 5 月 Mistral AI 发布的 Mistral 7B

1.1 Mistral 7B:通过分组查询注意力 + 滑动窗口注意力超越 13B 模型

23 年 5 月,DeepMind 和 Meta 的三位前员工在巴黎共同创立了 Mistral AI(其 CEO Arthur Mensch 此前在 DeepMind 巴黎工作,CTO Timothée Lacroix 和首席科学家 Guillaume Lample 则在 Meta 共同参与过 LLaMA 一代的研发,很像当年 OpenAI 的部分员工出走成立 Anthropic 啊)

23 年 10 月,他们发布了第一个基座大模型,即 Mistral 7B

Mistral 7B 对应的论文为《Mistral 7B》称( 另,这是其 GitHub 地址) ,以下是「模型参数图」

  1. Mistral 7B 在所有评估基准中均胜过了目前最好的 13B 参数模型(Llama 2,对标的第二代),并在推理、数学和代码生成方面超越了 Llama 34B(对,这里其对标 Llama 第一代的 34B)
    Mistral 7B outperforms the previous best 13B model (Llama 2, [Llama 2 : Open foundation and fine-tuned chat models ]) across all testedbenchmarks, and surpasses the best 34B model (LLaMa 34B, [Llama : Open and efficient foundation language models ]) in mathematics and codegeneration.

  2. 该模型采用了分组查询注意力(GQA),GQA 显著加快了推理速度,还减少了解码期间的内存需求,允许更高的批处理大小,从而提高吞吐量
    GQA significantly accelerates the inference speed, and also reduces the memory requirement during decoding, allowing for higher batch sizes hence higher throughput
    所以你看上面的「模型参数图」,维度(dim):4096,总计 32 个头(n_heads),每个头的维度(head_dim):128,这一眼可以看出来,而 n_kv_heads 是啥呢?
    咋一看好像不太好理解 是不?其实,正是因为 Mistral 用了 GQA,n_heads 指的是 Q 的头数,n_kv_heads 指的是 K、V 的头数

    不过要注意的是,与上图中间所示部分不太一样的地方在于:
    \rightarrow 上图中间所示部分中,Q 的头数是 K V 头数的 2 倍
    \rightarrow 但在 Mistral 的 GQA 中,Q 的头数是 K V 头数的 4 倍

    关于 GQA 的更多介绍,请参见《一文通透各种注意力:从多头注意力 MHA 到分组查询注意力 GQA、多查询注意力 MQA

  3. 同时结合滑动窗口注意力(sliding window attention,简称 SWA)以有效处理任意长度的序列
    SWA is designed to handle longer sequences more effectively at a reduced computational cost

    包括你再看上上张图所示的「模型参数图」,可知 context_len 8192 是说它训练的时候,传进来的数据最大只能到 8192 个 tokens,也就是训练时的上下文长度上限,
    windows_size 4096 是 sliding windows attention 的滑窗大小,1 次 attention 计算的上下文范围只 4096 个 tokens

    言外之意是,每个 token 只最多计算 4096 的范围
    第 5000 个 token 只计算[905: 5000]这个范围的 attention
    第 5001 个 token 只计算[906: 5001]这个范围的 attention
    以此类推..

此外,作者提供了一个针对遵循指令进行了微调的模型,名为 Mistral 7B-Instruct,它在人工和自动化基准测试中均超过了 LLaMA2 13B-chat 模型

1.2 三个显著特点:滑动窗口注意力、滚动缓冲区缓存、预填充与分块

1.2.1 滑动窗口注意力:扩展上下文长度

vanilla attention 的操作次数在序列长度上是二次型的,记忆量随着 token 数量线性增加。在推理时,由于缓存可用性的降低,这导致了更高的延迟和更小的吞吐量(The number of operations in vanilla attention is quadratic in the sequence length, and the memory increases linearly with the number of tokens. At inference time, this incurs higherlatency and smaller throughput due to reduced cache availability )

为了缓解这个问题,Mistral 7B 使用滑动窗口注意力(sliding window attention)

  1. 每个 token 最多可以关注来自上一层的 W 个 token(上图中,W = 3)。请注意,滑动窗口之外的 token 仍然影响下一个单词预测
    each token can attend to at most W tokens from the previous layer (here, W = 3). Note that tokensoutside the sliding window still influence next word prediction.

    举个例子,在面对这个序列时:The cat sat on the
    如果是标准注意力,在计算最后一个 token “the”时,得计算 the 本身所对应的 query 与整个上文每个 token 对应的 key 的内积,当序列长度一长时,该计算量还是比较大的
    但如果是滑动窗口注意力,则在计算最后一个 token “the”时,只需计算 the 本身所对应的 query 与上文中 3 个 token 对应的 key 的内积(这里说的上文中的 3 个 token 包括 the 自己在内 )

  2. 在每个注意力层,信息可以向前移动 W 个 token。因此,在 k 层注意力之后,信息最多可以向前移动 k 个 ×W 个 token
    At each attention layer, information can moveforward by W tokens. Hence, after k attention layers, information can move forward by up to k ×W tokens.

1.2.2 滚动缓冲区缓存(Rolling Buffer Cache)

固定的注意力长度意味着可以使用滚动缓存来限制的缓存大小(A fixed attention span means that we can limit our cache size using a rollingbuffer cache)

  1. 缓存的大小是固定的 W,时间步长 i 的键和值存储在缓存的位置 i mod W 中。因此,当位置 i 大于 W 时,缓存中过去的值就会被覆盖,缓存的大小就会停止增加
    The cache has a fixed size of W, and the keys and values for the timestep i are storedin position i mod W of the cache. As a result, when the position i is larger than W, past valuesin the cache are overwritten, and the size of the cache stops increasing

    以“The cat sat on the mat”为例..
    当 i = 0 时,指 The,0 mod  3=0
    当 i = 1 时,指 cat,1 mod  3=1
    当 i = 2 时,指 sat,2 mod  3=2
    当 i = 3 时,指 on,3 mod  3=0
    当 i = 4 时,指 the,4 mod  3=1
    当 i = 5 时,指 mat,5 mod 3 = 2

  2. 在 32k token 的序列长度上,这减少了 8 倍的缓存内存使用,而不影响模型质量
    On a sequence length of 32k tokens, this reduces the cache memory usageby 8x, without impacting the model quality.

如果把缓冲区比作一座仓库,每存进一个新东西,都会占据相应的位置,而仓库的总容量是固定的,当仓库被装满时,就会把最早放入的东西移除,让新的物品继续进仓,相当于入仓时间更接近当前时间的物品则会留在仓库中,如此,即能在节约资源的同时保留一定长度的序列

1.2.3 预填充与分块:减少重复运算

在生成序列时,需要一个一个地预测 token,因为每个 token 都以前面的 token 为条件。然而,prompt 是提前知道的,可以用 prompt 预填充(k, v)缓存,即

  1. 如果 prompt 非常大,可以把它分成更小的块,用每个块预填充缓存。为此,可以选择窗口大小作为分块大小。因此,对于每个块,需要计算缓存和块上的注意力

  2. 下图展示了注意力掩码在缓存和分块上的工作原理

    在预填充缓存时,长序列被分块,以限制内存使用
    我们把一个序列分成三个块来处理,“The cat sat on”,“the mat and saw”,“the dog go to”。上图中显示了第三块(“the dog go to”)发生的情况:它使用因果掩码(最右块)来关注自己,使用滑动窗口(中心块)来关注缓存,并且不关注过去的 token,因为它们在滑动窗口之外(左块)

1.3 Mistral 7B – Instruct

与 Mistral 7B 同期发布的 Mistral 7B – Instruct(We also provide a model fine-tuned to follow instructions,Mistral 7B –Instruct),在 MT-Bench 的表现可以略微超过 13B –Chat 模型

// 待更

第二部分 首个开源 MoE 大模型 Mixtral 8x7B

2.1 Mixtral 8x7B 的整体架构与模型细节

23 年 12 月 8 日,Mistral AI 在 X 平台甩出一条磁力链接(当然,后来很多人打开一看,发现是接近 87 GB 的种子)

看上去,Mixtral 8x7B 的架构此前传闻的 GPT-4 架构非常相似(很像传闻中 GPT-4 的同款方案),但是「缩小版」:

  • 8 个专家总数,而不是 16 名(减少一半)
  • 每个专家为 7B 参数,而不是 166B(减少 24 倍)
  • 47B 总参数(估计)而不是 1.8T(减少 42 倍)
  • 与原始 GPT-4 相同的 32K 上下文

在发布后 24 小时内,已经有开发者做出了在线体验网站:https://replicate.com/nateraw/mixtral-8x7b-32kseqlen

两天后的 23 年 12.11 日,Mistral AI 团队对外正式发布 Mixtral 8x7B,其在大多数基准测试中都优于 Llama 2 70B,推理速度提高了 6 倍,且它在大多数标准基准测试中匹配或优于 GPT3.5

为免歧义,补充说明下,Mistral AI 团队目前总共发布了两个模型

  • 今年 10 月发布的 Mistral 7B
  • 今年 12 月则发布的混合专家模型,称之为 Mixtral 8x7B

特意注意,一个 mis 一个 mix,本质不同

而 Mixtral 8x7B 是一个纯解码器模型,下图是 Mixtral 的核心参数(可以把它和 Mistral 的核心参数做个对比 )

  1. 其中前馈块从一组 8 个不同的参数组中进行选择(It is a decoder-only model where the feedforward block picks from a set of 8 distinct groups of parameters)

  2. 在每一层,对于每个 token,路由器网络选择其中的两个组(“专家”)来处理 token 并通过组合相加得到它们的输出(At every layer, for every token, a router network chooses two of these groups (the “experts”) to process the token and combine their output additively)

    这点可能很多朋友不会特别在意,但你仔细品味下,你会发现大有天地,即:每个 token 都由某两个专家负责完成,最后整个序列 则是由一系列「不同的两两专家」组合完成,下文还会详述该点

  3. 上下文长度达到 32K
    Mixtral is pretrained with multilingual data using a context size of 32k tokens

2.1.1 Mixtral 8x7B 是一个稀疏的专家混合网络

如下图所示,传入模型的各个 token 在经过 Attention 层及残差连接后,进一步将由路由(Gating/Router)导向 2 个 expert(FFN)中,之后对 expert 的输出进行加权聚合,再经过残差连接得到当前层的输出

即对于给定的输入x,MoE 模块的输出由“专家网络输出的加权和”决定,其中权重由“门控网络的输出”确定(The output of the MoE module for a given input x is determined by the weighted sum of the outputs of the expert networks , where the weights are given by the gating network’s output .)

当给定n个专家网络\left{E_{0}, E_{i}, \ldots, E_{n-1}\right},则专家层(expert layer)的输出为:

\sum_{i=0}^{n-1} G(x){i} \cdot E{i}(x)

其中

  1. G(x)_{i}表示第i 个专家的门控网络的 n 维输出(denotes the n-dimensional output of the gating network for the i-th expert)
  2. E_{i}(x) 是第i个专家网络的输出(the output of the i-th expert network)

如果门控向量稀疏,我们可以避免计算门为零的专家输出(If the gating vector is sparse, we can avoid computing the outputs of experts whose gates are zero)。有多种实现 G(x)的可选方法,但一种简单且高性能的方法是通过对线性层的 Top-K logits 进行 softmax(but a simple and performant one is implemented by taking the softmax over the Top-K logits of a linear layer [28])

G(x):=\operatorname{Softmax}\left(\operatorname{TopK}\left(x \cdot W_{g}\right)\right)

其中

  1. 如果\ell_{i}在 logits 的 top-K 坐标\ell \in \mathbb{R}^{n}中,则(\operatorname{TopK}(\ell)){i}:=\ell{i},否则(\operatorname{TopK}(\ell))_{i}:=-\infty
    where(\operatorname{TopK}(\ell)){i}:=\ell{i} if \ell_{i} is among the top-K coordinates of logits \ell \in \mathbb{R}^{n}and (\operatorname{TopK}(\ell))_{i}:=-\infty otherwise.

  2. 每个 token 所使用的专家数量K是可调的参数
    当保持K不变但增加n时,可以增加模型的总参数数量,同时保持计算成本有效不变
    The value of K – the number of experts used per token – is a hyper-parameter that modulates the amount of compute used to process each token. If one increases n while keeping K fixed, one can increase the model’s parameter count while keeping its computational cost effectively constant.

    这引出了「总参数数量(通常称为稀疏参数数量)」与用于「处理单个 token 的活动参数数量」之间的区别
    对总参数数量而言,随着n的增加而增加;而对于活动参数数量而言,K直到n逐渐增加
    This motivates a distinction between the model’s total parameter count (commonly referenced as the sparse parameter count), which grows with n, and the number of parameters used for processing an individual token (called the active parameter count), which grows with K up to n.

MoE 层能够在具备高性能专用内核的单个 GPU 上高效运行

  1. 例如,Megablocks 将 MoE 层的前馈网络(FFN)操作转换为大型稀疏矩阵乘法(Megablocks [13] casts the feed-forward network (FFN) operations of the MoE layer as large sparse matrix multiplications),从而显著提升了执行速度
    并且可以自动处理不同专家被分配可变数量 token 的情况(naturally handling cases where different experts get a variable number of tokens assigned to them.)

  2. 此外,通过标准模型并行技术和一种名为专家并行(EP)的特殊分区策略,MoE 层可以在多个 GPU 上进行分布
    Moreover, the MoE layer can be distributed to multiple GPUs through standard Model Parallelism techniques , and through a particular kind of partitioning strategy called Expert Parallelism (EP) [28].
    在 MoE 层执行过程中,旨在由特定专家处理的 token 会被路由到相应的 GPU 进行处理,并将专家输出返回到原始 token 位置 During the MoE layer’s execution, tokens meant to be processed by a specific expert are routed to the corresponding GPU for processing , and the expert’s output is returned to the original token location.

    需要注意的是,在负载平衡方面,EP 带来了挑战,因为均匀地分配工作负载至关重要以避免单个 GPU 过载或遇到计算瓶颈
    Note that EP introduces challenges in load balancing, as it is essential to distribute the workload evenly across the GPUs to prevent overloading individual GPUs or hitting computational bottlenecks.

在 Transformer 模型中,MoE 层独立应用于每个 token,并替换了 Transformer 块的前馈(FFN)子块(In a Transformer model, the MoE layer is applied independently per token and replaces the feed-forward (FFN) sub-block of the transformer block)

对于 Mixtral

  1. 采用与专家函数E_{i}(x)相同的 SwiGLU 架构,并设置 K = 2
  2. 这意味着每个 token 被路由到两个具有不同权重集的 SwiGLU 子块
    For Mixtral we use the same SwiGLU architecture as the expert function Ei(x) and set K = 2

综上,输入 token x经过处理后得到输出y(This means each token is routed to two SwiGLU sub-blocks with different sets of weights)

y=\sum_{i=0}^{n-1} \operatorname{Softmax}\left(\operatorname{Top} 2\left(x \cdot W_{g}\right)\right){i} \cdot \operatorname{SwiGLU}{i}(x)

这个公式类似于 GShard 架构,不同之处是 mixtral 用 MoE 层替换所有 FFN 子块,而 GShard 替换所有其他块,并且 GShard 对分配给每个 token 的第二个专家使用更详细的门策略

2.1.2 Mixtral 的参数总量为何是 46.7B 而非 56B

Mixtral 共有 46.7B 个参数,但每个 token 仅使用 12.9B 个参数。因此,它以与 12.9B 模型相同的速度和相同的成本处理输入并生成输出( Mixtral has 46.7B total parameters but only uses 12.9B parameters per token. It, therefore, processes input and generates output at the same speed and for the same cost as a 12.9B model )

  1. 即,虽然 Mixtral 模型的完整名称为“Mixtral-8x7B-v0.1”,看似有“8x7B=56B”的参数量,但实际的参数量应当是约 47B 而非 56B,因为在各个层中仅有 experts 部分(FFN)是独立存在的,其余的部分(Attention 等)则是各个 expert 均有共享的
  2. 可以想象成一个“纺锤状”的样式,数据由共享模块传输至 expert 模块对应于纺锤中部发散的部分,对 expert 的输出进行加权聚合则对应纺锤末端收束的部分

2.1.3 Mixtral 中所采取的 GQA 机制

Mixtral 沿用了 Mistral 7B 中所采取的 GQA 机制,与传统的 MHA(Multi-Head Attention)相比,主要是对 Attention 机制中的 K、V 表征维度进行控制,从而降低 K、V 对应的参数量,除 GQA 外相应地还有 MQA(Multi-Query Attention),MQA 可以认为是 GQA 的特例。相关维度如下表所示:

Q K V
MHA hidden_dim hidden_dim hidden_dim
GQA hidden_dim hidden_dim/n hidden_dim/n
MQA hidden_dim 1 1

其中 n 为 K 和 V 相对 MHA 参数量降低的比例,具体地,在 Mixtral 中 n 为 4

关于 GQA 的更多细节详见此文《一文通透各种注意力:从多头注意力 MHA 到分组查询注意力 GQA、多查询注意力 MQA

2.1.4 Mixtral 中的路由(Gating/Router)

路由(Gating/Router)本质是一个线性层,输入维度为隐层维度 hidden_dim、输出维度为 expert 数 num_experts。正向传播过程中将被用作预测给定 token 对应输入各个 expert 的分值

self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

至于路由处理的对象可以是 Sentence-Level、Token-Level 或者 Task-Level

  • Sentence-Level 是对各个样本分别进行路由
  • Token-Level 是对样本中的各个 token 分别进行路由
  • Task-Level 要求不同的 expert 明确负责不同任务

因此同样也是对各个样本分别进行路由,但其所路由的目标 expert 是有明确导向的,例如某样本的数据还提供有“所属任务”信息,通过该信息可明确将该样本导向某个专职负责对应任务的 expert 中

Mixtral 采取了 Token-Level 的处理单位

  1. 至于首次在 NLP 任务中使用 Token-Level 的 MOE 可以追溯至 2017 年的《Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer
  2. 该论文展示了 Token-Level 的一些有趣现象,通过观察各个 expert 所负责 token 的统计特征,不同的 expert 确实掌握了一些语法层面理解, 当需要不定冠词“a”在重要的动词短语中引入直接宾语时,则会有专门的 752 号 expert 来负责输出这个“a”

2.2 模型表现:匹配或超越Llama 2 70B 以及 GPT3.5

我们将 Mixtral 与 Llama 2 系列和 GPT3.5 基础模型进行比较。Mixtral 在大多数基准测试中均匹配或优于 Llama 2 70B 以及 GPT3.5

性能概览

在下图中的测试,衡量了质量与推理预算的权衡。与 Llama 2 相比,Mistral 7B 和 Mixtral 8x7B 更高效

性能规模

下表给出了上图的详细结果

详细的基准测试

为了识别可能的缺陷,通过微调/偏好建模来纠正,测量了其在 BBQ/BOLD 上的性能

BBQ BOLD 基准

与 Llama 2 相比,Mixtral 对 BBQ 基准的偏差较小。总体而言,Mixtral 在 BOLD 上比 Llama 2 显示出更积极的情绪

2.3 指令遵循模型Mixtral 8x7B Instruct

与 Mixtral 8x7B 一起发布还有 Mixtral 8x7B Instruct,其在 Mixtral 8x7B 的基础上通过监督微调和直接偏好优化(DPO)进行优化,以让之严格的遵循指令

关于什么是 DPO 及其原理细节,请参见此文《RLHF 的替代之 DPO 原理解析:从 RLHF、Claude 的 RAILF 到 DPO、Zephyr

在 MT-Bench 上,它达到了 8.30 的分数,使其成为最好的开源模型,性能可与 GPT3.5 相媲美

第三部分 Mixtral(MOE 架构)的实现细节:代码解读

如阿荀所说(*本部分的 base 版本由我司大模型项目团队第二项目组的阿荀提供,我在其基础上陆陆续续做了大量的补充、说明 *),上文中关于 mixtral 一个比较反直觉的点是:

  • 对于每个 token,路由器网络选择其中的两个组(“专家”)来处理 token 并通过组合相加得到它们的输出「At every layer, for every token, a router network chooses two of these groups (the “experts”) to process the token and combine their output additively
  • 啥意思,就是如果不仔细了解的话,很容易误以为是“输入的一整个序列”分给 TOP 2 专家,结果事实是每个 token 都各自分配 TOP 2 专家,而且当你仔细抠完 mixtral 的代码之后,你会发现还真是如此..

3.1 MOE 模块的前向传播:整体流程

单个 Mixtral 层可以大体划分为 Attention 模块和 MOE 模块,以下重点关注 MOE 模块的前向传播过程

3.1.1 获取各 token 对应的 top2 expert 及其权重

为确保大家可以以最快的速度理解各行代码的含义,我在阿荀分析的基础上拆成了以下六个步骤,且对每个步骤都加了额外的解释说明

  1. 由于 hidden_states 的维度,通常包括批大小(batch_size)、序列长度(sequence_length)和隐藏层维度(hidden_dim),故有
    # 由Attention模块输出的hidden_states作为本部分的输入
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    
  2. 将 hidden_states 的形状重构为一个二维张量,用于将其处理为每个 token 的表示
    # 转换成(bs*seq_len, hidden_dim),即token-level
    hidden_states = hidden_states.view(-1, hidden_dim)
    
  3. 通过一个门控(gate)机制来生成路由逻辑(router_logits),用于后续决定每个 token 应由哪些专家(experts)处理
    # router_logits: (batch * sequence_length, n_experts)
    # (bs * seq_len, n_experts)
    router_logits = self.gate(hidden_states)
    
  4. 对每个 token 的路由逻辑应用 softmax 函数,计算每个专家对每个 token 的处理权重
    # 在token-level(dim=1)进行softmax,即每个token都各自进行n_experts分类的输出
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    
  5. 选取每个 token 的前 top_k 个最重要的专家及其权重
    # routing_weights: (bs * seq_len, topk),是选取的experts对应的原始权重
    # selected_experts: (bs * seq_len, topk),是选取的experts的编号/索引号
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
    
  6. 对选出的每个 token 的专家权重进行归一化处理,确保每个 token 的专家权重之和为 1
    # 对原始权重重新归一化,使得所取出的experts权重加和等于1
    # routing_weights的具体样例见下文的【代码块A】
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    

3.1.2 将各 token 传入对应的 expert 模型中进行前向传播得到输出

  1. 首先

    # final_hidden_states: (bs * seq_len, hidden_dim)
    # 由全0张量初始化
    # final_hidden_states将用于存储各token对应expert的聚合结果
    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
    )
    
  2. 根据给定的 selected_experts 作为元素 1 所在位置的索引,构建向量长度为 num_experts 的 one-hot 编码
    好比 24 个 token,需要由 8 个 expert 两两组合处理,那我针对每一个 token 都构建长度为 8 的 0 1 编码,这个编码分别代表 8 个 expert
    故,每个 token 选择了哪两个 expert,则对应的编码位上变为 1,否则为 0

    比如 July 这个 token 选择 3 7 两个 expert,则 July 对应的 0 1 编码位:0 0 1 0 0 0 1 0
    再比如 Edu 这个 token 如果选择了 2 4 两个 expert,则其 01 编码为:0 1 0 1 0 0 0 0
    依此类推..

    # selected_experts.shape: (bs*seq_len, topk)
    # torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).shape: (bs*seq_len, topk, num_experts)
    
  3. 使用相对取巧方法来进行前向传播

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
    

    具体而言,下面这个张量
    torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0).shape: (num_experts, topk, bsseq_len)*
    的物理含义是由“每个 token 分别选取了哪 topk 个 expert”变成了“每个 expert 分别作为各个排位存在的时候,对应需要处理哪些 token”
    这样做的好处在于:后续循环的时候只需要进行 num_experts 次前向传播就能得到结果,而无需进行 bs*seq_len 次前向传播

    为方便大家更好的理解上面那行代码的含义,我特地画了个示意图以加快理解
    \rightarrow A B C D E F G H I J K L M N O P Q R S T U V W X Y Z,是需要处理的 token
    \rightarrow 1 2 3 4 5 6 7 8,代表 8 个 expert
    (*如阿荀所说,如此,便把关注视角从“各个 token”变成了“各个专家”,当然,大部分情况下 token 数远远不止下图这 5 个,而是比专家数多很多。总之,这么一转换,最终可以省掉很多循环 *)

  4. 所以接下来只需要进行 num_experts 次循环

    # 根据次序逐个取出expert模型
    for expert_idx in range(self.num_experts):
        expert_layer = self.experts[expert_idx]
        idx, top_x = torch.where(expert_mask[expert_idx])
    

    上面这几行代码得好好解释下
    由于 expert_mask 记录有各个 expert 分别作为各个排位存在的时候,对应需要处理哪些 token,故 expert_mask[expert_idx].shape: (topk, bsseq_len)*,便是从 expert_mask 中取出其对应的,详见下文的【代码块 B】
    故上面三行的最后一行中等式中的右边项:torch.where(expert_mask[expert_idx]),则是辨析出 *expert_mask[expert_idx]*值为 1 的位置索引,详见下文的【代码块 C】

    至于:idx.shape: (bs * seq_len, ),则代表 expert_mask[expert_idx]中(每列)元素值为 1 的索引位置
    以及:top_x.shape: (bs * seq_len, ),则代表 expert_mask[expert_idx]中(每行)元素值为 1 的索引位置

    继续分析该 for 循环之后的代码,如下

    # 如果exert_mask[expert_idx]不存在元素为1的值则跳过
        if top_x.shape[0] == 0:
            continue
    
        # 全部token的隐向量hidden_states中取出当前expert对应token的隐向量
        # current_state.shape: (top_x_length, hidden_dim)
        current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
    
        # 将取出的token隐向量传入expert模型进行前向传播得到返回
        # current_hidden_states.shape: (top_x_length, hidden_dim)
        # expert_layer的正向过程详见下文的【代码块D】
        current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
    
        # 将当前expert的输出以加和的形式写入预先定义好的final_hidden_states张量中
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
    
  5. for 循环结束后,相当于所有 expert 均处理完毕后,将维护好的final_hidden_states由(bs * seq_len, hidden_dim)转为(bs, seq_len, hidden_dim),并将作为本批次运行的返回
    更多详见下文的【代码块 E】

    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
    

3.2 MOE 前向传播中五个代码块的细致分析:鞭辟入里

3.2.1 代码块 A:routing_weights 的具体样例

# 【代码块Arouting_weights
# 每行对应1token,第0列为其对应排位第1expert、第1列为其对应排位第2expert,元素值为相应权重
[[0.5310, 0.4690],
 [0.5087, 0.4913],
 [0.5775, 0.4225],
 [0.5014, 0.4986],
 [0.5030, 0.4970],
 [0.5479, 0.4521],
 [0.5794, 0.4206],
 [0.5545, 0.4455],
 [0.5310, 0.4690],
 [0.5294, 0.4706],
 [0.5375, 0.4625],
 [0.5417, 0.4583],
 [0.5014, 0.4986],
 [0.5239, 0.4761],
 [0.5817, 0.4183],
 [0.5126, 0.4874]]

3.2.2 代码块 B:expert_mask[expert_idx]

因为有:expert_mask 记录有各个 expert 分别作为各个排位存在的时候,对应需要处理哪些 token
故而有:expert_mask[expert_idx]从 expert_mask 中取出第 expert_idx 个 expert 将处理哪些 token
\rightarrow 第 0 行为该 expert 作为排位第 1 存在的时候处理的 token
\rightarrow 第 1 行为该 expert 作为排位第 2 存在的时候处理的 token

# 【代码块Bexpert_mask[expert_idx]
# 下述两行例子的物理含义为:
# 第一行是“该expert作为排位1exert存在时,需要处理第9token
# 第二行是“该expert作为排位2expert存在时,需要处理第1011token
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]]

3.2.3 代码块 C:idx, top_x = torch.where(expert_mask[expert_idx])

# 【代码块C】idx, top_x = torch.where(expert_mask[expert_idx])
# 以上述expert_mask[expert_idx]样例为例,对应的torch.where(expert_mask[expert_idx])结果如下
idx: [0, 1, 1]
top_x: [9, 10, 11]

idx 对应行索引,top_x 对应列索引,例如张量 expert_mask[expert_idx]中,出现元素 1 的索引为(0, 9)、(1, 10)、(1, 11)
从物理含义来理解,top_x 实际上就对应着“关乎当前 expert 的 token 索引”,第 9、第 10、第 11 个 token 被“路由”导向了当前所关注的 expert,通过 top_x 可以取到“需要传入该 expert 的输入”,也即第 9、第 10、第 11 个 token 对应的隐向量

  • 因此 top_x 将作为索引用于从全部 token 的隐向量 hidden_states 中取出对应 token 的隐向量
  • 而 idx 和 top_x 也会组合起来被用于从 expert 权重张量 routing_weights 中取出对应的权重

并且通过行索引、列索引的组合 routing_weights

3.2.4 代码块 D:expert 内部的前向传播

# 【代码块D】expert内部的前向传播
def forward(self, hidden_states, routing_weights):
    current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
    current_hidden_states = self.w2(current_hidden_states)
    return routing_weights * current_hidden_states

其入参不仅有 expert 相应 token 的隐向量,还有对应 expert 的权重,整体是一个基于 swiGLU 激活的 FFN

最后对 FFN 的输出进行加权得到该 expert 的实际输出,因此加权处理是在 expert 的内部就已经进行了

3.2.5 代码块 E:final_hidden_states

  1. 最初 final_hidden_states 是全 0 张量
    # 查看与当前expert有关的final_hidden_states部分,即final_hidden_states[top_x]
    [[0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.]]
    
  2. 使用.index_add_函数后在指定位置(top_x)加上了指定值(current_hidden_states)
    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
    
  3. 再次查看与当前 expert 有关的 final_hidden_states 部分,即
    [[ 0.0938,  0.0509, -0.0689,  ..., -0.0182, -0.0246,  0.0468],
     [ 0.1246,  0.0642,  0.0015,  ...,  0.0100, -0.0110,  0.0219],
     [ 0.0478, -0.0192,  0.0139,  ..., -0.0039, -0.0197,  0.0475]]
    

第四部分 混合专家模型 MOE 的发展史与更多实践细节

// 待更

第五部分 MoE-Mamba 模型:将 Mamba 和混合专家层组合起来

// 待更

参考文献与推荐阅读

  1. 一条磁力链接席卷 AI 圈,87GB 种子直接开源 8x7B MoE 模型
  2. Mistral AI 对 Mixtral of experts 的介绍:Mixtral of experts | Mistral AI | Open source models
  3. 开源大模型超越 GPT-3.5!爆火 MoE 实测结果出炉
  4. https://github.com/nateraw/replicate-examples/tree/main/mixtral
  5. 预训练大模型:百度 UFO(Unified Feature Optimization)
  6. 集 4 学员且友人 wstart 推荐的三篇论文
    LoRAMoE: Revolutionizing Mixture of Experts for Maintaining World Knowledge in Language Model Alignment
    MegaBlocks: Efficient Sparse Training with Mixture-of-Experts
    Weak-to-Strong Generalization: Eliciting Strong Capabilities With Weak Supervision
  7. Mixtral 8x7B 论文终于来了:架构细节、参数量首次曝光
    一条磁力链爆全网,Mixtral 8x7B 论文来了!碾压 Llama 2 70B,每 token 仅需激活 13B 参数
  8. Mixtral of Experts 论文,是本文中此节“1.1.1 Mixtral 8x7B 是一个稀疏的专家混合网络”的核心参考