手写MOE

MOE(Mixture of Experts)也就是混合专家系统,已经在LLM(Large Language Model)的结构中成为标配了。最近看到一篇手写MOE教程,所学下来,受益颇多。
MOE概述
MOE的核心思想就是通过多个神经网络构成多个“领域”的专家,对输入进行一个判断,让最“擅长”这个任务的专家处理这个任务。从而实现更高效、准确的处理任务。
因此,MOE模型的必要组成部分包含:
-
专家(Experts):模型中的每个专家都是一个独立的神经网络,专门处理输入数据的特定子集或特定任务。通常一个Experts是一个 前馈神经网络(FeadFoward Network,FFN)
-
门控网络(Gate Network):门控网络的作用是决定每个输入应该由哪个专家或哪些专家来处理。通常它根据输入样本的特征计算出每个专家的权重或重要性,然后根据这些权重将输入样本分配给相应的专家。门控网络通常是一个简单的神经网络,其输出经过softmax激活函数处理,以确保所有专家的权重之和为1。
而现在 LLM + MOE 如此火热的原因主要在于:
-
提高模型性能:通过将多个专家的预测结果进行整合,MoE模型可以在不同的数据子集或任务方面发挥每个专家的优势,从而提高整体模型的性能。在不同的任务中,激活擅长的专家,可以让模型可以更准确地对不同领域的任务进行处理。
-
减少计算消耗:与传统的密集模型相比,MoE模型在处理每个输入样本时,只有相关的专家会被激活,而不是整个模型的所有参数都被使用。这意味着MoE模型可以在保持较高性能的同时,显著减少计算资源的消耗,特别是在LLM中,这种优势更为明显。
-
增强模型的可扩展性:MoE模型的架构可以很容易地扩展到更多的专家和更大的模型规模。通过增加专家的数量,模型可以覆盖更广泛的数据特征和任务类型,从而在不增加计算复杂度的情况下,提升模型的表达能力和泛化能力。这种可扩展性为处理大规模、复杂的数据集提供了有效的解决方案,例如在处理多模态数据时,MoE模型可以通过设置不同的专家来专门处理不同模态的数据,实现更高效的多模态融合。
Expert
因为 MOE 网络对应着多个专家,而通常来说,专家的结构是一个前馈神经网络。
因此我们首先需要实现一个普通前馈神经网络结构的专家。
import torch
import torch.nn as nn
class MOE_basic(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.fc = nn.Linear(in_features=input_size, out_features=hidden_size)
self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)
self.active = nn.ReLU()
def forward(self, x):
x = self.fc2(self.active(self.fc(x)))
return x
这个 FFN 专家先把输入向量从 input_size
维线性投影到更高的 hidden_size
维,经 ReLU激活函数, 进行非线性变换,再线性缩回 input_size
维,输出与输入同形状。
通过这种“先扩后缩”的方式能够在更高维空间中学习复杂的特征交互。
基础版本MOE
基础版本的MOE思想非常的简单,即首先将输入通过一个Gate NetWork,获得每一个专家的权重(weight),然后将输入依次通过每一个专家,最后将权重乘上专家的输出作为专家的”分析结果“,然后将每个专家的结果相加即可。
class simple_MOE(nn.Module):
def __init__(self, input_size, hidden_size, output_size, num_experts):
super().__init__()
self.num_experts = num_experts
self.experts = nn.ModuleList(
[
MOE_basic(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
for _ in range(num_experts)
]
)
self.gate = nn.Linear(in_features=input_size, out_features=num_experts)
# x [batch, input_size]
def forward(self, x):
# shape: [batch, num_experts]
experts_weight = self.gate(x)
# shape: [[batch, 1, output_size] * num_experts]
experts_out_list = [expert(x).unsqueeze(1) for expert in self.experts]
# shape: [batch, num_experts, output_size]
experts_out_list = torch.cat(experts_out_list, dim=1)
# shape: [batch, 1, num_experts]
experts_weight = experts_weight.unsqueeze(1)
# shape: [batch, 1, output_size]
out = experts_weight @ experts_out_list
# shape: [batch, output_size]
return out.squeeze(1)
Sparse MoE
Sparse MoE 开始就是大模型训练时所采用的结构了。回看基础版本的MOE,让输入进过了每一个专家的计算,然后采取了不同权重,表示对不同专家的侧重程度。但是在大模型的训练架构中,由于模型的参数量过大,这种基础版本的MOE并没有减轻训练的负担。
因此,Sparse Moe相较于基础版本的MOE的区别是,通过Router选择 topK 个专家,然后对这 topK 个专家的输出进行加权求和。
如果 topK = 专家数量,那么其实也退化成了基础版本的MOE
具体结构如下图所示:
以 switch transformers 模型的 MOE 架构图作为演示
首先是 Router,对每个 token 通过 top-k 选择专家
class MOE_router(nn.Module):
def __init__(self, hidden_size, num_experts, top_k):
super().__init__()
self.num_experts = num_experts
self.gate = nn.Linear(in_features=hidden_size, out_features=num_experts)
self.top_k = top_k
# hidden_states :[batch* seq, hidden_size]
def forward(self, hidden_states):
routers_logits = self.gate(hidden_states) # shape: [batch* seq, num_experts]
routers_probs = F.softmax(
routers_logits, dim=-1, dtype=torch.float
) # shape:[batch* seq, num_experts]
router_weights, selected_experts = torch.topk(
routers_probs, self.top_k, dim=1
) # shape:[batch* seq,top_k], topk 返回选择的top-k的值和对应的索引
router_weights = router_weights / router_weights.sum(
dim=-1, keepdim=True
) # 归一化
router_weights = router_weights.to(hidden_states.dtype) # 转换精度
# 生成专家掩码,提升计算效率
experts_mask = F.one_hot(
selected_experts, num_classes=self.num_experts
) # 选的是1,没选的是0 shape:[batch* seq, top_k, num_experts]
experts_mask = experts_mask.permute(2, 1, 0) # (num_experts, top_k, batch* seq)
return routers_logits, router_weights, selected_experts, experts_mask
class spare_MOE(nn.Module):
def __init__(self, input_size, output_size, hidden_size, top_k, num_experts):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.hidden_size = hidden_size
self.num_experts = num_experts
self.top_k = top_k
self.router = MOE_router(self.hidden_size, num_experts, top_k)
self.MOE = nn.ModuleList([MOE_basic(input_size, output_size) for _ in range(num_experts)])
def forward(self, x):
# x shape [batch, seq,hidden_size]
print(x.shape)
batch_size, seq, hidden_size = x.size()
# 合并强两个维度,将所有token视为独立的输入,这样每个专家可以独立处理每个token,而不需要考虑它们属于哪个批次或序列
# [batch*seq, hidden_size]
hidden_state = x.view(-1, hidden_size)
routers_logits, router_weights, selected_experts, experts_mask = self.router(hidden_state)
print(routers_logits.shape, router_weights.shape, selected_experts.shape, experts_mask.shape)
final_hidden_states = torch.zeros_like(hidden_state)
for experts_idx in range(self.num_experts):
experts_layer = self.MOE[experts_idx] # 当前专家网络
# idx [top_k]
# top_X [batch* seq]
idx, top_x = torch.where(experts_mask[experts_idx] == 1)
# 选出哪些token需要经过专家
current_state = hidden_state.unsqueeze(0)[:, top_x, :].reshape(-1, hidden_size)
# 专家处理*权重
current_hidden_state = experts_layer(current_state) * router_weights[top_x, idx].unsqueeze(-1)
final_hidden_states.index_add(0, top_x, current_hidden_state.to(hidden_state.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, seq, hidden_size)
return final_hidden_states, routers_logits
然后是整个 Sparse MoE 的代码
class spare_MOE(nn.Module):
def __init__(self, input_size, output_size, hidden_size, top_k, num_experts):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.hidden_size = hidden_size
self.num_experts = num_experts
self.top_k = top_k
self.router = MOE_router(self.hidden_size, num_experts, top_k)
self.MOE = nn.ModuleList([MOE_basic(input_size, hidden_size, output_size) for _ in range(num_experts)])
def forward(self, x):
# x shape [batch, seq,hidden_size]
print(x.shape)
batch_size, seq, hidden_size = x.size()
# 合并强两个维度,将所有token视为独立的输入,这样每个专家可以独立处理每个token,而不需要考虑它们属于哪个批次或序列
# [batch*seq, hidden_size]
hidden_state = x.view(-1, hidden_size)
routers_logits, router_weights, selected_experts, experts_mask = self.router(hidden_state)
print(routers_logits.shape, router_weights.shape, selected_experts.shape, experts_mask.shape)
final_hidden_states = torch.zeros_like(hidden_state)
for experts_idx in range(self.num_experts):
experts_layer = self.MOE[experts_idx] # 当前专家网络
# idx [top_k]
# top_X [batch* seq]
idx, top_x = torch.where(experts_mask[experts_idx] == 1)
# 选出哪些token需要经过专家
current_state = hidden_state.unsqueeze(0)[:, top_x, :].reshape(-1, hidden_size)
# 专家处理*权重
current_hidden_state = experts_layer(current_state) * router_weights[top_x, idx].unsqueeze(-1)
final_hidden_states.index_add(0, top_x, current_hidden_state.to(hidden_state.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, seq, hidden_size)
return final_hidden_states, routers_logits
ShareExpert Sparse MoE
ShareExpert SparseMoE 则是 Deepseek 所采用的一种 MOE架构。
和上面的 Sparse MoE 的区别是,ShareExpert Sparse MoE 增加了一组共享的专家模型。这些模型是所有输入的 token 共享的,也就是说,所有 token 都过这些共享专家模型,然后每个 token 再经过Sparse Moe,选出 top-k 个专家,然后这 top-K 个专家和共享的专家的输出一起再加权求和。
具体结构如下图所示:

下面的手写代码参考了 deepseek MoE 的思想,有一定的简化,但是可以方便理解训练过程。
class shareExpert_MOE(nn.Module):
def __init__(self, input_size, output_size, hidden_size, top_k, num_experts):
super().__init__()
self.spare_MOE = spare_MOE(input_size, output_size, hidden_size, top_k, num_experts)
self.share_MOE = nn.ModuleList(
[MOE_basic(input_size=input_size, hidden_size=hidden_size, output_size=output_size) for _ in range(num_experts)]
)
def forward(self, x):
# spare_moe_out [batch, seq, hidden], routers_logits [batch* seq, num_experts]
spare_moe_out, routers_logits = self.spare_MOE(x)
share_moe_out = [
expert(x) for expert in self.share_MOE
] # [batch, seq, hidden]
share_moe_out = torch.stack(share_moe_out, dim=0).sum(dim=0)
return share_moe_out + spare_moe_out, routers_logits
参考文章
LLM MOE的进化之路,从普通简化 MOE,到 sparse moe,再到 deepseek 使用的 share_expert sparse moe | chaofa用代码打点酱油