From 83b91859ce6b3361db1d07e20afa016fbbf061be Mon Sep 17 00:00:00 2001
From: iomgaa
Date: Fri, 20 Jun 2025 12:43:21 +0800
Subject: [PATCH] KnowledgeDataset: replace flat keys with a trainable 2D
 product-key decomposition

---
 model/model.py               | 124 ++++++++++++-----------------------
 train_pretrain_accelerate.py |   4 +-
 2 files changed, 44 insertions(+), 84 deletions(-)

diff --git a/model/model.py b/model/model.py
index c94fd2c..67699b5 100644
--- a/model/model.py
+++ b/model/model.py
@@ -2,7 +2,7 @@ import math
 import struct
 import inspect
 import time
-
+# 2D subspace decomposition + gradient-based key updates
 from .LMConfig import LMConfig
 from typing import Any, Optional, Tuple, List, Union
 import numpy as np
@@ -67,23 +67,21 @@ class KnowledgeDataset(nn.Module):
         ## Database parameters
         self.knowledge_num = params.knowledge_num
         self.knowledge_length = params.knowledge_length
-        self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
-        self.product_key_topk = min(16, self.knowledge_num)
-
-        # Usage-frequency statistics - register_buffer so it moves correctly between GPU/CPU
-        self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
-
+        # Store the keys in a 2D decomposed (product-key) space, as trainable parameters
+        self.num_keys = int(math.sqrt(self.knowledge_num))  # assumes knowledge_num is a perfect square
+        # Make sure the keys are trainable parameters
+        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
+        self.product_key_topk = min(16, self.num_keys)
+
+        # Knowledge store - register_buffer, because these are integer token indices that need no gradients
         self.register_buffer('knowledge_dataset',
-            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
-        )
+            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
 
         # Step counter, used to adjust weights dynamically
         self.step_counter = 0
 
-        self.freeze_embedding = False
-
-
+        # (batch counter and update-frequency bookkeeping removed)
 
     def intelligent_selection(self, query, all_scores, all_indices):
         """Intelligent hierarchical selection strategy."""
@@ -106,7 +104,8 @@ class KnowledgeDataset(nn.Module):
             candidate_tokens = self.knowledge_dataset[unique_indices]
             flat_tokens = candidate_tokens.view(-1)
             flat_embeddings = self.tok_embeddings(flat_tokens)
-            # Get the indices corresponding to flat_tokens
+
+            # Indices corresponding to flat_tokens (kept so other code paths can use them)
             pre_update_indices = unique_indices.view(-1)
             pre_update_embeddings = flat_embeddings.view(
                 len(unique_indices), self.knowledge_length, -1
@@ -158,85 +157,46 @@ class KnowledgeDataset(nn.Module):
             all_best_tokens = torch.stack(batch_best_tokens, dim=0)
             all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
 
-            # Get
-
-            # Update self.keys with the recomputed embeddings
-            if self.is_train:
-                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
-
-            # Mark the modified keys as updated
-            with torch.no_grad():
-                self.has_update_keys[pre_update_indices] = 1
-
         return all_best_tokens, all_best_tokens_embeddings
 
-    def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
-        if self.freeze_embedding:
-            return
-        # Update self.keys using pre_update_embeddings
-        with torch.no_grad():
-            pre_update_embeddings = pre_update_embeddings.mean(dim=1)  # [337, 512]
-            pre_update_embeddings = self.to_queries(pre_update_embeddings)
-            self.keys[pre_update_indices] = pre_update_embeddings
+
-    def search_index(self,x):
+    def search_index(self, x):
         batch_size, seq_len, dim = x.shape
 
-        # collapse sequence dimension by averaging
+        # 1. Average over the sequence dimension
         x_flat = x.mean(dim=1)  # [batch_size, dim]
 
-        queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
-        # queries = queries.reshape(batch_size, 2, self.key_dim)
-        # queries = queries.permute(1, 0, 2)
+        # 2. Generate the query vector and reshape it into two sub-queries
+        queries = self.to_queries(x_flat)  # [batch_size, 2 * key_dim]
+        queries = queries.reshape(batch_size, 2, self.key_dim)  # [batch_size, 2, key_dim]
+        # Move the subspace dimension to the front
+        queries = queries.permute(1, 0, 2)  # [2, batch_size, key_dim]
 
-        # 2. Compute the similarity between queries and keys
-        sim = torch.einsum('b d, k d -> b k', queries, self.keys)
+        # 3. Compute the similarities within each subspace
+        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
 
-        # 3. Take the top-k in each of the two subspaces
-        scores_and_indices = sim.topk(self.product_key_topk, dim=-1)
-        scores, indices = scores_and_indices[0], scores_and_indices[1]
+        # 4. Take the top-k in each of the two subspaces
+        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
+        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
+        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
 
-        # 5. Apply the intelligent hierarchical selection strategy
+        # 5. Combine the results of the two subspaces
+        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)  # [batch_size, topk, topk]
+        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)  # [batch_size, topk, topk]
+
+        # 6. Reshape the combined results to 2D
+        all_scores = all_scores.reshape(batch_size, -1)  # [batch_size, topk*topk]
+        all_indices = all_indices.reshape(batch_size, -1)  # [batch_size, topk*topk]
+
+        # 7. Select the final top-k results
+        scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
+        indices = torch.gather(all_indices, 1, indices_of_indices)
+
+        # 8. Apply the intelligent hierarchical selection strategy
         best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
 
