Compare commits

..

No commits in common. "cy" and "master" have entirely different histories.
cy ... master

12 changed files with 207 additions and 2860 deletions

102
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,102 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Debug Train Pretrain Accelerate",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
"console": "integratedTerminal",
"python": "/opt/conda/envs/mini/bin/python",
"cwd": "${workspaceFolder}",
"env": {
"PYTHONPATH": "${workspaceFolder}",
"CUDA_VISIBLE_DEVICES": "0"
},
"justMyCode": false,
"stopOnEntry": false,
"redirectOutput": true
},
{
"name": "Debug Train Pretrain Accelerate (Multi-GPU)",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
"console": "integratedTerminal",
"python": "/opt/conda/envs/mini/bin/python",
"args": [
"--hidden_size", "512",
"--max_seq_len", "512",
"--n_layers", "8",
"--batch_size", "8",
"--epochs", "1"
],
"cwd": "${workspaceFolder}",
"env": {
"PYTHONPATH": "${workspaceFolder}",
"CUDA_VISIBLE_DEVICES": "0,1"
},
"justMyCode": false,
"stopOnEntry": false,
"redirectOutput": true
},
{
"name": "Debug Train Pretrain Accelerate (Small Test)",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
"console": "integratedTerminal",
"python": "/opt/conda/envs/mini/bin/python",
"args": [
"--hidden_size", "512",
"--max_seq_len", "512",
"--n_layers", "8",
"--batch_size", "2",
"--epochs", "1",
"--log_interval", "10",
"--save_interval", "100",
"--accumulation_steps", "4"
],
"cwd": "${workspaceFolder}",
"env": {
"PYTHONPATH": "${workspaceFolder}",
"CUDA_VISIBLE_DEVICES": "0",
"WANDB_MODE": "offline"
},
"justMyCode": false,
"stopOnEntry": false,
"redirectOutput": true
},
{
"name": "Debug ExtractDB Comparison",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
"console": "integratedTerminal",
"python": "/opt/conda/envs/mini/bin/python",
"args": [
"--hidden_size", "512",
"--max_seq_len", "256",
"--n_layers", "4",
"--batch_size", "2",
"--epochs", "1",
"--log_interval", "10",
"--save_interval", "200",
"--accumulation_steps", "2",
"--comparison_mode",
"--knowledge_num", "256",
"--knowledge_length", "64",
"--comparison_mode"
],
"cwd": "${workspaceFolder}",
"env": {
"PYTHONPATH": "${workspaceFolder}",
"CUDA_VISIBLE_DEVICES": "0",
"WANDB_MODE": "offline"
},
"justMyCode": false,
"stopOnEntry": false,
"redirectOutput": true
}
]
}

18
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,18 @@
{
"python.pythonPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
"python.defaultInterpreterPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
"python.terminal.activateEnvironment": true,
"python.terminal.activateEnvInCurrentTerminal": true,
"python.linting.enabled": true,
"python.linting.pylintEnabled": false,
"python.linting.flake8Enabled": true,
"python.formatting.provider": "black",
"python.analysis.autoImportCompletions": true,
"python.analysis.typeCheckingMode": "off",
"files.exclude": {
"**/__pycache__": true,
"**/*.pyc": true,
"**/.git": false,
"**/wandb": false
}
}

View File

