update
This commit is contained in:
parent
8dd7cfaf72
commit
6eaca41018
@@ -551,7 +551,6 @@ class MiniMindLM(PreTrainedModel):
         self.register_buffer("pos_cis",
                              precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                              persistent=False)
-        self.OUT = CausalLMOutputWithPast()
         self.params = params

     def forward(self,
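For context, `precompute_pos_cis` builds the rotary-position table that `register_buffer(..., persistent=False)` keeps on the module without writing it into checkpoints. A minimal sketch of a Llama-style precompute, assuming the usual complex-exponential formulation (the `end` default here is an assumption, not necessarily the repo's value):

import torch

def precompute_pos_cis(dim: int, end: int = 32768, theta: float = 1e6):
    # Rotation frequency per even dimension: theta^(-2i/dim)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end)                    # token positions 0..end-1
    angles = torch.outer(t, freqs)           # (end, dim/2) angle table
    # Complex exponentials e^{i*angle}; attention rotates q/k by these
    return torch.polar(torch.ones_like(angles), angles)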
@@ -608,11 +607,21 @@ class MiniMindLM(PreTrainedModel):
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.output(self.norm(h)[:, slice_indices, :])
         aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
-        self.OUT.__setitem__('last_hidden_state', h)
-        self.OUT.__setitem__('logits', logits)
-        self.OUT.__setitem__('aux_loss', aux_loss)
-        self.OUT.__setitem__('past_key_values', past_kvs)
-        return self.OUT
+        # Simplify further: keep only the required fields
+        output = CausalLMOutputWithPast(
+            logits=logits,
+            past_key_values=past_kvs,
+        )
+
+        # Try to attach the extra attributes (if supported)
+        try:
+            output.hidden_states = h
+
+            output.aux_loss = aux_loss
+        except:
+            pass
+
+        return output

     @torch.inference_mode()
     def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
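The replacement relies on `CausalLMOutputWithPast` being a `transformers.ModelOutput`: declared fields can be passed as keywords, read as attributes, and written dict-style, which is what the old `self.OUT.__setitem__(...)` calls used. A small illustration with dummy tensors:

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

logits = torch.randn(1, 4, 32)               # dummy (batch, seq_len, vocab)
out = CausalLMOutputWithPast(logits=logits)  # keyword construction

out['aux_loss'] = torch.tensor(0.0)          # dict-style write, as the old code did
print(out.logits.shape)                      # attribute read -> torch.Size([1, 4, 32])
print(out.past_key_values)                   # unset declared fields default to None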
@@ -40,6 +40,7 @@ def get_lr(current_step, total_steps, lr):
 def train_epoch(epoch, wandb):
     loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
+    moe_path = '_moe' if lm_config.use_moe else ''
     for step, (X, Y, loss_mask) in enumerate(train_loader):
         try:
             # Load the data onto the device
@@ -59,7 +60,20 @@ def train_epoch(epoch, wandb):
                 Y.view(-1)
             ).view(Y.size())
             loss = (loss * loss_mask).sum() / loss_mask.sum()
-            loss += res.aux_loss
+            # Add the auxiliary loss, if present
+            try:
+                if hasattr(model, 'module'):
+                    # DDP case
+                    aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
+                                   if hasattr(l.feed_forward, 'aux_loss'))
+                else:
+                    # non-DDP case
+                    aux_loss = sum(l.feed_forward.aux_loss for l in model.layers
+                                   if hasattr(l.feed_forward, 'aux_loss'))
+                loss += aux_loss
+            except Exception as e:
+                Logger(f"Warning: Could not add auxiliary loss: {e}")
+                # On error, skip the auxiliary loss
             loss = loss / args.accumulation_steps

             # Print data types for debugging
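The pattern above is a masked mean: per-token cross-entropy (via `reduction='none'`) is zeroed where `loss_mask` is 0 so padding never contributes, and the MoE auxiliary loss is then added on top. A self-contained sketch on dummy tensors, with the auxiliary term stubbed to zero:

import torch
import torch.nn as nn

batch, seq_len, vocab = 2, 8, 100
logits = torch.randn(batch, seq_len, vocab)
Y = torch.randint(0, vocab, (batch, seq_len))           # target token ids
loss_mask = (torch.rand(batch, seq_len) > 0.2).float()  # 1 = real token, 0 = padding

loss_fct = nn.CrossEntropyLoss(reduction='none')        # keep per-token losses
loss = loss_fct(logits.view(-1, vocab), Y.view(-1)).view(Y.size())

loss = (loss * loss_mask).sum() / loss_mask.sum()       # masked mean over real tokens
loss = loss + torch.tensor(0.0)                         # stand-in for the MoE aux_loss
print(loss.item())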
@@ -106,7 +120,7 @@ def train_epoch(epoch, wandb):
             # Save the model
             if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
                 model.eval()
-                moe_path = '_moe' if lm_config.use_moe else ''
+                # moe_path = '_moe' if lm_config.use_moe else ''
                 ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'

                 if isinstance(model, torch.nn.parallel.DistributedDataParallel):
@@ -124,7 +138,7 @@ def train_epoch(epoch, wandb):
                     os.remove(save_path)

                 if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                     state_dict = model.module.state_dict()
                 else:
                     state_dict = model.state_dict()
                 torch.save(state_dict, save_path)
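Under `DistributedDataParallel` the real model lives at `model.module`, so taking the state_dict from the wrapper would prefix every key with `module.` and break a later `load_state_dict` on a plain model. A minimal sketch of the unwrap-and-save round trip, using an `nn.Linear` stand-in and an illustrative path:

import torch
import torch.nn as nn

model = nn.Linear(16, 16)  # stand-in for MiniMindLM (possibly DDP-wrapped)

# Unwrap DDP if present so checkpoint keys carry no 'module.' prefix
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
    state_dict = model.module.state_dict()
else:
    state_dict = model.state_dict()

save_path = 'pretrain_demo.pth'  # illustrative path
torch.save(state_dict, save_path)

# Reload into a fresh instance to verify the round trip
reloaded = nn.Linear(16, 16)
reloaded.load_state_dict(torch.load(save_path, map_location='cpu'))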