Compare commits

...

6 Commits

Author SHA1 Message Date
c09cd63794 Preprocess initial values for the knowledge base built from the TREx dataset 2025-05-23 15:47:17 +08:00
Jax922
45da3b383b DynamicKV-LLM Pretrain v1.1.2 2025-05-23 01:18:08 +08:00
00d3c24e03 Build the database decoder model 2025-05-22 11:32:15 +08:00
Jax922
feeccf733c DynamicKV-LLM Pretrain v1.1.1 2025-05-22 10:05:31 +08:00
Jax922
42e3d38a3f Use variables instead of hard-coded values 2025-05-21 08:14:36 +00:00
Gary
d7fe504e1e update 2025-05-16 08:38:59 +00:00
8 changed files with 932 additions and 30 deletions

144
dataset_decoder.py Normal file
View File

@ -0,0 +1,144 @@
import os
import argparse
import torch
from transformers import AutoTokenizer
from model.model import MiniMindLM, ExtractDB
from model.LMConfig import LMConfig


def decode_dataset(model_path, output_path, device="cuda"):
    """
    Decode the weight_down_embed buffer in the model to readable text

    Args:
        model_path: Path to the model checkpoint
        output_path: Path to save the decoded text
        device: Device to load the model on
    """
    print(f"Loading tokenizer from ./model/minimind_tokenizer")
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    print(f"Setting up model configuration")
    # Create model configuration matching the training parameters
    lm_config = LMConfig(
        dim=1024,
        n_layers=32,
        max_seq_len=1024,
        use_flash_attn=True,
        knowledge_num=16384,  # From the script parameters
        knowledge_length=64   # From the script parameters
    )

    print(f"Initializing model")
    model = MiniMindLM(lm_config).to(device)

    print(f"Loading model weights from {model_path}")
    state_dict = torch.load(model_path, map_location=device)

    # Get model parameters
    model_state = dict(model.named_parameters())
    model_state.update(dict(model.named_buffers()))

    # Find parameters with matching names but different shapes
    shape_mismatch = {}
    for name, param in model_state.items():
        if name in state_dict and param.shape != state_dict[name].shape:
            shape_mismatch[name] = (param.shape, state_dict[name].shape)

    # Find parameters in model but not in state_dict and vice versa
    model_only = set(model_state.keys()) - set(state_dict.keys())
    state_dict_only = set(state_dict.keys()) - set(model_state.keys())

    # Create filtered state_dict with only compatible parameters
    filtered_state_dict = {}
    for name, param in state_dict.items():
        if name in model_state and param.shape == model_state[name].shape:
            filtered_state_dict[name] = param

    # Print parameter differences
    if shape_mismatch:
        print(f"Parameters with shape mismatches: {len(shape_mismatch)}")
        for name, (model_shape, state_shape) in shape_mismatch.items():
            print(f"  {name}: model={model_shape}, checkpoint={state_shape}")

    if model_only:
        print(f"Parameters in model but not in checkpoint: {len(model_only)}")
        for name in sorted(model_only):
            print(f"  {name}: {model_state[name].shape}")
            # Special handling for the pos_cis_real parameter
            if name == "pos_cis_real":
                print(f"Detected pos_cis_real parameter. This is a position encoding that will be initialized automatically.")

    if state_dict_only:
        print(f"Parameters in checkpoint but not in model: {len(state_dict_only)}")
        for name in sorted(state_dict_only):
            print(f"  {name}: {state_dict[name].shape}")
            # If the checkpoint has output.weight but the model does not, try loading it into tok_embeddings
            if name == "output.weight" and "tok_embeddings.weight" in model_state:
                print(f"Found output.weight in checkpoint but not in model. Will try to map it to tok_embeddings.weight")
                if model_state["tok_embeddings.weight"].shape == state_dict["output.weight"].shape:
                    filtered_state_dict["tok_embeddings.weight"] = state_dict["output.weight"]

    # Load only the compatible parameters
    print(f"Loading {len(filtered_state_dict)}/{len(state_dict)} parameters")
    model.load_state_dict(filtered_state_dict, strict=False)

    # Check that extract_db and weight_down_embed exist
    if not hasattr(model, "extract_db"):
        print("ERROR: Model does not have extract_db attribute. This is required for decoding.")
        return

    print("Accessing weight_down_embed buffer")
    # Get the weight_down_embed buffer from the model
    try:
        weight_down_embed = model.extract_db.weight_down_embed
        print(f"Successfully accessed weight_down_embed buffer")
    except Exception as e:
        print(f"ERROR: Failed to access weight_down_embed buffer: {e}")
        print(f"Model structure: {model.__class__.__name__}")
        print(f"ExtractDB attributes: {dir(model.extract_db)}")
        return

    print(f"Shape of weight_down_embed: {weight_down_embed.shape}")
    print(f"Data type of weight_down_embed: {weight_down_embed.dtype}")

    # Create output directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    print(f"Decoding knowledge and writing to {output_path}")
    knowledge_num, knowledge_length = weight_down_embed.shape

    with open(output_path, 'w', encoding='utf-8') as f:
        for i in range(knowledge_num):
            try:
                # Get token IDs for this knowledge entry
                token_ids = weight_down_embed[i].cpu().tolist()
                # Decode tokens to text
                text = tokenizer.decode(token_ids, skip_special_tokens=True)
                # Write to file
                f.write(f"Knowledge_{i}: {text}\n")
                # Print progress periodically
                if (i + 1) % 100 == 0:
                    print(f"Decoded {i + 1}/{knowledge_num} knowledge entries")
            except Exception as e:
                print(f"Error decoding knowledge entry {i}: {e}")
                f.write(f"Knowledge_{i}: [ERROR DECODING]\n")

    print(f"Decoding completed. Output saved to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
    parser.add_argument("--model_path", type=str, default="out/pretrain_1024.pth",
                        help="Path to the model checkpoint")
    parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
                        help="Path to save the decoded text file")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to load the model on")
    args = parser.parse_args()
    decode_dataset(args.model_path, args.output_path, args.device)
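
The script above instantiates the full MiniMindLM just to read one buffer. As a lighter-weight alternative, the token-id buffer can be read straight out of the checkpoint; this is a minimal sketch, and the state-dict key "extract_db.weight_down_embed" is an assumption inferred from the attribute path used above, not something verified against the checkpoint format.

import torch
from transformers import AutoTokenizer

# Sketch: decode the first knowledge entry directly from the checkpoint.
# Assumes the buffer is stored under "extract_db.weight_down_embed".
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
state_dict = torch.load('out/pretrain_1024.pth', map_location='cpu')
weight_down_embed = state_dict['extract_db.weight_down_embed']  # [knowledge_num, knowledge_length]
token_ids = weight_down_embed[0].tolist()
print(tokenizer.decode(token_ids, skip_special_tokens=True))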

View File

@ -37,8 +37,8 @@ class LMConfig(PretrainedConfig):
seq_aux: bool = True,
norm_topk_prob: bool = True,
####################################################
knowlwdge_num: int = 64*64,
knowlwdge_length: int = 8,
knowledge_num: int = 64*64,
knowledge_length: int = 8,
**kwargs,
):
self.dim = dim
@ -70,6 +70,6 @@ class LMConfig(PretrainedConfig):
self.seq_aux = seq_aux # whether to compute the auxiliary loss at the sequence level
self.norm_topk_prob = norm_topk_prob # whether to normalize the top-k probabilities
####################################################
self.knowlwdge_num = knowlwdge_num
self.knowlwdge_length = knowlwdge_length
self.knowledge_num = knowledge_num
self.knowledge_length = knowledge_length
super().__init__(**kwargs)
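
For reference, a hedged sketch of how the renamed fields are meant to be passed after this fix; the values mirror the updated 4-GPU run script below (16384 entries of 64 tokens each) rather than the class defaults.

from model.LMConfig import LMConfig

# Sketch: construct a config using the corrected spelling of the knowledge-base fields.
lm_config = LMConfig(
    dim=1024,
    n_layers=32,
    max_seq_len=1024,
    knowledge_num=16384,   # number of knowledge-base entries
    knowledge_length=64,   # tokens stored per entry
)
assert lm_config.knowledge_num == 16384 and lm_config.knowledge_length == 64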

View File

@ -515,7 +515,7 @@ class MiniMindBlock(nn.Module):
# Feed-forward network
out = h + self.feed_forward(self.ffn_norm(h))
return out
return out
class ExtractDB(nn.Module):
def __init__(self,params):
@ -524,22 +524,26 @@ class ExtractDB(nn.Module):
self.batch_size = None
self.dim = params.dim
self.dim_key = self.dim // 2
self.knowlwdge_num = params.knowlwdge_num # e.g. 100 experts; make sure this is a perfect square
self.knowledge_num = params.knowledge_num # e.g. 100 experts; make sure this is a perfect square
# Set knowledge_dim equal to head_dim so it can be used directly in attention
self.head_dim = params.dim // params.n_heads
self.knowledge_length = params.knowlwdge_length*params.dim
self.knowledge_length = params.knowledge_length
# Use register_buffer instead of nn.Parameter to avoid gradient issues
self.register_buffer('weight_down_embed', torch.randn(self.knowlwdge_num, self.knowledge_length) * 0.02)
# self.register_buffer('weight_down_embed', torch.randn(self.knowledge_num, self.knowledge_length) * 0.02)
self.register_buffer('weight_down_embed',torch.randint(low=0,high=6400, size=(self.knowledge_num, self.knowledge_length),dtype=torch.long))
self.num_keys = int(math.sqrt(self.knowlwdge_num)) if self.knowlwdge_num > 0 else 0
self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
self.product_key_topk = min(16, self.num_keys)
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
self.num_experts_per_head_topk = 1
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.dim_key * 2, bias=False),
)
def q_to_k(self,x):
# 1. Generate queries
self.batch_size, seq_len, dim = x.shape
@ -574,12 +578,12 @@ class ExtractDB(nn.Module):
def get_data(self, index):
# Fetch the embedding directly on the GPU
db_values = self.weight_down_embed[index]
db_value = db_values.view(self.batch_size, -1, self.dim)
return db_value
db_values = self.weight_down_embed[index]# now token ids (so dim 1); the embedding is applied later
# db_value = db_values.view(self.batch_size,-1)
return db_values
@torch.no_grad()
def updata_value(self, k, v):
def updata_value(self, k, v):# TODO: add a step that maps a vector back to an index
# Update the values in the buffer directly (no gradients needed)
v_reshaped = v.view(v.size(0), -1)
# Make sure the data types match
@ -604,7 +608,9 @@ class MiniMindLM(PreTrainedModel):
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.database_output = nn.Linear(params.dim, params.knowledge_length, bias=False)
self.tok_embeddings.weight = self.output.weight
self.database_output.weight = self.output.weight
# Calculate input dimension
input_dim = (self.params.max_seq_len-1)*self.params.n_layers
@ -623,9 +629,9 @@ class MiniMindLM(PreTrainedModel):
# Specific layers for v path
self.downsample_v_specific = nn.Sequential(
nn.Conv1d(128*8, 128, kernel_size=1, padding='same'),
nn.Conv1d(128, 8, kernel_size=1, padding='same')
nn.Conv1d(128, self.params.knowledge_length, kernel_size=1, padding='same')
)
# Specific layers for q path
self.downsample_q_specific = nn.Sequential(
nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
@ -654,9 +660,13 @@ class MiniMindLM(PreTrainedModel):
dtype=h.dtype, device=h.device)
else:
# Normal mode: query the database
# import pdb;pdb.set_trace()
index = self.extract_db.q_to_k(h)
db_value = self.extract_db.get_data(index)
token_idx = self.extract_db.get_data(index) # these are token indices
db_value =self.tok_embeddings(token_idx)
h = layer(
h, db_value, pos_cis_real
)
@ -673,12 +683,27 @@ class MiniMindLM(PreTrainedModel):
# The database update logic is detached from the main computation graph
with torch.no_grad():
# Compute shared downsampling layer once
shared_features = self.shared_downsample(h_tensor_detached)
z_v = self.downsample_v_specific(shared_features)
# Get features from v path - now we output embedding-dimension vectors
z_v_features = self.downsample_v_specific(shared_features)
batch_z, seq_len, dim_z = z_v_features.shape
# Reshape to batch_size * knowledge_length, dim
z_v_flat = z_v_features.reshape(-1, dim_z)
# Direct token prediction - like the main language model head
token_logits = self.database_output(z_v_flat) # [batch_z * seq_len, vocab_size]
# Get token indices directly from logits
token_indices_flat = torch.argmax(token_logits, dim=-1)
token_indices = token_indices_flat.reshape(batch_z, -1)
# Process query path as before
z_q = self.downsample_q_specific(shared_features)
z_k = self.extract_db.q_to_k(z_q)
self.extract_db.updata_value(z_k, z_v)
self.extract_db.updata_value(z_k, token_indices)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
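
In short, this diff switches the knowledge store from float vectors to token ids: weight_down_embed now holds long token indices, lookups are embedded through the tied tok_embeddings, and updates write back the argmax of the tied database_output head. A toy, self-contained sketch of that flow is given below; the shapes and standalone modules are illustrative stand-ins, not the repository's actual wiring.

import torch
import torch.nn as nn

vocab_size, dim, knowledge_num, knowledge_length = 6400, 32, 16, 8

tok_embeddings = nn.Embedding(vocab_size, dim)
database_output = nn.Linear(dim, vocab_size, bias=False)
database_output.weight = tok_embeddings.weight           # weight tying, as in the diff

# The knowledge base now holds token ids, not float vectors.
weight_down_embed = torch.randint(0, vocab_size, (knowledge_num, knowledge_length))

# Lookup path: index -> token ids -> embeddings fed into the block.
index = torch.tensor([3, 7])                              # stand-in for q_to_k output
db_value = tok_embeddings(weight_down_embed[index])       # [2, knowledge_length, dim]

# Update path: predicted features -> logits -> argmax token ids written back.
z_v_features = torch.randn(2, knowledge_length, dim)      # stand-in for downsample_v output
token_indices = database_output(z_v_features).argmax(dim=-1)
weight_down_embed[index] = token_indices
print(db_value.shape, token_indices.shape)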

View File

@ -0,0 +1,734 @@
#!/usr/bin/env python3
"""
Enhanced preprocessing script for the TREx dataset.
Uses the agno framework with ollama qwen3:4b for sentence post-processing and importance scoring.
"""
import json
import os
import glob
from typing import List, Dict, Any, Union
import re
import asyncio
import time
from pydantic import BaseModel, Field
from agno.agent import Agent
from agno.models.ollama import Ollama
class ProcessedSentence(BaseModel):
"""处理后的句子结构"""
corrected_sentence: str = Field(
...,
description="修正后的句子,只修正语法错误、乱码和不通顺的地方,不进行额外润色"
)
importance_score: float = Field(
...,
description="重要性评分范围0.0-10.0以0.1递进。评判这个知识在现实世界中的常用程度和重要度",
ge=0.0,
le=10.0
)
class EnhancedTRExProcessor:
def __init__(self, input_dir: str, output_file: str, max_files: int = None, enable_llm_processing: bool = True):
self.input_dir = input_dir
self.output_file = output_file
self.max_files = max_files
self.enable_llm_processing = enable_llm_processing
# Initialize the agno agent
if self.enable_llm_processing:
self.setup_agent()
# Extended Wikidata property mappings
self.property_mappings = {
# 基本关系
"http://www.wikidata.org/prop/direct/P31": "is a",
"http://www.wikidata.org/prop/direct/P279": "is a type of",
# 人物相关
"http://www.wikidata.org/prop/direct/P106": "works as",
"http://www.wikidata.org/prop/direct/P27": "is a citizen of",
"http://www.wikidata.org/prop/direct/P19": "was born in",
"http://www.wikidata.org/prop/direct/P20": "died in",
"http://www.wikidata.org/prop/direct/P569": "was born on",
"http://www.wikidata.org/prop/direct/P570": "died on",
"http://www.wikidata.org/prop/direct/P22": "has father",
"http://www.wikidata.org/prop/direct/P25": "has mother",
"http://www.wikidata.org/prop/direct/P26": "is married to",
# 组织相关
"http://www.wikidata.org/prop/direct/P102": "is a member of",
"http://www.wikidata.org/prop/direct/P108": "works for",
"http://www.wikidata.org/prop/direct/P159": "has headquarters in",
"http://www.wikidata.org/prop/direct/P112": "was founded by",
"http://www.wikidata.org/prop/direct/P571": "was founded in",
"http://www.wikidata.org/prop/direct/P169": "has CEO",
# 地理相关
"http://www.wikidata.org/prop/direct/P17": "is located in",
"http://www.wikidata.org/prop/direct/P131": "is located in",
"http://www.wikidata.org/prop/direct/P36": "has capital",
"http://www.wikidata.org/prop/direct/P47": "borders",
# 其他关系
"http://www.wikidata.org/prop/direct/P1142": "has ideology",
"http://www.wikidata.org/prop/direct/P361": "is part of",
"http://www.wikidata.org/prop/direct/P737": "was influenced by",
"http://www.wikidata.org/prop/direct/P127": "is owned by",
"http://www.wikidata.org/prop/direct/P155": "follows",
"http://www.wikidata.org/prop/direct/P156": "is followed by",
"http://www.wikidata.org/prop/direct/P138": "is named after"
}
def setup_agent(self):
"""设置agno agent"""
try:
self.agent = Agent(
model=Ollama(
id="qwen3:4b",
# 使用options设置temperature和其他参数
options={
"temperature": 0.7,
"top_p": 0.8,
"top_k": 20,
"num_ctx": 4096,
}
),
response_model=ProcessedSentence,
instructions=[
"你是一个专业的文本处理助手,负责修正句子中的错误并评估知识的重要性。",
"",
"### 句子修正规则:",
"1. 移除Wikipedia特有标记如(disambiguation)、(film)、(band)等括号内容",
"2. 确保句子语法完整:主语+谓语+宾语结构完整,避免悬空的'and is''or'",
"3. 修正明显的语法错误:时态一致、单复数一致、介词使用正确",
"4. 清理乱码和特殊字符:如â、€、™等编码问题",
"5. 确保句子语义通顺:如果原句无法修复,重新组织语言使其通顺",
"6. 不要添加原文没有的信息,只修正错误",
"",
"### 修正示例:",
"- 错误:'Argument (disambiguation) is related to philosophy, logic, and is an.'",
"- 修正:'Argument is related to philosophy and logic.'",
"",
"- 错误:'Beijing is a capital city and are.'",
"- 修正:'Beijing is a capital city.'",
"",
"重要性评分标准0.0-10.0以0.1递进):",
"",
"0.0分 - 完全错误或无意义的信息",
"例:'苹果是一种金属''太阳从西边升起''1+1=3'",
"",
"0.5分 - 几乎无价值的信息",
"例:'某个虚构角色的袜子颜色''游戏中NPC的对话第三句话''某人昨天早餐吃了什么'",
"",
"1.0分 - 极其罕见、无实用价值的知识",
"例:'某小说背景角色宠物名字''某部电影片尾字幕第15行内容''某网站用户ID为123456的昵称'",
"",
"1.5分 - 非常小众的细节信息",
"例:'某电影第37分钟路人甲服装''某游戏隐藏关卡的背景音乐时长''某漫画第200页第3个对话框内容'",
"",
"2.0分 - 小众专业领域的细节",
"例:'稀有矿物在特定温度下颜色变化''某种昆虫的第三对触角长度''某化学反应的副产物分子式'",
"",
"2.5分 - 专业人士才关心的技术细节",
"例:'软件库特定版本发布日期''某算法的时间复杂度系数''某种材料的热膨胀系数'",
"",
"3.0分 - 特定领域的专业知识",
"例:'编程语言语法特性''某种病毒的基因序列''古代某朝代的官职制度'",
"",
"3.5分 - 有一定价值的专业信息",
"例:'某历史朝代特定制度''某种药物的作用机制''某技术标准的制定时间'",
"",
"4.0分 - 较少人知道但有意义的知识",
"例:'某国家独特文化传统''某科学家的重要发现''某历史事件的详细过程'",
"",
"4.5分 - 部分人群感兴趣的知识",
"例:'作家创作背景''某艺术流派特点''某运动项目规则细节'",
"",
"5.0分 - 中等重要性的一般知识",
"例:'城市著名景点''某企业发展历史''某动物生活习性'",
"",
"5.5分 - 比较有用的常识",
"例:'植物生长环境''健康饮食常识''基本急救知识'",
"",
"6.0分 - 多数受教育人群应该知道的知识",
"例:'莎士比亚代表作品''基本几何定理''世界主要货币'",
"",
"6.5分 - 重要的文化或科学常识",
"例:'DNA基本结构''牛顿三大定律''世界主要宗教'",
"",
"7.0分 - 重要的基础知识",
"例:'二次世界大战时间''人体主要器官功能''基本数学运算规则'",
"",
"7.5分 - 非常重要的常识",
"例:'光速是宇宙中最快的''地球是圆的''血液循环基本原理'",
"",
"8.0分 - 基础教育中的核心知识",
"例:'地球绕太阳运行''四季形成原理''基本语法规则'",
"",
"8.5分 - 每个人都应该掌握的重要知识",
"例:'水的化学式H2O''基本安全常识''简单数学计算'",
"",
"9.0分 - 极其重要的基础概念",
"例:'人类需要氧气生存''火是热的''基本方向概念'",
"",
"9.5分 - 人人必知的核心知识",
"例:'一天有24小时''一年有12个月''基本数字概念'",
"",
"10.0分 - 最基础、最重要的常识",
"例:'人类需要食物和水生存''天空是蓝色的''石头比羽毛重'",
"",
"评分时请考虑:",
"1. 知识的普及程度 - 有多少人知道这个知识",
"2. 实用价值 - 这个知识在日常生活中有多大用处",
"3. 教育重要性 - 这个知识在教育体系中的地位",
"4. 文化意义 - 这个知识对理解世界的重要性",
"",
"请直接输出结构化结果,不需要思考过程。"
],
markdown=False
)
print("LLM处理器初始化成功")
except Exception as e:
print(f"LLM处理器初始化失败: {e}")
print("将使用基础模式不使用LLM后处理")
self.enable_llm_processing = False
async def process_sentence_with_llm(self, sentence: str) -> ProcessedSentence:
"""使用LLM处理单个句子保留用于单独调用"""
try:
prompt = f"请修正以下句子中的错误并评估其重要性:{sentence}"
# 使用agent.arun进行异步调用
response = await self.agent.arun(prompt)
# 根据agno文档response应该直接是ProcessedSentence类型
if isinstance(response, ProcessedSentence):
return response
else:
message = response.messages[-1].content
message = message.replace("```json", "").replace("```", "")
message = json.loads(message)
return ProcessedSentence(
corrected_sentence=message['corrected_sentence'],
importance_score=message['importance_score']
)
except Exception as e:
print(f"LLM处理句子时出错: {e}")
# 出错时返回原句子和中等评分
return ProcessedSentence(
corrected_sentence=sentence,
importance_score=5.0
)
def clean_text(self, text: str) -> str:
"""清理文本,处理特殊字符"""
if not text:
return ""
# Handle common Unicode characters
text = text.replace("–", "-") # en dash
text = text.replace("—", "-") # em dash
text = text.replace("’", "'") # right single quotation mark
text = text.replace("‘", "'") # left single quotation mark
text = text.replace("“", '"') # left double quotation mark
text = text.replace("”", '"') # right double quotation mark
# 处理可能的转义序列
try:
text = text.encode('utf-8').decode('utf-8')
except:
pass
# 清理多余的空格
text = re.sub(r'\s+', ' ', text).strip()
# 移除可能的引号
text = text.strip('"\'')
return text
def parse_large_json_file(self, file_path: str) -> List[Dict]:
"""解析大型JSON文件处理可能的格式问题"""
documents = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read().strip()
# 尝试不同的解析方法
if content.startswith('[') and content.endswith(']'):
# 标准JSON数组
documents = json.loads(content)
else:
# 可能是连续的JSON对象
# 尝试在}{"之间分割
if '}{"' in content:
json_strings = content.split('}{')
json_strings[0] += '}' # 第一个对象
json_strings[-1] = '{' + json_strings[-1] # 最后一个对象
for i in range(1, len(json_strings) - 1):
json_strings[i] = '{' + json_strings[i] + '}'
for json_str in json_strings:
try:
doc = json.loads(json_str)
documents.append(doc)
except json.JSONDecodeError:
continue
else:
# 尝试作为单个JSON对象
try:
documents = [json.loads(content)]
except json.JSONDecodeError:
pass
except Exception as e:
print(f"Error parsing {file_path}: {e}")
return documents
def extract_sentences_from_document(self, doc: Dict[str, Any]) -> List[str]:
"""从文档中提取句子"""
sentences = []
title = self.clean_text(doc.get('title', ''))
text = self.clean_text(doc.get('text', ''))
entities = doc.get('entities', [])
triples = doc.get('triples', [])
# 处理显式三元组
for triple in triples:
sentence = self.triple_to_sentence(triple)
if sentence:
sentences.append(sentence)
# 从实体和文本中生成基本句子(如果三元组句子不够)
if title and text and len(sentences) < 5:
# 基于标题和实体生成句子
entity_names = []
for entity in entities[:15]:
entity_name = self.clean_text(entity.get('surfaceform', ''))
if entity_name and len(entity_name) > 2:
entity_names.append(entity_name)
# 生成简单的描述句子
if entity_names:
important_entities = []
title_lower = title.lower()
for entity in entity_names:
if (entity.lower() != title_lower and
entity not in important_entities and
not any(t.lower() in entity.lower() for t in title_lower.split()[:2])):
important_entities.append(entity)
if len(important_entities) >= 6:
break
if important_entities and len(sentences) < 3:
entities_str = ', '.join(important_entities[:3])
sentences.append(f"{title} is related to {entities_str}.")
return sentences
def triple_to_sentence(self, triple: Dict[str, Any]) -> str:
"""将三元组转换为自然语言句子"""
try:
subject = triple.get('subject', {})
predicate = triple.get('predicate', {})
obj = triple.get('object', {})
subject_name = self.clean_text(subject.get('surfaceform', ''))
object_name = self.clean_text(obj.get('surfaceform', ''))
predicate_uri = predicate.get('uri', '')
# 检查是否有有效的主语和宾语
if not subject_name or not object_name:
return ""
# 检查主语和宾语是否过短或无意义
if len(subject_name) <= 2 or len(object_name) <= 2:
return ""
# 获取关系文本
relation_text = self.property_mappings.get(predicate_uri, "is related to")
# 避免重复的主语宾语
if subject_name.lower() == object_name.lower():
return ""
return f"{subject_name} {relation_text} {object_name}."
except Exception as e:
print(f"Error converting triple to sentence: {e}")
return ""
async def process_sentence_with_llm_concurrent(self, semaphore: asyncio.Semaphore, sentence: str, index: int, total_sentences: int, start_time: float) -> Dict[str, Any]:
"""使用信号量控制并发的LLM处理"""
async with semaphore:
try:
prompt = f"请修正以下句子中的错误并评估其重要性:{sentence}"
# 使用agent.arun进行异步调用
response = await self.agent.arun(prompt)
# 根据agno文档response应该直接是ProcessedSentence类型
if isinstance(response, ProcessedSentence):
result = {
"index": index,
"original_sentence": sentence,
"corrected_sentence": response.corrected_sentence,
"importance_score": response.importance_score
}
else:
message = response.messages[-1].content
message = message.replace("```json", "").replace("```", "")
message = json.loads(message)
# print(message)
result = {
"index": index,
"original_sentence": sentence,
"corrected_sentence": message['corrected_sentence'],
"importance_score": message['importance_score']
}
# 打印详细进度信息
if index % 100 == 0:
current_time = time.time()
elapsed_time = current_time - start_time
avg_time_per_sentence = elapsed_time / (index + 1) if index > 0 else elapsed_time
remaining_sentences = total_sentences - (index + 1)
estimated_remaining_time = avg_time_per_sentence * remaining_sentences
# 格式化时间显示
def format_time(seconds):
if seconds < 60:
return f"{seconds:.1f}"
elif seconds < 3600:
minutes = seconds / 60
return f"{minutes:.1f}分钟"
else:
hours = seconds / 3600
return f"{hours:.1f}小时"
print(f"已完成第 {index + 1} 个句子的处理")
print(f" - 剩余句子数: {remaining_sentences}")
print(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句")
print(f" - 预估剩余时间: {format_time(estimated_remaining_time)}")
print(f" - 已用时间: {format_time(elapsed_time)}")
return result
except Exception as e:
print(f"处理第 {index} 个句子时出错: {e}")
# 出错时返回原句子和中等评分
return {
"index": index,
"original_sentence": sentence,
"corrected_sentence": sentence,
"importance_score": 5.0
}
async def process_sentences_with_llm(self, sentences: List[str]) -> List[Dict[str, Any]]:
"""批量并发处理句子每2000条保存一次检查点"""
print(f"开始使用LLM并发处理 {len(sentences)} 个句子最大并发数54...")
# 记录开始时间
start_time = time.time()
total_sentences = len(sentences)
# 分批处理每批2000个句子
batch_size = 2000
all_processed_sentences = []
for batch_start in range(0, total_sentences, batch_size):
batch_end = min(batch_start + batch_size, total_sentences)
batch_sentences = sentences[batch_start:batch_end]
print(f"\n=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===")
# 创建信号量限制并发数
semaphore = asyncio.Semaphore(54)
# 创建当前批次的任务
tasks = []
for i, sentence in enumerate(batch_sentences):
global_index = batch_start + i
task = self.process_sentence_with_llm_concurrent(semaphore, sentence, global_index, total_sentences, start_time)
tasks.append(task)
# 并发执行当前批次的任务
print(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...")
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理当前批次的结果,过滤异常
batch_processed_sentences = []
batch_error_count = 0
for result in batch_results:
if isinstance(result, Exception):
print(f"任务执行异常: {result}")
batch_error_count += 1
elif isinstance(result, dict):
batch_processed_sentences.append(result)
else:
batch_error_count += 1
# 按原始顺序排序(因为并发执行可能改变顺序)
batch_processed_sentences.sort(key=lambda x: x['index'])
# 移除index字段
for item in batch_processed_sentences:
del item['index']
# 添加到总结果中
all_processed_sentences.extend(batch_processed_sentences)
# 保存检查点
checkpoint_filename = self.save_checkpoint(all_processed_sentences, batch_end)
# 打印当前批次统计信息
elapsed_time = time.time() - start_time
completed_sentences = len(all_processed_sentences)
print(f"{batch_start//batch_size + 1} 批处理完成!")
print(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}")
print(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
print(f" - 已用时间:{elapsed_time/60:.1f}分钟")
print(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒")
print(f" - 检查点已保存:{checkpoint_filename}")
if batch_end < total_sentences:
remaining_sentences = total_sentences - completed_sentences
avg_time_per_sentence = elapsed_time / completed_sentences
estimated_remaining_time = avg_time_per_sentence * remaining_sentences
print(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟")
# 打印最终统计信息
total_time = time.time() - start_time
print(f"\n=== 全部处理完成!===")
print(f" - 总成功:{len(all_processed_sentences)}")
print(f" - 总用时:{total_time/60:.1f}分钟")
print(f" - 平均处理速度:{len(all_processed_sentences)/total_time:.2f}句/秒")
return all_processed_sentences
def save_checkpoint(self, processed_sentences: List[Dict[str, Any]], current_count: int) -> str:
"""保存检查点文件"""
# 生成检查点文件名
base_name = os.path.splitext(self.output_file)[0]
checkpoint_filename = f"{base_name}_checkpoint_{current_count}.json"
# 保存检查点
with open(checkpoint_filename, 'w', encoding='utf-8') as f:
json.dump({
"metadata": {
"total_processed": len(processed_sentences),
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"checkpoint_number": current_count
},
"sentences": processed_sentences
}, f, ensure_ascii=False, indent=2)
return checkpoint_filename
async def process_files(self) -> List[Dict[str, Any]]:
"""处理所有文件"""
json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))
if not json_files:
print(f"No JSON files found in {self.input_dir}")
return []
# 排序文件以确保一致的处理顺序
json_files.sort()
if self.max_files:
json_files = json_files[:self.max_files]
print(f"Found {len(json_files)} JSON files to process")
all_sentences = []
for i, file_path in enumerate(json_files):
print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")
documents = self.parse_large_json_file(file_path)
print(f" Parsed {len(documents)} documents")
for doc in documents:
sentences = self.extract_sentences_from_document(doc)
all_sentences.extend(sentences)
print(f" Generated {len(all_sentences)} total raw sentences so far")
print(f"总共提取了 {len(all_sentences)} 个原始句子")
# 去重
unique_sentences = []
seen = set()
for sentence in all_sentences:
sentence = sentence.strip()
if sentence and sentence not in seen and len(sentence) > 10:
unique_sentences.append(sentence)
seen.add(sentence)
print(f"去重后剩余 {len(unique_sentences)} 个句子")
# 使用LLM处理句子
if self.enable_llm_processing:
processed_sentences = await self.process_sentences_with_llm(unique_sentences)
else:
# 基础模式不使用LLM
processed_sentences = [
{
"original_sentence": sentence,
"corrected_sentence": sentence,
"importance_score": 5.0
}
for sentence in unique_sentences
]
return processed_sentences
def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
"""保存处理后的句子到文件"""
# 确保输出目录存在
os.makedirs(os.path.dirname(self.output_file) if os.path.dirname(self.output_file) else '.', exist_ok=True)
# 保存为JSON格式包含完整信息
json_output_file = self.output_file.replace('.txt', '.json')
with open(json_output_file, 'w', encoding='utf-8') as f:
json.dump(processed_sentences, f, ensure_ascii=False, indent=2)
# 保存为简单文本格式(仅修正后的句子)
with open(self.output_file, 'w', encoding='utf-8') as f:
for item in processed_sentences:
f.write(item['corrected_sentence'] + '\n')
# 生成重要性排序文件
importance_sorted = sorted(processed_sentences, key=lambda x: x['importance_score'], reverse=True)
importance_file = self.output_file.replace('.txt', '_sorted_by_importance.txt')
with open(importance_file, 'w', encoding='utf-8') as f:
for item in importance_sorted:
f.write(f"[{item['importance_score']:.1f}] {item['corrected_sentence']}\n")
print(f"保存了 {len(processed_sentences)} 个处理后的句子:")
print(f" - JSON格式: {json_output_file}")
print(f" - 文本格式: {self.output_file}")
print(f" - 重要性排序: {importance_file}")
# 统计信息
scores = [item['importance_score'] for item in processed_sentences]
avg_score = sum(scores) / len(scores) if scores else 0
print(f" - 平均重要性评分: {avg_score:.2f}")
print(f" - 最高评分: {max(scores):.1f}")
print(f" - 最低评分: {min(scores):.1f}")
async def run(self):
"""运行处理流程"""
print("Starting enhanced TREx to sentences conversion...")
processed_sentences = await self.process_files()
self.save_sentences(processed_sentences)
print("Enhanced conversion completed!")
def find_latest_checkpoint(self) -> Union[tuple, None]:
"""查找最新的检查点文件"""
base_name = os.path.splitext(self.output_file)[0]
pattern = f"./output/{base_name}_checkpoint_*.json"
checkpoint_files = glob.glob(pattern)
if not checkpoint_files:
return None
# 按检查点编号排序,获取最新的
latest_file = None
latest_count = 0
for file in checkpoint_files:
try:
# 从文件名中提取数字
match = re.search(r'checkpoint_(\d+)\.json$', file)
if match:
count = int(match.group(1))
if count > latest_count:
latest_count = count
latest_file = file
except:
continue
if latest_file:
return latest_file, latest_count
else:
return None
def load_checkpoint(self, checkpoint_file: str) -> List[Dict[str, Any]]:
"""从检查点文件加载已处理的句子"""
try:
with open(checkpoint_file, 'r', encoding='utf-8') as f:
data = json.load(f)
if 'sentences' in data:
return data['sentences']
else:
# 旧格式的检查点文件
return data
except Exception as e:
print(f"加载检查点文件失败: {e}")
return []
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with LLM processing')
parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path')
parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
parser.add_argument('--no_llm', action='store_true', help='Disable LLM processing (basic mode)')
parser.add_argument('--resume', action='store_true', help='Resume from latest checkpoint if available')
args = parser.parse_args()
if not os.path.exists(args.input_dir):
print(f"Error: Input directory {args.input_dir} does not exist!")
return
processor = EnhancedTRExProcessor(
args.input_dir,
args.output_file,
args.max_files,
enable_llm_processing=not args.no_llm
)
# 检查是否要从检查点恢复
if args.resume:
checkpoint_result = processor.find_latest_checkpoint()
if checkpoint_result:
latest_checkpoint, latest_count = checkpoint_result
print(f"发现检查点文件: {latest_checkpoint} (包含 {latest_count} 条记录)")
confirm = input("是否从检查点恢复?(y/n): ").lower().strip()
if confirm == 'y':
processed_sentences = processor.load_checkpoint(latest_checkpoint)
if processed_sentences:
print(f"成功加载 {len(processed_sentences)} 条已处理的句子")
processor.save_sentences(processed_sentences)
print("从检查点恢复完成!")
return
else:
print("检查点文件加载失败,将重新开始处理")
else:
print("不从检查点恢复,将重新开始处理")
else:
print("未找到检查点文件,将重新开始处理")
# 运行异步处理
asyncio.run(processor.run())
if __name__ == "__main__":
main()
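
A minimal way to exercise this processor without an ollama server is the basic (no-LLM) mode. The sketch below assumes the script can be imported as a module named trex_processor (the real file name is not visible in this compare view, so that import path is hypothetical) and relies only on the constructor and run() shown above.

# Hypothetical driver: run the processor in basic (no-LLM) mode on a few files.
import asyncio
from trex_processor import EnhancedTRExProcessor  # placeholder module name

processor = EnhancedTRExProcessor(
    input_dir="dataset/TREx",
    output_file="output/trex_sentences_enhanced.txt",
    max_files=2,                      # keep the smoke test small
    enable_llm_processing=False,      # skip agno/ollama entirely
)
asyncio.run(processor.run())
# Produces: output/trex_sentences_enhanced.txt, .json, and _sorted_by_importance.txt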

View File

@ -45,5 +45,5 @@ CUDA_VISIBLE_DEVICES=0 accelerate launch \
--use_flash_attn \
--profile \
--profile_interval 10\
--knowlwdge_num 4096 \
--knowlwdge_length 8
--knowledge_num 4096 \
--knowledge_length 8

View File

@ -46,5 +46,5 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--use_flash_attn \
--profile \
--profile_interval 10\
--knowlwdge_num 1024 \
--knowlwdge_length 8
--knowledge_num 16384 \
--knowledge_length 64
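
Both run scripts should keep knowledge_num a perfect square, since ExtractDB derives its product-key grid from num_keys = int(math.sqrt(knowledge_num)) (see the model diff above). A quick sanity check for the values used here:

import math

# Sketch: verify the knowledge_num values used in the run scripts are perfect squares.
for knowledge_num in (4096, 16384):
    num_keys = int(math.sqrt(knowledge_num))
    print(knowledge_num, num_keys, num_keys * num_keys == knowledge_num)
# 4096 -> 64, 16384 -> 128; both are perfect squares.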

View File

@ -291,7 +291,7 @@ def train_epoch(epoch, wandb):
def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
tokenizer = AutoTokenizer.from_pretrained('/mnt/lzn/Minimind/Minimind/model/minimind_tokenizer')
# Load the model
model = MiniMindLM(lm_config).to(args.device)
@ -349,7 +349,7 @@ if __name__ == "__main__":
parser.add_argument('--max_seq_len', default=1024, type=int) # Maximum sequence length; caps the length of input sequences.
parser.add_argument('--use_moe', default=False, type=bool) # Whether to use MoE.
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代") # Disable the database feature and switch to the special mode
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") # Path to the training dataset.
parser.add_argument("--data_path", type=str, default="/mnt/lzn/Minimind/dataset/dir/pretrain_hq.jsonl") # Path to the training dataset.
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
# Profiling-related arguments
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
@ -406,7 +406,6 @@ if __name__ == "__main__":
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config)
else:
wandb = None
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None

View File

@ -289,8 +289,8 @@ def main():
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
parser.add_argument("--knowlwdge_num", type=int, default=64*64,help="知识库的数据数目")
parser.add_argument("--knowlwdge_length", type=int, default=8,help="知识库的句子长度")
parser.add_argument("--knowledge_num", type=int, default=64*64,help="知识库的数据数目")
parser.add_argument("--knowledge_length", type=int, default=8,help="知识库的句子长度")
args = parser.parse_args()
#########################################################
@ -327,8 +327,8 @@ def main():
use_moe=args.use_moe,
disable_db=args.disable_db,
flash_attn=args.use_flash_attn,
knowlwdge_num=args.knowlwdge_num,
knowlwdge_length=args.knowlwdge_length
knowledge_num=args.knowledge_num,
knowledge_length=args.knowledge_length
)
#########################################################