@ -1,97 +0,0 @@
import json
import os
import torch
from transformers import AutoTokenizer
def analyze_database(json_path, tokenizer_path='./model/minimind_tokenizer'):
"""分析database_init.json文件中的数据条目数量和质量"""
print(f"开始分析数据库文件: {json_path}")
# 1. 加载tokenizer
try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
print(f"成功加载tokenizer: {tokenizer_path}")
except Exception as e:
print(f"加载tokenizer失败: {e}")
return
# 2. 加载JSON文件
try:
with open(json_path, 'r', encoding='utf-8') as f:
database_data = json.load(f)
# 提取sentences列表
sentences_data = database_data.get('sentences', [])
print(f"加载了 {len(sentences_data)} 条sentences数据")
except Exception as e:
print(f"加载JSON文件失败: {e}")
return
# 3. 分析句子长度分布
if len(sentences_data) == 0:
print("没有找到有效的句子数据")
return
# 按照importance_score排序
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
print(f"按importance_score排序完成最高分: {sorted_sentences[0].get('importance_score', 0.0)}, 最低分: {sorted_sentences[-1].get('importance_score', 0.0)}")
# 统计句子长度分布
token_lengths = []
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
# 4. 分析token长度分布
for i, sentence_data in enumerate(sorted_sentences):
sentence = sentence_data.get('corrected_sentence', '')
if not sentence:
print(f"警告: 第 {i+1} 条数据没有corrected_sentence字段")
continue
# 将句子转换为tokens
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
token_lengths.append(len(sentence_tokens))
if i < 5: # 显示前5条数据样例
print(f"样例 {i+1}: {sentence[:50]}... (tokens: {len(sentence_tokens)})")
# 5. 统计分析结果
token_lengths = torch.tensor(token_lengths)
stats = {
"总条目数": len(sorted_sentences),
"有效条目数": len(token_lengths),
"token长度-平均值": token_lengths.float().mean().item(),
"token长度-最小值": token_lengths.min().item(),
"token长度-最大值": token_lengths.max().item(),
"token长度-中位数": token_lengths.median().item(),
"token长度-标准差": token_lengths.float().std().item(),
}
# 统计长度分布
length_bins = {
"小于16": (token_lengths < 16).sum().item(),
"16-32": ((token_lengths >= 16) & (token_lengths < 32)).sum().item(),
"32-64": ((token_lengths >= 32) & (token_lengths < 64)).sum().item(),
"64-128": ((token_lengths >= 64) & (token_lengths < 128)).sum().item(),
"128-256": ((token_lengths >= 128) & (token_lengths < 256)).sum().item(),
"256及以上": (token_lengths >= 256).sum().item(),
}
# 打印统计信息
print("\n===== 数据库分析结果 =====")
for key, value in stats.items():
print(f"{key}: {value}")
print("\n===== Token长度分布 =====")
for bin_name, count in length_bins.items():
percentage = (count / len(token_lengths)) * 100
print(f"{bin_name}: {count} ({percentage:.1f}%)")
print(f"\n结论: 该数据库文件包含 {stats['有效条目数']} 条有效数据,可以全部填充到知识库中。")
return stats, length_bins
if __name__ == "__main__":
# 指定数据库文件路径
database_path = "./dataset/database_init.json"
analyze_database(database_path)

View File

@ -132,7 +132,7 @@ def decode_dataset(model_path, output_path, device="cuda"):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database") parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
parser.add_argument("--model_path", type=str, default="out/pretrain_512.pth", parser.add_argument("--model_path", type=str, default="out/pretrain_1024.pth",
help="Path to the model checkpoint") help="Path to the model checkpoint")
parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt", parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
help="Path to save the decoded text file") help="Path to save the decoded text file")

112
loss.py
View File

@ -1,112 +0,0 @@
import re
import matplotlib.pyplot as plt
import numpy as np
def parse_log_file(file_path):
"""
Parse the training log file to extract epoch, step, and loss information.
"""
# Regular expression to match log entries with loss information
pattern = r'\[.*?\] Epoch (\d+)/\d+, Step (\d+)/\d+, Loss: ([\d\.]+)'
epochs = []
steps = []
losses = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
log_content = f.read()
# Find all matches
matches = re.findall(pattern, log_content)
for match in matches:
epoch = int(match[0])
step = int(match[1])
loss = float(match[2])
epochs.append(epoch)
steps.append(step)
losses.append(loss)
return epochs, steps, losses
except Exception as e:
print(f"Error parsing log file: {e}")
return [], [], []
def plot_loss_curve(epochs, steps, losses, output_file='loss_curve.png'):
"""
Plot the loss curve and save it to a file.
"""
plt.figure(figsize=(12, 6))
# Create continuous steps for better visualization
continuous_steps = []
current_max_step = 0
prev_epoch = 0
for i, (e, s) in enumerate(zip(epochs, steps)):
if e > prev_epoch:
# New epoch starts
if i > 0:
current_max_step = continuous_steps[-1]
prev_epoch = e
continuous_steps.append(s + current_max_step)
# 修改:减小线条宽度和点的大小
plt.plot(continuous_steps, losses, marker='.', linestyle='-',
color='#1f77b4', markersize=3, linewidth=0.8)
plt.title('Training Loss Over Steps', fontsize=16)
plt.xlabel('Steps (Continuous)', fontsize=14)
plt.ylabel('Loss', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.5, linewidth=0.5)
# 修改:减小红线宽度
for i in range(1, len(epochs)):
if epochs[i] > epochs[i-1]:
plt.axvline(x=continuous_steps[i], color='r',
linestyle='--', alpha=0.5, linewidth=0.8)
unique_epochs = sorted(set(epochs))
# Add epoch labels
for e in unique_epochs:
indices = [i for i, epoch in enumerate(epochs) if epoch == e]
if indices:
mid_idx = indices[len(indices) // 2]
plt.text(continuous_steps[mid_idx], max(losses) * 0.95, f'Epoch {e}',
horizontalalignment='center', verticalalignment='center',
fontsize=10,
bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 3})
# 移除悬停注释,简化图表
# for i, (e, s, l) in enumerate(zip(epochs, steps, losses)):
# plt.annotate(...)
plt.tight_layout()
plt.savefig(output_file, dpi=300)
print(f"Loss curve saved as {output_file}")
# Also display the data in a table format
print("\nExtracted training data:")
print("-" * 50)
print(f"{'Epoch':<10}{'Step':<10}{'Loss':<15}")
print("-" * 50)
for e, s, l in zip(epochs, steps, losses):
print(f"{e:<10}{s:<10}{l:<15.6f}")
def main():
# Specify the path to your log file
log_file_path = 'out/train.log'
# Parse the log file
epochs, steps, losses = parse_log_file(log_file_path)
if epochs and steps and losses:
plot_loss_curve(epochs, steps, losses)
else:
print("No data extracted from log file. Please check if the file format is correct.")
if __name__ == "__main__":
main()

