iomgaa 2025-06-20 12:43:21 +08:00
parent 770c34f0e3
commit 83b91859ce
2 changed files with 44 additions and 84 deletions

View File

@@ -2,7 +2,7 @@ import math
 import struct
 import inspect
 import time
+# Two-dimensional subspace decomposition for keys + gradient-based updates
 from .LMConfig import LMConfig
 from typing import Any, Optional, Tuple, List, Union
 import numpy as np
@@ -67,23 +67,21 @@ class KnowledgeDataset(nn.Module):
         ## Database parameters
         self.knowledge_num = params.knowledge_num
         self.knowledge_length = params.knowledge_length
-        self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
-        self.product_key_topk = min(16, self.knowledge_num)
-        # Usage frequency statistics - use register_buffer so it moves correctly between GPU/CPU
-        self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
+        # Change key storage to a two-way factored space, set as a trainable parameter
+        self.num_keys = int(math.sqrt(self.knowledge_num))
+        # Ensure the keys are a trainable parameter
+        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
+        self.product_key_topk = min(16, self.num_keys)

         # Knowledge storage - use register_buffer since these are integer indices and need no gradient
         self.register_buffer('knowledge_dataset',
-            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
-        )
+            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))

         # Count steps for dynamic weight adjustment
         self.step_counter = 0
-        self.freeze_embedding = False
+        # Removed the batch counter and update-frequency related code

     def intelligent_selection(self, query, all_scores, all_indices):
         """Intelligent hierarchical selection strategy"""
@@ -106,7 +104,8 @@ class KnowledgeDataset(nn.Module):
         candidate_tokens = self.knowledge_dataset[unique_indices]
         flat_tokens = candidate_tokens.view(-1)
         flat_embeddings = self.tok_embeddings(flat_tokens)
-        # Get the indices corresponding to flat_tokens
+        # Get the indices corresponding to flat_tokens; keep these variables for use elsewhere
         pre_update_indices = unique_indices.view(-1)
         pre_update_embeddings = flat_embeddings.view(
             len(unique_indices), self.knowledge_length, -1
@@ -158,84 +157,45 @@ class KnowledgeDataset(nn.Module):
         all_best_tokens = torch.stack(batch_best_tokens, dim=0)
         all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)

-        # Get
-        # Update self.keys with the recomputed embeddings
-        if self.is_train:
-            self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
-            # Mark the keys that were modified
-            with torch.no_grad():
-                self.has_update_keys[pre_update_indices] = 1
-
         return all_best_tokens, all_best_tokens_embeddings

-    def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
-        if self.freeze_embedding:
-            return
-        # Update self.keys with pre_update_embeddings
-        with torch.no_grad():
-            pre_update_embeddings = pre_update_embeddings.mean(dim=1)  # [337, 512]
-            pre_update_embeddings = self.to_queries(pre_update_embeddings)
-            self.keys[pre_update_indices] = pre_update_embeddings
-
-    def search_index(self,x):
+    def search_index(self, x):
         batch_size, seq_len, dim = x.shape

-        # collapse sequence dimension by averaging
+        # 1. Average over the sequence dimension
         x_flat = x.mean(dim=1)  # [batch_size, dim]

-        queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
-        # queries = queries.reshape(batch_size, 2, self.key_dim)
-        # queries = queries.permute(1, 0, 2)
+        # 2. Generate query vectors and reshape into two sub-queries
+        queries = self.to_queries(x_flat)  # [batch_size, knowledge_dim]
+        queries = queries.reshape(batch_size, 2, self.key_dim)  # [batch_size, 2, key_dim]
+        # Reorder dimensions so the subspace axis comes first
+        queries = queries.permute(1, 0, 2)  # [2, batch_size, key_dim]

-        # 2. Compute similarity between queries and keys
-        sim = torch.einsum('b d, k d -> b k', queries, self.keys)
+        # 3. Compute similarity in each subspace
+        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

-        # 3. Take top-k in each of the two subspaces
-        scores_and_indices = sim.topk(self.product_key_topk, dim=-1)
-        scores, indices = scores_and_indices[0], scores_and_indices[1]
+        # 4. Take top-k separately in each of the two subspaces
+        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
+        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
+        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]

-        # 5. Apply the intelligent hierarchical selection strategy
+        # 5. Combine the results of the two subspaces
+        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)  # [batch_size, topk, topk]
+        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)  # [batch_size, topk, topk]
+
+        # 6. Flatten the combined results to 2-D
+        all_scores = all_scores.reshape(batch_size, -1)  # [batch_size, topk*topk]
+        all_indices = all_indices.reshape(batch_size, -1)  # [batch_size, topk*topk]
+
+        # 7. Select the final top-k results
+        scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
+        indices = torch.gather(all_indices, 1, indices_of_indices)
+
+        # 8. Apply the intelligent hierarchical selection strategy
         best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)

-        # 6. Update 1% of the keys
-        if self.is_train:
-            # Get the indices of keys that have not been updated
-            not_updated_indices = torch.where(self.has_update_keys == 0)[0]
-            # If there are un-updated keys, randomly pick num_update_keys of them to update
-            if len(not_updated_indices) > 0:
-                num_update_keys = int(self.knowledge_num * 0.01)
-                perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
-                perm_num = perm.shape[0]
-                pre_update_indices = not_updated_indices[perm]
-                pre_update_tokens = self.knowledge_dataset[pre_update_indices]
-                pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
-                pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
-                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
-                # Mark the keys that were modified
-                with torch.no_grad():
-                    self.has_update_keys[pre_update_indices] = 1
-            else:
-                print("all keys are updated")
-                # Reset the update status of all keys
-                self.has_update_keys.zero_()
-                # Re-fetch all updatable indices
-                not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
-                num_update_keys = int(self.knowledge_num * 0.01)
-                perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
-                pre_update_indices = not_updated_indices[perm]
-                pre_update_tokens = self.knowledge_dataset[pre_update_indices]
-                pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
-                pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
-                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
-                # Mark the keys that were modified
-                with torch.no_grad():
-                    self.has_update_keys[pre_update_indices] = 1
-
         return best_tokens, best_tokens_embeddings
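For reference, the new `search_index` composes the two per-subspace top-k lists into candidates over all `num_keys**2` entries by summing scores pairwise and encoding the pair `(ix, iy)` as the flat index `ix * num_keys + iy`. A self-contained toy run of that combination step (toy shapes, not repository code):

```python
# Product-key combination step in isolation: top-k per half-space,
# pairwise score sums, composite indices, final top-k.
import torch

batch_size, key_dim, num_keys, topk = 4, 64, 980, 16

queries = torch.randn(2, batch_size, key_dim)  # one query per subspace
keys = torch.randn(num_keys, 2, key_dim)       # [K, 2, d], as in the diff

sim = torch.einsum('p b d, k p d -> p b k', queries, keys)  # [2, B, K]
scores_x, indices_x = sim[0].topk(topk, dim=-1)             # [B, topk]
scores_y, indices_y = sim[1].topk(topk, dim=-1)             # [B, topk]

# All topk*topk pairwise combinations of the two candidate lists
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)                # [B, topk, topk]
all_indices = indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)  # [B, topk, topk]

# Final top-k over the flattened Cartesian product
scores, pos = all_scores.reshape(batch_size, -1).topk(topk, dim=-1)
indices = torch.gather(all_indices.reshape(batch_size, -1), 1, pos)  # flat ids in [0, K*K)
print(scores.shape, indices.shape)  # torch.Size([4, 16]) torch.Size([4, 16])
```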
@@ -522,10 +482,9 @@ class MiniMindLM(PreTrainedModel):
         start_pos = args.get('start_pos', 0)
         if self.freeze_embedding and step == 0:
             self.tok_embeddings.weight.requires_grad = False
-            # Also freeze the KnowledgeDataset's embedding updates
-            self.knowledge_dataset.freeze_embedding = True
+            # Removed the knowledge_dataset.freeze_embedding setting; key updates are controlled by batch_counter
+            # self.knowledge_dataset.freeze_embedding = True
             print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
-            print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
         h = self.dropout(self.tok_embeddings(input_ids))
         pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
         for l, layer in enumerate(self.layers):
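Dropping the `freeze_embedding` plumbing is consistent with the new design: the keys are now an `nn.Parameter` trained by backpropagation rather than overwritten inside `torch.no_grad()` blocks, and scores selected by top-k remain differentiable with respect to the keys. A toy check of that gradient path (plain PyTorch, not repository code):

```python
# Verify that gradients flow back to the factored keys through einsum + topk.
import torch

keys = torch.randn(10, 2, 8, requires_grad=True)  # toy stand-in for self.keys
queries = torch.randn(2, 3, 8)                    # [2, batch, key_dim]

sim = torch.einsum('p b d, k p d -> p b k', queries, keys)
loss = sim.topk(4, dim=-1).values.sum()  # use the selected scores downstream
loss.backward()
print(keys.grad.abs().sum() > 0)  # tensor(True): gradient reaches the keys
```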
@@ -601,3 +560,4 @@ class MiniMindLM(PreTrainedModel):
                 yield input_ids[:, start:]
             if input_ids_next.item() == eos_token_id:
                 break

View File

@@ -461,7 +461,7 @@ def main():
     parser.add_argument("--profile", action="store_true", default=True, help="Enable profiling")
     parser.add_argument("--profile_interval", type=int, default=10, help="Profiling print interval (steps)")
     parser.add_argument("--use_flash_attn", action="store_true", default=True, help="Enable FlashAttention")
-    parser.add_argument("--knowledge_num", type=int, default=8192, help="Number of entries in the knowledge base")
+    parser.add_argument("--knowledge_num", type=int, default=960400, help="Number of entries in the knowledge base")
     parser.add_argument("--knowledge_length", type=int, default=32, help="Sentence length of knowledge base entries")
     parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="Database initialization path")
     parser.add_argument("--fast_clustering", action="store_true", default=True, help="Use fast approximate clustering (suitable for large datasets)")