This commit is contained in:
iomgaa 2025-06-20 12:43:21 +08:00
parent 770c34f0e3
commit 83b91859ce
2 changed files with 44 additions and 84 deletions

View File

@ -2,7 +2,7 @@ import math
import struct
import inspect
import time
#子空间二维分解+梯度更新
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
@ -67,23 +67,21 @@ class KnowledgeDataset(nn.Module):
## 数据库参数
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.knowledge_num)
# 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动
self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
# 修改键存储为二维分解空间,设置为可训练参数
self.num_keys = int(math.sqrt(self.knowledge_num))
# 确保keys是可训练参数
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
)
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# 计算step数目用于动态调整权重
self.step_counter = 0
self.freeze_embedding = False
# 移除批次计数器和更新频率相关代码
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
@ -106,7 +104,8 @@ class KnowledgeDataset(nn.Module):
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
#获取flat_tokens对应的index
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
@ -158,85 +157,46 @@ class KnowledgeDataset(nn.Module):
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
# 获取
# 使用重新计算的embeddings更新self.keys
if self.is_train:
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return all_best_tokens, all_best_tokens_embeddings
def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
if self.freeze_embedding:
return
# 使用pre_update_embeddings更新self.keys
with torch.no_grad():
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [337, 512]
pre_update_embeddings = self.to_queries(pre_update_embeddings)
self.keys[pre_update_indices] = pre_update_embeddings
def search_index(self,x):
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# collapse sequence dimension by averaging
# 1. 序列维度平均
x_flat = x.mean(dim=1) # [batch_size, dim]
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
# queries = queries.reshape(batch_size, 2, self.key_dim)
# queries = queries.permute(1, 0, 2)
# 2. 生成查询向量并重塑为两个子查询
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 2. 计算queries与keys的相似度
sim = torch.einsum('b d, k d -> b k', queries, self.keys)
# 3. 计算每个子空间的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 3. 在两个子空间分别做top-k
scores_and_indices = sim.topk(self.product_key_topk, dim=-1)
scores, indices = scores_and_indices[0], scores_and_indices[1]
# 4. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. 应用智能分层选择策略
# 5. 组合两个子空间的结果
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
# 6. 更新1%的keys
if self.is_train:
# 获取未更新过的keys的索引
not_updated_indices = torch.where(self.has_update_keys == 0)[0]
# 如果有未更新的keys随机选择num_update_keys个进行更新
if len(not_updated_indices) > 0:
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
perm_num = perm.shape[0]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
else:
print("all keys are updated")
# 重置所有keys的更新状态
self.has_update_keys.zero_()
# 重新获取所有可更新的索引
not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
@ -522,10 +482,9 @@ class MiniMindLM(PreTrainedModel):
start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False
# 同时冻结KnowledgeDataset的嵌入更新
self.knowledge_dataset.freeze_embedding = True
# 移除对knowledge_dataset.freeze_embedding的设置让键更新由batch_counter控制
# self.knowledge_dataset.freeze_embedding = True
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
@ -600,4 +559,5 @@ class MiniMindLM(PreTrainedModel):
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break
break

View File

@ -461,14 +461,14 @@ def main():
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
parser.add_argument("--knowledge_num", type=int, default=8192,help="知识库的数据数目")
parser.add_argument("--knowledge_num", type=int, default=960400,help="知识库的数据数目")
parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度")
parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="聚类结果缓存文件路径")
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
args = parser.parse_args()
#########################################################
# 初始化accelerator和deepspeed
#########################################################