This commit is contained in:
Gary 2025-05-16 08:38:59 +00:00
parent 5841f8b4e5
commit d7fe504e1e
5 changed files with 49 additions and 26 deletions

View File

@@ -37,8 +37,8 @@ class LMConfig(PretrainedConfig):
seq_aux: bool = True, seq_aux: bool = True,
norm_topk_prob: bool = True, norm_topk_prob: bool = True,
#################################################### ####################################################
knowlwdge_num: int = 64*64, knowledge_num: int = 64*64,
knowlwdge_length: int = 8, knowledge_length: int = 8,
**kwargs, **kwargs,
): ):
self.dim = dim self.dim = dim
@@ -70,6 +70,6 @@ class LMConfig(PretrainedConfig):
self.seq_aux = seq_aux # 是否在序列级别上计算辅助损失 self.seq_aux = seq_aux # 是否在序列级别上计算辅助损失
self.norm_topk_prob = norm_topk_prob # 是否标准化top-k概率 self.norm_topk_prob = norm_topk_prob # 是否标准化top-k概率
#################################################### ####################################################
self.knowlwdge_num = knowlwdge_num self.knowledge_num = knowledge_num
self.knowlwdge_length = knowlwdge_length self.knowledge_length = knowledge_length
super().__init__(**kwargs) super().__init__(**kwargs)

View File

@@ -524,15 +524,19 @@ class ExtractDB(nn.Module):
self.batch_size = None self.batch_size = None
self.dim = params.dim self.dim = params.dim
self.dim_key = self.dim // 2 self.dim_key = self.dim // 2
self.knowlwdge_num = params.knowlwdge_num # 100专家确保是完全平方数 self.knowledge_num = params.knowledge_num # 100专家确保是完全平方数
# 将knowledge_dim设置为与head_dim相同以便在attention中直接使用 # 将knowledge_dim设置为与head_dim相同以便在attention中直接使用
self.head_dim = params.dim // params.n_heads self.head_dim = params.dim // params.n_heads
self.knowledge_length = params.knowlwdge_length*params.dim self.knowledge_length = params.knowledge_length
# 使用register_buffer代替nn.Parameter避免梯度问题 # 使用register_buffer代替nn.Parameter避免梯度问题
self.register_buffer('weight_down_embed', torch.randn(self.knowlwdge_num, self.knowledge_length) * 0.02) # self.register_buffer('weight_down_embed', torch.randn(self.knowledge_num, self.knowledge_length) * 0.02)
self.register_buffer('weight_down_embed',torch.randint(low=0,high=6400, size=(self.knowledge_num, self.knowledge_length),dtype=torch.long))
self.num_keys = int(math.sqrt(self.knowlwdge_num)) if self.knowlwdge_num > 0 else 0
self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
self.product_key_topk = min(16, self.num_keys) self.product_key_topk = min(16, self.num_keys)
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02) self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
self.num_experts_per_head_topk = 1 self.num_experts_per_head_topk = 1
@@ -574,12 +578,12 @@ class ExtractDB(nn.Module):
def get_data(self, index): def get_data(self, index):
# 直接从GPU获取embedding # 直接从GPU获取embedding
db_values = self.weight_down_embed[index] db_values = self.weight_down_embed[index]#变成token了所以是1,后续再过emb
db_value = db_values.view(self.batch_size, -1, self.dim) # db_value = db_values.view(self.batch_size,-1)
return db_value return db_values
@torch.no_grad() @torch.no_grad()
def updata_value(self, k, v): def updata_value(self, k, v):#要加一个从向量返回index的过程
# 直接更新buffer上的值 (不需要梯度) # 直接更新buffer上的值 (不需要梯度)
v_reshaped = v.view(v.size(0), -1) v_reshaped = v.view(v.size(0), -1)
# 确保数据类型匹配 # 确保数据类型匹配
@@ -654,8 +658,12 @@ class MiniMindLM(PreTrainedModel):
dtype=h.dtype, device=h.device) dtype=h.dtype, device=h.device)
else: else:
# 正常模式,使用数据库查询 # 正常模式,使用数据库查询
# import pdb;pdb.set_trace()
index = self.extract_db.q_to_k(h) index = self.extract_db.q_to_k(h)
db_value = self.extract_db.get_data(index)
token_idx = self.extract_db.get_data(index) #这里是index
db_value =self.tok_embeddings(token_idx)
h = layer( h = layer(
h, db_value, pos_cis_real h, db_value, pos_cis_real
@@ -673,12 +681,28 @@ class MiniMindLM(PreTrainedModel):
# 数据库更新逻辑与主计算图分离 # 数据库更新逻辑与主计算图分离
with torch.no_grad(): with torch.no_grad():
# Compute shared downsampling layer once # Compute shared downsampling layer once
shared_features = self.shared_downsample(h_tensor_detached) shared_features = self.shared_downsample(h_tensor_detached)
z_v = self.downsample_v_specific(shared_features) z_v = self.downsample_v_specific(shared_features)#这里需要从emb返回为token
batch_z, seq_len, dim_z = z_v.shape
embedding_weights = self.tok_embeddings.weight.detach()#[vocab_size, dim]
z_v_flat = z_v.reshape(-1, z_v.shape[-1]) # [batch_size_z * knowledge_len, dim]
# 余弦相似度版本
# z_v_flat_norm = F.normalize(z_v_flat, p=2, dim=1)
# embedding_weights_norm = F.normalize(embedding_weights, p=2, dim=1)
# similarity_scores = torch.matmul(z_v_flat_norm, embedding_weights_norm.transpose(0, 1))
# token_indices_flat = torch.argmax(similarity_scores, dim=1) # [batch_size_z * seq_len]
# token_indices = token_indices_flat.reshape(batch_size, -1)
distances = torch.cdist(z_v_flat, embedding_weights, p=2)
token_indices_flat = torch.argmin(distances, dim=1)
token_indices = token_indices_flat.reshape(batch_z,-1)
z_q = self.downsample_q_specific(shared_features) z_q = self.downsample_q_specific(shared_features)
# import pdb;pdb.set_trace()
z_k = self.extract_db.q_to_k(z_q) z_k = self.extract_db.q_to_k(z_q)
self.extract_db.updata_value(z_k, z_v) self.extract_db.updata_value(z_k, token_indices)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :]) logits = self.output(self.norm(h)[:, slice_indices, :])

