diff --git a/model/model.py b/model/model.py index d53b57a..848cf8a 100644 --- a/model/model.py +++ b/model/model.py @@ -221,7 +221,6 @@ class MOEFeedForward(nn.Module): x = x.view(-1, x.shape[-1]) flat_topk_idx = topk_idx.view(-1) if self.training: - # 训练模式下,重复输入数据 x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0) y = torch.empty_like(x, dtype=torch.float16) for i, expert in enumerate(self.experts): @@ -229,7 +228,6 @@ class MOEFeedForward(nn.Module): y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) y = y.view(*orig_shape) else: - # 推理模式下,只选择最优专家 y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) if self.config.n_shared_experts is not None: y = y + self.shared_experts(identity) @@ -242,9 +240,10 @@ class MOEFeedForward(nn.Module): idxs = flat_expert_indices.argsort() tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) token_idxs = idxs // self.config.num_experts_per_tok - # 例如当tokens_per_expert=[6, 15, 20, 26, 33, 38, 46, 52] - # 当token_idxs=[3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] - # 意味着当token_idxs[:6] -> [3, 7, 19, 21, 24, 25, 4]位置的token都由专家0处理,token_idxs[6:15]位置的token都由专家1处理...... + # 当tokens_per_expert = [6, 15, 20, 26],tokens_per_expert.shape[0]即为专家数量(此时为4) + # 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时 + # 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token(每个token有可能被多个专家处理,这取决于num_experts_per_tok) + # 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推 for i, end_idx in enumerate(tokens_per_expert): start_idx = 0 if i == 0 else tokens_per_expert[i - 1] if start_idx == end_idx: @@ -254,7 +253,6 @@ class MOEFeedForward(nn.Module): expert_tokens = x[exp_token_idx] expert_out = expert(expert_tokens).to(expert_cache.dtype) expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) - # 使用 scatter_add_ 进行 sum 操作 expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out) return expert_cache