Fixed the bug where the loss became NaN
parent 0c8c6e5d1a
commit eba4311ac5
.gitignore (vendored)
@@ -2,4 +2,4 @@
 /dataset
 /out
 wandb/
-**/*.log
+**/*.lognohup.out
@@ -2,6 +2,8 @@ import math
 import struct
 import inspect
 import time
+import sys
+sys.path.append('/mnt/lzn/Minimind')
 
 from .LMConfig import LMConfig
 from typing import Any, Optional, Tuple, List, Union
@@ -94,22 +96,25 @@ class Attention(nn.Module):
                 x: torch.Tensor,
                 pos_cis: torch.Tensor,
                 past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-                use_cache=False,
+                use_cache=True,
                 db_value=None):
         bsz, seq_len, _ = x.shape  # bsz: batch size, seq_len: sequence length, _: hidden dim
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  # project x through the linear layers wq, wk, wv to get queries, keys and values
         xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)  # reshape xq to (bsz, seq_len, n_local_heads, head_dim)
         xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # reshape xk to (bsz, seq_len, n_local_kv_heads, head_dim)
         xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # reshape xv to (bsz, seq_len, n_local_kv_heads, head_dim)
+
+
         # apply rotary position embeddings
         xq, xk = apply_rotary_emb(xq, xk, pos_cis)
 
         # KV-cache implementation
         if past_key_value is not None:
             xk = torch.cat([past_key_value[0], xk], dim=1)
             xv = torch.cat([past_key_value[1], xv], dim=1)
         past_kv = (xk, xv) if use_cache else None
 
+        # print(xk, xv)
         # repeat key/value heads
         xq, xk, xv = (
             xq.transpose(1, 2),
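Note: the KV-cache logic in this hunk is the standard incremental-attention pattern: freshly projected keys/values are appended to the cached ones along the sequence dimension (dim=1), and the grown cache is handed back when use_cache is enabled. A minimal self-contained sketch of that pattern (function name and shapes are illustrative, not from this repo):

import torch

def kv_cache_step(xk, xv, past_key_value=None, use_cache=True):
    # xk, xv: (bsz, new_seq_len, n_kv_heads, head_dim) for the current step.
    # past_key_value: optional (cached_k, cached_v) from earlier steps.
    if past_key_value is not None:
        xk = torch.cat([past_key_value[0], xk], dim=1)  # grow along seq_len
        xv = torch.cat([past_key_value[1], xv], dim=1)
    past_kv = (xk, xv) if use_cache else None
    return xk, xv, past_kv

# Two decode steps of one token each: the cache grows from 1 to 2 positions.
k1, v1 = torch.randn(1, 1, 2, 8), torch.randn(1, 1, 2, 8)
k, v, cache = kv_cache_step(k1, v1)
k2, v2 = torch.randn(1, 1, 2, 8), torch.randn(1, 1, 2, 8)
k, v, cache = kv_cache_step(k2, v2, past_key_value=cache)
print(k.shape)  # torch.Size([1, 2, 2, 8])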
@@ -373,7 +378,7 @@ class MiniMindBlock(nn.Module):
         # self.product_key_topk = min(16, self.num_keys)  # make sure it does not exceed num_keys
         # self.num_experts_per_head_topk = 1  # number of experts finally selected per head
 
-    def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=False):
+    def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True):
         # import pdb;pdb.set_trace()
         # db_value = None
 
@@ -425,7 +430,7 @@ class MiniMindBlock(nn.Module):
             use_cache=use_cache,
             db_value=db_value
         )
-
+        # print(past_kv)
         h_attn = self.cross_att(h_attn, db_value)
 
         # residual connection
@@ -485,7 +490,7 @@ class ExtractDB(nn.Module):
         all_indices = all_indices.view(*all_indices.shape[:-2], -1)
 
         # 5. final top-k selection
-        scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
+        scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)  # no gradient
         indices = all_indices.gather(-1, pk_indices)
         flat_indices = indices.view(-1)
         return flat_indices
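Note: the added "# no gradient" comment is accurate. topk returns integer index tensors, and neither they nor gather's indices carry gradients; a grad path exists only through the gathered values. A quick demonstration:

import torch

scores = torch.randn(4, 16, requires_grad=True)
top_vals, top_idx = scores.topk(2, dim=-1)

print(top_idx.dtype, top_idx.requires_grad)  # torch.int64 False
top_vals.sum().backward()                    # gradients flow through values only
print(scores.grad.count_nonzero())           # tensor(8): 4 rows x 2 selected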
@@ -557,7 +562,7 @@ class MiniMindLM(PreTrainedModel):
     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
                 past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
-                use_cache: bool = False,
+                use_cache: bool = True,
                 logits_to_keep: Union[int, torch.Tensor] = 0,
                 **args):
         past_key_values = past_key_values or [None] * len(self.layers)
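Note: flipping the use_cache default to True mainly matters for incremental decoding, where each call feeds only the newest token plus the cache returned by earlier calls. A hedged sketch of such a driver loop; the out.logits / out.past_key_values attribute names are assumptions about the return type, not confirmed from this repo:

import torch

@torch.no_grad()
def greedy_decode(model, input_ids, max_new_tokens=16):
    out = model(input_ids=input_ids, use_cache=True)   # prefill the cache
    past = out.past_key_values
    next_tok = out.logits[:, -1].argmax(-1, keepdim=True)
    tokens = [next_tok]
    for _ in range(max_new_tokens - 1):
        # Only the newest token goes in; attention reuses cached K/V.
        out = model(input_ids=next_tok, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_tok = out.logits[:, -1].argmax(-1, keepdim=True)
        tokens.append(next_tok)
    return torch.cat(tokens, dim=1)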
@@ -578,16 +583,17 @@ class MiniMindLM(PreTrainedModel):
             # normal mode: query the database
             index = self.extract_db.q_to_k(h)
             db_value = self.extract_db.get_data(index)
 
             h, past_kv = layer(
                 h, db_value, pos_cis,
                 past_key_value=past_key_values[l],
                 use_cache=use_cache
             )
 
             past_kvs.append(past_kv)
             h_list.append(h.unsqueeze(0))
 
+        # print(past_kvs)
         h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
 
         # run the database update logic only when database mode is not disabled
@@ -595,7 +601,9 @@ class MiniMindLM(PreTrainedModel):
             # use detach() to cut the computation graph and avoid backpropagating through it more than once
             h_tensor_detached = h_tensor.detach()
             h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim)
 
+
+
             # the database update logic is decoupled from the main computation graph
             with torch.no_grad():
                 # Compute shared downsampling layer once
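Note: the detach-plus-torch.no_grad() pattern here is the usual way to write hidden states into an external memory without entangling the write with the loss graph; mixing the two is a classic source of "backward through the graph a second time" errors (and a plausible contributor to the NaNs this commit targets). A standalone sketch with made-up module names:

import torch
import torch.nn as nn

class MemoryBank(nn.Module):
    def __init__(self, n_slots=128, dim=32):
        super().__init__()
        # A buffer, not a Parameter: updated manually, never by the optimizer.
        self.register_buffer("values", torch.zeros(n_slots, dim))

    @torch.no_grad()
    def update(self, idx, new_values):
        # EMA write into the selected slots, entirely outside autograd.
        self.values[idx] = 0.9 * self.values[idx] + 0.1 * new_values

bank = MemoryBank()
h = torch.randn(4, 32, requires_grad=True)  # stand-in for hidden states
loss = h.pow(2).mean()                      # main loss keeps the live graph
bank.update(torch.arange(4), h.detach())    # detached copy feeds the bank
loss.backward()                             # safe: the write shares no graph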
@@ -604,7 +612,7 @@ class MiniMindLM(PreTrainedModel):
                 z_q = self.downsample_q_specific(shared_features)
                 z_k = self.extract_db.q_to_k(z_q)
                 self.extract_db.updata_value(z_k, z_v)
-
+        # import pdb;pdb.set_trace()
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.output(self.norm(h)[:, slice_indices, :])
         aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
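Note: logits_to_keep follows the common convention of "int n keeps the logits for the last n positions" (with 0 keeping everything, since slice(-0, None) equals slice(0, None)), while a tensor is used directly as an index. A quick check of the slicing expression:

import torch

h = torch.randn(2, 10, 8)  # (bsz, seq_len, dim)

def keep(h, logits_to_keep):
    s = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return h[:, s, :]

print(keep(h, 0).shape)                     # torch.Size([2, 10, 8])
print(keep(h, 1).shape)                     # torch.Size([2, 1, 8])
print(keep(h, torch.tensor([0, 9])).shape)  # torch.Size([2, 2, 8])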
@@ -1,6 +1,9 @@
 import os
 # set environment variables
-os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
+os.environ["WANDB_MODE"] = "offline"
+import sys
+sys.path.append('/mnt/lzn/Minimind')
+
 import platform
 import argparse
 import time
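Note: WANDB_MODE must be set before wandb.init() runs; "offline" records the run locally under ./wandb (which the .gitignore change above also covers) so it can be uploaded later with the wandb sync CLI, and "dryrun" is the legacy alias for the same behaviour, which is why the comment could be dropped. A small usage sketch (the project name is illustrative):

import os
os.environ["WANDB_MODE"] = "offline"  # must precede wandb.init()

import wandb

run = wandb.init(project="minimind-pretrain")
wandb.log({"loss": 0.0})
run.finish()  # run dir stays local; upload later with `wandb sync <run-dir>`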
@@ -119,6 +122,7 @@ def train_epoch(epoch, wandb):
 
         except Exception as e:
             print(f"Error occurred: {str(e)}")
+            moe_path = '_moe' if lm_config.use_moe else ''
+            save_path = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}_nanERROR.pth'
+            if os.path.exists(save_path):
+                os.remove(save_path)
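Note: given the commit message, the natural companion to this error handler is an explicit finiteness check on the loss before backward/step, so a single bad batch is skipped instead of poisoning the weights. A hedged sketch of such a guard (names are illustrative, not this repo's API):

import torch

def guarded_step(model, optimizer, loss):
    # Skip the update when the loss is NaN/Inf; one bad step can poison weights.
    if not torch.isfinite(loss):
        print(f"non-finite loss {loss.item()}, skipping step")
        optimizer.zero_grad(set_to_none=True)
        return False
    loss.backward()
    # Gradient clipping is a common extra safeguard against blow-ups.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return True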