diff --git a/model/LMConfig.py b/model/LMConfig.py index 8ce52fc..1312e97 100644 --- a/model/LMConfig.py +++ b/model/LMConfig.py @@ -37,8 +37,8 @@ class LMConfig(PretrainedConfig): seq_aux: bool = True, norm_topk_prob: bool = True, #################################################### - knowlwdge_num: int = 64*64, - knowlwdge_length: int = 8, + knowledge_num: int = 64*64, + knowledge_length: int = 8, **kwargs, ): self.dim = dim @@ -70,6 +70,6 @@ class LMConfig(PretrainedConfig): self.seq_aux = seq_aux # 是否在序列级别上计算辅助损失 self.norm_topk_prob = norm_topk_prob # 是否标准化top-k概率 #################################################### - self.knowlwdge_num = knowlwdge_num - self.knowlwdge_length = knowlwdge_length + self.knowledge_num = knowledge_num + self.knowledge_length = knowledge_length super().__init__(**kwargs) diff --git a/model/model.py b/model/model.py index 9cb6a49..010c8e3 100644 --- a/model/model.py +++ b/model/model.py @@ -515,7 +515,7 @@ class MiniMindBlock(nn.Module): # 前馈神经网络 out = h + self.feed_forward(self.ffn_norm(h)) - return out + return out class ExtractDB(nn.Module): def __init__(self,params): @@ -524,22 +524,26 @@ class ExtractDB(nn.Module): self.batch_size = None self.dim = params.dim self.dim_key = self.dim // 2 - self.knowlwdge_num = params.knowlwdge_num # 100专家,确保是完全平方数 + self.knowledge_num = params.knowledge_num # 100专家,确保是完全平方数 # 将knowledge_dim设置为与head_dim相同,以便在attention中直接使用 self.head_dim = params.dim // params.n_heads - self.knowledge_length = params.knowlwdge_length*params.dim + self.knowledge_length = params.knowledge_length # 使用register_buffer代替nn.Parameter,避免梯度问题 - self.register_buffer('weight_down_embed', torch.randn(self.knowlwdge_num, self.knowledge_length) * 0.02) + # self.register_buffer('weight_down_embed', torch.randn(self.knowledge_num, self.knowledge_length) * 0.02) + self.register_buffer('weight_down_embed',torch.randint(low=0,high=6400, size=(self.knowledge_num, self.knowledge_length),dtype=torch.long)) + + - self.num_keys = int(math.sqrt(self.knowlwdge_num)) if self.knowlwdge_num > 0 else 0 + + self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0 self.product_key_topk = min(16, self.num_keys) self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02) self.num_experts_per_head_topk = 1 self.to_queries = nn.Sequential( nn.Linear(params.dim, self.dim_key * 2, bias=False), ) - + def q_to_k(self,x): # 1. 生成queries self.batch_size, seq_len, dim = x.shape @@ -574,12 +578,12 @@ class ExtractDB(nn.Module): def get_data(self, index): # 直接从GPU获取embedding - db_values = self.weight_down_embed[index] - db_value = db_values.view(self.batch_size, -1, self.dim) - return db_value + db_values = self.weight_down_embed[index]#变成token了所以是1,后续再过emb + # db_value = db_values.view(self.batch_size,-1) + return db_values @torch.no_grad() - def updata_value(self, k, v): + def updata_value(self, k, v):#要加一个从向量返回index的过程 # 直接更新buffer上的值 (不需要梯度) v_reshaped = v.view(v.size(0), -1) # 确保数据类型匹配 @@ -654,9 +658,13 @@ class MiniMindLM(PreTrainedModel): dtype=h.dtype, device=h.device) else: # 正常模式,使用数据库查询 + # import pdb;pdb.set_trace() index = self.extract_db.q_to_k(h) - db_value = self.extract_db.get_data(index) + + token_idx = self.extract_db.get_data(index) #这里是index + db_value =self.tok_embeddings(token_idx) + h = layer( h, db_value, pos_cis_real ) @@ -673,12 +681,28 @@ class MiniMindLM(PreTrainedModel): # 数据库更新逻辑与主计算图分离 with torch.no_grad(): + # Compute shared downsampling layer once shared_features = self.shared_downsample(h_tensor_detached) - z_v = self.downsample_v_specific(shared_features) + z_v = self.downsample_v_specific(shared_features)#这里需要从emb返回为token + batch_z, seq_len, dim_z = z_v.shape + embedding_weights = self.tok_embeddings.weight.detach()#[vocab_size, dim] + z_v_flat = z_v.reshape(-1, z_v.shape[-1]) # [batch_size_z * knowledge_len, dim] + + # 余弦相似度版本 + # z_v_flat_norm = F.normalize(z_v_flat, p=2, dim=1) + # embedding_weights_norm = F.normalize(embedding_weights, p=2, dim=1) + # similarity_scores = torch.matmul(z_v_flat_norm, embedding_weights_norm.transpose(0, 1)) + # token_indices_flat = torch.argmax(similarity_scores, dim=1) # [batch_size_z * seq_len] + # token_indices = token_indices_flat.reshape(batch_size, -1) + + distances = torch.cdist(z_v_flat, embedding_weights, p=2) + token_indices_flat = torch.argmin(distances, dim=1) + token_indices = token_indices_flat.reshape(batch_z,-1) z_q = self.downsample_q_specific(shared_features) + # import pdb;pdb.set_trace() z_k = self.extract_db.q_to_k(z_q) - self.extract_db.updata_value(z_k, z_v) + self.extract_db.updata_value(z_k, token_indices) slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.output(self.norm(h)[:, slice_indices, :]) diff --git a/run_file/DynamicKV-LLM_Mini_Minimind.sh b/run_file/DynamicKV-LLM_Mini_Minimind.sh index 0c92075..9495752 100644 --- a/run_file/DynamicKV-LLM_Mini_Minimind.sh +++ b/run_file/DynamicKV-LLM_Mini_Minimind.sh @@ -45,5 +45,5 @@ CUDA_VISIBLE_DEVICES=0 accelerate launch \ --use_flash_attn \ --profile \ --profile_interval 10\ - --knowlwdge_num 4096 \ - --knowlwdge_length 8 + --knowledge_num 4096 \ + --knowledge_length 8 diff --git a/train_pretrain.py b/train_pretrain.py index 7397984..9776e39 100644 --- a/train_pretrain.py +++ b/train_pretrain.py @@ -291,7 +291,7 @@ def train_epoch(epoch, wandb): def init_model(lm_config, pretrained_embedding_path: Optional[str] = None): # 加载tokenizer - tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer') + tokenizer = AutoTokenizer.from_pretrained('/mnt/lzn/Minimind/Minimind/model/minimind_tokenizer') # 加载模型 model = MiniMindLM(lm_config).to(args.device) @@ -349,7 +349,7 @@ if __name__ == "__main__": parser.add_argument('--max_seq_len', default=1024, type=int) #最大序列长度,用于控制输入序列的最大长度。 parser.add_argument('--use_moe', default=False, type=bool) #是否使用MOE,用于控制是否使用MOE。 parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代") #禁用数据库功能,启用特殊模式 - parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。 + parser.add_argument("--data_path", type=str, default="/mnt/lzn/Minimind/dataset/dir/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。 parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)") # 性能分析相关参数 parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") @@ -406,7 +406,6 @@ if __name__ == "__main__": wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config) else: wandb = None - model, tokenizer = init_model(lm_config, args.pretrained_embedding_path) train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len) train_sampler = DistributedSampler(train_ds) if ddp else None diff --git a/train_pretrain_accelerate.py b/train_pretrain_accelerate.py index 5faf219..6fa19b9 100644 --- a/train_pretrain_accelerate.py +++ b/train_pretrain_accelerate.py @@ -289,8 +289,8 @@ def main(): parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)") parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention") - parser.add_argument("--knowlwdge_num", type=int, default=64*64,help="知识库的数据数目") - parser.add_argument("--knowlwdge_length", type=int, default=8,help="知识库的句子长度") + parser.add_argument("--knowledge_num", type=int, default=64*64,help="知识库的数据数目") + parser.add_argument("--knowledge_length", type=int, default=8,help="知识库的句子长度") args = parser.parse_args() ######################################################### @@ -327,8 +327,8 @@ def main(): use_moe=args.use_moe, disable_db=args.disable_db, flash_attn=args.use_flash_attn, - knowlwdge_num=args.knowlwdge_num, - knowlwdge_length=args.knowlwdge_length + knowledge_num=args.knowledge_num, + knowledge_length=args.knowledge_length ) #########################################################