View File

@ -2,7 +2,7 @@ import math
import struct import struct
import inspect import inspect
import time import time
#子空间二维分解+梯度更新
from .LMConfig import LMConfig from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union from typing import Any, Optional, Tuple, List, Union
import numpy as np import numpy as np
@ -67,21 +67,23 @@ class KnowledgeDataset(nn.Module):
## 数据库参数 ## 数据库参数
self.knowledge_num = params.knowledge_num self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length self.knowledge_length = params.knowledge_length
self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.knowledge_num)
# 修改键存储为二维分解空间,设置为可训练参数 # 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动
self.num_keys = int(math.sqrt(self.knowledge_num)) self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
# 确保keys是可训练参数
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度 # 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset', self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)) torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
)
# 计算step数目用于动态调整权重 # 计算step数目用于动态调整权重
self.step_counter = 0 self.step_counter = 0
# 移除批次计数器和更新频率相关代码 self.freeze_embedding = False
def intelligent_selection(self, query, all_scores, all_indices): def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略""" """智能分层选择策略"""
@ -104,8 +106,7 @@ class KnowledgeDataset(nn.Module):
candidate_tokens = self.knowledge_dataset[unique_indices] candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1) flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens) flat_embeddings = self.tok_embeddings(flat_tokens)
#获取flat_tokens对应的index
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1) pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view( pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1 len(unique_indices), self.knowledge_length, -1
@ -157,86 +158,84 @@ class KnowledgeDataset(nn.Module):
all_best_tokens = torch.stack(batch_best_tokens, dim=0) all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0) all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
# 获取
# 使用重新计算的embeddings更新self.keys
if self.is_train:
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return all_best_tokens, all_best_tokens_embeddings return all_best_tokens, all_best_tokens_embeddings
def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
if self.freeze_embedding:
return
# 使用pre_update_embeddings更新self.keys
with torch.no_grad(): with torch.no_grad():
# 1. 计算token序列的平均嵌入 pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [337, 512]
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [num_indices, dim] pre_update_embeddings = self.to_queries(pre_update_embeddings)
# 2. 转换维度 self.keys[pre_update_indices] = pre_update_embeddings
pre_update_embeddings = self.to_queries(pre_update_embeddings) # [num_indices, knowledge_dim]
# 3. 将one-hot索引转换为子空间索引 def search_index(self,x):
indices_x = pre_update_indices // self.num_keys
indices_y = pre_update_indices % self.num_keys
# 4. 收集需要更新的唯一子键
unique_x = torch.unique(indices_x)
unique_y = torch.unique(indices_y)
# 5. 更新第一个子空间的键
for k1 in unique_x:
# 找出所有使用该子键的索引
mask_k1 = (indices_x == k1)
if mask_k1.sum() == 0:
continue
# 获取所有相关嵌入并计算平均值
k1_embeddings = pre_update_embeddings[mask_k1]
k1_avg_embedding = k1_embeddings.mean(dim=0)
# 拆分为两个子空间并更新第一个子空间
self.keys[k1, 0] = k1_avg_embedding[:self.key_dim]
# 6. 更新第二个子空间的键
for k2 in unique_y:
# 找出所有使用该子键的索引
mask_k2 = (indices_y == k2)
if mask_k2.sum() == 0:
continue
# 获取所有相关嵌入并计算平均值
k2_embeddings = pre_update_embeddings[mask_k2]
k2_avg_embedding = k2_embeddings.mean(dim=0)
# 更新第二个子空间
self.keys[k2, 1] = k2_avg_embedding[self.key_dim:]
def search_index(self, x):
batch_size, seq_len, dim = x.shape batch_size, seq_len, dim = x.shape
# 1. 序列维度平均 # collapse sequence dimension by averaging
x_flat = x.mean(dim=1) # [batch_size, dim] x_flat = x.mean(dim=1) # [batch_size, dim]
# 2. 生成查询向量并重塑为两个子查询 queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim] # queries = queries.reshape(batch_size, 2, self.key_dim)
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim] # queries = queries.permute(1, 0, 2)
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 3. 计算每个子空间的相似度 # 2. 计算queries与keys的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys) sim = torch.einsum('b d, k d -> b k', queries, self.keys)
# 4. 在两个子空间分别做top-k # 3. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)] scores_and_indices = sim.topk(self.product_key_topk, dim=-1)
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0] scores, indices = scores_and_indices[0], scores_and_indices[1]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. 组合两个子空间的结果 # 5. 应用智能分层选择策略
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices) best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
# 6. 更新1%的keys
if self.is_train:
# 获取未更新过的keys的索引
not_updated_indices = torch.where(self.has_update_keys == 0)[0]
# 如果有未更新的keys随机选择num_update_keys个进行更新
if len(not_updated_indices) > 0:
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
perm_num = perm.shape[0]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
else:
print("all keys are updated")
# 重置所有keys的更新状态
self.has_update_keys.zero_()
# 重新获取所有可更新的索引
not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return best_tokens, best_tokens_embeddings return best_tokens, best_tokens_embeddings
@ -523,9 +522,10 @@ class MiniMindLM(PreTrainedModel):
start_pos = args.get('start_pos', 0) start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0: if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False self.tok_embeddings.weight.requires_grad = False
# 移除对knowledge_dataset.freeze_embedding的设置让键更新由batch_counter控制 # 同时冻结KnowledgeDataset的嵌入更新
# self.knowledge_dataset.freeze_embedding = True self.knowledge_dataset.freeze_embedding = True
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad) print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
h = self.dropout(self.tok_embeddings(input_ids)) h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)] pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
@ -601,4 +601,3 @@ class MiniMindLM(PreTrainedModel):
yield input_ids[:, start:] yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id: if input_ids_next.item() == eos_token_id:
break break

View File

@ -1,603 +0,0 @@
import math
import struct
import inspect
import time
#子空间不分解+嵌入更新
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# 嵌入参数
self.knowledge_dim = params.knowledge_dim
self.key_dim = self.knowledge_dim // 2
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
## 数据库参数
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.knowledge_num)
# 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动
self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
)
# 计算step数目用于动态调整权重
self.step_counter = 0
self.freeze_embedding = False
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
#获取flat_tokens对应的index
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# 收集所有batch的best_tokens
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# 找到最大相似度分数的索引
max_similarity_idx = torch.argmax(similarity_scores)
# 获取最大相似度对应的候选条目索引
best_candidate_idx = indices[max_similarity_idx]
# 获取对应的tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# 将当前batch的best_tokens添加到列表中
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# 将所有batch的best_tokens堆叠成一个张量
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
# 获取
# 使用重新计算的embeddings更新self.keys
if self.is_train:
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return all_best_tokens, all_best_tokens_embeddings
def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
if self.freeze_embedding:
return
# 使用pre_update_embeddings更新self.keys
with torch.no_grad():
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [337, 512]
pre_update_embeddings = self.to_queries(pre_update_embeddings)
self.keys[pre_update_indices] = pre_update_embeddings
def search_index(self,x):
batch_size, seq_len, dim = x.shape
# collapse sequence dimension by averaging
x_flat = x.mean(dim=1) # [batch_size, dim]
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
# queries = queries.reshape(batch_size, 2, self.key_dim)
# queries = queries.permute(1, 0, 2)
# 2. 计算queries与keys的相似度
sim = torch.einsum('b d, k d -> b k', queries, self.keys)
# 3. 在两个子空间分别做top-k
scores_and_indices = sim.topk(self.product_key_topk, dim=-1)
scores, indices = scores_and_indices[0], scores_and_indices[1]
# 5. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
# 6. 更新1%的keys
if self.is_train:
# 获取未更新过的keys的索引
not_updated_indices = torch.where(self.has_update_keys == 0)[0]
# 如果有未更新的keys随机选择num_update_keys个进行更新
if len(not_updated_indices) > 0:
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
perm_num = perm.shape[0]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
else:
print("all keys are updated")
# 重置所有keys的更新状态
self.has_update_keys.zero_()
# 重新获取所有可更新的索引
not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
if config.hidden_dim is None:
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0
return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)
def forward(self, x):
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# 使用门控机制选择专家
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok
# 当tokens_per_expert = [6, 15, 20, 26]tokens_per_expert.shape[0]即为专家数量此时为4
# 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时
# 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token每个token有可能被多个专家处理这取决于num_experts_per_tok
# 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, pos_cis):
h_attn = self.self_attention(
self.attention_norm(x),
pos_cis
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h = x + h_attn
out = h + self.feed_forward(self.ffn_norm(h))
return out
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False
# 同时冻结KnowledgeDataset的嵌入更新
self.knowledge_dataset.freeze_embedding = True
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
h = layer(
h, pos_cis
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
# 进一步简化,只保留必要的参数
output = CausalLMOutputWithPast(
logits=logits,
)
output.hidden_states = h
output.aux_loss = aux_loss
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
# 流式生成
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
# 直接生成
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq:
out, first_seq = self(input_ids, **args), False
else:
out = self(input_ids[:, -1:],
start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break

View File

@ -1,675 +0,0 @@
import math
import struct
import inspect
import time
#子空间二维分解+全局嵌入更新
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# 嵌入参数
self.knowledge_dim = params.knowledge_dim
self.key_dim = self.knowledge_dim // 2
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
## 数据库参数
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
# 修改键存储为二维分解空间
self.num_keys = int(math.sqrt(self.knowledge_num))
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动
self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# 计算step数目用于动态调整权重
self.step_counter = 0
self.freeze_embedding = False
# 添加批次计数器和更新频率
self.batch_counter = 0
self.update_frequency = 100 # 每100个批次更新一次
def _global_keys_update(self):
"""全局更新所有子键"""
# 移除对self.freeze_embedding的检查确保在调用时总是执行更新
with torch.no_grad():
# 创建用于存储每个子键的嵌入和计数的张量
k1_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
k2_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
k1_counts = torch.zeros(self.num_keys, device=self.keys.device)
k2_counts = torch.zeros(self.num_keys, device=self.keys.device)
# 分批处理所有知识条目,避免内存溢出
batch_size = 1000 # 可根据可用内存调整
for i in range(0, self.knowledge_num, batch_size):
end_idx = min(i + batch_size, self.knowledge_num)
batch_indices = torch.arange(i, end_idx, device=self.keys.device)
# 获取批次的嵌入
batch_tokens = self.knowledge_dataset[batch_indices]
batch_embeddings = self.tok_embeddings(batch_tokens.view(-1))
batch_embeddings = batch_embeddings.view(len(batch_indices), self.knowledge_length, -1).mean(dim=1)
batch_embeddings = self.to_queries(batch_embeddings)
# 计算批次中每个条目对应的子键索引
indices_x = batch_indices // self.num_keys
indices_y = batch_indices % self.num_keys
# 累加每个子键对应的嵌入
for j in range(len(batch_indices)):
k1, k2 = indices_x[j].item(), indices_y[j].item()
embedding = batch_embeddings[j]
# 更新第一个子空间累加值
k1_embeddings_sum[k1] += embedding[:self.key_dim]
k1_counts[k1] += 1
# 更新第二个子空间累加值
k2_embeddings_sum[k2] += embedding[self.key_dim:]
k2_counts[k2] += 1
# 计算平均值并更新键
# 避免除零错误
k1_counts = torch.clamp(k1_counts, min=1)
k2_counts = torch.clamp(k2_counts, min=1)
# 计算每个子键的平均嵌入
self.keys[:, 0] = k1_embeddings_sum / k1_counts.unsqueeze(1)
self.keys[:, 1] = k2_embeddings_sum / k2_counts.unsqueeze(1)
print(f"执行了全局键更新,批次: {self.batch_counter}")
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# 收集所有batch的best_tokens
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# 找到最大相似度分数的索引
max_similarity_idx = torch.argmax(similarity_scores)
# 获取最大相似度对应的候选条目索引
best_candidate_idx = indices[max_similarity_idx]
# 获取对应的tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# 将当前batch的best_tokens添加到列表中
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# 将所有batch的best_tokens堆叠成一个张量
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return all_best_tokens, all_best_tokens_embeddings
with torch.no_grad():
# 1. 计算token序列的平均嵌入
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [num_indices, dim]
# 2. 转换维度
pre_update_embeddings = self.to_queries(pre_update_embeddings) # [num_indices, knowledge_dim]
# 3. 将one-hot索引转换为子空间索引
indices_x = pre_update_indices // self.num_keys
indices_y = pre_update_indices % self.num_keys
# 4. 收集需要更新的唯一子键
unique_x = torch.unique(indices_x)
unique_y = torch.unique(indices_y)
# 5. 更新第一个子空间的键
for k1 in unique_x:
# 找出所有使用该子键的索引
mask_k1 = (indices_x == k1)
if mask_k1.sum() == 0:
continue
# 获取所有相关嵌入并计算平均值
k1_embeddings = pre_update_embeddings[mask_k1]
k1_avg_embedding = k1_embeddings.mean(dim=0)
# 拆分为两个子空间并更新第一个子空间
self.keys[k1, 0] = k1_avg_embedding[:self.key_dim]
# 6. 更新第二个子空间的键
for k2 in unique_y:
# 找出所有使用该子键的索引
mask_k2 = (indices_y == k2)
if mask_k2.sum() == 0:
continue
# 获取所有相关嵌入并计算平均值
k2_embeddings = pre_update_embeddings[mask_k2]
k2_avg_embedding = k2_embeddings.mean(dim=0)
# 更新第二个子空间
self.keys[k2, 1] = k2_avg_embedding[self.key_dim:]
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# 1. 序列维度平均
x_flat = x.mean(dim=1) # [batch_size, dim]
# 2. 生成查询向量并重塑为两个子查询
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 3. 计算每个子空间的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 4. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. 组合两个子空间的结果
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
# 9. 更新批次计数并在特定批次执行全局更新
if self.is_train:
self.batch_counter += 1
# 每update_frequency个批次执行一次全局更新其余时间保持冻结
if self.batch_counter % self.update_frequency == 0:
# 只在特定批次更新键无论freeze_embedding状态如何
self._global_keys_update()
# 标记所有键为已更新状态
with torch.no_grad():
self.has_update_keys.fill_(1)
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
if config.hidden_dim is None:
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0
return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)
def forward(self, x):
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# 使用门控机制选择专家
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok
# 当tokens_per_expert = [6, 15, 20, 26]tokens_per_expert.shape[0]即为专家数量此时为4
# 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时
# 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token每个token有可能被多个专家处理这取决于num_experts_per_tok
# 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, pos_cis):
h_attn = self.self_attention(
self.attention_norm(x),
pos_cis
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h = x + h_attn
out = h + self.feed_forward(self.ffn_norm(h))
return out
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False
# 移除对knowledge_dataset.freeze_embedding的设置让键更新由batch_counter控制
# self.knowledge_dataset.freeze_embedding = True
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
h = layer(
h, pos_cis
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
# 进一步简化,只保留必要的参数
output = CausalLMOutputWithPast(
logits=logits,
)
output.hidden_states = h
output.aux_loss = aux_loss
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
# 流式生成
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
# 直接生成
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq:
out, first_seq = self(input_ids, **args), False
else:
out = self(input_ids[:, -1:],
start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break

View File

@ -1,679 +0,0 @@
import math
import struct
import inspect
import time
#子空间四维分解+全局嵌入更新
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# 嵌入参数
self.knowledge_dim = params.knowledge_dim
# 修改:子空间维度从原来的一半变为四分之一
self.key_dim = self.knowledge_dim // 4
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
## 数据库参数
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
# 修改:将键存储从二维分解空间改为四维分解空间
# 计算每个子空间的键数量(使用四次根号N)
self.num_keys = int(self.knowledge_num ** 0.25)
# 修改子空间个数从2变为4
self.keys = nn.Parameter(torch.randn(self.num_keys, 4, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动
self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# 计算step数目用于动态调整权重
self.step_counter = 0
self.freeze_embedding = False
# 添加批次计数器和更新频率
self.batch_counter = 0
self.update_frequency = 100 # 每100个批次更新一次
def _global_keys_update(self):
"""全局更新所有子键"""
# 移除对self.freeze_embedding的检查确保在调用时总是执行更新
with torch.no_grad():
# 创建用于存储每个子键的嵌入和计数的张量修改为4个子空间
k1_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
k2_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
k3_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
k4_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
k1_counts = torch.zeros(self.num_keys, device=self.keys.device)
k2_counts = torch.zeros(self.num_keys, device=self.keys.device)
k3_counts = torch.zeros(self.num_keys, device=self.keys.device)
k4_counts = torch.zeros(self.num_keys, device=self.keys.device)
# 分批处理所有知识条目,避免内存溢出
batch_size = 1000 # 可根据可用内存调整
for i in range(0, self.knowledge_num, batch_size):
end_idx = min(i + batch_size, self.knowledge_num)
batch_indices = torch.arange(i, end_idx, device=self.keys.device)
# 获取批次的嵌入
batch_tokens = self.knowledge_dataset[batch_indices]
batch_embeddings = self.tok_embeddings(batch_tokens.view(-1))
batch_embeddings = batch_embeddings.view(len(batch_indices), self.knowledge_length, -1).mean(dim=1)
batch_embeddings = self.to_queries(batch_embeddings)
# 计算批次中每个条目对应的子键索引修改为4个子空间的索引计算
# 使用整数除法和取模运算来提取四维索引
temp = batch_indices
indices_4 = temp % self.num_keys
temp = temp // self.num_keys
indices_3 = temp % self.num_keys
temp = temp // self.num_keys
indices_2 = temp % self.num_keys
indices_1 = temp // self.num_keys
# 累加每个子键对应的嵌入
for j in range(len(batch_indices)):
k1, k2, k3, k4 = indices_1[j].item(), indices_2[j].item(), indices_3[j].item(), indices_4[j].item()
embedding = batch_embeddings[j]
# 将嵌入分为四份并分别累加到对应的子空间
quarter = self.key_dim
k1_embeddings_sum[k1] += embedding[:quarter]
k1_counts[k1] += 1
k2_embeddings_sum[k2] += embedding[quarter:2*quarter]
k2_counts[k2] += 1
k3_embeddings_sum[k3] += embedding[2*quarter:3*quarter]
k3_counts[k3] += 1
k4_embeddings_sum[k4] += embedding[3*quarter:]
k4_counts[k4] += 1
# 计算平均值并更新键
# 避免除零错误
k1_counts = torch.clamp(k1_counts, min=1)
k2_counts = torch.clamp(k2_counts, min=1)
k3_counts = torch.clamp(k3_counts, min=1)
k4_counts = torch.clamp(k4_counts, min=1)
# 计算每个子键的平均嵌入
self.keys[:, 0] = k1_embeddings_sum / k1_counts.unsqueeze(1)
self.keys[:, 1] = k2_embeddings_sum / k2_counts.unsqueeze(1)
self.keys[:, 2] = k3_embeddings_sum / k3_counts.unsqueeze(1)
self.keys[:, 3] = k4_embeddings_sum / k4_counts.unsqueeze(1)
print(f"执行了全局键更新,批次: {self.batch_counter}")
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# 收集所有batch的best_tokens
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# 找到最大相似度分数的索引
max_similarity_idx = torch.argmax(similarity_scores)
# 获取最大相似度对应的候选条目索引
best_candidate_idx = indices[max_similarity_idx]
# 获取对应的tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# 将当前batch的best_tokens添加到列表中
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# 将所有batch的best_tokens堆叠成一个张量
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return all_best_tokens, all_best_tokens_embeddings
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# 1. 序列维度平均
x_flat = x.mean(dim=1) # [batch_size, dim]
# 2. 生成查询向量并重塑为四个子查询
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
# 修改:重塑为四个子查询而非两个
queries = queries.reshape(batch_size, 4, self.key_dim) # [batch_size, 4, key_dim]
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [4, batch_size, key_dim]
# 3. 计算每个子空间的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 4. 在四个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(4)]
scores_1, scores_2, scores_3, scores_4 = [scores_and_indices[p][0] for p in range(4)]
indices_1, indices_2, indices_3, indices_4 = [scores_and_indices[p][1] for p in range(4)]
# 5. 组合四个子空间的结果
# 首先组合第一、第二子空间
scores_12 = scores_1.unsqueeze(-1) + scores_2.unsqueeze(-2) # [batch_size, topk, topk]
indices_12_base = (indices_1.unsqueeze(-1) * self.num_keys) + indices_2.unsqueeze(-2) # [batch_size, topk, topk]
# 然后组合第三、第四子空间
scores_34 = scores_3.unsqueeze(-1) + scores_4.unsqueeze(-2) # [batch_size, topk, topk]
indices_34_base = (indices_3.unsqueeze(-1) * self.num_keys) + indices_4.unsqueeze(-2) # [batch_size, topk, topk]
# 最后组合所有子空间
scores_flat_12 = scores_12.reshape(batch_size, -1) # [batch_size, topk*topk]
indices_flat_12 = indices_12_base.reshape(batch_size, -1) # [batch_size, topk*topk]
scores_flat_34 = scores_34.reshape(batch_size, -1) # [batch_size, topk*topk]
indices_flat_34 = indices_34_base.reshape(batch_size, -1) # [batch_size, topk*topk]
# 对12和34组合的结果进行top-k选择
topk_scores_12, topk_indices_12 = scores_flat_12.topk(min(self.product_key_topk, scores_flat_12.size(1)), dim=-1)
topk_indices_12 = torch.gather(indices_flat_12, 1, topk_indices_12)
topk_scores_34, topk_indices_34 = scores_flat_34.topk(min(self.product_key_topk, scores_flat_34.size(1)), dim=-1)
topk_indices_34 = torch.gather(indices_flat_34, 1, topk_indices_34)
# 将12和34的结果组合
all_scores = topk_scores_12.unsqueeze(-1) + topk_scores_34.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (topk_indices_12.unsqueeze(-1) * (self.num_keys**2)) + topk_indices_34.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
# 9. 更新批次计数并在特定批次执行全局更新
if self.is_train:
self.batch_counter += 1
# 每update_frequency个批次执行一次全局更新其余时间保持冻结
if self.batch_counter % self.update_frequency == 0:
# 只在特定批次更新键无论freeze_embedding状态如何
self._global_keys_update()
# 标记所有键为已更新状态
with torch.no_grad():
self.has_update_keys.fill_(1)
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
if config.hidden_dim is None:
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0
return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)
def forward(self, x):
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# 使用门控机制选择专家
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok
# 当tokens_per_expert = [6, 15, 20, 26]tokens_per_expert.shape[0]即为专家数量此时为4
# 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时
# 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token每个token有可能被多个专家处理这取决于num_experts_per_tok
# 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, pos_cis):
h_attn = self.self_attention(
self.attention_norm(x),
pos_cis
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h = x + h_attn
out = h + self.feed_forward(self.ffn_norm(h))
return out
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
h = layer(
h, pos_cis
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
# 进一步简化,只保留必要的参数
output = CausalLMOutputWithPast(
logits=logits,
)
output.hidden_states = h
output.aux_loss = aux_loss
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
# 流式生成
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
# 直接生成
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq:
out, first_seq = self(input_ids, **args), False
else:
out = self(input_ids[:, -1:],
start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break

View File

@ -1,604 +0,0 @@
import math
import struct
import inspect
import time
#子空间二维分解+梯度更新
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# 嵌入参数
self.knowledge_dim = params.knowledge_dim
self.key_dim = self.knowledge_dim // 2
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
## 数据库参数
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
# 修改键存储为二维分解空间,设置为可训练参数
self.num_keys = int(math.sqrt(self.knowledge_num))
# 确保keys是可训练参数
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# 计算step数目用于动态调整权重
self.step_counter = 0
# 移除批次计数器和更新频率相关代码
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# 收集所有batch的best_tokens
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# 找到最大相似度分数的索引
max_similarity_idx = torch.argmax(similarity_scores)
# 获取最大相似度对应的候选条目索引
best_candidate_idx = indices[max_similarity_idx]
# 获取对应的tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# 将当前batch的best_tokens添加到列表中
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# 将所有batch的best_tokens堆叠成一个张量
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
return all_best_tokens, all_best_tokens_embeddings
with torch.no_grad():
# 1. 计算token序列的平均嵌入
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [num_indices, dim]
# 2. 转换维度
pre_update_embeddings = self.to_queries(pre_update_embeddings) # [num_indices, knowledge_dim]
# 3. 将one-hot索引转换为子空间索引
indices_x = pre_update_indices // self.num_keys
indices_y = pre_update_indices % self.num_keys
# 4. 收集需要更新的唯一子键
unique_x = torch.unique(indices_x)
unique_y = torch.unique(indices_y)
# 5. 更新第一个子空间的键
for k1 in unique_x:
# 找出所有使用该子键的索引
mask_k1 = (indices_x == k1)
if mask_k1.sum() == 0:
continue
# 获取所有相关嵌入并计算平均值
k1_embeddings = pre_update_embeddings[mask_k1]
k1_avg_embedding = k1_embeddings.mean(dim=0)
# 拆分为两个子空间并更新第一个子空间
self.keys[k1, 0] = k1_avg_embedding[:self.key_dim]
# 6. 更新第二个子空间的键
for k2 in unique_y:
# 找出所有使用该子键的索引
mask_k2 = (indices_y == k2)
if mask_k2.sum() == 0:
continue
# 获取所有相关嵌入并计算平均值
k2_embeddings = pre_update_embeddings[mask_k2]
k2_avg_embedding = k2_embeddings.mean(dim=0)
# 更新第二个子空间
self.keys[k2, 1] = k2_avg_embedding[self.key_dim:]
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# 1. 序列维度平均
x_flat = x.mean(dim=1) # [batch_size, dim]
# 2. 生成查询向量并重塑为两个子查询
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 3. 计算每个子空间的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 4. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. 组合两个子空间的结果
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
if config.hidden_dim is None:
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0
return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)
def forward(self, x):
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# 使用门控机制选择专家
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok
# 当tokens_per_expert = [6, 15, 20, 26]tokens_per_expert.shape[0]即为专家数量此时为4
# 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时
# 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token每个token有可能被多个专家处理这取决于num_experts_per_tok
# 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, pos_cis):
h_attn = self.self_attention(
self.attention_norm(x),
pos_cis
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h = x + h_attn
out = h + self.feed_forward(self.ffn_norm(h))
return out
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False
# 移除对knowledge_dataset.freeze_embedding的设置让键更新由batch_counter控制
# self.knowledge_dataset.freeze_embedding = True
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
h = layer(
h, pos_cis
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
# 进一步简化,只保留必要的参数
output = CausalLMOutputWithPast(
logits=logits,
)
output.hidden_states = h
output.aux_loss = aux_loss
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
# 流式生成
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
# 直接生成
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq:
out, first_seq = self(input_ids, **args), False
else:
out = self(input_ids[:, -1:],
start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break

8
run_file/DynamicKV-LLM_Mini_Minimind.sh Executable file → Normal file
View File

@ -1,10 +1,8 @@
#!/bin/bash #!/bin/bash
# 激活conda环境 # 激活conda环境
#source $(conda info --base)/etc/profile.d/conda.sh source $(conda info --base)/etc/profile.d/conda.sh
#conda activate mini conda activate mini
source /mnt/wcy/miniconda/bin/activate
conda activate accelerate
# 设置环境变量以帮助调试 # 设置环境变量以帮助调试
export NCCL_DEBUG=INFO export NCCL_DEBUG=INFO
@ -28,7 +26,7 @@ export PYTHONFAULTHANDLER=1
# --profile_interval 10 # --profile_interval 10
# 方法2: 使用命令行参数直接配置accelerate # 方法2: 使用命令行参数直接配置accelerate
CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch \ CUDA_VISIBLE_DEVICES=0 /opt/conda/envs/mini/bin/python -m accelerate.commands.launch \
--num_processes=1 \ --num_processes=1 \
--mixed_precision=bf16 \ --mixed_precision=bf16 \
--main_process_port=29500 \ --main_process_port=29500 \

View File

@ -461,7 +461,7 @@ def main():
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)") parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention") parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
parser.add_argument("--knowledge_num", type=int, default=960400,help="知识库的数据数目") parser.add_argument("--knowledge_num", type=int, default=8192,help="知识库的数据数目")
parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度") parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度")
parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径") parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)") parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")