View File

@@ -45,5 +45,5 @@ CUDA_VISIBLE_DEVICES=0 accelerate launch \
--use_flash_attn \ --use_flash_attn \
--profile \ --profile \
--profile_interval 10\ --profile_interval 10\
--knowlwdge_num 4096 \ --knowledge_num 4096 \
--knowlwdge_length 8 --knowledge_length 8

View File

@@ -291,7 +291,7 @@ def train_epoch(epoch, wandb):
def init_model(lm_config, pretrained_embedding_path: Optional[str] = None): def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
# 加载tokenizer # 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer') tokenizer = AutoTokenizer.from_pretrained('/mnt/lzn/Minimind/Minimind/model/minimind_tokenizer')
# 加载模型 # 加载模型
model = MiniMindLM(lm_config).to(args.device) model = MiniMindLM(lm_config).to(args.device)
@@ -349,7 +349,7 @@ if __name__ == "__main__":
parser.add_argument('--max_seq_len', default=1024, type=int) #最大序列长度,用于控制输入序列的最大长度。 parser.add_argument('--max_seq_len', default=1024, type=int) #最大序列长度,用于控制输入序列的最大长度。
parser.add_argument('--use_moe', default=False, type=bool) #是否使用MOE用于控制是否使用MOE。 parser.add_argument('--use_moe', default=False, type=bool) #是否使用MOE用于控制是否使用MOE。
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代") #禁用数据库功能,启用特殊模式 parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代") #禁用数据库功能,启用特殊模式
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。 parser.add_argument("--data_path", type=str, default="/mnt/lzn/Minimind/dataset/dir/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)") parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
# 性能分析相关参数 # 性能分析相关参数
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
@@ -406,7 +406,6 @@ if __name__ == "__main__":
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config) wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config)
else: else:
wandb = None wandb = None
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path) model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len) train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None train_sampler = DistributedSampler(train_ds) if ddp else None

View File

@@ -289,8 +289,8 @@ def main():
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)") parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention") parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
parser.add_argument("--knowlwdge_num", type=int, default=64*64,help="知识库的数据数目") parser.add_argument("--knowledge_num", type=int, default=64*64,help="知识库的数据数目")
parser.add_argument("--knowlwdge_length", type=int, default=8,help="知识库的句子长度") parser.add_argument("--knowledge_length", type=int, default=8,help="知识库的句子长度")
args = parser.parse_args() args = parser.parse_args()
######################################################### #########################################################
@@ -327,8 +327,8 @@ def main():
use_moe=args.use_moe, use_moe=args.use_moe,
disable_db=args.disable_db, disable_db=args.disable_db,
flash_attn=args.use_flash_attn, flash_attn=args.use_flash_attn,
knowlwdge_num=args.knowlwdge_num, knowledge_num=args.knowledge_num,
knowlwdge_length=args.knowlwdge_length knowledge_length=args.knowledge_length
) )
######################################################### #########################################################