commit 6eaca41018
parent 8dd7cfaf72
Author: iomgaa
Date:   2025-05-12 12:11:29 +08:00

2 changed files with 32 additions and 9 deletions

@@ -551,7 +551,6 @@ class MiniMindLM(PreTrainedModel):
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
-       self.OUT = CausalLMOutputWithPast()
        self.params = params

    def forward(self,
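The deleted line is the interesting one: a single `CausalLMOutputWithPast` was created once in `__init__` and then mutated in place on every forward call, so any caller holding a result from an earlier step would see its fields silently overwritten. A minimal stand-alone sketch of the failure mode (the tensors are placeholders, not the repo's forward):

```python
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

# Anti-pattern removed by this commit: one cached output object,
# mutated in place on every call.
shared = CausalLMOutputWithPast()

def forward_cached(logits):
    shared.__setitem__('logits', logits)  # same mutation style as the old code
    return shared

out1 = forward_cached(torch.zeros(1, 4))
out2 = forward_cached(torch.ones(1, 4))
assert out1 is out2                            # both results alias the same object
assert torch.equal(out1.logits, out2.logits)   # the step-1 logits are gone
```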
@@ -608,11 +607,21 @@ class MiniMindLM(PreTrainedModel):
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
-       self.OUT.__setitem__('last_hidden_state', h)
-       self.OUT.__setitem__('logits', logits)
-       self.OUT.__setitem__('aux_loss', aux_loss)
-       self.OUT.__setitem__('past_key_values', past_kvs)
-       return self.OUT
+       # Simplified further: build a fresh output with only the required fields
+       output = CausalLMOutputWithPast(
+           logits=logits,
+           past_key_values=past_kvs,
+       )
+       # Attach the extra attributes opportunistically (if the output class accepts them)
+       try:
+           output.hidden_states = h
+           output.aux_loss = aux_loss
+       except Exception:
+           pass
+       return output

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
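With this change every forward call returns a fresh `CausalLMOutputWithPast`, so results from different steps no longer alias one another. Note two deviations from the usual Hugging Face convention: `aux_loss` is not a declared field of `CausalLMOutputWithPast`, and `hidden_states` here holds the raw final hidden state rather than the usual per-layer tuple. Callers should therefore read both defensively; a hedged consumer-side sketch (the output object below is a stand-in for a real forward result):

```python
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

# Stand-in for a forward() result; in the repo this comes from MiniMindLM.
out = CausalLMOutputWithPast(logits=torch.randn(1, 8, 32))
out.aux_loss = torch.tensor(0.01)    # attached opportunistically, as above

logits = out.logits                                 # always present
aux = getattr(out, 'aux_loss', torch.tensor(0.0))   # optional: read defensively
hidden = getattr(out, 'hidden_states', None)        # a tensor here, not the usual HF tuple
print(logits.shape, float(aux), hidden)
```

Downstream code that expects the standard tuple-of-layers `hidden_states` should check the type before indexing.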

@@ -40,6 +40,7 @@ def get_lr(current_step, total_steps, lr):
def train_epoch(epoch, wandb):
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    start_time = time.time()
+   moe_path = '_moe' if lm_config.use_moe else ''
    for step, (X, Y, loss_mask) in enumerate(train_loader):
        try:
            # Move the batch onto the target device
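Hoisting `moe_path` to the top of `train_epoch` makes it a loop invariant that both checkpoint-saving sites below can share. A small self-contained sketch of the resulting checkpoint naming (`SimpleNamespace` stands in for the script's `lm_config` and `args`; the values are illustrative):

```python
from types import SimpleNamespace

# Stand-ins mirroring the script's config objects (illustrative values).
lm_config = SimpleNamespace(use_moe=True, dim=512)
args = SimpleNamespace(save_dir='./out')

# Computed once, before the step loop, instead of inside the save branch.
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
print(ckp)  # ./out/pretrain_512_moe.pth
```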
@@ -59,7 +60,20 @@ def train_epoch(epoch, wandb):
                Y.view(-1)
            ).view(Y.size())
            loss = (loss * loss_mask).sum() / loss_mask.sum()
-           loss += res.aux_loss
+           # Add the auxiliary (MoE load-balancing) loss if the model exposes one
+           try:
+               if hasattr(model, 'module'):
+                   # DDP case: the real model sits behind model.module
+                   aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
+                                  if hasattr(l.feed_forward, 'aux_loss'))
+               else:
+                   # non-DDP case
+                   aux_loss = sum(l.feed_forward.aux_loss for l in model.layers
+                                  if hasattr(l.feed_forward, 'aux_loss'))
+               loss += aux_loss
+           except Exception as e:
+               Logger(f"Warning: Could not add auxiliary loss: {e}")
+               # On failure, train without the auxiliary loss
            loss = loss / args.accumulation_steps

            # Print data types for debugging
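The branch on `hasattr(model, 'module')` exists because `DistributedDataParallel` wraps the user model and exposes it as `model.module`. If this pattern recurs, it can be factored into a helper; `collect_aux_loss` below is a hypothetical name, not something the repo defines:

```python
import torch
import torch.nn as nn

def collect_aux_loss(model: nn.Module):
    """Sum the MoE load-balancing losses over all decoder layers.

    Hypothetical helper: unwraps a DistributedDataParallel wrapper, then
    walks `layers` the same way the training loop above does. Returns a
    tensor once the first aux_loss term is added, else 0.0.
    """
    core = model.module if hasattr(model, 'module') else model
    total = 0.0
    for layer in getattr(core, 'layers', []):
        ff = getattr(layer, 'feed_forward', None)
        if ff is not None and hasattr(ff, 'aux_loss'):
            total = total + ff.aux_loss
    return total

# Call site inside the loop would then shrink to:
#     loss = loss + collect_aux_loss(model)
```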
@@ -106,7 +120,7 @@ def train_epoch(epoch, wandb):
            # Save the model checkpoint
            if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
                model.eval()
-               moe_path = '_moe' if lm_config.use_moe else ''
+               # moe_path = '_moe' if lm_config.use_moe else ''
                ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'

                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
@@ -124,7 +138,7 @@ def train_epoch(epoch, wandb):
                    os.remove(save_path)

                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                   state_dict = model.module.state_dict()
+                   state_dict = model.module.state_dict()
                else:
                    state_dict = model.state_dict()
                torch.save(state_dict, save_path)
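Unwrapping the DDP wrapper before `state_dict()` keeps the saved keys free of the `module.` prefix, so the checkpoint later loads into an unwrapped model. A minimal round-trip sketch (`nn.Linear` stands in for MiniMindLM; the DDP line is commented out because it needs an initialized process group):

```python
import torch
import torch.nn as nn

# Stand-in model; in the script this would be the (possibly wrapped) MiniMindLM.
model = nn.Linear(8, 8)
# model = torch.nn.parallel.DistributedDataParallel(model)  # inside a real DDP run

def unwrap(m: nn.Module) -> nn.Module:
    # DDP stores the user model on .module; plain models pass through unchanged.
    return m.module if isinstance(m, torch.nn.parallel.DistributedDataParallel) else m

state_dict = unwrap(model).state_dict()       # keys without the 'module.' prefix
torch.save(state_dict, 'pretrain_demo.pth')

fresh = nn.Linear(8, 8)
fresh.load_state_dict(torch.load('pretrain_demo.pth'))  # loads cleanly
```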