Compare commits

...

8 Commits

14 changed files with 4015 additions and 1365 deletions

6
.gitignore vendored
View File

@ -4,4 +4,8 @@
wandb/ wandb/
**/*.log **/*.log
models/sentence_transformers/ models/sentence_transformers/
models/sentence_transformers_cache/ models/sentence_transformers_cache/
**/*.pyc
qwen2-1.7B/
images/
cache/

97
analyze_database.py Normal file
View File

@ -0,0 +1,97 @@
import json
import os
import torch
from transformers import AutoTokenizer
def analyze_database(json_path, tokenizer_path='./model/minimind_tokenizer'):
"""分析database_init.json文件中的数据条目数量和质量"""
print(f"开始分析数据库文件: {json_path}")
# 1. 加载tokenizer
try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
print(f"成功加载tokenizer: {tokenizer_path}")
except Exception as e:
print(f"加载tokenizer失败: {e}")
return
# 2. 加载JSON文件
try:
with open(json_path, 'r', encoding='utf-8') as f:
database_data = json.load(f)
# 提取sentences列表
sentences_data = database_data.get('sentences', [])
print(f"加载了 {len(sentences_data)} 条sentences数据")
except Exception as e:
print(f"加载JSON文件失败: {e}")
return
# 3. 分析句子长度分布
if len(sentences_data) == 0:
print("没有找到有效的句子数据")
return
# 按照importance_score排序
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
print(f"按importance_score排序完成最高分: {sorted_sentences[0].get('importance_score', 0.0)}, 最低分: {sorted_sentences[-1].get('importance_score', 0.0)}")
# 统计句子长度分布
token_lengths = []
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
# 4. 分析token长度分布
for i, sentence_data in enumerate(sorted_sentences):
sentence = sentence_data.get('corrected_sentence', '')
if not sentence:
print(f"警告: 第 {i+1} 条数据没有corrected_sentence字段")
continue
# 将句子转换为tokens
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
token_lengths.append(len(sentence_tokens))
if i < 5: # 显示前5条数据样例
print(f"样例 {i+1}: {sentence[:50]}... (tokens: {len(sentence_tokens)})")
# 5. 统计分析结果
token_lengths = torch.tensor(token_lengths)
stats = {
"总条目数": len(sorted_sentences),
"有效条目数": len(token_lengths),
"token长度-平均值": token_lengths.float().mean().item(),
"token长度-最小值": token_lengths.min().item(),
"token长度-最大值": token_lengths.max().item(),
"token长度-中位数": token_lengths.median().item(),
"token长度-标准差": token_lengths.float().std().item(),
}
# 统计长度分布
length_bins = {
"小于16": (token_lengths < 16).sum().item(),
"16-32": ((token_lengths >= 16) & (token_lengths < 32)).sum().item(),
"32-64": ((token_lengths >= 32) & (token_lengths < 64)).sum().item(),
"64-128": ((token_lengths >= 64) & (token_lengths < 128)).sum().item(),
"128-256": ((token_lengths >= 128) & (token_lengths < 256)).sum().item(),
"256及以上": (token_lengths >= 256).sum().item(),
}
# 打印统计信息
print("\n===== 数据库分析结果 =====")
for key, value in stats.items():
print(f"{key}: {value}")
print("\n===== Token长度分布 =====")
for bin_name, count in length_bins.items():
percentage = (count / len(token_lengths)) * 100
print(f"{bin_name}: {count} ({percentage:.1f}%)")
print(f"\n结论: 该数据库文件包含 {stats['有效条目数']} 条有效数据,可以全部填充到知识库中。")
return stats, length_bins
if __name__ == "__main__":
# 指定数据库文件路径
database_path = "./dataset/database_init.json"
analyze_database(database_path)

133
loss.py
View File

@ -1,33 +1,112 @@
import re import re
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
log_file = 'out/train.log' def parse_log_file(file_path):
steps_per_epoch = 58880 # 你需要根据实际日志设置 """
Parse the training log file to extract epoch, step, and loss information.
"""
# Regular expression to match log entries with loss information
pattern = r'\[.*?\] Epoch (\d+)/\d+, Step (\d+)/\d+, Loss: ([\d\.]+)'
epochs = []
steps = []
losses = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
log_content = f.read()
# Find all matches
matches = re.findall(pattern, log_content)
for match in matches:
epoch = int(match[0])
step = int(match[1])
loss = float(match[2])
epochs.append(epoch)
steps.append(step)
losses.append(loss)
return epochs, steps, losses
except Exception as e:
print(f"Error parsing log file: {e}")
return [], [], []
with open(log_file, 'r', encoding='utf-8') as f: def plot_loss_curve(epochs, steps, losses, output_file='loss_curve.png'):
log_text = f.read() """
Plot the loss curve and save it to a file.
"""
plt.figure(figsize=(12, 6))
# Create continuous steps for better visualization
continuous_steps = []
current_max_step = 0
prev_epoch = 0
for i, (e, s) in enumerate(zip(epochs, steps)):
if e > prev_epoch:
# New epoch starts
if i > 0:
current_max_step = continuous_steps[-1]
prev_epoch = e
continuous_steps.append(s + current_max_step)
# 修改:减小线条宽度和点的大小
plt.plot(continuous_steps, losses, marker='.', linestyle='-',
color='#1f77b4', markersize=3, linewidth=0.8)
plt.title('Training Loss Over Steps', fontsize=16)
plt.xlabel('Steps (Continuous)', fontsize=14)
plt.ylabel('Loss', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.5, linewidth=0.5)
# 修改:减小红线宽度
for i in range(1, len(epochs)):
if epochs[i] > epochs[i-1]:
plt.axvline(x=continuous_steps[i], color='r',
linestyle='--', alpha=0.5, linewidth=0.8)
unique_epochs = sorted(set(epochs))
# Add epoch labels
for e in unique_epochs:
indices = [i for i, epoch in enumerate(epochs) if epoch == e]
if indices:
mid_idx = indices[len(indices) // 2]
plt.text(continuous_steps[mid_idx], max(losses) * 0.95, f'Epoch {e}',
horizontalalignment='center', verticalalignment='center',
fontsize=10,
bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 3})
# 移除悬停注释,简化图表
# for i, (e, s, l) in enumerate(zip(epochs, steps, losses)):
# plt.annotate(...)
plt.tight_layout()
plt.savefig(output_file, dpi=300)
print(f"Loss curve saved as {output_file}")
# Also display the data in a table format
print("\nExtracted training data:")
print("-" * 50)
print(f"{'Epoch':<10}{'Step':<10}{'Loss':<15}")
print("-" * 50)
for e, s, l in zip(epochs, steps, losses):
print(f"{e:<10}{s:<10}{l:<15.6f}")
# 提取 epoch, step, loss def main():
pattern = re.compile(r'Epoch\s+(\d+)/\d+,\s+Step\s+(\d+)/\d+,\s+Loss:\s*([0-9.]+)', re.MULTILINE) # Specify the path to your log file
matches = pattern.findall(log_text) log_file_path = 'out/train.log'
# Parse the log file
epochs, steps, losses = parse_log_file(log_file_path)
if epochs and steps and losses:
plot_loss_curve(epochs, steps, losses)
else:
print("No data extracted from log file. Please check if the file format is correct.")
global_steps = [] if __name__ == "__main__":
losses = [] main()
for epoch, step, loss in matches:
epoch = int(epoch)
step = int(step)
global_step = (epoch - 1) * steps_per_epoch + step
global_steps.append(global_step)
losses.append(float(loss))
plt.figure(figsize=(12, 6))
plt.plot(global_steps, losses, label='Loss')
plt.xlabel('Global Step')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('out/loss_curve.png')
plt.show()

View File

@ -19,6 +19,7 @@ class LMConfig(PretrainedConfig):
rope_theta: int = 1e6, rope_theta: int = 1e6,
dropout: float = 0.0, dropout: float = 0.0,
flash_attn: bool = True, flash_attn: bool = True,
embeddings_epoch: int = 2,
#################################################### ####################################################
# DB related configurations # DB related configurations
#################################################### ####################################################
@ -39,6 +40,7 @@ class LMConfig(PretrainedConfig):
#################################################### ####################################################
knowledge_num: int = 64*64, knowledge_num: int = 64*64,
knowledge_length: int = 8, knowledge_length: int = 8,
knowledge_dim: int = 128,
**kwargs, **kwargs,
): ):
self.dim = dim self.dim = dim
@ -53,6 +55,7 @@ class LMConfig(PretrainedConfig):
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.dropout = dropout self.dropout = dropout
self.flash_attn = flash_attn self.flash_attn = flash_attn
self.embeddings_epoch = embeddings_epoch
#################################################### ####################################################
# DB related configurations # DB related configurations
#################################################### ####################################################
@ -72,4 +75,5 @@ class LMConfig(PretrainedConfig):
#################################################### ####################################################
self.knowledge_num = knowledge_num self.knowledge_num = knowledge_num
self.knowledge_length = knowledge_length self.knowledge_length = knowledge_length
self.knowledge_dim = knowledge_dim
super().__init__(**kwargs) super().__init__(**kwargs)

View File

@ -2,7 +2,7 @@ import math
import struct import struct
import inspect import inspect
import time import time
#子空间二维分解+梯度更新
from .LMConfig import LMConfig from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union from typing import Any, Optional, Tuple, List, Union
import numpy as np import numpy as np
@ -11,14 +11,9 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import PreTrainedModel from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.modeling_outputs import CausalLMOutputWithPast
from torch import nn, einsum
from einops import rearrange, repeat
def exists(val):
return val is not None
# RMSNorm 类定义了一个用于归一化输入张量的模块。
class RMSNorm(torch.nn.Module): class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6): def __init__(self, dim: int, eps: float = 1e-6):
super().__init__() super().__init__()
@ -31,7 +26,7 @@ class RMSNorm(torch.nn.Module):
def forward(self, x): def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x) return self.weight * self._norm(x.float()).type_as(x)
# precompute_pos_cis 函数用于预计算位置编码(复数版本)。
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore t = torch.arange(end, device=freqs.device) # type: ignore
@ -39,7 +34,7 @@ def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis return pos_cis
# apply_rotary_emb 函数用于应用旋转位置编码(复数版本)。
def apply_rotary_emb(xq, xk, pos_cis): def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x): def unite_shape(pos_cis, x):
ndim = x.ndim ndim = x.ndim
@ -55,200 +50,195 @@ def apply_rotary_emb(xq, xk, pos_cis):
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3) xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk) return xq_out.type_as(xq), xk_out.type_as(xk)
# precompute_pos_cis_real 函数用于预计算位置编码(实数版本)。 class KnowledgeDataset(nn.Module):
def precompute_pos_cis_real(dim: int, end: int = int(32 * 1024), theta: float = 1e6): def __init__(self, params, tok_embeddings, is_train=True):
"""使用实数张量实现位置编码,避免使用复数张量
这个函数与precompute_pos_cis完全等价但使用实数张量而非复数张量
原始函数生成形状为[seq_len, dim//2]的复数张量其中实部全为1虚部为旋转角度
这个函数生成形状为[seq_len, dim]的实数张量其中偶数索引是cos(角度)奇数索引是sin(角度)
"""
# 确保dim是偶数
if dim % 2 != 0:
raise ValueError(f"维度必须是偶数,但得到了 {dim}")
# 复制原始函数的频率计算逻辑
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
# 计算cos和sin值
# 在复数版本中pos_cis = torch.polar(torch.ones_like(freqs), freqs)
# 等价于 cos(freqs) + i*sin(freqs)
cos = torch.cos(freqs)
sin = torch.sin(freqs)
# 创建实数张量交错排列cos和sin
pos_emb = torch.zeros((end, dim), device=freqs.device)
pos_emb[:, 0::2] = cos # 偶数索引放cos
pos_emb[:, 1::2] = sin # 奇数索引放sin
return pos_emb
# apply_rotary_emb_real 函数用于应用旋转位置编码(实数版本)。
def apply_rotary_emb_real(xq, xk, pos_emb):
"""使用实数张量实现旋转位置编码,避免使用复数张量
这个函数与apply_rotary_emb完全等价但使用实数张量而非复数张量
原始函数将输入张量转换为复数形式与位置编码相乘然后再转回实数形式
这个函数直接使用实数运算实现相同的旋转操作
"""
# 获取形状信息
bsz, seq_len, n_heads, head_dim = xq.shape
# 确保pos_emb形状正确
assert pos_emb.shape[0] >= seq_len, f"位置编码长度 {pos_emb.shape[0]} 小于序列长度 {seq_len}"
assert pos_emb.shape[1] == head_dim, f"位置编码维度 {pos_emb.shape[1]} 与头维度 {head_dim} 不匹配"
# 截取需要的位置编码长度
pos_emb = pos_emb[:seq_len]
# 将pos_emb调整为广播形状 [1, seq_len, 1, head_dim]
pos_emb = pos_emb.unsqueeze(0).unsqueeze(2)
# 将head_dim分成两半
half_head_dim = head_dim // 2
# 提取cos和sin值偶数索引是cos奇数索引是sin
cos = pos_emb[..., 0::2]
sin = pos_emb[..., 1::2]
# 将xq和xk重新排列以便进行旋转操作
# 原始复数版本中xq和xk被重塑为复数张量其中实部和虚部交错排列
# 在实数版本中,我们需要将偶数索引和奇数索引分开处理
# 分离偶数和奇数索引
xq_even = xq[..., 0::2] # 偶数索引,对应复数的实部
xq_odd = xq[..., 1::2] # 奇数索引,对应复数的虚部
xk_even = xk[..., 0::2]
xk_odd = xk[..., 1::2]
# 应用旋转(等价于复数乘法)
# (a + bi)(cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
# 其中a是偶数索引b是奇数索引
xq_out_even = xq_even * cos - xq_odd * sin # 新的偶数索引(实部)
xq_out_odd = xq_even * sin + xq_odd * cos # 新的奇数索引(虚部)
xk_out_even = xk_even * cos - xk_odd * sin
xk_out_odd = xk_even * sin + xk_odd * cos
# 重新组合偶数和奇数索引
xq_out = torch.zeros_like(xq)
xk_out = torch.zeros_like(xk)
xq_out[..., 0::2] = xq_out_even
xq_out[..., 1::2] = xq_out_odd
xk_out[..., 0::2] = xk_out_even
xk_out[..., 1::2] = xk_out_odd
return xq_out.type_as(xq), xk_out.type_as(xk)
# repeat_kv 函数用于重复键值对。
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__() super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads self.is_train = is_train
assert args.n_heads % self.n_kv_heads == 0 self.params = params
self.n_local_heads = args.n_heads self.tok_embeddings = tok_embeddings
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self, # 嵌入参数
x: torch.Tensor, self.knowledge_dim = params.knowledge_dim
pos_cis: torch.Tensor, self.key_dim = self.knowledge_dim // 2
db_value=None): self.to_queries = nn.Sequential(
bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度 nn.Linear(params.dim, self.knowledge_dim, bias=False),
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换得到查询、键和值。
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) #将变换后的张量xq重塑为形状为(bsz, seq_len, n_local_heads, head_dim)的形状。
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xk重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xv重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
# 应用旋转位置编码(使用实数版本)
xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
# kv_cache实现 REMOVED
# if past_key_value is not None:
# xk = torch.cat([past_key_value[0], xk], dim=1)
# xv = torch.cat([past_key_value[1], xv], dim=1)
# past_kv = (xk, xv) if use_cache else None
# 重复键值对
xq, xk, xv = (
xq.transpose(1, 2),
repeat_kv(xk, self.n_rep).transpose(1, 2),
repeat_kv(xv, self.n_rep).transpose(1, 2)
) )
# 如果提供了db_value根据头的数量调整它的形状并与xv合并 ## 数据库参数
if db_value is not None: self.knowledge_num = params.knowledge_num
# 确保db_value的形状与xv兼容假设db_value形状为[B, N, H, D] self.knowledge_length = params.knowledge_length
if db_value.ndim == 4: # [B, N, H, D]
db_value = db_value.transpose(1, 2) # -> [B, H, N, D] # 修改键存储为二维分解空间,设置为可训练参数
self.num_keys = int(math.sqrt(self.knowledge_num))
# 确保keys是可训练参数
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# 检查是否需要调整D维度 # 计算step数目用于动态调整权重
if db_value.shape[-1] != xv.shape[-1]: self.step_counter = 0
# 如果db_value的维度与xv不同可以添加一个投影层
# 或者在这里使用简单的调整方法
# 这里我们简单地通过均值池化或重复来调整维度
if db_value.shape[-1] > xv.shape[-1]:
# 降维
factor = db_value.shape[-1] // xv.shape[-1]
db_value = db_value.view(bsz, self.n_local_heads, seq_len, factor, xv.shape[-1])
db_value = db_value.mean(dim=3)
else:
# 升维
factor = xv.shape[-1] // db_value.shape[-1]
db_value = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor)
db_value = db_value.view(bsz, self.n_local_heads, seq_len, xv.shape[-1])
# 将db_value与xv相加或融合 # 移除批次计数器和更新频率相关代码
# 这里我们简单地将它们相加,但你也可以使用其他融合方法
xv = xv + db_value
# 使用Flash Attention def intelligent_selection(self, query, all_scores, all_indices):
if self.flash and seq_len != 1: """智能分层选择策略"""
dropout_p = self.dropout if self.training else 0.0 if self.is_train == False:
output = F.scaled_dot_product_attention( return all_scores, all_indices
xq, xk, xv,
attn_mask=None, batch_size = all_scores.size(0)
dropout_p=dropout_p, device = all_scores.device
is_causal=True dtype = all_scores.dtype
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# 收集所有batch的best_tokens
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# 找到最大相似度分数的索引
max_similarity_idx = torch.argmax(similarity_scores)
# 获取最大相似度对应的候选条目索引
best_candidate_idx = indices[max_similarity_idx]
# 获取对应的tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# 将当前batch的best_tokens添加到列表中
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# 将所有batch的best_tokens堆叠成一个张量
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
return all_best_tokens, all_best_tokens_embeddings
with torch.no_grad():
# 1. 计算token序列的平均嵌入
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [num_indices, dim]
# 2. 转换维度
pre_update_embeddings = self.to_queries(pre_update_embeddings) # [num_indices, knowledge_dim]
# 3. 将one-hot索引转换为子空间索引
indices_x = pre_update_indices // self.num_keys
indices_y = pre_update_indices % self.num_keys
# 4. 收集需要更新的唯一子键
unique_x = torch.unique(indices_x)
unique_y = torch.unique(indices_y)
# 5. 更新第一个子空间的键
for k1 in unique_x:
# 找出所有使用该子键的索引
mask_k1 = (indices_x == k1)
if mask_k1.sum() == 0:
continue
# 获取所有相关嵌入并计算平均值
k1_embeddings = pre_update_embeddings[mask_k1]
k1_avg_embedding = k1_embeddings.mean(dim=0)
# 拆分为两个子空间并更新第一个子空间
self.keys[k1, 0] = k1_avg_embedding[:self.key_dim]
# 6. 更新第二个子空间的键
for k2 in unique_y:
# 找出所有使用该子键的索引
mask_k2 = (indices_y == k2)
if mask_k2.sum() == 0:
continue
# 获取所有相关嵌入并计算平均值
k2_embeddings = pre_update_embeddings[mask_k2]
k2_avg_embedding = k2_embeddings.mean(dim=0)
# 更新第二个子空间
self.keys[k2, 1] = k2_avg_embedding[self.key_dim:]
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# 1. 序列维度平均
x_flat = x.mean(dim=1) # [batch_size, dim]
# 2. 生成查询向量并重塑为两个子查询
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 3. 计算每个子空间的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 4. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. 组合两个子空间的结果
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__( def __init__(
@ -295,6 +285,58 @@ class CrossAttention(nn.Module):
return context return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, config: LMConfig): def __init__(self, config: LMConfig):
super().__init__() super().__init__()
@ -427,183 +469,30 @@ class MOEFeedForward(nn.Module):
class MiniMindBlock(nn.Module): class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig): def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
super().__init__() super().__init__()
self.n_heads = config.n_heads self.n_heads = config.n_heads
self.dim = config.dim self.dim = config.dim
self.head_dim = config.dim // config.n_heads self.head_dim = config.dim // config.n_heads
self.attention = Attention(config) self.self_attention = Attention(config)
self.cross_att = CrossAttention(config) self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.layer_id = layer_id self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config) self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
# 假设num_experts是已定义的总专家数量的平方根 def forward(self, x, pos_cis):
h_attn = self.self_attention(
# 查询生成的参数
# 创建查询生成模块
# if weight_down_embed is not None:
# self.to_queries = nn.Sequential(
# nn.Linear(config.dim, self.dim_key * 2, bias=False),
# # nn.Unflatten(2, (2, self.n_heads, self.dim_key)) # 替代Rearrange
# )
# # 超参数
# self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys
# self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数
def forward(self, x, db_value, pos_cis):
# import pdb;pdb.set_trace()
# db_value = None
# # 如果有weight_down_embed使用Product Key机制
# if self.weight_down_embed is not None:
# # 1. 生成queries
# batch_size, seq_len, dim = x.shape
# # collapse sequence dimension by averaging
# x_flat = x.mean(dim=1) # [batch_size, dim]
# queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
# queries = queries.reshape(batch_size, 2, self.dim_key) # [batch_size, 2, dim_key]
# queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key]
# # 2. 计算queries与keys的相似度
# sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# # 3. 在两个子空间分别做top-k
# scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
# scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
# indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# # 4. 组合两个子空间的分数和索引
# all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
# all_scores = all_scores.view(*all_scores.shape[:-2], -1)
# all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
# all_indices = all_indices.view(*all_indices.shape[:-2], -1)
# # 5. 最终top-k选择
# scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
# indices = all_indices.gather(-1, pk_indices)
# # 6. 从embedding中获取专家值
# # 从embedding中获取值
# flat_indices = indices.view(-1) # 将索引展平为一维张量
# db_values = self.weight_down_embed(flat_indices)
# # 重塑回原始形状
# db_value = db_values.view(batch_size, -1, dim)
# 注意力计算
h_attn = self.attention(
self.attention_norm(x), self.attention_norm(x),
pos_cis, pos_cis
db_value=db_value
) )
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_att(h_attn, db_value) h_attn = self.cross_attention(h_attn, db_embeddings)
# 残差连接
h = x + h_attn h = x + h_attn
# 前馈神经网络
out = h + self.feed_forward(self.ffn_norm(h)) out = h + self.feed_forward(self.ffn_norm(h))
return out return out
class ExtractDB(nn.Module):
def __init__(self,params):
# 修改专家数量和知识维度,确保能开方
super().__init__()
self.batch_size = None
self.dim = params.dim
self.dim_key = self.dim // 2
self.knowledge_num = params.knowledge_num # 100专家确保是完全平方数
# 将knowledge_dim设置为与head_dim相同以便在attention中直接使用
self.head_dim = params.dim // params.n_heads
self.knowledge_length = params.knowledge_length
# 使用register_buffer代替nn.Parameter避免梯度问题
# self.register_buffer('weight_down_embed', torch.randn(self.knowledge_num, self.knowledge_length) * 0.02)
self.register_buffer('weight_down_embed',torch.randint(low=0,high=6400, size=(self.knowledge_num, self.knowledge_length),dtype=torch.long))
self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
self.product_key_topk = min(16, self.num_keys)
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
self.num_experts_per_head_topk = 1
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.dim_key * 2, bias=False),
)
def q_to_k(self,x):
# 1. 生成queries
self.batch_size, seq_len, dim = x.shape
# collapse sequence dimension by averaging
x_flat = x.mean(dim=1) # [batch_size, dim]
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
queries = queries.reshape(self.batch_size, 2, self.dim_key) # [batch_size, 2, dim_key]
queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key]
# 2. 计算queries与keys的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 3. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 4. 组合两个子空间的分数和索引
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
all_scores = all_scores.view(*all_scores.shape[:-2], -1)
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
# 5. 最终top-k选择
scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
indices = all_indices.gather(-1, pk_indices)
flat_indices = indices.view(-1)
return flat_indices
def get_data(self, index):
# 直接从GPU获取embedding
db_values = self.weight_down_embed[index]#变成token了所以是1,后续再过emb
# db_value = db_values.view(self.batch_size,-1)
return db_values
@torch.no_grad()
def updata_value(self, k, v):#要加一个从向量返回index的过程
# 直接更新buffer上的值 (不需要梯度)
v_reshaped = v.view(v.size(0), -1)
# 确保数据类型匹配
v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
self.weight_down_embed[k] = v_reshaped
@torch.no_grad()
def update_keys_with_zq(self, flat_indices, z_q):
"""
flat_indices: [batch]q_to_k输出的检索到的key的全局索引0~knowledge_num-1
z_q: [batch, 2, dim_key]每个样本的两个子空间query
"""
num_keys = self.num_keys
idx_x = flat_indices // num_keys # [batch]
idx_y = flat_indices % num_keys # [batch]
# 对于每个样本把keys的两个子空间分别替换为z_q的对应部分
for i in range(flat_indices.size(0)):
self.keys.data[idx_x[i], 0, :] = z_q[i, 0, :].to(self.keys.dtype)
self.keys.data[idx_y[i], 1, :] = z_q[i, 1, :].to(self.keys.dtype)
class MiniMindLM(PreTrainedModel): class MiniMindLM(PreTrainedModel):
@ -615,115 +504,35 @@ class MiniMindLM(PreTrainedModel):
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout) self.dropout = nn.Dropout(params.dropout)
# 移除旧的weight_down_embed声明 self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.extract_db = ExtractDB(self.params) self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
# 将self.weight_down_embed传递给每个MiniMindBlock
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps) self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False) self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.database_output = nn.Linear(params.dim, params.knowledge_length, bias=False)
self.tok_embeddings.weight = self.output.weight self.tok_embeddings.weight = self.output.weight
self.database_output.weight = self.output.weight self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
# Calculate input dimension
input_dim = (self.params.max_seq_len-1)*self.params.n_layers
# Use a bottleneck architecture to reduce parameters
bottleneck_dim = 256 # Significantly smaller bottleneck dimension
# Factorized shared downsampling using two smaller convolutions
self.shared_downsample = nn.Sequential(
# First reduce input dimension to bottleneck
nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1, padding='same'),
nn.ReLU(), # Non-linearity to improve representation capacity
# Then expand to target dimension
nn.Conv1d(bottleneck_dim, 128*8, kernel_size=1, padding='same')
)
# Specific layers for v path
self.downsample_v_specific = nn.Sequential(
nn.Conv1d(128*8, 128, kernel_size=1, padding='same'),
nn.Conv1d(128, self.params.knowledge_length, kernel_size=1, padding='same')
)
# Specific layers for q path
self.downsample_q_specific = nn.Sequential(
nn.Conv1d(128*8, self.params.dim, kernel_size=1, padding='same')
)
# 使用实数版本的位置编码,避免复数张量可能导致的段错误
self.register_buffer("pos_cis_real",
precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False) persistent=False)
self.params = params self.OUT = CausalLMOutputWithPast()
self.value_update_schedule = 0.9 # 前%冻结 self.freeze_embedding = False
self.global_step = 0 # 当前步数
self.total_steps = None # 总步数,训练脚本里赋值
def forward(self, def forward(self,
input_ids: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0, logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args): **args):
start_pos = args.get('start_pos', 0) start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False
# 移除对knowledge_dataset.freeze_embedding的设置让键更新由batch_counter控制
# self.knowledge_dataset.freeze_embedding = True
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
h = self.dropout(self.tok_embeddings(input_ids)) h = self.dropout(self.tok_embeddings(input_ids))
pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)] pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
h_list = []
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
# 禁用数据库模式,使用固定值替代数据库查询
if self.params.disable_db:
# 创建一个形状为[batch_size, n_layers, dim]的tensor所有元素值为1e-4
batch_size = h.size(0)
db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4,
dtype=h.dtype, device=h.device)
else:
# 正常模式,使用数据库查询
# import pdb;pdb.set_trace()
index = self.extract_db.q_to_k(h)
token_idx = self.extract_db.get_data(index) #这里是index
db_value =self.tok_embeddings(token_idx)
h = layer( h = layer(
h, db_value, pos_cis_real h, pos_cis
) )
h_list.append(h.unsqueeze(0))
h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
# 只在非禁用数据库模式下执行数据库更新逻辑
if not self.params.disable_db:
# 使用detach()分离计算图,避免多次反向传播
h_tensor_detached = h_tensor.detach()
h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim)
# 数据库更新逻辑与主计算图分离
with torch.no_grad():
# Compute shared downsampling layer once
shared_features = self.shared_downsample(h_tensor_detached)
# Get features from v path
z_v_features = self.downsample_v_specific(shared_features)
batch_z, seq_len, dim_z = z_v_features.shape
z_v_flat = z_v_features.reshape(-1, dim_z)
token_logits = self.database_output(z_v_flat)
token_indices_flat = torch.argmax(token_logits, dim=-1)
token_indices = token_indices_flat.reshape(batch_z, -1)
# Process query path
z_q_input = self.downsample_q_specific(shared_features) # [batch, dim, seq_len]
z_q_input = z_q_input.permute(0, 2, 1) # [batch, seq_len, dim]
z_k = self.extract_db.q_to_k(z_q_input) # [batch]
z_q_pooled = z_q_input.mean(dim=1) # [batch, dim]
z_q_vec = self.extract_db.to_queries(z_q_pooled) # [batch, 2*dim_key]
z_q_vec = z_q_vec.view(z_q_vec.size(0), 2, self.extract_db.dim_key) # [batch, 2, dim_key]
progress = self.global_step / self.total_steps if self.total_steps else 0
if progress >= self.value_update_schedule:
self.extract_db.updata_value(z_k, token_indices)
self.extract_db.update_keys_with_zq(z_k, z_q_vec)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :]) logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward)) aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
@ -736,12 +545,6 @@ class MiniMindLM(PreTrainedModel):
output.aux_loss = aux_loss output.aux_loss = aux_loss
# 尝试添加其他属性(如果支持的话)
# try:
# output.hidden_states = h
# except:
# pass
return output return output
@torch.inference_mode() @torch.inference_mode()
@ -774,13 +577,14 @@ class MiniMindLM(PreTrainedModel):
return res return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args): def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start, first_seq = input_ids.shape[1], True start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1: while input_ids.shape[1] < max_new_tokens - 1:
if first_seq: if first_seq:
out, first_seq = self(input_ids, **args), False out, first_seq = self(input_ids, **args), False
else: else:
out = self(input_ids[:, -1:], start_pos=input_ids.shape[1] - 1, **args) out = self(input_ids[:, -1:],
logits = out.logits[:, -1, :] start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9) logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0: if top_p is not None and top_p < 1.0:
@ -798,4 +602,3 @@ class MiniMindLM(PreTrainedModel):
if input_ids_next.item() == eos_token_id: if input_ids_next.item() == eos_token_id:
break break

603
model/model0.py Normal file
View File

@ -0,0 +1,603 @@
import math
import struct
import inspect
import time
#子空间不分解+嵌入更新
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# 嵌入参数
self.knowledge_dim = params.knowledge_dim
self.key_dim = self.knowledge_dim // 2
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
## 数据库参数
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.knowledge_num)
# 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动
self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
)
# 计算step数目用于动态调整权重
self.step_counter = 0
self.freeze_embedding = False
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
#获取flat_tokens对应的index
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# 收集所有batch的best_tokens
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# 找到最大相似度分数的索引
max_similarity_idx = torch.argmax(similarity_scores)
# 获取最大相似度对应的候选条目索引
best_candidate_idx = indices[max_similarity_idx]
# 获取对应的tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# 将当前batch的best_tokens添加到列表中
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# 将所有batch的best_tokens堆叠成一个张量
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
# 获取
# 使用重新计算的embeddings更新self.keys
if self.is_train:
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return all_best_tokens, all_best_tokens_embeddings
def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
if self.freeze_embedding:
return
# 使用pre_update_embeddings更新self.keys
with torch.no_grad():
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [337, 512]
pre_update_embeddings = self.to_queries(pre_update_embeddings)
self.keys[pre_update_indices] = pre_update_embeddings
def search_index(self,x):
batch_size, seq_len, dim = x.shape
# collapse sequence dimension by averaging
x_flat = x.mean(dim=1) # [batch_size, dim]
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
# queries = queries.reshape(batch_size, 2, self.key_dim)
# queries = queries.permute(1, 0, 2)
# 2. 计算queries与keys的相似度
sim = torch.einsum('b d, k d -> b k', queries, self.keys)
# 3. 在两个子空间分别做top-k
scores_and_indices = sim.topk(self.product_key_topk, dim=-1)
scores, indices = scores_and_indices[0], scores_and_indices[1]
# 5. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
# 6. 更新1%的keys
if self.is_train:
# 获取未更新过的keys的索引
not_updated_indices = torch.where(self.has_update_keys == 0)[0]
# 如果有未更新的keys随机选择num_update_keys个进行更新
if len(not_updated_indices) > 0:
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
perm_num = perm.shape[0]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
else:
print("all keys are updated")
# 重置所有keys的更新状态
self.has_update_keys.zero_()
# 重新获取所有可更新的索引
not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
if config.hidden_dim is None:
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0
return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)
def forward(self, x):
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# 使用门控机制选择专家
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok
# 当tokens_per_expert = [6, 15, 20, 26]tokens_per_expert.shape[0]即为专家数量此时为4
# 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时
# 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token每个token有可能被多个专家处理这取决于num_experts_per_tok
# 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, pos_cis):
h_attn = self.self_attention(
self.attention_norm(x),
pos_cis
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h = x + h_attn
out = h + self.feed_forward(self.ffn_norm(h))
return out
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False
# 同时冻结KnowledgeDataset的嵌入更新
self.knowledge_dataset.freeze_embedding = True
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
h = layer(
h, pos_cis
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
# 进一步简化,只保留必要的参数
output = CausalLMOutputWithPast(
logits=logits,
)
output.hidden_states = h
output.aux_loss = aux_loss
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
# 流式生成
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
# 直接生成
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq:
out, first_seq = self(input_ids, **args), False
else:
out = self(input_ids[:, -1:],
start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break

675
model/model1.py Normal file
View File

@ -0,0 +1,675 @@
import math
import struct
import inspect
import time
#子空间二维分解+全局嵌入更新
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# 嵌入参数
self.knowledge_dim = params.knowledge_dim
self.key_dim = self.knowledge_dim // 2
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
## 数据库参数
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
# 修改键存储为二维分解空间
self.num_keys = int(math.sqrt(self.knowledge_num))
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动
self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# 计算step数目用于动态调整权重
self.step_counter = 0
self.freeze_embedding = False
# 添加批次计数器和更新频率
self.batch_counter = 0
self.update_frequency = 100 # 每100个批次更新一次
def _global_keys_update(self):
"""全局更新所有子键"""
# 移除对self.freeze_embedding的检查确保在调用时总是执行更新
with torch.no_grad():
# 创建用于存储每个子键的嵌入和计数的张量
k1_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
k2_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
k1_counts = torch.zeros(self.num_keys, device=self.keys.device)
k2_counts = torch.zeros(self.num_keys, device=self.keys.device)
# 分批处理所有知识条目,避免内存溢出
batch_size = 1000 # 可根据可用内存调整
for i in range(0, self.knowledge_num, batch_size):
end_idx = min(i + batch_size, self.knowledge_num)
batch_indices = torch.arange(i, end_idx, device=self.keys.device)
# 获取批次的嵌入
batch_tokens = self.knowledge_dataset[batch_indices]
batch_embeddings = self.tok_embeddings(batch_tokens.view(-1))
batch_embeddings = batch_embeddings.view(len(batch_indices), self.knowledge_length, -1).mean(dim=1)
batch_embeddings = self.to_queries(batch_embeddings)
# 计算批次中每个条目对应的子键索引
indices_x = batch_indices // self.num_keys
indices_y = batch_indices % self.num_keys
# 累加每个子键对应的嵌入
for j in range(len(batch_indices)):
k1, k2 = indices_x[j].item(), indices_y[j].item()
embedding = batch_embeddings[j]
# 更新第一个子空间累加值
k1_embeddings_sum[k1] += embedding[:self.key_dim]
k1_counts[k1] += 1
# 更新第二个子空间累加值
k2_embeddings_sum[k2] += embedding[self.key_dim:]
k2_counts[k2] += 1
# 计算平均值并更新键
# 避免除零错误
k1_counts = torch.clamp(k1_counts, min=1)
k2_counts = torch.clamp(k2_counts, min=1)
# 计算每个子键的平均嵌入
self.keys[:, 0] = k1_embeddings_sum / k1_counts.unsqueeze(1)
self.keys[:, 1] = k2_embeddings_sum / k2_counts.unsqueeze(1)
print(f"执行了全局键更新,批次: {self.batch_counter}")
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# 收集所有batch的best_tokens
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# 找到最大相似度分数的索引
max_similarity_idx = torch.argmax(similarity_scores)
# 获取最大相似度对应的候选条目索引
best_candidate_idx = indices[max_similarity_idx]
# 获取对应的tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# 将当前batch的best_tokens添加到列表中
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# 将所有batch的best_tokens堆叠成一个张量
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return all_best_tokens, all_best_tokens_embeddings
with torch.no_grad():
# 1. 计算token序列的平均嵌入
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [num_indices, dim]
# 2. 转换维度
pre_update_embeddings = self.to_queries(pre_update_embeddings) # [num_indices, knowledge_dim]
# 3. 将one-hot索引转换为子空间索引
indices_x = pre_update_indices // self.num_keys
indices_y = pre_update_indices % self.num_keys
# 4. 收集需要更新的唯一子键
unique_x = torch.unique(indices_x)
unique_y = torch.unique(indices_y)
# 5. 更新第一个子空间的键
for k1 in unique_x:
# 找出所有使用该子键的索引
mask_k1 = (indices_x == k1)
if mask_k1.sum() == 0:
continue
# 获取所有相关嵌入并计算平均值
k1_embeddings = pre_update_embeddings[mask_k1]
k1_avg_embedding = k1_embeddings.mean(dim=0)
# 拆分为两个子空间并更新第一个子空间
self.keys[k1, 0] = k1_avg_embedding[:self.key_dim]
# 6. 更新第二个子空间的键
for k2 in unique_y:
# 找出所有使用该子键的索引
mask_k2 = (indices_y == k2)
if mask_k2.sum() == 0:
continue
# 获取所有相关嵌入并计算平均值
k2_embeddings = pre_update_embeddings[mask_k2]
k2_avg_embedding = k2_embeddings.mean(dim=0)
# 更新第二个子空间
self.keys[k2, 1] = k2_avg_embedding[self.key_dim:]
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# 1. 序列维度平均
x_flat = x.mean(dim=1) # [batch_size, dim]
# 2. 生成查询向量并重塑为两个子查询
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 3. 计算每个子空间的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 4. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. 组合两个子空间的结果
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
# 9. 更新批次计数并在特定批次执行全局更新
if self.is_train:
self.batch_counter += 1
# 每update_frequency个批次执行一次全局更新其余时间保持冻结
if self.batch_counter % self.update_frequency == 0:
# 只在特定批次更新键无论freeze_embedding状态如何
self._global_keys_update()
# 标记所有键为已更新状态
with torch.no_grad():
self.has_update_keys.fill_(1)
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
if config.hidden_dim is None:
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0
return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)
def forward(self, x):
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# 使用门控机制选择专家
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok
# 当tokens_per_expert = [6, 15, 20, 26]tokens_per_expert.shape[0]即为专家数量此时为4
# 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时
# 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token每个token有可能被多个专家处理这取决于num_experts_per_tok
# 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, pos_cis):
h_attn = self.self_attention(
self.attention_norm(x),
pos_cis
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h = x + h_attn
out = h + self.feed_forward(self.ffn_norm(h))
return out
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False
# 移除对knowledge_dataset.freeze_embedding的设置让键更新由batch_counter控制
# self.knowledge_dataset.freeze_embedding = True
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
h = layer(
h, pos_cis
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
# 进一步简化,只保留必要的参数
output = CausalLMOutputWithPast(
logits=logits,
)
output.hidden_states = h
output.aux_loss = aux_loss
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
# 流式生成
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
# 直接生成
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq:
out, first_seq = self(input_ids, **args), False
else:
out = self(input_ids[:, -1:],
start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break

679
model/model2.py Normal file
View File

@ -0,0 +1,679 @@
import math
import struct
import inspect
import time
#子空间四维分解+全局嵌入更新
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# 嵌入参数
self.knowledge_dim = params.knowledge_dim
# 修改:子空间维度从原来的一半变为四分之一
self.key_dim = self.knowledge_dim // 4
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
## 数据库参数
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
# 修改:将键存储从二维分解空间改为四维分解空间
# 计算每个子空间的键数量(使用四次根号N)
self.num_keys = int(self.knowledge_num ** 0.25)
# 修改子空间个数从2变为4
self.keys = nn.Parameter(torch.randn(self.num_keys, 4, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动
self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# 计算step数目用于动态调整权重
self.step_counter = 0
self.freeze_embedding = False
# 添加批次计数器和更新频率
self.batch_counter = 0
self.update_frequency = 100 # 每100个批次更新一次
def _global_keys_update(self):
"""全局更新所有子键"""
# 移除对self.freeze_embedding的检查确保在调用时总是执行更新
with torch.no_grad():
# 创建用于存储每个子键的嵌入和计数的张量修改为4个子空间
k1_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
k2_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
k3_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
k4_embeddings_sum = torch.zeros(self.num_keys, self.key_dim, device=self.keys.device)
k1_counts = torch.zeros(self.num_keys, device=self.keys.device)
k2_counts = torch.zeros(self.num_keys, device=self.keys.device)
k3_counts = torch.zeros(self.num_keys, device=self.keys.device)
k4_counts = torch.zeros(self.num_keys, device=self.keys.device)
# 分批处理所有知识条目,避免内存溢出
batch_size = 1000 # 可根据可用内存调整
for i in range(0, self.knowledge_num, batch_size):
end_idx = min(i + batch_size, self.knowledge_num)
batch_indices = torch.arange(i, end_idx, device=self.keys.device)
# 获取批次的嵌入
batch_tokens = self.knowledge_dataset[batch_indices]
batch_embeddings = self.tok_embeddings(batch_tokens.view(-1))
batch_embeddings = batch_embeddings.view(len(batch_indices), self.knowledge_length, -1).mean(dim=1)
batch_embeddings = self.to_queries(batch_embeddings)
# 计算批次中每个条目对应的子键索引修改为4个子空间的索引计算
# 使用整数除法和取模运算来提取四维索引
temp = batch_indices
indices_4 = temp % self.num_keys
temp = temp // self.num_keys
indices_3 = temp % self.num_keys
temp = temp // self.num_keys
indices_2 = temp % self.num_keys
indices_1 = temp // self.num_keys
# 累加每个子键对应的嵌入
for j in range(len(batch_indices)):
k1, k2, k3, k4 = indices_1[j].item(), indices_2[j].item(), indices_3[j].item(), indices_4[j].item()
embedding = batch_embeddings[j]
# 将嵌入分为四份并分别累加到对应的子空间
quarter = self.key_dim
k1_embeddings_sum[k1] += embedding[:quarter]
k1_counts[k1] += 1
k2_embeddings_sum[k2] += embedding[quarter:2*quarter]
k2_counts[k2] += 1
k3_embeddings_sum[k3] += embedding[2*quarter:3*quarter]
k3_counts[k3] += 1
k4_embeddings_sum[k4] += embedding[3*quarter:]
k4_counts[k4] += 1
# 计算平均值并更新键
# 避免除零错误
k1_counts = torch.clamp(k1_counts, min=1)
k2_counts = torch.clamp(k2_counts, min=1)
k3_counts = torch.clamp(k3_counts, min=1)
k4_counts = torch.clamp(k4_counts, min=1)
# 计算每个子键的平均嵌入
self.keys[:, 0] = k1_embeddings_sum / k1_counts.unsqueeze(1)
self.keys[:, 1] = k2_embeddings_sum / k2_counts.unsqueeze(1)
self.keys[:, 2] = k3_embeddings_sum / k3_counts.unsqueeze(1)
self.keys[:, 3] = k4_embeddings_sum / k4_counts.unsqueeze(1)
print(f"执行了全局键更新,批次: {self.batch_counter}")
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# 收集所有batch的best_tokens
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# 找到最大相似度分数的索引
max_similarity_idx = torch.argmax(similarity_scores)
# 获取最大相似度对应的候选条目索引
best_candidate_idx = indices[max_similarity_idx]
# 获取对应的tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# 将当前batch的best_tokens添加到列表中
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# 将所有batch的best_tokens堆叠成一个张量
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return all_best_tokens, all_best_tokens_embeddings
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# 1. 序列维度平均
x_flat = x.mean(dim=1) # [batch_size, dim]
# 2. 生成查询向量并重塑为四个子查询
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
# 修改:重塑为四个子查询而非两个
queries = queries.reshape(batch_size, 4, self.key_dim) # [batch_size, 4, key_dim]
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [4, batch_size, key_dim]
# 3. 计算每个子空间的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 4. 在四个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(4)]
scores_1, scores_2, scores_3, scores_4 = [scores_and_indices[p][0] for p in range(4)]
indices_1, indices_2, indices_3, indices_4 = [scores_and_indices[p][1] for p in range(4)]
# 5. 组合四个子空间的结果
# 首先组合第一、第二子空间
scores_12 = scores_1.unsqueeze(-1) + scores_2.unsqueeze(-2) # [batch_size, topk, topk]
indices_12_base = (indices_1.unsqueeze(-1) * self.num_keys) + indices_2.unsqueeze(-2) # [batch_size, topk, topk]
# 然后组合第三、第四子空间
scores_34 = scores_3.unsqueeze(-1) + scores_4.unsqueeze(-2) # [batch_size, topk, topk]
indices_34_base = (indices_3.unsqueeze(-1) * self.num_keys) + indices_4.unsqueeze(-2) # [batch_size, topk, topk]
# 最后组合所有子空间
scores_flat_12 = scores_12.reshape(batch_size, -1) # [batch_size, topk*topk]
indices_flat_12 = indices_12_base.reshape(batch_size, -1) # [batch_size, topk*topk]
scores_flat_34 = scores_34.reshape(batch_size, -1) # [batch_size, topk*topk]
indices_flat_34 = indices_34_base.reshape(batch_size, -1) # [batch_size, topk*topk]
# 对12和34组合的结果进行top-k选择
topk_scores_12, topk_indices_12 = scores_flat_12.topk(min(self.product_key_topk, scores_flat_12.size(1)), dim=-1)
topk_indices_12 = torch.gather(indices_flat_12, 1, topk_indices_12)
topk_scores_34, topk_indices_34 = scores_flat_34.topk(min(self.product_key_topk, scores_flat_34.size(1)), dim=-1)
topk_indices_34 = torch.gather(indices_flat_34, 1, topk_indices_34)
# 将12和34的结果组合
all_scores = topk_scores_12.unsqueeze(-1) + topk_scores_34.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (topk_indices_12.unsqueeze(-1) * (self.num_keys**2)) + topk_indices_34.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
# 9. 更新批次计数并在特定批次执行全局更新
if self.is_train:
self.batch_counter += 1
# 每update_frequency个批次执行一次全局更新其余时间保持冻结
if self.batch_counter % self.update_frequency == 0:
# 只在特定批次更新键无论freeze_embedding状态如何
self._global_keys_update()
# 标记所有键为已更新状态
with torch.no_grad():
self.has_update_keys.fill_(1)
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
if config.hidden_dim is None:
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0
return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)
def forward(self, x):
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# 使用门控机制选择专家
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok
# 当tokens_per_expert = [6, 15, 20, 26]tokens_per_expert.shape[0]即为专家数量此时为4
# 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时
# 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token每个token有可能被多个专家处理这取决于num_experts_per_tok
# 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, pos_cis):
h_attn = self.self_attention(
self.attention_norm(x),
pos_cis
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h = x + h_attn
out = h + self.feed_forward(self.ffn_norm(h))
return out
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
h = layer(
h, pos_cis
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
# 进一步简化,只保留必要的参数
output = CausalLMOutputWithPast(
logits=logits,
)
output.hidden_states = h
output.aux_loss = aux_loss
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
# 流式生成
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
# 直接生成
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq:
out, first_seq = self(input_ids, **args), False
else:
out = self(input_ids[:, -1:],
start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break

View File

@ -0,0 +1,604 @@
import math
import struct
import inspect
import time
#子空间二维分解+梯度更新
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# 嵌入参数
self.knowledge_dim = params.knowledge_dim
self.key_dim = self.knowledge_dim // 2
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
## 数据库参数
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
# 修改键存储为二维分解空间,设置为可训练参数
self.num_keys = int(math.sqrt(self.knowledge_num))
# 确保keys是可训练参数
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# 计算step数目用于动态调整权重
self.step_counter = 0
# 移除批次计数器和更新频率相关代码
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# 收集所有batch的best_tokens
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# 找到最大相似度分数的索引
max_similarity_idx = torch.argmax(similarity_scores)
# 获取最大相似度对应的候选条目索引
best_candidate_idx = indices[max_similarity_idx]
# 获取对应的tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# 将当前batch的best_tokens添加到列表中
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# 将所有batch的best_tokens堆叠成一个张量
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
return all_best_tokens, all_best_tokens_embeddings
with torch.no_grad():
# 1. 计算token序列的平均嵌入
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [num_indices, dim]
# 2. 转换维度
pre_update_embeddings = self.to_queries(pre_update_embeddings) # [num_indices, knowledge_dim]
# 3. 将one-hot索引转换为子空间索引
indices_x = pre_update_indices // self.num_keys
indices_y = pre_update_indices % self.num_keys
# 4. 收集需要更新的唯一子键
unique_x = torch.unique(indices_x)
unique_y = torch.unique(indices_y)
# 5. 更新第一个子空间的键
for k1 in unique_x:
# 找出所有使用该子键的索引
mask_k1 = (indices_x == k1)
if mask_k1.sum() == 0:
continue
# 获取所有相关嵌入并计算平均值
k1_embeddings = pre_update_embeddings[mask_k1]
k1_avg_embedding = k1_embeddings.mean(dim=0)
# 拆分为两个子空间并更新第一个子空间
self.keys[k1, 0] = k1_avg_embedding[:self.key_dim]
# 6. 更新第二个子空间的键
for k2 in unique_y:
# 找出所有使用该子键的索引
mask_k2 = (indices_y == k2)
if mask_k2.sum() == 0:
continue
# 获取所有相关嵌入并计算平均值
k2_embeddings = pre_update_embeddings[mask_k2]
k2_avg_embedding = k2_embeddings.mean(dim=0)
# 更新第二个子空间
self.keys[k2, 1] = k2_avg_embedding[self.key_dim:]
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# 1. 序列维度平均
x_flat = x.mean(dim=1) # [batch_size, dim]
# 2. 生成查询向量并重塑为两个子查询
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 3. 计算每个子空间的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 4. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. 组合两个子空间的结果
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
if config.hidden_dim is None:
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0
return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)
def forward(self, x):
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# 使用门控机制选择专家
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok
# 当tokens_per_expert = [6, 15, 20, 26]tokens_per_expert.shape[0]即为专家数量此时为4
# 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时
# 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token每个token有可能被多个专家处理这取决于num_experts_per_tok
# 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, pos_cis):
h_attn = self.self_attention(
self.attention_norm(x),
pos_cis
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h = x + h_attn
out = h + self.feed_forward(self.ffn_norm(h))
return out
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False
# 移除对knowledge_dataset.freeze_embedding的设置让键更新由batch_counter控制
# self.knowledge_dataset.freeze_embedding = True
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
h = layer(
h, pos_cis
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
# 进一步简化,只保留必要的参数
output = CausalLMOutputWithPast(
logits=logits,
)
output.hidden_states = h
output.aux_loss = aux_loss
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
# 流式生成
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
# 直接生成
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq:
out, first_seq = self(input_ids, **args), False
else:
out = self(input_ids[:, -1:],
start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break

View File

@ -4,10 +4,35 @@
1. **句子提取**:从 TREx 数据集提取三元组并转换为自然语言句子 1. **句子提取**:从 TREx 数据集提取三元组并转换为自然语言句子
2. **LLM 处理**:使用 ollama qwen3:4b 模型进行句子修正和重要性评分 2. **LLM 处理**:使用 ollama qwen3:4b 模型进行句子修正和重要性评分
## 🆕 防卡死机制
为了解决LLM处理时可能出现的卡死问题新增了以下功能
### 超时和重试机制
- **超时时间**每个LLM请求60秒超时
- **重试机制**失败后最多重试2次采用指数退避策略
- **并发控制**降低并发数至4个减少服务器压力
### 心跳监控系统
- **实时监控**每30秒检查一次LLM响应状态
- **异常警告**超过30秒无成功响应时发出警告
- **服务检测**自动检查ollama服务状态
- **详细统计**:实时显示成功率、超时率等统计信息
### 日志系统
- **详细日志**:所有操作都记录在 `logs/` 目录下
- **双重输出**:同时输出到日志文件和控制台
- **时间戳标记**:日志文件包含启动时间戳
### 改进的错误处理
- **异常恢复**LLM处理失败时使用原句子和默认评分
- **状态监控**处理前检查ollama服务状态
- **批次间休息**批次之间休息5秒避免过度压力
## 安装依赖 ## 安装依赖
```bash ```bash
pip install agno asyncio pydantic pip install agno asyncio pydantic requests
``` ```
确保已安装并启动 ollama并下载 qwen3:4b 模型: 确保已安装并启动 ollama并下载 qwen3:4b 模型:
@ -50,24 +75,52 @@ python trex_to_sentences_simple.py --step llm --sentences_json my_sentences.json
## 输出文件 ## 输出文件
**注意:所有输出文件都会自动保存在 `./output/` 目录中** **注意:所有输出文件都会自动保存在相应目录中**
### 步骤1输出 ### 句子提取输出
- `output/extracted_sentences.json`: 提取的原始句子,包含元数据 - `output/extracted_sentences.json`: 提取的原始句子,包含元数据
### 步骤2输出 ### LLM处理输出
- `output/{output_file}.txt`: 修正后的句子文本文件 - `output/{output_file}.txt`: 修正后的句子文本文件
- `output/{output_file}.json`: 完整的处理结果(包含原句、修正句、评分) - `output/{output_file}.json`: 完整的处理结果(包含原句、修正句、评分)
- `output/{output_file}_sorted_by_importance.txt`: 按重要性评分排序的句子 - `output/{output_file}_sorted_by_importance.txt`: 按重要性评分排序的句子
### 检查点文件 ### 检查点文件
- `output/{output_file}_checkpoint_{数量}.json`: 每2000条句子自动保存的检查点 - `output/{output_file}_checkpoint_{数量}.json`: 每1000条句子自动保存的检查点
### 日志文件
- `logs/trex_processor_{时间戳}.log`: 详细的处理日志
## 🆕 故障诊断
### 如果遇到卡死问题:
1. **检查日志文件**:查看 `logs/` 目录下的最新日志
2. **观察心跳监控**:注意控制台的心跳警告信息
3. **检查ollama服务**
```bash
ps aux | grep ollama
curl http://localhost:11434/api/tags
```
4. **重启ollama服务**(如果需要):
```bash
pkill ollama
ollama serve &
```
### 常见警告信息:
- `⚠️ 心跳检测`: 30秒无成功响应正常情况下会自动恢复
- `❌ 严重警告`: 90秒无成功响应可能需要检查服务
- `💀 Ollama服务异常`: ollama服务可能已停止
- `💀 致命错误`: 连续多次警告(建议重启程序)
## 检查点恢复机制 ## 检查点恢复机制
- 步骤2会自动检测已有的检查点文件`output/` 目录中) - 步骤2会自动检测已有的检查点文件`output/` 目录中)
- 只处理尚未处理的句子,避免重复工作 - 只处理尚未处理的句子,避免重复工作
- 如果所有句子都已处理,会直接生成最终输出文件 - 如果所有句子都已处理,会直接生成最终输出文件
- 中断后重新运行会自动从最新检查点继续
## 示例工作流 ## 示例工作流
@ -84,14 +137,18 @@ python trex_to_sentences_simple.py --step llm
## 性能特点 ## 性能特点
- **并发处理**: 最大54个并发LLM请求 - **保守的并发**: 最大4个并发LLM请求降低卡死风险
- **检查点保存**: 每2000条句子自动保存支持断点续传 - **检查点保存**: 每1000条句子自动保存支持断点续传
- **进度显示**: 详细的处理进度和时间预估 - **智能监控**: 详细的处理进度和时间预估
- **错误处理**: LLM请求失败时使用原句子和默认评分 - **健壮的错误处理**: LLM请求失败时使用原句子和默认评分
- **服务监控**: 自动检测ollama服务状态
## 注意事项 ## 注意事项
1. 首次运行步骤2前必须先完成步骤1 1. 首次运行步骤2前必须先完成步骤1
2. 检查点文件会占用额外磁盘空间(每个都包含所有已处理数据) 2. 检查点文件会占用额外磁盘空间(每个都包含所有已处理数据)
3. LLM处理速度取决于模型性能和网络状况 3. LLM处理速度取决于模型性能和网络状况
4. 建议先用`--max_files`参数测试小批量数据 4. 建议先用`--max_files`参数测试小批量数据
5. **新增**:如果遇到卡死,查看日志文件和心跳监控信息
6. **新增**程序会自动检测并报告ollama服务状态
7. **新增**:所有处理过程都有详细日志记录,便于问题诊断

View File

@ -0,0 +1,225 @@
#!/usr/bin/env python3
"""
JSON文件合并脚本
读取多个JSON文件并合并为一个JSON文件
"""
import json
import os
from typing import Dict, List, Any, Union
# 需要合并的JSON文件列表
JSON_FILES_TO_MERGE = [
"output/trex_sentences_enhanced_checkpoint_360000.json"
]
for i in range(1, 1010):
JSON_FILES_TO_MERGE.append(f"output/trex_sentences_enhanced_batch_{i}.json")
def load_json_file(file_path: str) -> Union[Dict, List, None]:
"""加载JSON文件"""
if not os.path.exists(file_path):
print(f"警告: 文件 {file_path} 不存在")
return None
try:
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
print(f"成功加载: {file_path}")
return data
except json.JSONDecodeError as e:
print(f"错误: 无法解析JSON文件 {file_path} - {e}")
return None
except Exception as e:
print(f"错误: 读取文件 {file_path} 失败 - {e}")
return None
def merge_json_data(data1: Union[Dict, List], data2: Union[Dict, List]) -> Union[Dict, List]:
"""合并两个JSON数据结构"""
# 如果两个都是列表,直接合并
if isinstance(data1, list) and isinstance(data2, list):
print(f"合并两个列表: {len(data1)} + {len(data2)} = {len(data1) + len(data2)}")
return data1 + data2
# 如果两个都是字典
elif isinstance(data1, dict) and isinstance(data2, dict):
print("合并两个字典结构")
merged = data1.copy()
# 特殊处理:如果都有'sentences'字段且为列表合并sentences
if 'sentences' in data1 and 'sentences' in data2:
if isinstance(data1['sentences'], list) and isinstance(data2['sentences'], list):
print(f"合并sentences字段: {len(data1['sentences'])} + {len(data2['sentences'])} = {len(data1['sentences']) + len(data2['sentences'])}")
merged['sentences'] = data1['sentences'] + data2['sentences']
# 更新metadata if exists
if 'metadata' in merged:
if isinstance(merged['metadata'], dict):
merged['metadata']['total_sentences'] = len(merged['sentences'])
merged['metadata']['merged_from'] = [os.path.basename(f) for f in JSON_FILES_TO_MERGE if os.path.exists(f)]
# 合并其他字段
for key, value in data2.items():
if key != 'sentences' and key not in merged:
merged[key] = value
return merged
# 普通字典合并
for key, value in data2.items():
if key in merged:
# 如果key重复且都是列表合并列表
if isinstance(merged[key], list) and isinstance(value, list):
merged[key] = merged[key] + value
# 如果key重复且都是字典递归合并
elif isinstance(merged[key], dict) and isinstance(value, dict):
merged[key] = merge_json_data(merged[key], value)
else:
# 其他情况保留第二个文件的值
merged[key] = value
print(f"字段 '{key}' 被覆盖")
else:
merged[key] = value
return merged
# 类型不匹配的情况,创建一个包含两者的新结构
else:
print("数据类型不匹配,创建包含两者的新结构")
return {
"data_from_save.json": data1,
"data_from_save2.json": data2,
"merged_at": "test.py"
}
def save_merged_json(data: Union[Dict, List], output_path: str):
"""保存合并后的JSON数据"""
try:
# 确保输出目录存在
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
print(f"合并结果已保存到: {output_path}")
# 显示统计信息
if isinstance(data, dict):
if 'sentences' in data and isinstance(data['sentences'], list):
print(f"总计句子数: {len(data['sentences'])}")
print(f"总计字段数: {len(data)}")
elif isinstance(data, list):
print(f"总计列表项数: {len(data)}")
except Exception as e:
print(f"错误: 保存文件失败 - {e}")
def remove_duplicates_from_sentences(data: Union[Dict, List]) -> Union[Dict, List]:
"""从合并结果中移除重复的句子(基于句子内容)"""
if isinstance(data, dict) and 'sentences' in data:
if isinstance(data['sentences'], list):
original_count = len(data['sentences'])
seen_sentences = set()
unique_sentences = []
for item in data['sentences']:
if isinstance(item, dict):
# 如果是字典使用sentence字段或corrected_sentence字段作为唯一标识
sentence_key = item.get('sentence') or item.get('corrected_sentence') or item.get('original_sentence')
elif isinstance(item, str):
sentence_key = item
else:
sentence_key = str(item)
if sentence_key and sentence_key not in seen_sentences:
seen_sentences.add(sentence_key)
unique_sentences.append(item)
data['sentences'] = unique_sentences
# 更新metadata
if 'metadata' in data and isinstance(data['metadata'], dict):
data['metadata']['total_sentences'] = len(unique_sentences)
data['metadata']['duplicates_removed'] = original_count - len(unique_sentences)
print(f"去重完成: {original_count} -> {len(unique_sentences)} (移除了 {original_count - len(unique_sentences)} 个重复项)")
return data
def merge_multiple_json_data(data_list: List[Union[Dict, List]]) -> Union[Dict, List]:
"""合并多个JSON数据结构"""
if not data_list:
return {}
if len(data_list) == 1:
return data_list[0]
print(f"准备合并 {len(data_list)} 个JSON数据结构")
# 从第一个数据开始,逐步合并其他数据
merged_data = data_list[0]
for i, data in enumerate(data_list[1:], 1):
print(f"正在合并第 {i+1} 个数据结构...")
merged_data = merge_json_data(merged_data, data)
return merged_data
def main():
"""主函数"""
print("=== JSON文件合并脚本 ===")
# 输出路径
output_path = "output/merged.json"
print(f"准备合并以下文件:")
for i, file_path in enumerate(JSON_FILES_TO_MERGE, 1):
print(f" {i}. {file_path}")
print(f"输出文件: {output_path}")
print()
# 加载所有文件
loaded_data = []
successfully_loaded = []
for file_path in JSON_FILES_TO_MERGE:
data = load_json_file(file_path)
if data is not None:
loaded_data.append(data)
successfully_loaded.append(file_path)
# 检查是否至少有一个文件加载成功
if not loaded_data:
print("错误: 没有文件能够成功加载,退出")
return
print(f"成功加载了 {len(loaded_data)} 个文件:")
for file_path in successfully_loaded:
print(f"{file_path}")
if len(loaded_data) < len(JSON_FILES_TO_MERGE):
failed_count = len(JSON_FILES_TO_MERGE) - len(loaded_data)
print(f"警告: {failed_count} 个文件加载失败")
print()
# 合并所有数据
if len(loaded_data) == 1:
print("只有一个文件可用,直接使用...")
merged_data = loaded_data[0]
else:
print("开始合并所有文件...")
merged_data = merge_multiple_json_data(loaded_data)
# 去重处理
print("\n检查并去除重复项...")
merged_data = remove_duplicates_from_sentences(merged_data)
# 保存合并结果
print("\n保存合并结果...")
save_merged_json(merged_data, output_path)
print("\n=== 合并完成 ===")
print(f"合并了 {len(successfully_loaded)} 个文件的数据")
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,10 @@
#!/bin/bash #!/bin/bash
# 激活conda环境 # 激活conda环境
# source $(conda info --base)/etc/profile.d/conda.sh #source $(conda info --base)/etc/profile.d/conda.sh
# conda activate ycz_accelerate #conda activate mini
source /mnt/wcy/miniconda/bin/activate
conda activate accelerate
# 设置环境变量以帮助调试 # 设置环境变量以帮助调试
export NCCL_DEBUG=INFO export NCCL_DEBUG=INFO
@ -26,24 +28,9 @@ export PYTHONFAULTHANDLER=1
# --profile_interval 10 # --profile_interval 10
# 方法2: 使用命令行参数直接配置accelerate # 方法2: 使用命令行参数直接配置accelerate
CUDA_VISIBLE_DEVICES=0 accelerate launch \ CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch \
--num_processes=1 \ --num_processes=1 \
--mixed_precision=bf16 \ --mixed_precision=bf16 \
--main_process_port=29500 \ --main_process_port=29500 \
train_pretrain_accelerate.py \ train_pretrain_accelerate.py \
--epochs 3 \
--batch_size 24 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--log_interval 100 \
--save_interval 10000 \
--dim 512 \
--n_layers 12 \
--max_seq_len 512 \
--use_flash_attn \
--profile \
--profile_interval 10\
--knowledge_num 4096 \
--knowledge_length 8

View File

@ -74,8 +74,8 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
nn.init.ones_(module.weight) nn.init.ones_(module.weight)
# 初始化位置编码相关参数 # 初始化位置编码相关参数
if hasattr(model.extract_db, 'keys'): if hasattr(model.knowledge_dataset, 'keys'):
nn.init.normal_(model.extract_db.keys, mean=0.0, std=0.02) nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)
Logger("Default model initialization completed") Logger("Default model initialization completed")
@ -88,329 +88,130 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
if database_init_path: if database_init_path:
import json import json
import numpy as np
from sentence_transformers import SentenceTransformer
import os import os
Logger(f"Loading database initialization data from {database_init_path}") # 数据库参数
# 1. 加载JSON文件并转换为字典
with open(database_init_path, 'r', encoding='utf-8') as f:
database_data = json.load(f)
# 提取sentences列表
sentences_data = database_data.get('sentences', [])
Logger(f"Loaded {len(sentences_data)} sentences from database")
# 2. 按照importance_score进行排序从高到低
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
# 3. 下载并初始化本地嵌入模型
embedding_model_name = "sentence-transformers/all-mpnet-base-v2" # 轻量级但效果好的模型
embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2"
embedding_cache_dir = "./models/sentence_transformers/cache"
os.makedirs(embedding_cache_dir, exist_ok=True)
Logger(f"Loading embedding model: {embedding_model_name}")
try:
embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir)
Logger("Embedding model loaded successfully")
except Exception as e:
Logger(f"Failed to load embedding model: {e}")
Logger("Falling back to random embeddings")
embedding_model = None
# 4. 对每个corrected_sentence进行嵌入和token长度计算
Logger("Processing sentences for embeddings and token lengths...")
# 提取所有句子
sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences]
# 批量计算token长度
Logger("Computing token lengths...")
token_lengths = []
for sentence in sentences:
tokens = tokenizer.encode(sentence, add_special_tokens=False)
token_lengths.append(len(tokens))
# 批量计算嵌入 - 大幅提升速度
Logger("Computing embeddings in batches...")
embeddings_list = []
batch_size = 256 # 可以根据GPU内存调整
if embedding_model is not None:
try:
for i in range(0, len(sentences), batch_size):
batch_sentences = sentences[i:i+batch_size]
batch_embeddings = embedding_model.encode(
batch_sentences,
convert_to_tensor=False,
show_progress_bar=True if i == 0 else False,
batch_size=batch_size
)
embeddings_list.extend(batch_embeddings)
if (i + batch_size) % (batch_size * 10) == 0:
Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences")
Logger("Batch embedding computation completed")
except Exception as e:
Logger(f"Error in batch encoding: {e}")
Logger("Falling back to random embeddings")
embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
else:
# 使用随机嵌入
embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
# 创建处理后的句子列表
processed_sentences = []
for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)):
processed_sentences.append({
'sentence': sentence_data.get('corrected_sentence', ''),
'importance_score': sentence_data.get('importance_score', 0.0),
'token_length': token_length,
'embedding': embedding, # Convert numpy array to list
'original_index': i
})
# # Create a JSON-serializable version for saving
# json_serializable_sentences = []
# for sentence in processed_sentences:
# json_sentence = sentence.copy()
# # Convert embedding to list if it's a numpy array
# if hasattr(json_sentence['embedding'], 'tolist'):
# json_sentence['embedding'] = json_sentence['embedding'].tolist()
# json_serializable_sentences.append(json_sentence)
# json.dump(json_serializable_sentences, open('processed_sentences.json', 'w', encoding='utf-8'))
# processed_sentences = json.load(open('processed_sentences.json', 'r', encoding='utf-8'))
# 转换为numpy数组以便后续处理
embeddings_array = np.array(embeddings_list)
token_lengths_array = np.array(token_lengths)
Logger(f"Embedding processing completed:")
Logger(f" - Total sentences: {len(processed_sentences)}")
Logger(f" - Embedding shape: {embeddings_array.shape}")
Logger(f" - Average token length: {np.mean(token_lengths_array):.2f}")
Logger(f" - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}")
# 2. 聚类处理 - 优化版本
Logger("Starting optimized clustering process...")
# 聚类参数
knowledge_num = args.knowledge_num knowledge_num = args.knowledge_num
knowledge_length = args.knowledge_length knowledge_length = args.knowledge_length
min_tokens = int(0.85 * knowledge_length)
max_tokens = int(0.95 * knowledge_length)
# 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大) # 检查是否使用缓存
if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算 cache_dir = os.path.dirname(args.cluster_cache_path)
Logger("Pre-computing similarity matrix for faster clustering...") if cache_dir:
embeddings_matrix = np.array([s['embedding'] for s in processed_sentences]) os.makedirs(cache_dir, exist_ok=True)
similarity_matrix = cosine_similarity(embeddings_matrix)
Logger(f"Similarity matrix computed: {similarity_matrix.shape}")
else:
similarity_matrix = None
embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
clustered_rows = [] processed_tensor = None
remaining_indices = list(range(len(processed_sentences))) # 使用索引而不是对象
Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens") # 尝试加载缓存的处理结果
if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
# 选择聚类算法 try:
if args.fast_clustering and len(processed_sentences) > 5000: Logger(f"Loading cached processed results from {args.cluster_cache_path}")
Logger("Using ultra-fast approximate clustering algorithm...") processed_tensor = torch.load(args.cluster_cache_path)
# 超快速聚类:随机采样 + 批量处理 # 验证缓存文件的形状是否可用
import random cached_knowledge_num, cached_knowledge_length = processed_tensor.shape
random.seed(42) # 确保可重现性
if cached_knowledge_length == knowledge_length:
# 按重要性分层采样 if cached_knowledge_num >= knowledge_num:
high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7] # 缓存足够大,可以截取使用
medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7] processed_tensor = processed_tensor[:knowledge_num, :]
low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3] Logger(f"Successfully loaded cached data with shape {processed_tensor.shape}")
Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}") Logger("Skipping database initialization - using cached results")
else:
for cluster_idx in tqdm(range(knowledge_num)): # 缓存太小,需要重新计算
# 分层选择种子:优先选择高重要性句子 Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
if high_importance: processed_tensor = None
seed_pool = high_importance
elif medium_importance:
seed_pool = medium_importance
else: else:
seed_pool = low_importance if low_importance else list(range(len(processed_sentences))) # knowledge_length不匹配需要重新计算
Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
if not seed_pool: processed_tensor = None
break except Exception as e:
Logger(f"Failed to load cached data: {e}, recomputing...")
# 随机选择种子(在同一重要性层级内) processed_tensor = None
seed_global_idx = random.choice(seed_pool)
seed_sentence = processed_sentences[seed_global_idx]
# 从所有池中移除种子
for pool in [high_importance, medium_importance, low_importance]:
if seed_global_idx in pool:
pool.remove(seed_global_idx)
current_cluster_indices = [seed_global_idx]
current_tokens = seed_sentence['token_length']
if current_tokens < max_tokens:
# 快速选择:只从附近的句子中随机选择
all_remaining = high_importance + medium_importance + low_importance
if all_remaining:
# 随机采样候选句子(而不是计算所有相似度)
sample_size = min(100, len(all_remaining))
candidates = random.sample(all_remaining, sample_size)
# 简单按token长度和重要性选择
for candidate_idx in candidates:
candidate = processed_sentences[candidate_idx]
candidate_tokens = candidate['token_length']
if current_tokens + candidate_tokens + 1 <= max_tokens:
current_cluster_indices.append(candidate_idx)
current_tokens += candidate_tokens + 1
# 从池中移除
for pool in [high_importance, medium_importance, low_importance]:
if candidate_idx in pool:
pool.remove(candidate_idx)
break
if current_tokens >= min_tokens:
break
# 生成聚类文本
cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
cluster_text = '\n '.join(cluster_sentences)
# 转换为tokens
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
if len(cluster_tokens) > knowledge_length:
cluster_tokens = cluster_tokens[:knowledge_length]
else:
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
clustered_rows.append(cluster_tokens)
if (cluster_idx + 1) % 1000 == 0:
total_remaining = len(high_importance) + len(medium_importance) + len(low_importance)
Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining")
else: # 只有在没有有效缓存时才进行数据库初始化和处理
# 原始优化算法(适用于中等规模数据集) if processed_tensor is None:
# 优化2: 批量处理和更高效的数据结构 Logger(f"Loading database initialization data from {database_init_path}")
for cluster_idx in tqdm(range(knowledge_num)):
if not remaining_indices: # 1. 加载JSON文件
Logger(f"No more sentences available. Created {cluster_idx} clusters.") with open(database_init_path, 'r', encoding='utf-8') as f:
break database_data = json.load(f)
# 提取sentences列表
sentences_data = database_data.get('sentences', [])
Logger(f"Loaded {len(sentences_data)} sentences from database")
# 2. 按照importance_score进行排序从高到低
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
# 3. 处理每条数据,不进行聚类
Logger("Processing individual sentences...")
processed_rows = []
# 获取空token的id用于填充
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
# 处理所需数量的句子
num_to_process = min(knowledge_num, len(sorted_sentences))
for i in range(num_to_process):
sentence_data = sorted_sentences[i]
sentence = sentence_data.get('corrected_sentence', '')
# 2.1 选择importance_score最高的句子作为种子 # 将句子转换为tokens
remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices] sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
seed_idx_in_subset = max(range(len(remaining_sentences_subset)),
key=lambda i: remaining_sentences_subset[i]['importance_score'])
seed_global_idx = remaining_indices[seed_idx_in_subset]
seed_sentence = processed_sentences[seed_global_idx]
# 从剩余索引中移除种子
remaining_indices.remove(seed_global_idx)
# 当前聚类
current_cluster_indices = [seed_global_idx]
current_tokens = seed_sentence['token_length']
if current_tokens >= max_tokens:
# 如果种子句子已经超过最大token数直接作为一个聚类
cluster_text = seed_sentence['sentence']
else:
# 2.2 优化的相似度计算和选择
if remaining_indices:
if similarity_matrix is not None:
# 使用预计算的相似度矩阵
similarities = similarity_matrix[seed_global_idx][remaining_indices]
else:
# 动态计算相似度(批量)
seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1]
remaining_embeddings = embeddings_matrix[remaining_indices]
similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0]
# 创建(相似度, 原始索引, 在remaining_indices中的位置)的元组列表
similarity_tuples = [(similarities[i], remaining_indices[i], i)
for i in range(len(remaining_indices))]
# 按相似度排序(降序)
similarity_tuples.sort(key=lambda x: x[0], reverse=True)
# 优化3: 贪心选择,但限制搜索范围以提高速度
max_candidates = min(len(similarity_tuples), 500) # 只考虑前500个最相似的句子
selected_indices_in_remaining = []
for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]:
candidate = processed_sentences[global_idx]
candidate_tokens = candidate['token_length']
if current_tokens + candidate_tokens + 1 <= max_tokens: # +1 for newline
current_cluster_indices.append(global_idx)
selected_indices_in_remaining.append(pos_in_remaining)
current_tokens += candidate_tokens + 1
if current_tokens >= min_tokens:
break
# 批量移除选中的句子(从后往前移除以避免索引问题)
for pos in sorted(selected_indices_in_remaining, reverse=True):
remaining_indices.pop(pos)
# 拼接句子
cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
cluster_text = '\n'.join(cluster_sentences)
# 将聚类文本转换为token
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
# 截断或填充到knowledge_length # 截断或填充到knowledge_length
if len(cluster_tokens) > knowledge_length: if len(sentence_tokens) > knowledge_length:
cluster_tokens = cluster_tokens[:knowledge_length] # 如果超过长度,截断
sentence_tokens = sentence_tokens[:knowledge_length]
Logger(f"Sentence {i+1} truncated from {len(tokenizer.encode(sentence, add_special_tokens=False))} to {knowledge_length} tokens")
else: else:
# 用pad_token_id填充 # 如果不足长度用空token填充
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 original_length = len(sentence_tokens)
cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens))) sentence_tokens.extend([pad_token_id] * (knowledge_length - len(sentence_tokens)))
if original_length < knowledge_length:
Logger(f"Sentence {i+1} padded from {original_length} to {knowledge_length} tokens")
clustered_rows.append(cluster_tokens) processed_rows.append(sentence_tokens)
# 优化4: 减少日志频率 if (i + 1) % 1000 == 0:
if (cluster_idx + 1) % 500 == 0: Logger(f"Processed {i + 1}/{num_to_process} sentences")
Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining")
# 如果句子数量不足用空token填充剩余位置
while len(processed_rows) < knowledge_num:
empty_tokens = [pad_token_id] * knowledge_length
processed_rows.append(empty_tokens)
if len(processed_rows) % 1000 == 0:
Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")
Logger(f"Finished adding empty entries. Total: {len(processed_rows)}/{knowledge_num}")
# 转换为tensor
processed_tensor = torch.tensor(processed_rows, dtype=torch.long)
Logger(f"Data processing completed:")
Logger(f" - Processed {num_to_process} sentences")
Logger(f" - Added {knowledge_num - num_to_process} empty entries")
Logger(f" - Final shape: {processed_tensor.shape}")
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")
# 保存处理结果到缓存文件
try:
torch.save(processed_tensor, args.cluster_cache_path)
Logger(f"Processed results saved to {args.cluster_cache_path}")
except Exception as e:
Logger(f"Failed to save processed results: {e}")
# 如果聚类数量不足用随机token填充 # 4. 初始化模型的knowledge_dataset
while len(clustered_rows) < knowledge_num: if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 model.knowledge_dataset.knowledge_dataset.data.copy_(processed_tensor)
random_tokens = [pad_token_id] * knowledge_length Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with processed data")
clustered_rows.append(random_tokens)
# 转换为tensor
clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long)
Logger(f"Clustering completed:")
Logger(f" - Created {len(clustered_rows)} clusters")
Logger(f" - Cluster shape: {clustered_tensor.shape}")
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")
# 3. 初始化模型的weight_down_embed
if hasattr(model, 'extract_db') and hasattr(model.extract_db, 'weight_down_embed'):
model.extract_db.weight_down_embed.data.copy_(clustered_tensor)
Logger("Successfully initialized model.extract_db.weight_down_embed with clustered data")
else: else:
Logger("Warning: Could not find model.extract_db.weight_down_embed to initialize") Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
# 存储为全局变量作为备选 # 存储为全局变量作为备选
globals()['clustered_database'] = clustered_tensor globals()['processed_database'] = processed_tensor
Logger(f"Database embeddings and sentences stored in model") Logger(f"Database embeddings and sentences stored in model")
@ -423,6 +224,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
total_steps_in_epoch = len(train_loader) total_steps_in_epoch = len(train_loader)
total_training_steps = args.epochs * total_steps_in_epoch total_training_steps = args.epochs * total_steps_in_epoch
moe_path = '_moe' if args.use_moe else '' moe_path = '_moe' if args.use_moe else ''
best_loss = float('10000')
# 添加CUDA事件来分析性能 (只在主进程进行) # 添加CUDA事件来分析性能 (只在主进程进行)
if args.profile and accelerator.is_main_process: if args.profile and accelerator.is_main_process:
@ -486,7 +288,12 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
# 前向传播 # 前向传播
with ctx: with ctx:
res = model(X) if step == 0 and args.embedding_epoch == epoch:
# 需要设置原始模型的freeze_embedding属性而不是包装后的模型
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.freeze_embedding = True
Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
res = model(X, step=step)
loss = loss_fct( loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)), res.logits.view(-1, res.logits.size(-1)),
Y.view(-1) Y.view(-1)
@ -610,7 +417,9 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
wandb.log(log_dict) wandb.log(log_dict)
# 保存模型 (只在主进程进行) # 保存模型 (只在主进程进行)
if (step + 1) % args.save_interval == 0 and accelerator.is_main_process: loss_total = loss.item() * args.accumulation_steps
if best_loss > loss_total and accelerator.is_main_process:
best_loss = loss_total
# 使用函数开始处定义的moe_path变量 # 使用函数开始处定义的moe_path变量
ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth' ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
@ -629,21 +438,22 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
def main(): def main():
parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate") parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
parser.add_argument("--out_dir", type=str, default="out") parser.add_argument("--out_dir", type=str, default="out")
parser.add_argument("--epochs", type=int, default=3) parser.add_argument("--epochs", type=int, default=4)
parser.add_argument("--batch_size", type=int, default=24) parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数")
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--learning_rate", type=float, default=2e-4) parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--dtype", type=str, default="bfloat16") parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", default=True, action="store_true") parser.add_argument("--use_wandb", default=True, action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain") parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
parser.add_argument("--num_workers", type=int, default=48) parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--accumulation_steps", type=int, default=32) parser.add_argument("--accumulation_steps", type=int, default=32)
parser.add_argument("--grad_clip", type=float, default=1.0) parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0) parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=100) parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=10000) parser.add_argument("--save_interval", type=int, default=10000)
parser.add_argument('--dim', default=1024, type=int) parser.add_argument('--dim', default=512, type=int)
parser.add_argument('--n_layers', default=32, type=int) parser.add_argument('--n_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=1024, type=int) parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--use_moe', default=False, type=bool) parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代") parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代")
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
@ -651,12 +461,14 @@ def main():
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)") parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention") parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
parser.add_argument("--knowledge_num", type=int, default=64*64,help="知识库的数据数目") parser.add_argument("--knowledge_num", type=int, default=960400,help="知识库的数据数目")
parser.add_argument("--knowledge_length", type=int, default=64,help="知识库的句子长度") parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度")
parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径") parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)") parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="聚类结果缓存文件路径")
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
args = parser.parse_args() args = parser.parse_args()
######################################################### #########################################################
# 初始化accelerator和deepspeed # 初始化accelerator和deepspeed
######################################################### #########################################################
@ -692,7 +504,8 @@ def main():
disable_db=args.disable_db, disable_db=args.disable_db,
flash_attn=args.use_flash_attn, flash_attn=args.use_flash_attn,
knowledge_num=args.knowledge_num, knowledge_num=args.knowledge_num,
knowledge_length=args.knowledge_length knowledge_length=args.knowledge_length,
embeddings_epoch=args.embedding_epoch
) )
######################################################### #########################################################