update
This commit is contained in:
parent
770c34f0e3
commit
83b91859ce
120
model/model.py
120
model/model.py
@ -2,7 +2,7 @@ import math
|
||||
import struct
|
||||
import inspect
|
||||
import time
|
||||
|
||||
#子空间二维分解+梯度更新
|
||||
from .LMConfig import LMConfig
|
||||
from typing import Any, Optional, Tuple, List, Union
|
||||
import numpy as np
|
||||
@ -67,23 +67,21 @@ class KnowledgeDataset(nn.Module):
|
||||
## 数据库参数
|
||||
self.knowledge_num = params.knowledge_num
|
||||
self.knowledge_length = params.knowledge_length
|
||||
self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
|
||||
self.product_key_topk = min(16, self.knowledge_num)
|
||||
|
||||
# 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动
|
||||
self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
|
||||
# 修改键存储为二维分解空间,设置为可训练参数
|
||||
self.num_keys = int(math.sqrt(self.knowledge_num))
|
||||
# 确保keys是可训练参数
|
||||
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
|
||||
self.product_key_topk = min(16, self.num_keys)
|
||||
|
||||
# 知识库存储 - 使用register_buffer因为这是整数索引,不需要梯度
|
||||
self.register_buffer('knowledge_dataset',
|
||||
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
|
||||
)
|
||||
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
|
||||
|
||||
# 计算step数目,用于动态调整权重
|
||||
self.step_counter = 0
|
||||
|
||||
self.freeze_embedding = False
|
||||
|
||||
|
||||
# 移除批次计数器和更新频率相关代码
|
||||
|
||||
def intelligent_selection(self, query, all_scores, all_indices):
|
||||
"""智能分层选择策略"""
|
||||
@ -106,7 +104,8 @@ class KnowledgeDataset(nn.Module):
|
||||
candidate_tokens = self.knowledge_dataset[unique_indices]
|
||||
flat_tokens = candidate_tokens.view(-1)
|
||||
flat_embeddings = self.tok_embeddings(flat_tokens)
|
||||
#获取flat_tokens对应的index
|
||||
|
||||
# 获取flat_tokens对应的index(保留这些变量以便其他地方使用)
|
||||
pre_update_indices = unique_indices.view(-1)
|
||||
pre_update_embeddings = flat_embeddings.view(
|
||||
len(unique_indices), self.knowledge_length, -1
|
||||
@ -158,84 +157,45 @@ class KnowledgeDataset(nn.Module):
|
||||
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
|
||||
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
|
||||
|
||||
# 获取
|
||||
|
||||
# 使用重新计算的embeddings更新self.keys
|
||||
if self.is_train:
|
||||
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
|
||||
|
||||
# 更新被修改过的key
|
||||
with torch.no_grad():
|
||||
self.has_update_keys[pre_update_indices] = 1
|
||||
|
||||
return all_best_tokens, all_best_tokens_embeddings
|
||||
|
||||
def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
|
||||
if self.freeze_embedding:
|
||||
return
|
||||
# 使用pre_update_embeddings更新self.keys
|
||||
with torch.no_grad():
|
||||
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [337, 512]
|
||||
pre_update_embeddings = self.to_queries(pre_update_embeddings)
|
||||
self.keys[pre_update_indices] = pre_update_embeddings
|
||||
|
||||
def search_index(self,x):
|
||||
|
||||
def search_index(self, x):
|
||||
batch_size, seq_len, dim = x.shape
|
||||
|
||||
# collapse sequence dimension by averaging
|
||||
# 1. 序列维度平均
|
||||
x_flat = x.mean(dim=1) # [batch_size, dim]
|
||||
|
||||
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
|
||||
# queries = queries.reshape(batch_size, 2, self.key_dim)
|
||||
# queries = queries.permute(1, 0, 2)
|
||||
# 2. 生成查询向量并重塑为两个子查询
|
||||
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
|
||||
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
|
||||
# 调整维度顺序,使子空间维度位于首位
|
||||
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
|
||||
|
||||
# 2. 计算queries与keys的相似度
|
||||
sim = torch.einsum('b d, k d -> b k', queries, self.keys)
|
||||
# 3. 计算每个子空间的相似度
|
||||
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
|
||||
|
||||
# 3. 在两个子空间分别做top-k
|
||||
scores_and_indices = sim.topk(self.product_key_topk, dim=-1)
|
||||
scores, indices = scores_and_indices[0], scores_and_indices[1]
|
||||
# 4. 在两个子空间分别做top-k
|
||||
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
|
||||
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
|
||||
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
|
||||
|
||||
# 5. 应用智能分层选择策略
|
||||
# 5. 组合两个子空间的结果
|
||||
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
|
||||
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
|
||||
|
||||
# 6. 将结果重塑为二维
|
||||
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
|
||||
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
|
||||
|
||||
# 7. 选择最终的top-k结果
|
||||
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
|
||||
indices = torch.gather(all_indices, 1, indices_of_indices)
|
||||
|
||||
# 8. 应用智能分层选择策略
|
||||
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
|
||||
|
||||
# 6. 更新1%的keys
|
||||
if self.is_train:
|
||||
# 获取未更新过的keys的索引
|
||||
not_updated_indices = torch.where(self.has_update_keys == 0)[0]
|
||||
|
||||
# 如果有未更新的keys,随机选择num_update_keys个进行更新
|
||||
if len(not_updated_indices) > 0:
|
||||
num_update_keys = int(self.knowledge_num * 0.01)
|
||||
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
|
||||
perm_num = perm.shape[0]
|
||||
pre_update_indices = not_updated_indices[perm]
|
||||
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
|
||||
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
|
||||
pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
|
||||
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
|
||||
# 更新被修改过的key
|
||||
with torch.no_grad():
|
||||
self.has_update_keys[pre_update_indices] = 1
|
||||
else:
|
||||
print("all keys are updated")
|
||||
# 重置所有keys的更新状态
|
||||
self.has_update_keys.zero_()
|
||||
# 重新获取所有可更新的索引
|
||||
not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
|
||||
num_update_keys = int(self.knowledge_num * 0.01)
|
||||
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
|
||||
pre_update_indices = not_updated_indices[perm]
|
||||
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
|
||||
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
|
||||
pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
|
||||
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
|
||||
# 更新被修改过的key
|
||||
with torch.no_grad():
|
||||
self.has_update_keys[pre_update_indices] = 1
|
||||
|
||||
|
||||
|
||||
|
||||
return best_tokens, best_tokens_embeddings
|
||||
|
||||
@ -522,10 +482,9 @@ class MiniMindLM(PreTrainedModel):
|
||||
start_pos = args.get('start_pos', 0)
|
||||
if self.freeze_embedding and step == 0:
|
||||
self.tok_embeddings.weight.requires_grad = False
|
||||
# 同时冻结KnowledgeDataset的嵌入更新
|
||||
self.knowledge_dataset.freeze_embedding = True
|
||||
# 移除对knowledge_dataset.freeze_embedding的设置,让键更新由batch_counter控制
|
||||
# self.knowledge_dataset.freeze_embedding = True
|
||||
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
|
||||
print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
|
||||
h = self.dropout(self.tok_embeddings(input_ids))
|
||||
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
|
||||
for l, layer in enumerate(self.layers):
|
||||
@ -601,3 +560,4 @@ class MiniMindLM(PreTrainedModel):
|
||||
yield input_ids[:, start:]
|
||||
if input_ids_next.item() == eos_token_id:
|
||||
break
|
||||
|
||||
|
@ -461,7 +461,7 @@ def main():
|
||||
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
|
||||
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
|
||||
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
|
||||
parser.add_argument("--knowledge_num", type=int, default=8192,help="知识库的数据数目")
|
||||
parser.add_argument("--knowledge_num", type=int, default=960400,help="知识库的数据数目")
|
||||
parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度")
|
||||
parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
|
||||
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
|
||||
|
Loading…
x
Reference in New Issue
Block a user