update

parent 5841f8b4e5
commit d7fe504e1e
@@ -37,8 +37,8 @@ class LMConfig(PretrainedConfig):
             seq_aux: bool = True,
             norm_topk_prob: bool = True,
             ####################################################
-            knowlwdge_num: int = 64*64,
-            knowlwdge_length: int = 8,
+            knowledge_num: int = 64*64,
+            knowledge_length: int = 8,
             **kwargs,
     ):
         self.dim = dim
@@ -70,6 +70,6 @@ class LMConfig(PretrainedConfig):
         self.seq_aux = seq_aux  # whether to compute the auxiliary loss at the sequence level
         self.norm_topk_prob = norm_topk_prob  # whether to normalize the top-k probabilities
         ####################################################
-        self.knowlwdge_num = knowlwdge_num
-        self.knowlwdge_length = knowlwdge_length
+        self.knowledge_num = knowledge_num
+        self.knowledge_length = knowledge_length
         super().__init__(**kwargs)
@@ -524,15 +524,19 @@ class ExtractDB(nn.Module):
         self.batch_size = None
         self.dim = params.dim
         self.dim_key = self.dim // 2
-        self.knowlwdge_num = params.knowlwdge_num  # number of experts; must be a perfect square
+        self.knowledge_num = params.knowledge_num  # number of experts; must be a perfect square
         # set knowledge_dim equal to head_dim so it can be used directly in attention
         self.head_dim = params.dim // params.n_heads
-        self.knowledge_length = params.knowlwdge_length*params.dim
+        self.knowledge_length = params.knowledge_length
 
         # use register_buffer instead of nn.Parameter to avoid gradient issues
-        self.register_buffer('weight_down_embed', torch.randn(self.knowlwdge_num, self.knowledge_length) * 0.02)
+        # self.register_buffer('weight_down_embed', torch.randn(self.knowledge_num, self.knowledge_length) * 0.02)
+        self.register_buffer('weight_down_embed', torch.randint(low=0, high=6400, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
 
-        self.num_keys = int(math.sqrt(self.knowlwdge_num)) if self.knowlwdge_num > 0 else 0
+
+
+
+        self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
         self.product_key_topk = min(16, self.num_keys)
         self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
         self.num_experts_per_head_topk = 1
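Net effect of this hunk: the knowledge base is no longer a float embedding table but a non-trainable buffer of token ids, still addressed with product keys. A minimal standalone sketch of the resulting shapes (toy values; the 6400 upper bound is read off the `randint` call above as the vocab size):

```python
import math

import torch

knowledge_num, knowledge_length, vocab_size = 64 * 64, 8, 6400

# Each of the 4096 entries is now 8 token ids rather than a float vector,
# registered as a buffer so no gradients ever flow into it.
weight_down_embed = torch.randint(0, vocab_size, (knowledge_num, knowledge_length), dtype=torch.long)

# Product-key addressing: an entry index decomposes into a pair of sub-key
# indices, so only 2 * sqrt(4096) = 128 keys are scored instead of 4096.
num_keys = int(math.sqrt(knowledge_num))  # 64 per sub-key dimension
print(weight_down_embed.shape, num_keys)  # torch.Size([4096, 8]) 64
```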
@@ -574,12 +578,12 @@ class ExtractDB(nn.Module):
 
     def get_data(self, index):
         # fetch the embedding directly from the GPU
-        db_values = self.weight_down_embed[index]
-        db_value = db_values.view(self.batch_size, -1, self.dim)
-        return db_value
+        db_values = self.weight_down_embed[index]  # these are token ids now, so width 1; they go through the embedding layer later
+        # db_value = db_values.view(self.batch_size, -1)
+        return db_values
 
     @torch.no_grad()
-    def updata_value(self, k, v):
+    def updata_value(self, k, v):  # TODO: needs a step that maps vectors back to indices
         # update the buffer values in place (no gradients needed)
         v_reshaped = v.view(v.size(0), -1)
         # make sure the data types match
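Both halves of the API change shape here: reads return raw token ids, and `updata_value` must now be fed ids rather than float vectors (the diff leaves that mapping as a TODO). A hedged sketch of an id-based in-place update, using hypothetical standalone names:

```python
import torch

knowledge_num, knowledge_length, vocab_size = 4096, 8, 6400
store = torch.randint(0, vocab_size, (knowledge_num, knowledge_length), dtype=torch.long)

@torch.no_grad()
def updata_value(k: torch.Tensor, v: torch.Tensor) -> None:
    # k: [n] entry indices; v: [n, knowledge_length] token ids, already quantized
    store[k] = v.to(store.dtype)  # dtype must match the long buffer

updata_value(torch.tensor([0, 7]), torch.randint(0, vocab_size, (2, knowledge_length)))
print(store[torch.tensor([0, 7])].shape)  # torch.Size([2, 8])
```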
@@ -654,8 +658,12 @@ class MiniMindLM(PreTrainedModel):
                                 dtype=h.dtype, device=h.device)
         else:
             # normal mode: query the database
             # import pdb;pdb.set_trace()
             index = self.extract_db.q_to_k(h)
-            db_value = self.extract_db.get_data(index)
+
+            token_idx = self.extract_db.get_data(index)  # these are token ids
+
+            db_value = self.tok_embeddings(token_idx)
+
         h = layer(
             h, db_value, pos_cis_real
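The lookup is now a two-hop dereference: product-key search yields entry indices, the buffer yields token ids, and the shared token embedding turns those ids into vectors, so gradients reach `tok_embeddings` even though the store itself stays frozen. A runnable shape-level sketch with stand-ins for `q_to_k`/`get_data`:

```python
import torch
import torch.nn as nn

vocab_size, dim, knowledge_num, knowledge_length = 6400, 512, 4096, 8
tok_embeddings = nn.Embedding(vocab_size, dim)
store = torch.randint(0, vocab_size, (knowledge_num, knowledge_length), dtype=torch.long)

index = torch.randint(0, knowledge_num, (2, 4))  # stand-in for extract_db.q_to_k(h)
token_idx = store[index]                         # stand-in for get_data: [2, 4, 8] token ids
db_value = tok_embeddings(token_idx)             # [2, 4, 8, 512], differentiable w.r.t. the embedding
print(db_value.shape)
```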
@@ -673,12 +681,28 @@ class MiniMindLM(PreTrainedModel):
 
             # keep the database update logic out of the main computation graph
             with torch.no_grad():
+
                 # Compute shared downsampling layer once
                 shared_features = self.shared_downsample(h_tensor_detached)
-                z_v = self.downsample_v_specific(shared_features)
+                z_v = self.downsample_v_specific(shared_features)  # must be mapped from embeddings back to tokens here
+                batch_z, seq_len, dim_z = z_v.shape
+                embedding_weights = self.tok_embeddings.weight.detach()  # [vocab_size, dim]
+                z_v_flat = z_v.reshape(-1, z_v.shape[-1])  # [batch_size_z * knowledge_len, dim]
+
+                # cosine-similarity version
+                # z_v_flat_norm = F.normalize(z_v_flat, p=2, dim=1)
+                # embedding_weights_norm = F.normalize(embedding_weights, p=2, dim=1)
+                # similarity_scores = torch.matmul(z_v_flat_norm, embedding_weights_norm.transpose(0, 1))
+                # token_indices_flat = torch.argmax(similarity_scores, dim=1)  # [batch_size_z * seq_len]
+                # token_indices = token_indices_flat.reshape(batch_size, -1)
+
+                distances = torch.cdist(z_v_flat, embedding_weights, p=2)
+                token_indices_flat = torch.argmin(distances, dim=1)
+                token_indices = token_indices_flat.reshape(batch_z, -1)
                 z_q = self.downsample_q_specific(shared_features)
+                # import pdb;pdb.set_trace()
                 z_k = self.extract_db.q_to_k(z_q)
-                self.extract_db.updata_value(z_k, z_v)
+                self.extract_db.updata_value(z_k, token_indices)
 
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.output(self.norm(h)[:, slice_indices, :])
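The added block is a nearest-neighbour quantizer: every continuous `z_v` row is snapped to the closest row of the embedding table before being written into the store. A self-contained sketch of the same step (toy sizes):

```python
import torch

vocab_size, dim = 6400, 512
embedding_weights = torch.randn(vocab_size, dim)  # stand-in for tok_embeddings.weight.detach()
z_v_flat = torch.randn(32, dim)                   # continuous vectors to quantize

distances = torch.cdist(z_v_flat, embedding_weights, p=2)  # [32, vocab_size]
token_indices = torch.argmin(distances, dim=1)             # nearest token id per row
print(token_indices.shape)  # torch.Size([32])
```

Note that `torch.cdist` materializes the full [rows, vocab_size] distance matrix, so rows may need chunking at scale, and that for unit-normalized inputs the Euclidean argmin and the commented-out cosine argmax select the same tokens.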
@@ -45,5 +45,5 @@ CUDA_VISIBLE_DEVICES=0 accelerate launch \
     --use_flash_attn \
     --profile \
     --profile_interval 10 \
-    --knowlwdge_num 4096 \
-    --knowlwdge_length 8
+    --knowledge_num 4096 \
+    --knowledge_length 8
@@ -291,7 +291,7 @@ def train_epoch(epoch, wandb):
 
 def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
     # load the tokenizer
-    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
+    tokenizer = AutoTokenizer.from_pretrained('/mnt/lzn/Minimind/Minimind/model/minimind_tokenizer')
     # load the model
     model = MiniMindLM(lm_config).to(args.device)
@@ -349,7 +349,7 @@ if __name__ == "__main__":
     parser.add_argument('--max_seq_len', default=1024, type=int)  # maximum input sequence length
     parser.add_argument('--use_moe', default=False, type=bool)  # whether to use MoE
     parser.add_argument('--disable_db', action='store_true', help="disable the database feature and use the fixed value 1e-4 instead")  # disables the database, enabling the special mode
-    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")  # path to the dataset
+    parser.add_argument("--data_path", type=str, default="/mnt/lzn/Minimind/dataset/dir/pretrain_hq.jsonl")  # path to the dataset
     parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
     # profiling-related arguments
     parser.add_argument("--profile", action="store_true", default=True, help="enable profiling")
@@ -406,7 +406,6 @@ if __name__ == "__main__":
         wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config)
     else:
         wandb = None
-
     model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
     train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
     train_sampler = DistributedSampler(train_ds) if ddp else None
@@ -289,8 +289,8 @@ def main():
     parser.add_argument("--profile", action="store_true", default=True, help="enable profiling")
     parser.add_argument("--profile_interval", type=int, default=10, help="profiling print interval (steps)")
     parser.add_argument("--use_flash_attn", action="store_true", default=True, help="enable FlashAttention")
-    parser.add_argument("--knowlwdge_num", type=int, default=64*64, help="number of entries in the knowledge base")
-    parser.add_argument("--knowlwdge_length", type=int, default=8, help="sentence length of knowledge-base entries")
+    parser.add_argument("--knowledge_num", type=int, default=64*64, help="number of entries in the knowledge base")
+    parser.add_argument("--knowledge_length", type=int, default=8, help="sentence length of knowledge-base entries")
     args = parser.parse_args()
 
     #########################################################
@@ -327,8 +327,8 @@ def main():
         use_moe=args.use_moe,
         disable_db=args.disable_db,
         flash_attn=args.use_flash_attn,
-        knowlwdge_num=args.knowlwdge_num,
-        knowlwdge_length=args.knowlwdge_length
+        knowledge_num=args.knowledge_num,
+        knowledge_length=args.knowledge_length
     )
 
     #########################################################
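With the rename applied end to end, the CLI flags, the LMConfig fields, and the keyword arguments above all agree. A hypothetical round-trip check under that assumption:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--knowledge_num", type=int, default=64 * 64)
parser.add_argument("--knowledge_length", type=int, default=8)
args = parser.parse_args([])

# Mirrors the LMConfig(...) call above; a mismatch between flag name and config
# field would fall through to PretrainedConfig's **kwargs, which is why the
# spelling fix has to land on both sides of the interface at once.
config_kwargs = dict(knowledge_num=args.knowledge_num, knowledge_length=args.knowledge_length)
print(config_kwargs)  # {'knowledge_num': 4096, 'knowledge_length': 8}
```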