-        # 6. Update 1% of the keys
-        if self.is_train:
-            # Indices of the keys that have not been updated yet
-            not_updated_indices = torch.where(self.has_update_keys == 0)[0]
-
-            # If there are un-updated keys, randomly pick num_update_keys of them to update
-            if len(not_updated_indices) > 0:
-                num_update_keys = int(self.knowledge_num * 0.01)
-                perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
-                perm_num = perm.shape[0]
-                pre_update_indices = not_updated_indices[perm]
-                pre_update_tokens = self.knowledge_dataset[pre_update_indices]
-                pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
-                pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
-                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
-                # Mark the modified keys as updated
-                with torch.no_grad():
-                    self.has_update_keys[pre_update_indices] = 1
-            else:
-                print("all keys are updated")
-                # Reset the update status of all keys
-                self.has_update_keys.zero_()
-                # Re-fetch all updatable indices
-                not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
-                num_update_keys = int(self.knowledge_num * 0.01)
-                perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
-                pre_update_indices = not_updated_indices[perm]
-                pre_update_tokens = self.knowledge_dataset[pre_update_indices]
-                pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
-                pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
-                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
-                # Mark the modified keys as updated
-                with torch.no_grad():
-                    self.has_update_keys[pre_update_indices] = 1
-
-
         return best_tokens, best_tokens_embeddings
 
 class CrossAttention(nn.Module):
@@ -522,10 +482,9 @@ class MiniMindLM(PreTrainedModel):
         start_pos = args.get('start_pos', 0)
         if self.freeze_embedding and step == 0:
             self.tok_embeddings.weight.requires_grad = False
-            # Also freeze KnowledgeDataset's embedding-driven key updates
-            self.knowledge_dataset.freeze_embedding = True
+            # knowledge_dataset.freeze_embedding is no longer set; the keys are now updated by gradient descent
+            # self.knowledge_dataset.freeze_embedding = True
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad) - print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding) h = self.dropout(self.tok_embeddings(input_ids)) pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)] for l, layer in enumerate(self.layers): @@ -600,4 +559,5 @@ class MiniMindLM(PreTrainedModel): input_ids = torch.cat((input_ids, input_ids_next), dim=1) yield input_ids[:, start:] if input_ids_next.item() == eos_token_id: - break \ No newline at end of file + break + diff --git a/train_pretrain_accelerate.py b/train_pretrain_accelerate.py index e00e6a4..aae2e81 100644 --- a/train_pretrain_accelerate.py +++ b/train_pretrain_accelerate.py @@ -461,14 +461,14 @@ def main(): parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)") parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention") - parser.add_argument("--knowledge_num", type=int, default=8192,help="知识库的数据数目") + parser.add_argument("--knowledge_num", type=int, default=960400,help="知识库的数据数目") parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度") parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径") parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)") parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="聚类结果缓存文件路径") parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件") args = parser.parse_args() - + ######################################################### # 初始化accelerator和deepspeed #########################################################