diff --git a/model/model.py b/model/model.py index 2873a77..1a94bd8 100644 --- a/model/model.py +++ b/model/model.py @@ -551,7 +551,6 @@ class MiniMindLM(PreTrainedModel): self.register_buffer("pos_cis", precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta), persistent=False) - self.OUT = CausalLMOutputWithPast() self.params = params def forward(self, @@ -608,11 +607,21 @@ class MiniMindLM(PreTrainedModel): slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.output(self.norm(h)[:, slice_indices, :]) aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward)) - self.OUT.__setitem__('last_hidden_state', h) - self.OUT.__setitem__('logits', logits) - self.OUT.__setitem__('aux_loss', aux_loss) - self.OUT.__setitem__('past_key_values', past_kvs) - return self.OUT + # 进一步简化,只保留必要的参数 + output = CausalLMOutputWithPast( + logits=logits, + past_key_values=past_kvs, + ) + + # 尝试添加其他属性(如果支持的话) + try: + output.hidden_states = h + + output.aux_loss = aux_loss + except: + pass + + return output @torch.inference_mode() def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90, diff --git a/train_pretrain.py b/train_pretrain.py index a3be32b..5cc556b 100644 --- a/train_pretrain.py +++ b/train_pretrain.py @@ -40,6 +40,7 @@ def get_lr(current_step, total_steps, lr): def train_epoch(epoch, wandb): loss_fct = nn.CrossEntropyLoss(reduction='none') start_time = time.time() + moe_path = '_moe' if lm_config.use_moe else '' for step, (X, Y, loss_mask) in enumerate(train_loader): try: # 将数据加载到设备上 @@ -59,7 +60,20 @@ def train_epoch(epoch, wandb): Y.view(-1) ).view(Y.size()) loss = (loss * loss_mask).sum() / loss_mask.sum() - loss += res.aux_loss + # 添加辅助损失,如果存在的话 + try: + if hasattr(model, 'module'): + # DDP情况 + aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers + if hasattr(l.feed_forward, 'aux_loss')) + else: + # 非DDP情况 + aux_loss = sum(l.feed_forward.aux_loss for l in model.layers + if hasattr(l.feed_forward, 'aux_loss')) + loss += aux_loss + except Exception as e: + Logger(f"Warning: Could not add auxiliary loss: {e}") + # 如果出错,不添加辅助损失 loss = loss / args.accumulation_steps # Print data types for debugging @@ -106,7 +120,7 @@ def train_epoch(epoch, wandb): # 保存模型 if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): model.eval() - moe_path = '_moe' if lm_config.use_moe else '' + # moe_path = '_moe' if lm_config.use_moe else '' ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth' if isinstance(model, torch.nn.parallel.DistributedDataParallel): @@ -124,7 +138,7 @@ def train_epoch(epoch, wandb): os.remove(save_path) if isinstance(model, torch.nn.parallel.DistributedDataParallel): - state_dict = model.module.state_dict() + state_dict = model.module.state_dict() else: state_dict = model.state_dict() torch.save(state_dict, save_path)