modified by wcy
This commit is contained in:
parent
770c34f0e3
commit
4b9c5e29ae
102 .vscode/launch.json (vendored)
@@ -1,102 +0,0 @@
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Debug Train Pretrain Accelerate",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
            "console": "integratedTerminal",
            "python": "/opt/conda/envs/mini/bin/python",
            "cwd": "${workspaceFolder}",
            "env": {
                "PYTHONPATH": "${workspaceFolder}",
                "CUDA_VISIBLE_DEVICES": "0"
            },
            "justMyCode": false,
            "stopOnEntry": false,
            "redirectOutput": true
        },
        {
            "name": "Debug Train Pretrain Accelerate (Multi-GPU)",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
            "console": "integratedTerminal",
            "python": "/opt/conda/envs/mini/bin/python",
            "args": [
                "--hidden_size", "512",
                "--max_seq_len", "512",
                "--n_layers", "8",
                "--batch_size", "8",
                "--epochs", "1"
            ],
            "cwd": "${workspaceFolder}",
            "env": {
                "PYTHONPATH": "${workspaceFolder}",
                "CUDA_VISIBLE_DEVICES": "0,1"
            },
            "justMyCode": false,
            "stopOnEntry": false,
            "redirectOutput": true
        },
        {
            "name": "Debug Train Pretrain Accelerate (Small Test)",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
            "console": "integratedTerminal",
            "python": "/opt/conda/envs/mini/bin/python",
            "args": [
                "--hidden_size", "512",
                "--max_seq_len", "512",
                "--n_layers", "8",
                "--batch_size", "2",
                "--epochs", "1",
                "--log_interval", "10",
                "--save_interval", "100",
                "--accumulation_steps", "4"
            ],
            "cwd": "${workspaceFolder}",
            "env": {
                "PYTHONPATH": "${workspaceFolder}",
                "CUDA_VISIBLE_DEVICES": "0",
                "WANDB_MODE": "offline"
            },
            "justMyCode": false,
            "stopOnEntry": false,
            "redirectOutput": true
        },
        {
            "name": "Debug ExtractDB Comparison",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
            "console": "integratedTerminal",
            "python": "/opt/conda/envs/mini/bin/python",
            "args": [
                "--hidden_size", "512",
                "--max_seq_len", "256",
                "--n_layers", "4",
                "--batch_size", "2",
                "--epochs", "1",
                "--log_interval", "10",
                "--save_interval", "200",
                "--accumulation_steps", "2",
                "--comparison_mode",
                "--knowledge_num", "256",
                "--knowledge_length", "64",
                "--comparison_mode"
            ],
            "cwd": "${workspaceFolder}",
            "env": {
                "PYTHONPATH": "${workspaceFolder}",
                "CUDA_VISIBLE_DEVICES": "0",
                "WANDB_MODE": "offline"
            },
            "justMyCode": false,
            "stopOnEntry": false,
            "redirectOutput": true
        }
    ]
}
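These launch configurations only wrap plain command-line invocations. For reference, a rough Python equivalent of the "Small Test" configuration above, as a minimal sketch (interpreter and script paths are copied from the deleted launch.json and are machine-specific; PYTHONPATH="." assumes you run from the workspace root):

import os
import subprocess

# Reconstructs the "Debug Train Pretrain Accelerate (Small Test)" launch
# configuration as a plain subprocess call.
env = dict(os.environ, PYTHONPATH=".", CUDA_VISIBLE_DEVICES="0", WANDB_MODE="offline")
subprocess.run(
    ["/opt/conda/envs/mini/bin/python", "train_pretrain_accelerate.py",
     "--hidden_size", "512", "--max_seq_len", "512", "--n_layers", "8",
     "--batch_size", "2", "--epochs", "1", "--log_interval", "10",
     "--save_interval", "100", "--accumulation_steps", "4"],
    env=env, check=True)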
18 .vscode/settings.json (vendored)
@@ -1,18 +0,0 @@
{
    "python.pythonPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
    "python.defaultInterpreterPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
    "python.terminal.activateEnvironment": true,
    "python.terminal.activateEnvInCurrentTerminal": true,
    "python.linting.enabled": true,
    "python.linting.pylintEnabled": false,
    "python.linting.flake8Enabled": true,
    "python.formatting.provider": "black",
    "python.analysis.autoImportCompletions": true,
    "python.analysis.typeCheckingMode": "off",
    "files.exclude": {
        "**/__pycache__": true,
        "**/*.pyc": true,
        "**/.git": false,
        "**/wandb": false
    }
}
97 analyze_database.py (new file)
@@ -0,0 +1,97 @@
import json
import os
import torch
from transformers import AutoTokenizer

def analyze_database(json_path, tokenizer_path='./model/minimind_tokenizer'):
    """Analyze the number and quality of the entries in database_init.json."""

    print(f"Starting analysis of database file: {json_path}")

    # 1. Load the tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        print(f"Successfully loaded tokenizer: {tokenizer_path}")
    except Exception as e:
        print(f"Failed to load tokenizer: {e}")
        return

    # 2. Load the JSON file
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            database_data = json.load(f)

        # Extract the sentences list
        sentences_data = database_data.get('sentences', [])
        print(f"Loaded {len(sentences_data)} sentence entries")
    except Exception as e:
        print(f"Failed to load JSON file: {e}")
        return

    # 3. Make sure there is sentence data to analyze
    if len(sentences_data) == 0:
        print("No valid sentence data found")
        return

    # Sort by importance_score
    sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
    print(f"Sorted by importance_score; highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)}")

    # Collect sentence length statistics
    token_lengths = []
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    # 4. Analyze the token length distribution
    for i, sentence_data in enumerate(sorted_sentences):
        sentence = sentence_data.get('corrected_sentence', '')
        if not sentence:
            print(f"Warning: entry {i+1} has no corrected_sentence field")
            continue

        # Convert the sentence to tokens
        sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
        token_lengths.append(len(sentence_tokens))

        if i < 5:  # Show the first 5 samples
            print(f"Sample {i+1}: {sentence[:50]}... (tokens: {len(sentence_tokens)})")

    # 5. Aggregate the statistics
    token_lengths = torch.tensor(token_lengths)
    stats = {
        "total entries": len(sorted_sentences),
        "valid entries": len(token_lengths),
        "token length (mean)": token_lengths.float().mean().item(),
        "token length (min)": token_lengths.min().item(),
        "token length (max)": token_lengths.max().item(),
        "token length (median)": token_lengths.median().item(),
        "token length (std)": token_lengths.float().std().item(),
    }

    # Length distribution histogram
    length_bins = {
        "< 16": (token_lengths < 16).sum().item(),
        "16-32": ((token_lengths >= 16) & (token_lengths < 32)).sum().item(),
        "32-64": ((token_lengths >= 32) & (token_lengths < 64)).sum().item(),
        "64-128": ((token_lengths >= 64) & (token_lengths < 128)).sum().item(),
        "128-256": ((token_lengths >= 128) & (token_lengths < 256)).sum().item(),
        ">= 256": (token_lengths >= 256).sum().item(),
    }

    # Print the statistics
    print("\n===== Database analysis results =====")
    for key, value in stats.items():
        print(f"{key}: {value}")

    print("\n===== Token length distribution =====")
    for bin_name, count in length_bins.items():
        percentage = (count / len(token_lengths)) * 100
        print(f"{bin_name}: {count} ({percentage:.1f}%)")

    print(f"\nConclusion: the database file contains {stats['valid entries']} valid entries, all of which can be loaded into the knowledge base.")

    return stats, length_bins

if __name__ == "__main__":
    # Path to the database file
    database_path = "./dataset/database_init.json"
    analyze_database(database_path)
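For reference, a minimal sketch of the input layout analyze_database() assumes: a top-level "sentences" list whose items carry corrected_sentence and importance_score fields. The field names come from the code above; the sentence texts and scores below are invented:

import json

# Hypothetical illustration of the database_init.json layout.
example = {
    "sentences": [
        {"corrected_sentence": "Water boils at 100 degrees Celsius at sea level.",
         "importance_score": 0.92},
        {"corrected_sentence": "RMSNorm rescales activations by their root mean square.",
         "importance_score": 0.75},
    ]
}

with open("example_database_init.json", "w", encoding="utf-8") as f:
    json.dump(example, f, ensure_ascii=False, indent=2)

# analyze_database("example_database_init.json") would then report 2 valid entries.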
@@ -132,7 +132,7 @@ def decode_dataset(model_path, output_path, device="cuda"):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
-    parser.add_argument("--model_path", type=str, default="out/pretrain_1024.pth",
+    parser.add_argument("--model_path", type=str, default="out/pretrain_512.pth",
                         help="Path to the model checkpoint")
     parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
                         help="Path to save the decoded text file")
112 loss.py (new file)
@@ -0,0 +1,112 @@
import re
import matplotlib.pyplot as plt
import numpy as np

def parse_log_file(file_path):
    """
    Parse the training log file to extract epoch, step, and loss information.
    """
    # Regular expression to match log entries with loss information
    pattern = r'\[.*?\] Epoch (\d+)/\d+, Step (\d+)/\d+, Loss: ([\d\.]+)'

    epochs = []
    steps = []
    losses = []

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            log_content = f.read()

        # Find all matches
        matches = re.findall(pattern, log_content)

        for match in matches:
            epoch = int(match[0])
            step = int(match[1])
            loss = float(match[2])

            epochs.append(epoch)
            steps.append(step)
            losses.append(loss)

        return epochs, steps, losses

    except Exception as e:
        print(f"Error parsing log file: {e}")
        return [], [], []

def plot_loss_curve(epochs, steps, losses, output_file='loss_curve.png'):
    """
    Plot the loss curve and save it to a file.
    """
    plt.figure(figsize=(12, 6))

    # Create continuous steps for better visualization
    continuous_steps = []
    current_max_step = 0
    prev_epoch = 0

    for i, (e, s) in enumerate(zip(epochs, steps)):
        if e > prev_epoch:
            # New epoch starts
            if i > 0:
                current_max_step = continuous_steps[-1]
            prev_epoch = e
        continuous_steps.append(s + current_max_step)

    # Changed: reduce the line width and marker size
    plt.plot(continuous_steps, losses, marker='.', linestyle='-',
             color='#1f77b4', markersize=3, linewidth=0.8)

    plt.title('Training Loss Over Steps', fontsize=16)
    plt.xlabel('Steps (Continuous)', fontsize=14)
    plt.ylabel('Loss', fontsize=14)
    plt.grid(True, linestyle='--', alpha=0.5, linewidth=0.5)

    # Changed: reduce the width of the epoch-boundary lines
    for i in range(1, len(epochs)):
        if epochs[i] > epochs[i-1]:
            plt.axvline(x=continuous_steps[i], color='r',
                        linestyle='--', alpha=0.5, linewidth=0.8)

    unique_epochs = sorted(set(epochs))
    # Add epoch labels
    for e in unique_epochs:
        indices = [i for i, epoch in enumerate(epochs) if epoch == e]
        if indices:
            mid_idx = indices[len(indices) // 2]
            plt.text(continuous_steps[mid_idx], max(losses) * 0.95, f'Epoch {e}',
                     horizontalalignment='center', verticalalignment='center',
                     fontsize=10,
                     bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 3})

    # Hover annotations removed to keep the chart simple
    # for i, (e, s, l) in enumerate(zip(epochs, steps, losses)):
    #     plt.annotate(...)

    plt.tight_layout()
    plt.savefig(output_file, dpi=300)
    print(f"Loss curve saved as {output_file}")

    # Also display the data in a table format
    print("\nExtracted training data:")
    print("-" * 50)
    print(f"{'Epoch':<10}{'Step':<10}{'Loss':<15}")
    print("-" * 50)
    for e, s, l in zip(epochs, steps, losses):
        print(f"{e:<10}{s:<10}{l:<15.6f}")

def main():
    # Specify the path to your log file
    log_file_path = 'out/train.log'

    # Parse the log file
    epochs, steps, losses = parse_log_file(log_file_path)

    if epochs and steps and losses:
        plot_loss_curve(epochs, steps, losses)
    else:
        print("No data extracted from log file. Please check if the file format is correct.")

if __name__ == "__main__":
    main()
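As a sanity check of the regex in parse_log_file(), the log format it assumes is "[timestamp] Epoch E/N, Step S/M, Loss: X". A small self-contained example (the log lines are illustrative, not from a real run):

import re

pattern = r'\[.*?\] Epoch (\d+)/\d+, Step (\d+)/\d+, Loss: ([\d\.]+)'

# Hypothetical log lines in the format parse_log_file() expects.
sample = """\
[2024-01-01 12:00:00] Epoch 1/1, Step 10/3000, Loss: 8.913277
[2024-01-01 12:01:00] Epoch 1/1, Step 20/3000, Loss: 7.402115
"""

print(re.findall(pattern, sample))
# [('1', '10', '8.913277'), ('1', '20', '7.402115')]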
163 model/model.py
@@ -2,7 +2,7 @@ import math
 import struct
 import inspect
 import time

 # 2D subspace decomposition + gradient-based updates
 from .LMConfig import LMConfig
 from typing import Any, Optional, Tuple, List, Union
 import numpy as np
@@ -67,23 +67,21 @@ class KnowledgeDataset(nn.Module):
         ## Database parameters
         self.knowledge_num = params.knowledge_num
         self.knowledge_length = params.knowledge_length
-        self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
-        self.product_key_topk = min(16, self.knowledge_num)

         # Usage-frequency statistics - register_buffer so they move correctly between GPU/CPU
         self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))

+        # Change key storage to a 2D decomposed space, defined as a trainable parameter
+        self.num_keys = int(math.sqrt(self.knowledge_num))
+        # Make sure the keys are trainable parameters
+        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
+        self.product_key_topk = min(16, self.num_keys)

         # Knowledge store - register_buffer because these are integer indices and need no gradients
         self.register_buffer('knowledge_dataset',
-            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
-        )
+            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))

         # Step counter, used for dynamically adjusting weights
         self.step_counter = 0

         self.freeze_embedding = False

+        # Batch counter and update-frequency code removed

     def intelligent_selection(self, query, all_scores, all_indices):
         """Intelligent hierarchical selection strategy"""
@@ -106,7 +104,8 @@ class KnowledgeDataset(nn.Module):
         candidate_tokens = self.knowledge_dataset[unique_indices]
         flat_tokens = candidate_tokens.view(-1)
         flat_embeddings = self.tok_embeddings(flat_tokens)
-        # Get the index corresponding to flat_tokens
+
+        # Get the index corresponding to flat_tokens (kept so these variables remain available elsewhere)
         pre_update_indices = unique_indices.view(-1)
         pre_update_embeddings = flat_embeddings.view(
             len(unique_indices), self.knowledge_length, -1
@@ -158,85 +157,87 @@ class KnowledgeDataset(nn.Module):
         all_best_tokens = torch.stack(batch_best_tokens, dim=0)
         all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)

         # Update self.keys with the recomputed embeddings
         if self.is_train:
             self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)

             # Mark the keys that were modified
             with torch.no_grad():
                 self.has_update_keys[pre_update_indices] = 1

         return all_best_tokens, all_best_tokens_embeddings

     def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
         if self.freeze_embedding:
             return
-        # Update self.keys using pre_update_embeddings
         with torch.no_grad():
-            pre_update_embeddings = pre_update_embeddings.mean(dim=1)  # [337, 512]
-            pre_update_embeddings = self.to_queries(pre_update_embeddings)
-            self.keys[pre_update_indices] = pre_update_embeddings
+            # 1. Average embedding over the token sequence
+            pre_update_embeddings = pre_update_embeddings.mean(dim=1)  # [num_indices, dim]
+            # 2. Project to the key dimension
+            pre_update_embeddings = self.to_queries(pre_update_embeddings)  # [num_indices, knowledge_dim]
+
+            # 3. Convert flat indices into subspace indices
+            indices_x = pre_update_indices // self.num_keys
+            indices_y = pre_update_indices % self.num_keys
+
+            # 4. Collect the unique sub-keys that need updating
+            unique_x = torch.unique(indices_x)
+            unique_y = torch.unique(indices_y)
+
+            # 5. Update the keys of the first subspace
+            for k1 in unique_x:
+                # Find all indices that use this sub-key
+                mask_k1 = (indices_x == k1)
+                if mask_k1.sum() == 0:
+                    continue
+
+                # Gather the relevant embeddings and average them
+                k1_embeddings = pre_update_embeddings[mask_k1]
+                k1_avg_embedding = k1_embeddings.mean(dim=0)
+
+                # Split into the two subspaces and update the first one
+                self.keys[k1, 0] = k1_avg_embedding[:self.key_dim]
+
+            # 6. Update the keys of the second subspace
+            for k2 in unique_y:
+                # Find all indices that use this sub-key
+                mask_k2 = (indices_y == k2)
+                if mask_k2.sum() == 0:
+                    continue
+
+                # Gather the relevant embeddings and average them
+                k2_embeddings = pre_update_embeddings[mask_k2]
+                k2_avg_embedding = k2_embeddings.mean(dim=0)
+
+                # Update the second subspace
+                self.keys[k2, 1] = k2_avg_embedding[self.key_dim:]

-    def search_index(self,x):
+    def search_index(self, x):
         batch_size, seq_len, dim = x.shape

-        # collapse sequence dimension by averaging
+        # 1. Average over the sequence dimension
         x_flat = x.mean(dim=1)  # [batch_size, dim]

-        queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
-        # queries = queries.reshape(batch_size, 2, self.key_dim)
-        # queries = queries.permute(1, 0, 2)
+        # 2. Build the query vector and reshape it into two sub-queries
+        queries = self.to_queries(x_flat)  # [batch_size, knowledge_dim]
+        queries = queries.reshape(batch_size, 2, self.key_dim)  # [batch_size, 2, key_dim]
+        # Move the subspace dimension to the front
+        queries = queries.permute(1, 0, 2)  # [2, batch_size, key_dim]

-        # 2. Compute the similarity between queries and keys
-        sim = torch.einsum('b d, k d -> b k', queries, self.keys)
+        # 3. Compute the similarity within each subspace
+        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

-        # 3. Take top-k in the two subspaces separately
-        scores_and_indices = sim.topk(self.product_key_topk, dim=-1)
-        scores, indices = scores_and_indices[0], scores_and_indices[1]
+        # 4. Take top-k in the two subspaces separately
+        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
+        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
+        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]

-        # 5. Apply the intelligent hierarchical selection strategy
+        # 5. Combine the results of the two subspaces
+        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)  # [batch_size, topk, topk]
+        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)  # [batch_size, topk, topk]
+
+        # 6. Flatten back to two dimensions
+        all_scores = all_scores.reshape(batch_size, -1)  # [batch_size, topk*topk]
+        all_indices = all_indices.reshape(batch_size, -1)  # [batch_size, topk*topk]
+
+        # 7. Select the final top-k results
+        scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
+        indices = torch.gather(all_indices, 1, indices_of_indices)
+
+        # 8. Apply the intelligent hierarchical selection strategy
         best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)

         # 6. Update 1% of the keys
         if self.is_train:
             # Indices of keys that have not been updated yet
             not_updated_indices = torch.where(self.has_update_keys == 0)[0]

             # If there are un-updated keys, randomly pick num_update_keys of them to update
             if len(not_updated_indices) > 0:
                 num_update_keys = int(self.knowledge_num * 0.01)
                 perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
                 perm_num = perm.shape[0]
                 pre_update_indices = not_updated_indices[perm]
                 pre_update_tokens = self.knowledge_dataset[pre_update_indices]
                 pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
                 pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
                 self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
                 # Mark the keys that were modified
                 with torch.no_grad():
                     self.has_update_keys[pre_update_indices] = 1
             else:
                 print("all keys are updated")
                 # Reset the update status of all keys
                 self.has_update_keys.zero_()
                 # Re-collect all updatable indices
                 not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
                 num_update_keys = int(self.knowledge_num * 0.01)
                 perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
                 pre_update_indices = not_updated_indices[perm]
                 pre_update_tokens = self.knowledge_dataset[pre_update_indices]
                 pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
                 pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
                 self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
                 # Mark the keys that were modified
                 with torch.no_grad():
                     self.has_update_keys[pre_update_indices] = 1

         return best_tokens, best_tokens_embeddings

 class CrossAttention(nn.Module):
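An aside on the product-key arithmetic introduced in this hunk: taking top-k per subspace, summing scores pairwise, and recombining indices as ix * num_keys + iy yields the exact top-k over all num_keys**2 composed keys, because a pair that is globally top-k must have each component in its own subspace's top-k. A minimal sketch with toy sizes and a brute-force check (names invented for illustration, batch dimension omitted):

import torch

torch.manual_seed(0)
num_keys, key_dim, topk = 8, 4, 3          # toy sizes; the real code uses sqrt(knowledge_num)
q = torch.randn(2, key_dim)                # one sub-query per subspace
keys = torch.randn(num_keys, 2, key_dim)   # same layout as self.keys

# Per-subspace similarities, mirroring search_index
sim = torch.einsum('p d, k p d -> p k', q, keys)           # [2, num_keys]

# Product-key route: top-k per subspace, then combine
sx, ix = sim[0].topk(topk)
sy, iy = sim[1].topk(topk)
comb_scores = (sx[:, None] + sy[None, :]).reshape(-1)      # [topk*topk]
comb_indices = (ix[:, None] * num_keys + iy[None, :]).reshape(-1)
top_scores, pos = comb_scores.topk(topk)
top_indices = comb_indices[pos]

# Brute-force reference over all num_keys**2 pairs
full = (sim[0][:, None] + sim[1][None, :]).reshape(-1)
ref_scores, ref_indices = full.topk(topk)

assert torch.allclose(top_scores, ref_scores)
assert torch.equal(top_indices, ref_indices)
print(top_indices, top_scores)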
@@ -522,10 +523,9 @@ class MiniMindLM(PreTrainedModel):
         start_pos = args.get('start_pos', 0)
         if self.freeze_embedding and step == 0:
             self.tok_embeddings.weight.requires_grad = False
-            # Also freeze the KnowledgeDataset embedding updates
-            self.knowledge_dataset.freeze_embedding = True
+            # No longer set knowledge_dataset.freeze_embedding here; let key updates be controlled by batch_counter
+            # self.knowledge_dataset.freeze_embedding = True
             print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
             print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
         h = self.dropout(self.tok_embeddings(input_ids))
         pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
         for l, layer in enumerate(self.layers):
@@ -600,4 +600,5 @@ class MiniMindLM(PreTrainedModel):
             input_ids = torch.cat((input_ids, input_ids_next), dim=1)
             yield input_ids[:, start:]
             if input_ids_next.item() == eos_token_id:
                 break
603 model/model0.py (new file)
@@ -0,0 +1,603 @@
import math
import struct
import inspect
import time
# No subspace decomposition + embedding-based updates
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)


def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class KnowledgeDataset(nn.Module):
    def __init__(self, params, tok_embeddings, is_train=True):
        super().__init__()
        self.is_train = is_train
        self.params = params
        self.tok_embeddings = tok_embeddings

        # Embedding parameters
        self.knowledge_dim = params.knowledge_dim
        self.key_dim = self.knowledge_dim // 2
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.knowledge_dim, bias=False),
        )

        ## Database parameters
        self.knowledge_num = params.knowledge_num
        self.knowledge_length = params.knowledge_length
        self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
        self.product_key_topk = min(16, self.knowledge_num)

        # Usage-frequency statistics - register_buffer so they move correctly between GPU/CPU
        self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))

        # Knowledge store - register_buffer because these are integer indices and need no gradients
        self.register_buffer('knowledge_dataset',
            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
        )

        # Step counter, used for dynamically adjusting weights
        self.step_counter = 0

        self.freeze_embedding = False

    def intelligent_selection(self, query, all_scores, all_indices):
        """Intelligent hierarchical selection strategy"""
        if self.is_train == False:
            return all_scores, all_indices

        batch_size = all_scores.size(0)
        device = all_scores.device
        dtype = all_scores.dtype

        # Hierarchical selection for each batch element
        enhanced_scores = all_scores.clone()
        query_features = query.mean(dim=1)  # [batch_size, dim]

        # Pre-compute the embeddings of all candidate entries (batched optimization)
        all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
        unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)

        # Batch-compute the embeddings of the unique candidate entries
        candidate_tokens = self.knowledge_dataset[unique_indices]
        flat_tokens = candidate_tokens.view(-1)
        flat_embeddings = self.tok_embeddings(flat_tokens)
        # Get the index corresponding to flat_tokens
        pre_update_indices = unique_indices.view(-1)
        pre_update_embeddings = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        )

        unique_candidate_features = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        ).mean(dim=1)  # [num_unique_candidates, dim]

        # Normalize the candidate features (speeds up the similarity computation)
        normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
        normalized_queries = F.normalize(query_features, dim=-1)

        # Collect the best_tokens of every batch element
        batch_best_tokens = []
        batch_best_tokens_embeddings = []

        for batch_idx in range(batch_size):
            indices = all_indices[batch_idx]

            # Feature indices of the current batch element's candidates
            start_idx = batch_idx * len(indices)
            end_idx = start_idx + len(indices)
            batch_inverse_indices = inverse_indices[start_idx:end_idx]

            # Use the pre-computed normalized features for the similarity computation
            batch_candidate_features = normalized_candidates[batch_inverse_indices]
            query_feature = normalized_queries[batch_idx]

            # Cosine similarity via a matrix-vector product
            similarity_scores = torch.mv(batch_candidate_features, query_feature)

            # Index of the highest similarity score
            max_similarity_idx = torch.argmax(similarity_scores)

            # Candidate entry index with the highest similarity
            best_candidate_idx = indices[max_similarity_idx]

            # Fetch the corresponding tokens
            best_tokens = self.knowledge_dataset[best_candidate_idx]
            best_tokens_embeddings = self.tok_embeddings(best_tokens)

            # Append the current batch element's best_tokens
            batch_best_tokens.append(best_tokens)
            batch_best_tokens_embeddings.append(best_tokens_embeddings)

        # Stack the best_tokens of all batch elements into one tensor
        # [batch_size, knowledge_length]
        all_best_tokens = torch.stack(batch_best_tokens, dim=0)
        all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)

        # Update self.keys with the recomputed embeddings
        if self.is_train:
            self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)

            # Mark the keys that were modified
            with torch.no_grad():
                self.has_update_keys[pre_update_indices] = 1

        return all_best_tokens, all_best_tokens_embeddings

    def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
        if self.freeze_embedding:
            return
        # Update self.keys using pre_update_embeddings
        with torch.no_grad():
            pre_update_embeddings = pre_update_embeddings.mean(dim=1)  # [num_indices, dim]
            pre_update_embeddings = self.to_queries(pre_update_embeddings)
            self.keys[pre_update_indices] = pre_update_embeddings

    def search_index(self, x):
        batch_size, seq_len, dim = x.shape

        # collapse sequence dimension by averaging
        x_flat = x.mean(dim=1)  # [batch_size, dim]

        queries = self.to_queries(x_flat)  # [batch_size, knowledge_dim]
        # queries = queries.reshape(batch_size, 2, self.key_dim)
        # queries = queries.permute(1, 0, 2)

        # 2. Compute the similarity between queries and keys
        sim = torch.einsum('b d, k d -> b k', queries, self.keys)

        # 3. Take top-k over the keys
        scores_and_indices = sim.topk(self.product_key_topk, dim=-1)
        scores, indices = scores_and_indices[0], scores_and_indices[1]

        # 5. Apply the intelligent hierarchical selection strategy
        best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)

        # 6. Update 1% of the keys
        if self.is_train:
            # Indices of keys that have not been updated yet
            not_updated_indices = torch.where(self.has_update_keys == 0)[0]

            # If there are un-updated keys, randomly pick num_update_keys of them to update
            if len(not_updated_indices) > 0:
                num_update_keys = int(self.knowledge_num * 0.01)
                perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
                perm_num = perm.shape[0]
                pre_update_indices = not_updated_indices[perm]
                pre_update_tokens = self.knowledge_dataset[pre_update_indices]
                pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
                pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
                # Mark the keys that were modified
                with torch.no_grad():
                    self.has_update_keys[pre_update_indices] = 1
            else:
                print("all keys are updated")
                # Reset the update status of all keys
                self.has_update_keys.zero_()
                # Re-collect all updatable indices
                not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
                num_update_keys = int(self.knowledge_num * 0.01)
                perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
                pre_update_indices = not_updated_indices[perm]
                pre_update_tokens = self.knowledge_dataset[pre_update_indices]
                pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
                pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
                # Mark the keys that were modified
                with torch.no_grad():
                    self.has_update_keys[pre_update_indices] = 1

        return best_tokens, best_tokens_embeddings


class CrossAttention(nn.Module):
    def __init__(
            self,
            config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)

        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        batch_size = x.size(0)

        # Split into heads
        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if pos_emb is not None:
            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)

        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v)

        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)

        context = self.to_out(context)

        return context


class Attention(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output


class FeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class MoEGate(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = 0
        return topk_idx, topk_weight, aux_loss


class MOEFeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Route tokens to experts via the gate
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=torch.float16)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # ensure dtype consistency
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # Example: if tokens_per_expert = [6, 15, 20, 26], then tokens_per_expert.shape[0] is the number of experts (4 here),
        # and if token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...], then
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the positions of tokens handled by expert 0
        # (a token may be processed by several experts, depending on num_experts_per_tok);
        # the next 9 positions token_idxs[6:15] -> [4, 5, 6, 10, 11, 12, ...] belong to expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache


class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.self_attention = Attention(config)
        self.cross_attention = CrossAttention(config)
        self.knowledge_dataset = knowledge_dataset

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, x, pos_cis):
        h_attn = self.self_attention(
            self.attention_norm(x),
            pos_cis
        )
        db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
        h_attn = self.cross_attention(h_attn, db_embeddings)
        h = x + h_attn
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        self.params = params or LMConfig()
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
        self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()
        self.freeze_embedding = False

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                step: int = 0,
                **args):
        start_pos = args.get('start_pos', 0)
        if self.freeze_embedding and step == 0:
            self.tok_embeddings.weight.requires_grad = False
            # Also freeze the KnowledgeDataset embedding updates
            self.knowledge_dataset.freeze_embedding = True
            print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
            print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        for l, layer in enumerate(self.layers):
            h = layer(
                h, pos_cis
            )

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        # Further simplified: keep only the necessary fields
        output = CausalLMOutputWithPast(
            logits=logits,
        )
        output.hidden_states = h

        output.aux_loss = aux_loss

        return output

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        # Streaming generation
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        # Direct generation
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        start, first_seq, past_kvs = input_ids.shape[1], True, None
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq:
                out, first_seq = self(input_ids, **args), False
            else:
                out = self(input_ids[:, -1:],
                           start_pos=input_ids.shape[1] - 1, **args)
            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break
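The practical difference between model0.py's flat key table and model.py's two-subspace decomposition is memory: the flat table stores knowledge_num x knowledge_dim values, while the decomposition stores only sqrt(knowledge_num) x 2 x (knowledge_dim / 2). A back-of-the-envelope sketch with illustrative sizes (not taken from this repository's config):

# Rough parameter-count comparison between the flat key table (model0.py)
# and the 2D product-key decomposition (model.py). Sizes are illustrative.
knowledge_num, knowledge_dim = 1024 * 1024, 128
key_dim = knowledge_dim // 2

flat_params = knowledge_num * knowledge_dim     # model0: keys of shape [N, D]
num_keys = int(knowledge_num ** 0.5)            # model.py: sqrt(N)
decomposed_params = num_keys * 2 * key_dim      # model.py: keys of shape [sqrt(N), 2, D/2]

print(f"flat:       {flat_params:,}")       # 134,217,728
print(f"decomposed: {decomposed_params:,}")  # 131,072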
675 model/model1.py (new file)
@@ -0,0 +1,675 @@
import math
import struct
import inspect
import time
# 2D subspace decomposition + global embedding updates
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)


def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class KnowledgeDataset(nn.Module):
    def __init__(self, params, tok_embeddings, is_train=True):
        super().__init__()
        self.is_train = is_train
        self.params = params
        self.tok_embeddings = tok_embeddings

        # Embedding parameters
        self.knowledge_dim = params.knowledge_dim
        self.key_dim = self.knowledge_dim // 2
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.knowledge_dim, bias=False),
        )

        ## Database parameters
        self.knowledge_num = params.knowledge_num
        self.knowledge_length = params.knowledge_length

        # Key storage decomposed into a 2D space
        self.num_keys = int(math.sqrt(self.knowledge_num))
        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
        self.product_key_topk = min(16, self.num_keys)

        # Usage-frequency statistics - register_buffer so they move correctly between GPU/CPU
        self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))

        # Knowledge store - register_buffer because these are integer indices and need no gradients
        self.register_buffer('knowledge_dataset',
            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))

        # Step counter, used for dynamically adjusting weights
        self.step_counter = 0

        self.freeze_embedding = False

        # Batch counter and update frequency
        self.batch_counter = 0
        self.update_frequency = 100  # update once every 100 batches

    def _global_keys_update(self):
        """Globally update all sub-keys"""
        # The self.freeze_embedding check is removed so the update always runs when called
        with torch.no_grad():
            # Tensors holding the accumulated embedding and the count for every sub-key
            k1_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
            k2_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
            k1_counts = torch.zeros(self.num_keys, device=self.keys.device)
            k2_counts = torch.zeros(self.num_keys, device=self.keys.device)

            # Process all knowledge entries in batches to avoid running out of memory
            batch_size = 1000  # adjust according to available memory
            for i in range(0, self.knowledge_num, batch_size):
                end_idx = min(i + batch_size, self.knowledge_num)
                batch_indices = torch.arange(i, end_idx, device=self.keys.device)

                # Embeddings of the current batch
                batch_tokens = self.knowledge_dataset[batch_indices]
                batch_embeddings = self.tok_embeddings(batch_tokens.view(-1))
                batch_embeddings = batch_embeddings.view(len(batch_indices), self.knowledge_length, -1).mean(dim=1)
                batch_embeddings = self.to_queries(batch_embeddings)

                # Sub-key indices of every entry in the batch
                indices_x = batch_indices // self.num_keys
                indices_y = batch_indices % self.num_keys

                # Accumulate the embedding belonging to each sub-key
                for j in range(len(batch_indices)):
                    k1, k2 = indices_x[j].item(), indices_y[j].item()
                    embedding = batch_embeddings[j]

                    # Accumulate into the first subspace
                    k1_embeddings_sum[k1] += embedding[:self.key_dim]
                    k1_counts[k1] += 1

                    # Accumulate into the second subspace
                    k2_embeddings_sum[k2] += embedding[self.key_dim:]
                    k2_counts[k2] += 1

            # Compute the averages and update the keys
            # Avoid division by zero
            k1_counts = torch.clamp(k1_counts, min=1)
            k2_counts = torch.clamp(k2_counts, min=1)

            # Average embedding per sub-key
            self.keys[:, 0] = k1_embeddings_sum / k1_counts.unsqueeze(1)
            self.keys[:, 1] = k2_embeddings_sum / k2_counts.unsqueeze(1)

            print(f"Performed a global key update, batch: {self.batch_counter}")
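The accumulation loop in _global_keys_update iterates over every entry in Python. For reference, a minimal sketch of a vectorized alternative using torch.index_add_, assuming the same shapes (batch_embeddings [B, 2*key_dim], sub-key indices [B]); the helper name is invented and this is not part of the commit's code:

import torch

# Vectorized equivalent of the per-item accumulation loop above:
# index_add_ scatters each row into the slot of its sub-key in one call.
def accumulate_subkeys(batch_embeddings, indices_x, indices_y,
                       k1_sum, k2_sum, k1_counts, k2_counts, key_dim):
    ones = torch.ones(len(indices_x), device=batch_embeddings.device)
    k1_sum.index_add_(0, indices_x, batch_embeddings[:, :key_dim])
    k1_counts.index_add_(0, indices_x, ones)
    k2_sum.index_add_(0, indices_y, batch_embeddings[:, key_dim:])
    k2_counts.index_add_(0, indices_y, ones)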
def intelligent_selection(self, query, all_scores, all_indices):
|
||||
"""智能分层选择策略"""
|
||||
if self.is_train == False:
|
||||
return all_scores, all_indices
|
||||
|
||||
batch_size = all_scores.size(0)
|
||||
device = all_scores.device
|
||||
dtype = all_scores.dtype
|
||||
|
||||
# 对每个batch进行分层选择
|
||||
enhanced_scores = all_scores.clone()
|
||||
query_features = query.mean(dim=1) # [batch_size, dim]
|
||||
|
||||
# 预先计算所有候选条目的嵌入(批量优化)
|
||||
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
|
||||
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
|
||||
|
||||
# 批量计算唯一候选条目的嵌入
|
||||
candidate_tokens = self.knowledge_dataset[unique_indices]
|
||||
flat_tokens = candidate_tokens.view(-1)
|
||||
flat_embeddings = self.tok_embeddings(flat_tokens)
|
||||
|
||||
# 获取flat_tokens对应的index(保留这些变量以便其他地方使用)
|
||||
pre_update_indices = unique_indices.view(-1)
|
||||
pre_update_embeddings = flat_embeddings.view(
|
||||
len(unique_indices), self.knowledge_length, -1
|
||||
)
|
||||
|
||||
unique_candidate_features = flat_embeddings.view(
|
||||
len(unique_indices), self.knowledge_length, -1
|
||||
).mean(dim=1) # [num_unique_candidates, dim]
|
||||
|
||||
# 归一化候选特征(优化相似度计算)
|
||||
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
|
||||
normalized_queries = F.normalize(query_features, dim=-1)
|
||||
|
||||
# 收集所有batch的best_tokens
|
||||
batch_best_tokens = []
|
||||
batch_best_tokens_embeddings = []
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
indices = all_indices[batch_idx]
|
||||
|
||||
# 获取当前batch候选条目对应的特征索引
|
||||
start_idx = batch_idx * len(indices)
|
||||
end_idx = start_idx + len(indices)
|
||||
batch_inverse_indices = inverse_indices[start_idx:end_idx]
|
||||
|
||||
# 使用预计算的归一化特征进行优化相似度计算
|
||||
batch_candidate_features = normalized_candidates[batch_inverse_indices]
|
||||
query_feature = normalized_queries[batch_idx]
|
||||
|
||||
# 使用矩阵乘法计算余弦相似度
|
||||
similarity_scores = torch.mv(batch_candidate_features, query_feature)
|
||||
|
||||
# 找到最大相似度分数的索引
|
||||
max_similarity_idx = torch.argmax(similarity_scores)
|
||||
|
||||
# 获取最大相似度对应的候选条目索引
|
||||
best_candidate_idx = indices[max_similarity_idx]
|
||||
|
||||
# 获取对应的tokens
|
||||
best_tokens = self.knowledge_dataset[best_candidate_idx]
|
||||
best_tokens_embeddings = self.tok_embeddings(best_tokens)
|
||||
|
||||
# 将当前batch的best_tokens添加到列表中
|
||||
batch_best_tokens.append(best_tokens)
|
||||
batch_best_tokens_embeddings.append(best_tokens_embeddings)
|
||||
|
||||
# 将所有batch的best_tokens堆叠成一个张量
|
||||
# [batch_size, knowledge_length]
|
||||
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
|
||||
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
self.has_update_keys[pre_update_indices] = 1
|
||||
|
||||
return all_best_tokens, all_best_tokens_embeddings
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
# 1. 计算token序列的平均嵌入
|
||||
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [num_indices, dim]
|
||||
# 2. 转换维度
|
||||
pre_update_embeddings = self.to_queries(pre_update_embeddings) # [num_indices, knowledge_dim]
|
||||
|
||||
# 3. 将one-hot索引转换为子空间索引
|
||||
indices_x = pre_update_indices // self.num_keys
|
||||
indices_y = pre_update_indices % self.num_keys
|
||||
|
||||
# 4. 收集需要更新的唯一子键
|
||||
unique_x = torch.unique(indices_x)
|
||||
unique_y = torch.unique(indices_y)
|
||||
|
||||
# 5. 更新第一个子空间的键
|
||||
for k1 in unique_x:
|
||||
# 找出所有使用该子键的索引
|
||||
mask_k1 = (indices_x == k1)
|
||||
if mask_k1.sum() == 0:
|
||||
continue
|
||||
|
||||
# 获取所有相关嵌入并计算平均值
|
||||
k1_embeddings = pre_update_embeddings[mask_k1]
|
||||
k1_avg_embedding = k1_embeddings.mean(dim=0)
|
||||
|
||||
# 拆分为两个子空间并更新第一个子空间
|
||||
self.keys[k1, 0] = k1_avg_embedding[:self.key_dim]
|
||||
|
||||
# 6. 更新第二个子空间的键
|
||||
for k2 in unique_y:
|
||||
# 找出所有使用该子键的索引
|
||||
mask_k2 = (indices_y == k2)
|
||||
if mask_k2.sum() == 0:
|
||||
continue
|
||||
|
||||
# 获取所有相关嵌入并计算平均值
|
||||
k2_embeddings = pre_update_embeddings[mask_k2]
|
||||
k2_avg_embedding = k2_embeddings.mean(dim=0)
|
||||
|
||||
# 更新第二个子空间
|
||||
self.keys[k2, 1] = k2_avg_embedding[self.key_dim:]
    def search_index(self, x):
        batch_size, seq_len, dim = x.shape

        # 1. Average over the sequence dimension
        x_flat = x.mean(dim=1)  # [batch_size, dim]

        # 2. Generate the query vector and reshape it into two sub-queries
        queries = self.to_queries(x_flat)  # [batch_size, knowledge_dim]
        queries = queries.reshape(batch_size, 2, self.key_dim)  # [batch_size, 2, key_dim]
        # Move the subspace dimension to the front
        queries = queries.permute(1, 0, 2)  # [2, batch_size, key_dim]

        # 3. Similarity within each subspace
        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

        # 4. Top-k within each of the two subspaces
        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]

        # 5. Combine the two subspaces
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)  # [batch_size, topk, topk]
        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)  # [batch_size, topk, topk]

        # 6. Flatten back to two dimensions
        all_scores = all_scores.reshape(batch_size, -1)  # [batch_size, topk*topk]
        all_indices = all_indices.reshape(batch_size, -1)  # [batch_size, topk*topk]

        # 7. Final top-k selection
        scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
        indices = torch.gather(all_indices, 1, indices_of_indices)

        # 8. Apply the hierarchical selection strategy
        best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)

        # 9. Advance the batch counter and run the global update on schedule
        if self.is_train:
            self.batch_counter += 1

            # Run a global update every update_frequency batches; keys stay frozen otherwise
            if self.batch_counter % self.update_frequency == 0:
                # Update keys only on these batches, regardless of freeze_embedding
                self._global_keys_update()
                # Mark all keys as updated
                with torch.no_grad():
                    self.has_update_keys.fill_(1)

        return best_tokens, best_tokens_embeddings
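
# Illustrative sketch (not part of the original file): how the two-subspace
# product-key composition in search_index enumerates candidates. Sizes are
# toy values; `num_keys` and `topk` stand in for self.num_keys / product_key_topk.
import torch

num_keys, topk = 4, 2
scores_x = torch.randn(1, topk); indices_x = torch.randint(0, num_keys, (1, topk))
scores_y = torch.randn(1, topk); indices_y = torch.randint(0, num_keys, (1, topk))

# Every (x, y) pair maps to the flat id x * num_keys + y, so topk*topk
# candidates are scored without touching all num_keys**2 entries.
pair_scores = (scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)).reshape(1, -1)
pair_ids = (indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)).reshape(1, -1)
best, pos = pair_scores.topk(topk, dim=-1)
print(torch.gather(pair_ids, 1, pos))  # flat ids of the best-scoring pairs
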
class CrossAttention(nn.Module):
    def __init__(
            self,
            config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)

        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        batch_size = x.size(0)

        # Split into heads
        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if pos_emb is not None:
            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)

        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v)

        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)

        context = self.to_out(context)

        return context
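
# Illustrative sketch (not part of the original file): queries come from the
# hidden states while keys/values come from the retrieved knowledge, so the
# output keeps the sequence length of x. The toy dims are assumptions; the
# SimpleNamespace only needs the .dim field that CrossAttention reads.
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(dim=64)
attn = CrossAttention(cfg)
x = torch.randn(2, 10, 64)    # [batch, seq_len, dim]
db = torch.randn(2, 32, 64)   # [batch, knowledge_length, dim]
print(attn(x, db).shape)      # torch.Size([2, 10, 64])
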
class Attention(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        # Move the head dimension ahead of the sequence dimension:
        # [bsz, n_heads, seq_len, head_dim]. This transpose was missing in the
        # source; attention below assumes it (and assumes n_kv_heads == n_heads,
        # since no repeat_kv helper is defined in this file).
        xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output

class FeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
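
# Illustrative sketch (not part of the original file): the hidden_dim rounding
# above, worked through for a hypothetical dim=512, multiple_of=64.
dim, multiple_of = 512, 64
hidden_dim = int(2 * (4 * dim) / 3)                                  # 1365
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
print(hidden_dim)  # 1408: the SwiGLU width, rounded up to a multiple of 64
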

class MoEGate(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = 0
        return topk_idx, topk_weight, aux_loss
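
# Illustrative sketch (not part of the original file): what the gate returns for
# a toy config. The SimpleNamespace mirrors the LMConfig fields MoEGate reads;
# the values are assumptions for the demo.
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(num_experts_per_tok=2, n_routed_experts=4, scoring_func='softmax',
                      aux_loss_alpha=0.1, seq_aux=True, norm_topk_prob=True, dim=32)
gate = MoEGate(cfg)
h = torch.randn(2, 5, 32)                 # [batch, seq_len, dim]
topk_idx, topk_weight, aux_loss = gate(h)
print(topk_idx.shape, topk_weight.shape)  # (10, 2) each: per-token expert ids / weights
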

class MOEFeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Select experts via the gating mechanism
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=torch.float16)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # keep dtypes consistent
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # With tokens_per_expert = [6, 15, 20, 26], tokens_per_expert.shape[0] is the
        # number of experts (4 here), and with
        # token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...]:
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the tokens routed to expert 0
        # (a token may be handled by several experts, depending on num_experts_per_tok);
        # the next nine slots, token_idxs[6:15] -> [4, 5, 6, 10, 11, 12, ...], belong to
        # expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache
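
# Illustrative sketch (not part of the original file): the argsort/bincount
# bookkeeping of moe_infer on a toy routing, independent of the model.
import torch

num_experts_per_tok = 2
flat_expert_indices = torch.tensor([1, 0, 0, 2, 1, 2])        # 3 tokens x 2 experts
idxs = flat_expert_indices.argsort()                           # slots grouped by expert
tokens_per_expert = flat_expert_indices.bincount().cumsum(0)   # tensor([2, 4, 6])
token_idxs = idxs // num_experts_per_tok                       # owning token of each slot
for i, end in enumerate(tokens_per_expert):
    start = 0 if i == 0 else tokens_per_expert[i - 1]
    print(f"expert {i} handles tokens {token_idxs[start:end].tolist()}")
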

class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.self_attention = Attention(config)
        self.cross_attention = CrossAttention(config)
        self.knowledge_dataset = knowledge_dataset

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, x, pos_cis):
        h_attn = self.self_attention(
            self.attention_norm(x),
            pos_cis
        )
        db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
        h_attn = self.cross_attention(h_attn, db_embeddings)
        h = x + h_attn
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        params = params or LMConfig()
        self.params = params
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
        self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()
        self.freeze_embedding = False

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                step: int = 0,
                **args):
        start_pos = args.get('start_pos', 0)
        if self.freeze_embedding and step == 0:
            self.tok_embeddings.weight.requires_grad = False
            # The knowledge_dataset.freeze_embedding toggle is removed; key updates are driven by batch_counter
            # self.knowledge_dataset.freeze_embedding = True
            print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        for l, layer in enumerate(self.layers):
            h = layer(
                h, pos_cis
            )

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        # Further simplified: keep only the necessary fields
        output = CausalLMOutputWithPast(
            logits=logits,
        )
        output.hidden_states = h

        output.aux_loss = aux_loss

        return output
    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        # Streaming generation
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        # Direct generation
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        start, first_seq, past_kvs = input_ids.shape[1], True, None
        # NOTE: this bounds the total sequence length, not only the newly generated tokens
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq:
                out, first_seq = self(input_ids, **args), False
            else:
                out = self(input_ids[:, -1:],
                           start_pos=input_ids.shape[1] - 1, **args)
            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
            # Repetition penalty on tokens already present in the sequence
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break
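
# Illustrative sketch (not part of the original file): driving generate() with
# nucleus sampling. Passing these fields as LMConfig keyword arguments is an
# assumption; adjust to however LMConfig is actually constructed.
import torch

cfg = LMConfig(dim=64, n_layers=1, n_heads=4, vocab_size=256, max_seq_len=128,
               knowledge_num=256, knowledge_length=8, knowledge_dim=32)  # hypothetical kwargs
model = MiniMindLM(cfg).eval()
prompt = torch.randint(0, 256, (1, 8))
out = model.generate(prompt, max_new_tokens=64, temperature=0.8, top_p=0.9)
print(out.shape)  # [num_return_sequences, padded total length]
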
679 model/model2.py Normal file
@ -0,0 +1,679 @@
import math
import struct
import inspect
import time

# Four-way subspace decomposition + global embedding update
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)


def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class KnowledgeDataset(nn.Module):
    def __init__(self, params, tok_embeddings, is_train=True):
        super().__init__()
        self.is_train = is_train
        self.params = params
        self.tok_embeddings = tok_embeddings

        # Embedding parameters
        self.knowledge_dim = params.knowledge_dim
        # Changed: each subspace is now a quarter of knowledge_dim rather than half
        self.key_dim = self.knowledge_dim // 4
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.knowledge_dim, bias=False),
        )

        ## Knowledge-base parameters
        self.knowledge_num = params.knowledge_num
        self.knowledge_length = params.knowledge_length

        # Changed: key storage goes from a 2-way to a 4-way decomposed space
        # Number of keys per subspace (fourth root of N)
        self.num_keys = int(self.knowledge_num ** 0.25)
        # Changed: 4 subspaces instead of 2
        self.keys = nn.Parameter(torch.randn(self.num_keys, 4, self.key_dim) * 0.02, requires_grad=True)
        self.product_key_topk = min(16, self.num_keys)

        # Usage statistics - register_buffer so it moves correctly between GPU/CPU
        self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))

        # Knowledge storage - register_buffer because these are integer indices and need no gradients
        self.register_buffer('knowledge_dataset',
            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))

        # Step counter, used to adjust weights dynamically
        self.step_counter = 0

        self.freeze_embedding = False

        # Batch counter and update frequency
        self.batch_counter = 0
        self.update_frequency = 100  # update once every 100 batches
    def _global_keys_update(self):
        """Globally update all sub-keys."""
        # The check on self.freeze_embedding is removed so the update always runs when called
        with torch.no_grad():
            # Accumulators for each sub-key's embedding sums and counts (now 4 subspaces)
            k1_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
            k2_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
            k3_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
            k4_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)

            k1_counts = torch.zeros(self.num_keys, device=self.keys.device)
            k2_counts = torch.zeros(self.num_keys, device=self.keys.device)
            k3_counts = torch.zeros(self.num_keys, device=self.keys.device)
            k4_counts = torch.zeros(self.num_keys, device=self.keys.device)

            # Process the knowledge entries in batches to avoid running out of memory
            batch_size = 1000  # tune to the available memory
            for i in range(0, self.knowledge_num, batch_size):
                end_idx = min(i + batch_size, self.knowledge_num)
                batch_indices = torch.arange(i, end_idx, device=self.keys.device)

                # Embeddings for this batch
                batch_tokens = self.knowledge_dataset[batch_indices]
                batch_embeddings = self.tok_embeddings(batch_tokens.view(-1))
                batch_embeddings = batch_embeddings.view(len(batch_indices), self.knowledge_length, -1).mean(dim=1)
                batch_embeddings = self.to_queries(batch_embeddings)

                # Sub-key indices for every entry in the batch (now a 4-way index)
                # Extract the four-dimensional index via integer division and modulo
                temp = batch_indices
                indices_4 = temp % self.num_keys
                temp = temp // self.num_keys
                indices_3 = temp % self.num_keys
                temp = temp // self.num_keys
                indices_2 = temp % self.num_keys
                indices_1 = temp // self.num_keys

                # Accumulate the embeddings for each sub-key
                for j in range(len(batch_indices)):
                    k1, k2, k3, k4 = indices_1[j].item(), indices_2[j].item(), indices_3[j].item(), indices_4[j].item()
                    embedding = batch_embeddings[j]

                    # Split the embedding into four quarters, one per subspace
                    quarter = self.key_dim
                    k1_embeddings_sum[k1] += embedding[:quarter]
                    k1_counts[k1] += 1

                    k2_embeddings_sum[k2] += embedding[quarter:2*quarter]
                    k2_counts[k2] += 1

                    k3_embeddings_sum[k3] += embedding[2*quarter:3*quarter]
                    k3_counts[k3] += 1

                    k4_embeddings_sum[k4] += embedding[3*quarter:]
                    k4_counts[k4] += 1

            # Average and write back the keys
            # Guard against division by zero
            k1_counts = torch.clamp(k1_counts, min=1)
            k2_counts = torch.clamp(k2_counts, min=1)
            k3_counts = torch.clamp(k3_counts, min=1)
            k4_counts = torch.clamp(k4_counts, min=1)

            # Mean embedding per sub-key
            self.keys[:, 0] = k1_embeddings_sum / k1_counts.unsqueeze(1)
            self.keys[:, 1] = k2_embeddings_sum / k2_counts.unsqueeze(1)
            self.keys[:, 2] = k3_embeddings_sum / k3_counts.unsqueeze(1)
            self.keys[:, 3] = k4_embeddings_sum / k4_counts.unsqueeze(1)

        print(f"Performed global key update at batch {self.batch_counter}")
    def intelligent_selection(self, query, all_scores, all_indices):
        """Hierarchical selection strategy."""
        if not self.is_train:
            return all_scores, all_indices

        batch_size = all_scores.size(0)
        device = all_scores.device
        dtype = all_scores.dtype

        # Hierarchical selection per batch element
        enhanced_scores = all_scores.clone()
        query_features = query.mean(dim=1)  # [batch_size, dim]

        # Precompute the embeddings of all candidate entries (batched optimization)
        all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
        unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)

        # Batched embeddings for the unique candidate entries
        candidate_tokens = self.knowledge_dataset[unique_indices]
        flat_tokens = candidate_tokens.view(-1)
        flat_embeddings = self.tok_embeddings(flat_tokens)

        # Indices corresponding to flat_tokens (kept around for use elsewhere)
        pre_update_indices = unique_indices.view(-1)
        pre_update_embeddings = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        )

        unique_candidate_features = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        ).mean(dim=1)  # [num_unique_candidates, dim]

        # Normalize the candidate features (optimized similarity computation)
        normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
        normalized_queries = F.normalize(query_features, dim=-1)

        # Collect the best tokens for every batch element
        batch_best_tokens = []
        batch_best_tokens_embeddings = []

        for batch_idx in range(batch_size):
            indices = all_indices[batch_idx]

            # Feature indices for this batch element's candidate entries
            start_idx = batch_idx * len(indices)
            end_idx = start_idx + len(indices)
            batch_inverse_indices = inverse_indices[start_idx:end_idx]

            # Use the precomputed normalized features for the similarity computation
            batch_candidate_features = normalized_candidates[batch_inverse_indices]
            query_feature = normalized_queries[batch_idx]

            # Cosine similarity via a single matrix-vector product
            similarity_scores = torch.mv(batch_candidate_features, query_feature)

            # Index of the maximum similarity score
            max_similarity_idx = torch.argmax(similarity_scores)

            # Candidate entry with the maximum similarity
            best_candidate_idx = indices[max_similarity_idx]

            # Fetch the corresponding tokens
            best_tokens = self.knowledge_dataset[best_candidate_idx]
            best_tokens_embeddings = self.tok_embeddings(best_tokens)

            # Collect this batch element's best tokens
            batch_best_tokens.append(best_tokens)
            batch_best_tokens_embeddings.append(best_tokens_embeddings)

        # Stack the per-batch best tokens into a single tensor
        # [batch_size, knowledge_length]
        all_best_tokens = torch.stack(batch_best_tokens, dim=0)
        all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)

        with torch.no_grad():
            self.has_update_keys[pre_update_indices] = 1

        return all_best_tokens, all_best_tokens_embeddings
    def search_index(self, x):
        batch_size, seq_len, dim = x.shape

        # 1. Average over the sequence dimension
        x_flat = x.mean(dim=1)  # [batch_size, dim]

        # 2. Generate the query vector and reshape it into four sub-queries
        queries = self.to_queries(x_flat)  # [batch_size, knowledge_dim]
        # Changed: reshape into four sub-queries rather than two
        queries = queries.reshape(batch_size, 4, self.key_dim)  # [batch_size, 4, key_dim]
        # Move the subspace dimension to the front
        queries = queries.permute(1, 0, 2)  # [4, batch_size, key_dim]

        # 3. Similarity within each subspace
        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

        # 4. Top-k within each of the four subspaces
        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(4)]
        scores_1, scores_2, scores_3, scores_4 = [scores_and_indices[p][0] for p in range(4)]
        indices_1, indices_2, indices_3, indices_4 = [scores_and_indices[p][1] for p in range(4)]

        # 5. Combine the four subspaces
        # First combine subspaces one and two
        scores_12 = scores_1.unsqueeze(-1) + scores_2.unsqueeze(-2)  # [batch_size, topk, topk]
        indices_12_base = (indices_1.unsqueeze(-1) * self.num_keys) + indices_2.unsqueeze(-2)  # [batch_size, topk, topk]

        # Then combine subspaces three and four
        scores_34 = scores_3.unsqueeze(-1) + scores_4.unsqueeze(-2)  # [batch_size, topk, topk]
        indices_34_base = (indices_3.unsqueeze(-1) * self.num_keys) + indices_4.unsqueeze(-2)  # [batch_size, topk, topk]

        # Finally combine everything
        scores_flat_12 = scores_12.reshape(batch_size, -1)  # [batch_size, topk*topk]
        indices_flat_12 = indices_12_base.reshape(batch_size, -1)  # [batch_size, topk*topk]

        scores_flat_34 = scores_34.reshape(batch_size, -1)  # [batch_size, topk*topk]
        indices_flat_34 = indices_34_base.reshape(batch_size, -1)  # [batch_size, topk*topk]

        # Top-k over the 12 and 34 combinations
        topk_scores_12, topk_indices_12 = scores_flat_12.topk(min(self.product_key_topk, scores_flat_12.size(1)), dim=-1)
        topk_indices_12 = torch.gather(indices_flat_12, 1, topk_indices_12)

        topk_scores_34, topk_indices_34 = scores_flat_34.topk(min(self.product_key_topk, scores_flat_34.size(1)), dim=-1)
        topk_indices_34 = torch.gather(indices_flat_34, 1, topk_indices_34)

        # Merge the 12 and 34 results
        all_scores = topk_scores_12.unsqueeze(-1) + topk_scores_34.unsqueeze(-2)  # [batch_size, topk, topk]
        all_indices = (topk_indices_12.unsqueeze(-1) * (self.num_keys**2)) + topk_indices_34.unsqueeze(-2)  # [batch_size, topk, topk]

        # 6. Flatten back to two dimensions
        all_scores = all_scores.reshape(batch_size, -1)  # [batch_size, topk*topk]
        all_indices = all_indices.reshape(batch_size, -1)  # [batch_size, topk*topk]

        # 7. Final top-k selection
        scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
        indices = torch.gather(all_indices, 1, indices_of_indices)

        # 8. Apply the hierarchical selection strategy
        best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)

        # 9. Advance the batch counter and run the global update on schedule
        if self.is_train:
            self.batch_counter += 1

            # Run a global update every update_frequency batches; keys stay frozen otherwise
            if self.batch_counter % self.update_frequency == 0:
                # Update keys only on these batches, regardless of freeze_embedding
                self._global_keys_update()
                # Mark all keys as updated
                with torch.no_grad():
                    self.has_update_keys.fill_(1)

        return best_tokens, best_tokens_embeddings
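
# Illustrative sketches (not part of the original file).
#
# (1) The four-way index decomposition used in _global_keys_update, checked as
#     a round trip on toy values:
num_keys = 5
flat = 123                                  # a flat knowledge index
temp = flat
i4 = temp % num_keys; temp //= num_keys
i3 = temp % num_keys; temp //= num_keys
i2 = temp % num_keys
i1 = temp // num_keys
assert ((i1 * num_keys + i2) * num_keys + i3) * num_keys + i4 == flat

# (2) The pairwise merge in search_index -- combine subspaces 1+2 and 3+4,
#     prune each to top-k, then merge the two pruned lists. The stage-wise
#     pruning is an approximation: a combination can be missed when its (1,2)
#     or (3,4) partial sum falls outside that stage's top-k.
import torch

topk = 2
s = [torch.randn(1, topk) for _ in range(4)]   # per-subspace top-k scores
s12 = (s[0].unsqueeze(-1) + s[1].unsqueeze(-2)).reshape(1, -1).topk(topk, dim=-1).values
s34 = (s[2].unsqueeze(-1) + s[3].unsqueeze(-2)).reshape(1, -1).topk(topk, dim=-1).values
final = (s12.unsqueeze(-1) + s34.unsqueeze(-2)).reshape(1, -1).topk(topk, dim=-1).values
print(final)                                   # best combined scores
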

class CrossAttention(nn.Module):
    def __init__(
            self,
            config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)

        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        batch_size = x.size(0)

        # Split into heads
        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if pos_emb is not None:
            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)

        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v)

        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)

        context = self.to_out(context)

        return context

class Attention(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        # Move the head dimension ahead of the sequence dimension:
        # [bsz, n_heads, seq_len, head_dim]. This transpose was missing in the
        # source; attention below assumes it (and assumes n_kv_heads == n_heads,
        # since no repeat_kv helper is defined in this file).
        xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output

class FeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

class MoEGate(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = 0
        return topk_idx, topk_weight, aux_loss

class MOEFeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Select experts via the gating mechanism
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=torch.float16)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # keep dtypes consistent
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # With tokens_per_expert = [6, 15, 20, 26], tokens_per_expert.shape[0] is the
        # number of experts (4 here), and with
        # token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...]:
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the tokens routed to expert 0
        # (a token may be handled by several experts, depending on num_experts_per_tok);
        # the next nine slots, token_idxs[6:15] -> [4, 5, 6, 10, 11, 12, ...], belong to
        # expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache

class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.self_attention = Attention(config)
        self.cross_attention = CrossAttention(config)
        self.knowledge_dataset = knowledge_dataset

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, x, pos_cis):
        h_attn = self.self_attention(
            self.attention_norm(x),
            pos_cis
        )
        db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
        h_attn = self.cross_attention(h_attn, db_embeddings)
        h = x + h_attn
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        params = params or LMConfig()
        self.params = params
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
        self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()
        self.freeze_embedding = False

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                step: int = 0,
                **args):
        start_pos = args.get('start_pos', 0)
        if self.freeze_embedding and step == 0:
            self.tok_embeddings.weight.requires_grad = False

            print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        for l, layer in enumerate(self.layers):
            h = layer(
                h, pos_cis
            )

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        # Further simplified: keep only the necessary fields
        output = CausalLMOutputWithPast(
            logits=logits,
        )
        output.hidden_states = h

        output.aux_loss = aux_loss

        return output
    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        # Streaming generation
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        # Direct generation
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        start, first_seq, past_kvs = input_ids.shape[1], True, None
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq:
                out, first_seq = self(input_ids, **args), False
            else:
                out = self(input_ids[:, -1:],
                           start_pos=input_ids.shape[1] - 1, **args)
            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
            # Repetition penalty on tokens already present in the sequence
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break
604 model/model_ADMIN_Jun-17-112121-2025_Conflict.py Normal file
@ -0,0 +1,604 @@
import math
import struct
import inspect
import time

# Two-way subspace decomposition + gradient-based key updates
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)


def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class KnowledgeDataset(nn.Module):
    def __init__(self, params, tok_embeddings, is_train=True):
        super().__init__()
        self.is_train = is_train
        self.params = params
        self.tok_embeddings = tok_embeddings

        # Embedding parameters
        self.knowledge_dim = params.knowledge_dim
        self.key_dim = self.knowledge_dim // 2
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.knowledge_dim, bias=False),
        )

        ## Knowledge-base parameters
        self.knowledge_num = params.knowledge_num
        self.knowledge_length = params.knowledge_length

        # Key storage as a two-way decomposed space, kept trainable
        self.num_keys = int(math.sqrt(self.knowledge_num))
        # Make sure the keys are trainable parameters
        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
        self.product_key_topk = min(16, self.num_keys)

        # Knowledge storage - register_buffer because these are integer indices and need no gradients
        self.register_buffer('knowledge_dataset',
            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))

        # Step counter, used to adjust weights dynamically
        self.step_counter = 0

        # Batch counter and update-frequency logic removed in this variant
    def intelligent_selection(self, query, all_scores, all_indices):
        """Hierarchical selection strategy."""
        if not self.is_train:
            return all_scores, all_indices

        batch_size = all_scores.size(0)
        device = all_scores.device
        dtype = all_scores.dtype

        # Hierarchical selection per batch element
        enhanced_scores = all_scores.clone()
        query_features = query.mean(dim=1)  # [batch_size, dim]

        # Precompute the embeddings of all candidate entries (batched optimization)
        all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
        unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)

        # Batched embeddings for the unique candidate entries
        candidate_tokens = self.knowledge_dataset[unique_indices]
        flat_tokens = candidate_tokens.view(-1)
        flat_embeddings = self.tok_embeddings(flat_tokens)

        # Indices corresponding to flat_tokens (kept around for use elsewhere)
        pre_update_indices = unique_indices.view(-1)
        pre_update_embeddings = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        )

        unique_candidate_features = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        ).mean(dim=1)  # [num_unique_candidates, dim]

        # Normalize the candidate features (optimized similarity computation)
        normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
        normalized_queries = F.normalize(query_features, dim=-1)

        # Collect the best tokens for every batch element
        batch_best_tokens = []
        batch_best_tokens_embeddings = []

        for batch_idx in range(batch_size):
            indices = all_indices[batch_idx]

            # Feature indices for this batch element's candidate entries
            start_idx = batch_idx * len(indices)
            end_idx = start_idx + len(indices)
            batch_inverse_indices = inverse_indices[start_idx:end_idx]

            # Use the precomputed normalized features for the similarity computation
            batch_candidate_features = normalized_candidates[batch_inverse_indices]
            query_feature = normalized_queries[batch_idx]

            # Cosine similarity via a single matrix-vector product
            similarity_scores = torch.mv(batch_candidate_features, query_feature)

            # Index of the maximum similarity score
            max_similarity_idx = torch.argmax(similarity_scores)

            # Candidate entry with the maximum similarity
            best_candidate_idx = indices[max_similarity_idx]

            # Fetch the corresponding tokens
            best_tokens = self.knowledge_dataset[best_candidate_idx]
            best_tokens_embeddings = self.tok_embeddings(best_tokens)

            # Collect this batch element's best tokens
            batch_best_tokens.append(best_tokens)
            batch_best_tokens_embeddings.append(best_tokens_embeddings)

        # Stack the per-batch best tokens into a single tensor
        # [batch_size, knowledge_length]
        all_best_tokens = torch.stack(batch_best_tokens, dim=0)
        all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)

        return all_best_tokens, all_best_tokens_embeddings
        # NOTE: everything below is unreachable (it follows the return above);
        # legacy per-call sub-key update, preserved from the source.
        with torch.no_grad():
            # 1. Mean embedding over the token sequence
            pre_update_embeddings = pre_update_embeddings.mean(dim=1)  # [num_indices, dim]
            # 2. Project to the knowledge dimension
            pre_update_embeddings = self.to_queries(pre_update_embeddings)  # [num_indices, knowledge_dim]

            # 3. Convert the flat indices into subspace indices
            indices_x = pre_update_indices // self.num_keys
            indices_y = pre_update_indices % self.num_keys

            # 4. Collect the unique sub-keys that need updating
            unique_x = torch.unique(indices_x)
            unique_y = torch.unique(indices_y)

            # 5. Update the keys of the first subspace
            for k1 in unique_x:
                # Find all indices that use this sub-key
                mask_k1 = (indices_x == k1)
                if mask_k1.sum() == 0:
                    continue

                # Gather the relevant embeddings and average them
                k1_embeddings = pre_update_embeddings[mask_k1]
                k1_avg_embedding = k1_embeddings.mean(dim=0)

                # Split into the two subspaces and update the first
                self.keys[k1, 0] = k1_avg_embedding[:self.key_dim]

            # 6. Update the keys of the second subspace
            for k2 in unique_y:
                # Find all indices that use this sub-key
                mask_k2 = (indices_y == k2)
                if mask_k2.sum() == 0:
                    continue

                # Gather the relevant embeddings and average them
                k2_embeddings = pre_update_embeddings[mask_k2]
                k2_avg_embedding = k2_embeddings.mean(dim=0)

                # Update the second subspace
                self.keys[k2, 1] = k2_avg_embedding[self.key_dim:]
    def search_index(self, x):
        batch_size, seq_len, dim = x.shape

        # 1. Average over the sequence dimension
        x_flat = x.mean(dim=1)  # [batch_size, dim]

        # 2. Generate the query vector and reshape it into two sub-queries
        queries = self.to_queries(x_flat)  # [batch_size, knowledge_dim]
        queries = queries.reshape(batch_size, 2, self.key_dim)  # [batch_size, 2, key_dim]
        # Move the subspace dimension to the front
        queries = queries.permute(1, 0, 2)  # [2, batch_size, key_dim]

        # 3. Similarity within each subspace
        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

        # 4. Top-k within each of the two subspaces
        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]

        # 5. Combine the two subspaces
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)  # [batch_size, topk, topk]
        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)  # [batch_size, topk, topk]

        # 6. Flatten back to two dimensions
        all_scores = all_scores.reshape(batch_size, -1)  # [batch_size, topk*topk]
        all_indices = all_indices.reshape(batch_size, -1)  # [batch_size, topk*topk]

        # 7. Final top-k selection
        scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
        indices = torch.gather(all_indices, 1, indices_of_indices)

        # 8. Apply the hierarchical selection strategy
        best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)

        return best_tokens, best_tokens_embeddings
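
# Illustrative sketch (not part of the original file): in this variant the keys
# are trained by gradient, so a loss built from the combined scores back-propagates
# into self.keys. Toy check below, independent of the model.
import torch

keys = torch.nn.Parameter(torch.randn(4, 2, 8) * 0.02)   # [num_keys, 2, key_dim]
queries = torch.randn(2, 3, 8)                            # [2, batch, key_dim]
sim = torch.einsum('p b d, k p d -> p b k', queries, keys)
loss = sim.topk(2, dim=-1).values.sum()                   # any scalar built from the scores
loss.backward()
print(keys.grad.abs().sum() > 0)                          # True: gradients reach the keys
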

class CrossAttention(nn.Module):
    def __init__(
            self,
            config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)

        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        batch_size = x.size(0)

        # Split into heads
        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if pos_emb is not None:
            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)

        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v)

        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)

        context = self.to_out(context)

        return context

class Attention(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        # Move the head dimension ahead of the sequence dimension:
        # [bsz, n_heads, seq_len, head_dim]. This transpose was missing in the
        # source; attention below assumes it (and assumes n_kv_heads == n_heads,
        # since no repeat_kv helper is defined in this file).
        xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output
|
||||
|
||||
|
||||
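
# Sketch (hypothetical seq_len=4, not part of the original file): what the
# registered causal mask looks like after slicing. Zeros sit on and below the
# diagonal and -inf above it, so each position attends only to itself and
# earlier positions.
def _causal_mask_demo():
    mask = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
    # tensor([[0., -inf, -inf, -inf],
    #         [0.,   0., -inf, -inf],
    #         [0.,   0.,   0., -inf],
    #         [0.,   0.,   0.,   0.]])
    return mask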
class FeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
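
# Worked example of the hidden_dim computation above, assuming dim=512 and
# multiple_of=64 (hypothetical values): 4*512 = 2048, two thirds of that is
# int(4096/3) = 1365, which rounds up to the next multiple of 64 via
# 64 * ((1365 + 63) // 64) = 64 * 22 = 1408. The SwiGLU block therefore uses
# a 512 -> 1408 -> 512 projection.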
class MoEGate(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = 0
        return topk_idx, topk_weight, aux_loss
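
# Usage sketch (not part of the original file), assuming the config provides
# num_experts_per_tok=2 and dim as usual: route a batch of hidden states and
# inspect the gate outputs.
def _moe_gate_demo(config: LMConfig):
    gate = MoEGate(config)
    h = torch.randn(2, 8, config.dim)          # [batch, seq_len, dim]
    topk_idx, topk_weight, aux_loss = gate(h)
    # topk_idx:    [batch*seq_len, top_k] expert ids per token
    # topk_weight: [batch*seq_len, top_k] routing weights (renormalized when norm_topk_prob is set)
    # aux_loss:    scalar load-balancing penalty (0 outside training)
    return topk_idx.shape, topk_weight.shape   # -> (16, 2), (16, 2) for top_k=2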
class MOEFeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Select experts via the gating mechanism
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=torch.float16)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # keep dtypes consistent
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # With tokens_per_expert = [6, 15, 20, 26], tokens_per_expert.shape[0] is the number of experts (4 here).
        # If token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...], then
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the tokens handled by expert 0
        # (a token may be handled by several experts, depending on num_experts_per_tok),
        # the next 9 positions token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...] belong to expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache
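
# Sketch (hypothetical toy data, not part of the original file) of the
# argsort-based dispatch used in moe_infer: slots routed to the same expert
# become contiguous after sorting, so each expert runs one batched forward
# pass instead of per-token calls.
def _dispatch_demo():
    flat_expert_indices = torch.tensor([1, 0, 1, 2, 0, 1])  # expert id per (token, slot)
    idxs = flat_expert_indices.argsort()                    # -> [1, 4, 0, 2, 5, 3]
    boundaries = flat_expert_indices.bincount().cumsum(0)   # -> [2, 5, 6]
    # idxs[0:2] go to expert 0, idxs[2:5] to expert 1, idxs[5:6] to expert 2
    return idxs, boundaries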
class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.self_attention = Attention(config)
        self.cross_attention = CrossAttention(config)
        self.knowledge_dataset = knowledge_dataset

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, x, pos_cis):
        h_attn = self.self_attention(
            self.attention_norm(x),
            pos_cis
        )
        db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
        h_attn = self.cross_attention(h_attn, db_embeddings)
        h = x + h_attn
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        # Rebind params so the accesses below remain safe when called with params=None
        self.params = params = params or LMConfig()
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
        self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()
        self.freeze_embedding = False

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                step: int = 0,
                **args):
        start_pos = args.get('start_pos', 0)
        if self.freeze_embedding and step == 0:
            self.tok_embeddings.weight.requires_grad = False
            # Removed the knowledge_dataset.freeze_embedding assignment; key updates are driven by batch_counter instead
            # self.knowledge_dataset.freeze_embedding = True
            print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        for l, layer in enumerate(self.layers):
            h = layer(
                h, pos_cis
            )

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        # Simplified further: keep only the fields that are actually needed
        output = CausalLMOutputWithPast(
            logits=logits,
        )
        output.hidden_states = h

        output.aux_loss = aux_loss

        return output
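
    # Note on logits_to_keep (hypothetical values): with the int form,
    # slice(-logits_to_keep, None) selects the trailing positions, e.g.
    #   logits_to_keep=0 -> slice(0, None)  -> logits for every position
    #   logits_to_keep=1 -> slice(-1, None) -> logits for the last position only,
    # which lets generation skip projecting the full sequence through the
    # vocabulary head.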
    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        # Streaming generation
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        # Direct (non-streaming) generation
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        # Right-pad all sequences to a common length with pad_token_id
        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        start, first_seq, past_kvs = input_ids.shape[1], True, None
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq:
                out, first_seq = self(input_ids, **args), False
            else:
                out = self(input_ids[:, -1:],
                           start_pos=input_ids.shape[1] - 1, **args)
            # past_key_values is not populated by this model's forward, so past_kvs stays None
            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
            # Repetition penalty: down-scale logits of tokens already in the sequence
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                # Nucleus (top-p) sampling: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break
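
# Usage sketch (hypothetical setup, not part of the original file):
# non-streaming generation pads every returned sequence to a common length,
# so the result is a single tensor. Note that in this implementation
# max_new_tokens bounds the *total* sequence length, prompt included.
def _generate_demo():
    model = MiniMindLM(LMConfig()).eval()
    prompt = torch.tensor([[1, 15, 27, 9]])  # [1, prompt_len] token ids
    out = model.generate(prompt, max_new_tokens=32, temperature=0.75, top_p=0.9)
    return out.shape                         # -> [1, prompt_len + generated_len]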
@ -1,8 +1,10 @@
#!/bin/bash

# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate mini
#source $(conda info --base)/etc/profile.d/conda.sh
#conda activate mini
source /mnt/wcy/miniconda/bin/activate
conda activate accelerate

# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
@ -26,7 +28,7 @@ export PYTHONFAULTHANDLER=1
# --profile_interval 10

# Method 2: configure accelerate directly via command-line arguments
CUDA_VISIBLE_DEVICES=0 /opt/conda/envs/mini/bin/python -m accelerate.commands.launch \
CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch \
--num_processes=1 \
--mixed_precision=bf16 \
--main_process_port=29500 \
@ -461,14 +461,14 @@ def main():
parser.add_argument("--profile", action="store_true", default=True, help="enable profiling")
parser.add_argument("--profile_interval", type=int, default=10, help="profiling print interval (in steps)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="enable FlashAttention")
parser.add_argument("--knowledge_num", type=int, default=8192, help="number of entries in the knowledge base")
parser.add_argument("--knowledge_num", type=int, default=960400, help="number of entries in the knowledge base")
parser.add_argument("--knowledge_length", type=int, default=32, help="token length of each knowledge base entry")
parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="database initialization path")
parser.add_argument("--fast_clustering", action="store_true", default=True, help="use the fast approximate clustering algorithm (suited to large datasets)")
parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="path of the clustering result cache file")
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="force recomputation of clusters, ignoring the cache file")
args = parser.parse_args()


#########################################################
# Initialize accelerator and deepspeed
#########################################################