update moe note

jingyaogong 2025-04-09 17:38:31 +08:00
parent d503093ec4
commit d9453ed9a3


@@ -221,7 +221,6 @@ class MOEFeedForward(nn.Module):
         x = x.view(-1, x.shape[-1])
         flat_topk_idx = topk_idx.view(-1)
         if self.training:
-            # In training mode, repeat the input data
             x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
             y = torch.empty_like(x, dtype=torch.float16)
             for i, expert in enumerate(self.experts):
@@ -229,7 +228,6 @@ class MOEFeedForward(nn.Module):
             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
             y = y.view(*orig_shape)
         else:
-            # In inference mode, select only the best experts
             y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
         if self.config.n_shared_experts is not None:
             y = y + self.shared_experts(identity)
@@ -242,9 +240,10 @@ class MOEFeedForward(nn.Module):
         idxs = flat_expert_indices.argsort()
         tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
         token_idxs = idxs // self.config.num_experts_per_tok
-        # For example, when tokens_per_expert = [6, 15, 20, 26, 33, 38, 46, 52]
-        # and token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...],
-        # the tokens at token_idxs[:6] -> [3, 7, 19, 21, 24, 25, 4] are all handled by expert 0, the tokens at token_idxs[6:15] by expert 1, ...
+        # When tokens_per_expert = [6, 15, 20, 26] (tokens_per_expert.shape[0] is the number of experts, 4 in this case)
+        # and token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...],
+        # the 6 positions token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the tokens handled by expert 0 (a token may be handled by several experts, depending on num_experts_per_tok);
+        # the next 9 positions token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...] are the tokens handled by expert 1, and so on.
         for i, end_idx in enumerate(tokens_per_expert):
             start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
             if start_idx == end_idx:
@@ -254,7 +253,6 @@ class MOEFeedForward(nn.Module):
             expert_tokens = x[exp_token_idx]
             expert_out = expert(expert_tokens).to(expert_cache.dtype)
             expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
-            # Use scatter_add_ to perform the sum
             expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

         return expert_cache
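
For readers following the diff, the training branch above duplicates every token once per selected expert (repeat_interleave) and then folds the per-expert outputs back together with the routing weights. Below is a minimal sketch of that reshape, assuming made-up sizes (3 tokens, hidden size 4, top-2 routing) and a dummy doubling function standing in for the expert FFNs:

import torch

num_experts_per_tok = 2                      # assumed top-k per token
tokens, hidden = 3, 4                        # toy sizes, not MiniMind's real dimensions
x = torch.randn(tokens, hidden)              # flattened token states
topk_weight = torch.softmax(torch.randn(tokens, num_experts_per_tok), dim=-1)

# training path: one copy of each token per selected expert
x_rep = x.repeat_interleave(num_experts_per_tok, dim=0)        # (tokens * k, hidden)
y = x_rep * 2.0                                                # stand-in for the expert FFNs

# recombine: weight each expert copy, then sum over the k expert slots
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)

# with identical "experts" and weights summing to 1, the result is just 2 * x
assert y.shape == (tokens, hidden)
assert torch.allclose(y, 2 * x)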
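
The updated note describes the index bookkeeping in moe_infer: argsort groups the flattened (token, expert-slot) positions by expert id, bincount().cumsum() marks where each expert's block ends, and integer division by num_experts_per_tok maps positions back to token indices. The same trick can be reproduced on a toy routing table; the 4 experts, num_experts_per_tok = 2, and the assignments below are invented for illustration, not taken from the MiniMind config:

import torch

num_experts_per_tok = 2
# expert id chosen for each (token, slot) pair, flattened: 6 tokens x 2 slots
flat_expert_indices = torch.tensor([0, 1, 2, 0, 1, 3, 0, 2, 1, 3, 2, 0])

idxs = flat_expert_indices.argsort()                     # slot positions grouped by expert id
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // num_experts_per_tok                 # map slot positions back to token indices

print(tokens_per_expert)   # [ 4  7 10 12]: expert 0 owns idxs[:4], expert 1 owns idxs[4:7], ...
for i, end_idx in enumerate(tokens_per_expert):
    start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
    print(f"expert {i} handles tokens {sorted(token_idxs[start_idx:end_idx].tolist())}")
# expert 0 handles tokens [0, 1, 3, 5]
# expert 1 handles tokens [0, 2, 4]
# expert 2 handles tokens [1, 3, 5]
# expert 3 handles tokens [2, 4]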
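
The removed "use scatter_add_ to perform the sum" note refers to the last step: each expert's weighted output rows are scattered back into expert_cache at their token positions, so a token routed to several experts accumulates a sum. A minimal sketch with made-up shapes (6 tokens, hidden size 4), showing one expert's write-back:

import torch

hidden = 4
x = torch.randn(6, hidden)             # flattened token states
expert_cache = torch.zeros_like(x)     # per-token accumulator

# suppose one expert handled tokens [0, 1, 3, 5] and produced these already-weighted outputs
exp_token_idx = torch.tensor([0, 1, 3, 5])
expert_out = torch.randn(4, hidden)

# expert_cache[exp_token_idx[i], j] += expert_out[i, j] for every i, j
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

assert torch.equal(expert_cache[0], expert_out[0])         # routed token received its row
assert torch.equal(expert_cache[2], torch.zeros(hidden))   # unrouted token stays zero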