update moe note
This commit is contained in:
parent
d503093ec4
commit
d9453ed9a3
@ -221,7 +221,6 @@ class MOEFeedForward(nn.Module):
|
||||
x = x.view(-1, x.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
if self.training:
|
||||
# 训练模式下,重复输入数据
|
||||
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
|
||||
y = torch.empty_like(x, dtype=torch.float16)
|
||||
for i, expert in enumerate(self.experts):
|
||||
@ -229,7 +228,6 @@ class MOEFeedForward(nn.Module):
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
y = y.view(*orig_shape)
|
||||
else:
|
||||
# 推理模式下,只选择最优专家
|
||||
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(identity)
|
||||
@ -242,9 +240,10 @@ class MOEFeedForward(nn.Module):
|
||||
idxs = flat_expert_indices.argsort()
|
||||
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
|
||||
token_idxs = idxs // self.config.num_experts_per_tok
|
||||
# 例如当tokens_per_expert=[6, 15, 20, 26, 33, 38, 46, 52]
|
||||
# 当token_idxs=[3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...]
|
||||
# 意味着当token_idxs[:6] -> [3, 7, 19, 21, 24, 25, 4]位置的token都由专家0处理,token_idxs[6:15]位置的token都由专家1处理......
|
||||
# 当tokens_per_expert = [6, 15, 20, 26],tokens_per_expert.shape[0]即为专家数量(此时为4)
|
||||
# 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时
|
||||
# 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token(每个token有可能被多个专家处理,这取决于num_experts_per_tok)
|
||||
# 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推
|
||||
for i, end_idx in enumerate(tokens_per_expert):
|
||||
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
|
||||
if start_idx == end_idx:
|
||||
@ -254,7 +253,6 @@ class MOEFeedForward(nn.Module):
|
||||
expert_tokens = x[exp_token_idx]
|
||||
expert_out = expert(expert_tokens).to(expert_cache.dtype)
|
||||
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||
# 使用 scatter_add_ 进行 sum 操作
|
||||
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
|
||||
|
||||
return expert_cache
|
||||
|
Loading…
x
Reference in New Issue
Block a user