Compare commits

6 Commits: 5841f8b4e5 ... c09cd63794

c09cd63794
45da3b383b
00d3c24e03
feeccf733c
42e3d38a3f
d7fe504e1e
dataset_decoder.py (new file, 144 lines)
@@ -0,0 +1,144 @@
import os
import argparse
import torch
from transformers import AutoTokenizer
from model.model import MiniMindLM, ExtractDB
from model.LMConfig import LMConfig


def decode_dataset(model_path, output_path, device="cuda"):
    """
    Decode the weight_down_embed buffer in the model to readable text

    Args:
        model_path: Path to the model checkpoint
        output_path: Path to save the decoded text
        device: Device to load the model on
    """
    print("Loading tokenizer from ./model/minimind_tokenizer")
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    print("Setting up model configuration")
    # Create model configuration matching the training parameters
    lm_config = LMConfig(
        dim=1024,
        n_layers=32,
        max_seq_len=1024,
        use_flash_attn=True,
        knowledge_num=16384,  # From the script parameters
        knowledge_length=64   # From the script parameters
    )

    print("Initializing model")
    model = MiniMindLM(lm_config).to(device)

    print(f"Loading model weights from {model_path}")
    state_dict = torch.load(model_path, map_location=device)

    # Get model parameters
    model_state = dict(model.named_parameters())
    model_state.update(dict(model.named_buffers()))

    # Find parameters with matching names but different shapes
    shape_mismatch = {}
    for name, param in model_state.items():
        if name in state_dict and param.shape != state_dict[name].shape:
            shape_mismatch[name] = (param.shape, state_dict[name].shape)

    # Find parameters in model but not in state_dict and vice versa
    model_only = set(model_state.keys()) - set(state_dict.keys())
    state_dict_only = set(state_dict.keys()) - set(model_state.keys())

    # Create filtered state_dict with only compatible parameters
    filtered_state_dict = {}
    for name, param in state_dict.items():
        if name in model_state and param.shape == model_state[name].shape:
            filtered_state_dict[name] = param

    # Print parameter differences
    if shape_mismatch:
        print(f"Parameters with shape mismatches: {len(shape_mismatch)}")
        for name, (model_shape, state_shape) in shape_mismatch.items():
            print(f" {name}: model={model_shape}, checkpoint={state_shape}")

    if model_only:
        print(f"Parameters in model but not in checkpoint: {len(model_only)}")
        for name in sorted(model_only):
            print(f" {name}: {model_state[name].shape}")

            # Special handling for the pos_cis_real parameter
            if name == "pos_cis_real":
                print("Detected pos_cis_real parameter. This is a position encoding that will be initialized automatically.")

    if state_dict_only:
        print(f"Parameters in checkpoint but not in model: {len(state_dict_only)}")
        for name in sorted(state_dict_only):
            print(f" {name}: {state_dict[name].shape}")

            # If the checkpoint has output.weight but the model does not, try mapping it to tok_embeddings
            if name == "output.weight" and "tok_embeddings.weight" in model_state:
                print("Found output.weight in checkpoint but not in model. Will try to map it to tok_embeddings.weight")
                if model_state["tok_embeddings.weight"].shape == state_dict["output.weight"].shape:
                    filtered_state_dict["tok_embeddings.weight"] = state_dict["output.weight"]

    # Load only the compatible parameters
    print(f"Loading {len(filtered_state_dict)}/{len(state_dict)} parameters")
    model.load_state_dict(filtered_state_dict, strict=False)

    # Check that extract_db and weight_down_embed exist
    if not hasattr(model, "extract_db"):
        print("ERROR: Model does not have extract_db attribute. This is required for decoding.")
        return

    print("Accessing weight_down_embed buffer")
    # Get the weight_down_embed buffer from the model
    try:
        weight_down_embed = model.extract_db.weight_down_embed
        print("Successfully accessed weight_down_embed buffer")
    except Exception as e:
        print(f"ERROR: Failed to access weight_down_embed buffer: {e}")
        print(f"Model structure: {model.__class__.__name__}")
        print(f"ExtractDB attributes: {dir(model.extract_db)}")
        return

    print(f"Shape of weight_down_embed: {weight_down_embed.shape}")
    print(f"Data type of weight_down_embed: {weight_down_embed.dtype}")

    # Create output directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    print(f"Decoding knowledge and writing to {output_path}")
    knowledge_num, knowledge_length = weight_down_embed.shape

    with open(output_path, 'w', encoding='utf-8') as f:
        for i in range(knowledge_num):
            try:
                # Get token IDs for this knowledge entry
                token_ids = weight_down_embed[i].cpu().tolist()

                # Decode tokens to text
                text = tokenizer.decode(token_ids, skip_special_tokens=True)

                # Write to file
                f.write(f"Knowledge_{i}: {text}\n")

                # Print progress periodically
                if (i + 1) % 100 == 0:
                    print(f"Decoded {i + 1}/{knowledge_num} knowledge entries")
            except Exception as e:
                print(f"Error decoding knowledge entry {i}: {e}")
                f.write(f"Knowledge_{i}: [ERROR DECODING]\n")

    print(f"Decoding completed. Output saved to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
    parser.add_argument("--model_path", type=str, default="out/pretrain_1024.pth",
                        help="Path to the model checkpoint")
    parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
                        help="Path to save the decoded text file")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to load the model on")

    args = parser.parse_args()

    decode_dataset(args.model_path, args.output_path, args.device)
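For quick verification, a minimal invocation sketch of the new decoder; the paths simply mirror the argparse defaults above and are not additional requirements of this PR:

```python
# Minimal usage sketch for dataset_decoder.py (paths are the script's own defaults).
from dataset_decoder import decode_dataset

decode_dataset(
    model_path="out/pretrain_1024.pth",   # checkpoint written by the pretraining run
    output_path="out/knowledge_db.txt",   # one "Knowledge_{i}: ..." line per entry
    device="cpu",                         # CUDA is optional for decoding
)
```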
model/LMConfig.py
@@ -37,8 +37,8 @@ class LMConfig(PretrainedConfig):
             seq_aux: bool = True,
             norm_topk_prob: bool = True,
             ####################################################
-            knowlwdge_num: int = 64*64,
-            knowlwdge_length: int = 8,
+            knowledge_num: int = 64*64,
+            knowledge_length: int = 8,
             **kwargs,
     ):
         self.dim = dim
@@ -70,6 +70,6 @@ class LMConfig(PretrainedConfig):
         self.seq_aux = seq_aux  # whether to compute the auxiliary loss at the sequence level
         self.norm_topk_prob = norm_topk_prob  # whether to normalize the top-k probabilities
         ####################################################
-        self.knowlwdge_num = knowlwdge_num
-        self.knowlwdge_length = knowlwdge_length
+        self.knowledge_num = knowledge_num
+        self.knowledge_length = knowledge_length
         super().__init__(**kwargs)
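For reference, a minimal construction sketch using the renamed fields; the values mirror the updated multi-GPU launch script further down in this diff, and the perfect-square expectation on `knowledge_num` comes from the `ExtractDB` comment below (16384 = 128 ** 2):

```python
# Sketch only: shows the renamed config fields; other LMConfig defaults are left untouched.
from model.LMConfig import LMConfig

lm_config = LMConfig(
    dim=1024,
    n_layers=32,
    max_seq_len=1024,
    knowledge_num=16384,   # number of stored knowledge entries (perfect square)
    knowledge_length=64,   # tokens stored per entry
)
print(lm_config.knowledge_num, lm_config.knowledge_length)
```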
model/model.py
@@ -515,7 +515,7 @@ class MiniMindBlock(nn.Module):
 
         # feed-forward network
         out = h + self.feed_forward(self.ffn_norm(h))
-        return out
+        return out
 
 class ExtractDB(nn.Module):
     def __init__(self,params):
@@ -524,22 +524,26 @@ class ExtractDB(nn.Module):
         self.batch_size = None
         self.dim = params.dim
         self.dim_key = self.dim // 2
-        self.knowlwdge_num = params.knowlwdge_num  # e.g. 100 experts; make sure it is a perfect square
+        self.knowledge_num = params.knowledge_num  # e.g. 100 experts; make sure it is a perfect square
         # set knowledge_dim equal to head_dim so it can be used directly in attention
         self.head_dim = params.dim // params.n_heads
-        self.knowledge_length = params.knowlwdge_length*params.dim
+        self.knowledge_length = params.knowledge_length
 
         # use register_buffer instead of nn.Parameter to avoid gradient issues
-        self.register_buffer('weight_down_embed', torch.randn(self.knowlwdge_num, self.knowledge_length) * 0.02)
+        # self.register_buffer('weight_down_embed', torch.randn(self.knowledge_num, self.knowledge_length) * 0.02)
+        self.register_buffer('weight_down_embed', torch.randint(low=0, high=6400, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
+
+
+
 
-        self.num_keys = int(math.sqrt(self.knowlwdge_num)) if self.knowlwdge_num > 0 else 0
+        self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
         self.product_key_topk = min(16, self.num_keys)
         self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
         self.num_experts_per_head_topk = 1
         self.to_queries = nn.Sequential(
             nn.Linear(params.dim, self.dim_key * 2, bias=False),
         )
 
     def q_to_k(self,x):
         # 1. generate queries
         self.batch_size, seq_len, dim = x.shape
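The buffers above set up a product-key memory: `keys` holds two key tables of shape `[num_keys, dim_key]`, and `knowledge_num == num_keys ** 2`. The diff does not show the full body of `q_to_k`, so the following standalone sketch is only an assumption about what it computes, kept consistent with those shapes (`topk` plays the role of `product_key_topk`, `final_topk` the role of `num_experts_per_head_topk`):

```python
# Illustrative product-key lookup, not the PR's exact q_to_k implementation.
import torch

def product_key_lookup(queries, keys, topk=16, final_topk=1):
    # queries: [batch, 2 * dim_key]; keys: [num_keys, 2, dim_key]
    q1, q2 = queries.chunk(2, dim=-1)                 # two half-queries, each [batch, dim_key]
    num_keys = keys.shape[0]

    s1 = q1 @ keys[:, 0, :].t()                       # [batch, num_keys] scores vs first key table
    s2 = q2 @ keys[:, 1, :].t()                       # [batch, num_keys] scores vs second key table
    v1, i1 = s1.topk(topk, dim=-1)
    v2, i2 = s2.topk(topk, dim=-1)

    # combine the two candidate sets into topk * topk composite entries
    combined_scores = v1.unsqueeze(-1) + v2.unsqueeze(-2)            # [batch, topk, topk]
    combined_index = i1.unsqueeze(-1) * num_keys + i2.unsqueeze(-2)  # index into num_keys ** 2 entries
    _, best = combined_scores.flatten(1).topk(final_topk, dim=-1)
    return combined_index.flatten(1).gather(1, best)                 # [batch, final_topk]

# example with the shapes used in this diff: knowledge_num = 16384 -> num_keys = 128, dim_key = 512
keys = torch.randn(128, 2, 512) * 0.02
queries = torch.randn(4, 1024)        # output of to_queries for 4 positions (dim_key * 2)
idx = product_key_lookup(queries, keys)
print(idx.shape, bool(idx.max() < 128 * 128))
```

Each returned index selects one row of `weight_down_embed`, i.e. one stored knowledge entry, which `get_data` then fetches.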
@@ -574,12 +578,12 @@ class ExtractDB(nn.Module):
 
     def get_data(self, index):
         # fetch the embedding directly on the GPU
-        db_values = self.weight_down_embed[index]
-        db_value = db_values.view(self.batch_size, -1, self.dim)
-        return db_value
+        db_values = self.weight_down_embed[index]  # these are token ids now; the embedding is applied later
+        # db_value = db_values.view(self.batch_size,-1)
+        return db_values
 
     @torch.no_grad()
-    def updata_value(self, k, v):
+    def updata_value(self, k, v):  # TODO: add a step that maps vectors back to indices
         # update the buffer values directly (no gradients needed)
         v_reshaped = v.view(v.size(0), -1)
         # make sure the data types match
@@ -604,7 +608,9 @@ class MiniMindLM(PreTrainedModel):
         self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.database_output = nn.Linear(params.dim, params.knowledge_length, bias=False)
         self.tok_embeddings.weight = self.output.weight
+        self.database_output.weight = self.output.weight
 
         # Calculate input dimension
         input_dim = (self.params.max_seq_len-1)*self.params.n_layers
@@ -623,9 +629,9 @@ class MiniMindLM(PreTrainedModel):
         # Specific layers for v path
         self.downsample_v_specific = nn.Sequential(
             nn.Conv1d(128*8, 128, kernel_size=1, padding='same'),
-            nn.Conv1d(128, 8, kernel_size=1, padding='same')
+            nn.Conv1d(128, self.params.knowledge_length, kernel_size=1, padding='same')
         )
 
         # Specific layers for q path
         self.downsample_q_specific = nn.Sequential(
             nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
@@ -654,9 +660,13 @@ class MiniMindLM(PreTrainedModel):
                                      dtype=h.dtype, device=h.device)
             else:
                 # normal mode: query the database
                 # import pdb;pdb.set_trace()
                 index = self.extract_db.q_to_k(h)
-                db_value = self.extract_db.get_data(index)
+
+                token_idx = self.extract_db.get_data(index)  # these are indices
+
+                db_value = self.tok_embeddings(token_idx)
+
             h = layer(
                 h, db_value, pos_cis_real
             )
@@ -673,12 +683,27 @@ class MiniMindLM(PreTrainedModel):
 
         # database updates are decoupled from the main computation graph
         with torch.no_grad():
+
             # Compute shared downsampling layer once
             shared_features = self.shared_downsample(h_tensor_detached)
-            z_v = self.downsample_v_specific(shared_features)
+
+            # Get features from v path - now we output embedding-dimension vectors
+            z_v_features = self.downsample_v_specific(shared_features)
+            batch_z, seq_len, dim_z = z_v_features.shape
+
+            # Reshape to batch_size * knowledge_length, dim
+            z_v_flat = z_v_features.reshape(-1, dim_z)
+
+            # Direct token prediction - like the main language model head
+            token_logits = self.database_output(z_v_flat)  # [batch_z * seq_len, vocab_size]
+            # Get token indices directly from logits
+            token_indices_flat = torch.argmax(token_logits, dim=-1)
+            token_indices = token_indices_flat.reshape(batch_z, -1)
+
+            # Process query path as before
             z_q = self.downsample_q_specific(shared_features)
             z_k = self.extract_db.q_to_k(z_q)
-            self.extract_db.updata_value(z_k, z_v)
+            self.extract_db.updata_value(z_k, token_indices)
 
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.output(self.norm(h)[:, slice_indices, :])
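A shape sketch of the new detached update path, under illustrative sizes only: `database_output` is declared with `out_features=knowledge_length`, but because its weight is tied to `output.weight` it effectively projects to the vocabulary size at run time (6400 here, matching the `high=6400` bound used to initialize the buffer above). This is not the PR's code, just the tensor bookkeeping it implies:

```python
# Illustrative sizes; dim and knowledge_length follow the LMConfig values used elsewhere in this diff.
import torch
import torch.nn as nn

dim, vocab_size, knowledge_length, batch = 1024, 6400, 64, 2
tok_embeddings = nn.Embedding(vocab_size, dim)
database_output = nn.Linear(dim, vocab_size, bias=False)
database_output.weight = tok_embeddings.weight                  # weight tying, as in the PR

z_v_features = torch.randn(batch, knowledge_length, dim)        # stand-in for downsample_v_specific output
token_logits = database_output(z_v_features.reshape(-1, dim))   # [batch * knowledge_length, vocab_size]
token_indices = token_logits.argmax(dim=-1).reshape(batch, -1)  # [batch, knowledge_length] token ids

# these ids are what updata_value(z_k, token_indices) now writes into weight_down_embed,
# and what get_data() returns before being re-embedded with tok_embeddings in forward()
print(token_indices.shape, token_indices.dtype)
```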
preprocessing/trex_to_sentences_simple.py (new file, 734 lines)
@@ -0,0 +1,734 @@
#!/usr/bin/env python3
"""
Enhanced preprocessing script for the TREx dataset.
Uses the agno framework with ollama qwen3:4b for sentence post-processing and importance scoring.
"""

import json
import os
import glob
from typing import List, Dict, Any, Union
import re
import asyncio
import time
from pydantic import BaseModel, Field
from agno.agent import Agent
from agno.models.ollama import Ollama


class ProcessedSentence(BaseModel):
    """Structure of a processed sentence"""
    corrected_sentence: str = Field(
        ...,
        description="修正后的句子,只修正语法错误、乱码和不通顺的地方,不进行额外润色"
    )
    importance_score: float = Field(
        ...,
        description="重要性评分,范围0.0-10.0,以0.1递进。评判这个知识在现实世界中的常用程度和重要度",
        ge=0.0,
        le=10.0
    )


class EnhancedTRExProcessor:
    def __init__(self, input_dir: str, output_file: str, max_files: int = None, enable_llm_processing: bool = True):
        self.input_dir = input_dir
        self.output_file = output_file
        self.max_files = max_files
        self.enable_llm_processing = enable_llm_processing

        # Initialize the agno agent
        if self.enable_llm_processing:
            self.setup_agent()

        # Extended Wikidata property mappings
        self.property_mappings = {
            # Basic relations
            "http://www.wikidata.org/prop/direct/P31": "is a",
            "http://www.wikidata.org/prop/direct/P279": "is a type of",

            # Person-related
            "http://www.wikidata.org/prop/direct/P106": "works as",
            "http://www.wikidata.org/prop/direct/P27": "is a citizen of",
            "http://www.wikidata.org/prop/direct/P19": "was born in",
            "http://www.wikidata.org/prop/direct/P20": "died in",
            "http://www.wikidata.org/prop/direct/P569": "was born on",
            "http://www.wikidata.org/prop/direct/P570": "died on",
            "http://www.wikidata.org/prop/direct/P22": "has father",
            "http://www.wikidata.org/prop/direct/P25": "has mother",
            "http://www.wikidata.org/prop/direct/P26": "is married to",

            # Organization-related
            "http://www.wikidata.org/prop/direct/P102": "is a member of",
            "http://www.wikidata.org/prop/direct/P108": "works for",
            "http://www.wikidata.org/prop/direct/P159": "has headquarters in",
            "http://www.wikidata.org/prop/direct/P112": "was founded by",
            "http://www.wikidata.org/prop/direct/P571": "was founded in",
            "http://www.wikidata.org/prop/direct/P169": "has CEO",

            # Geography-related
            "http://www.wikidata.org/prop/direct/P17": "is located in",
            "http://www.wikidata.org/prop/direct/P131": "is located in",
            "http://www.wikidata.org/prop/direct/P36": "has capital",
            "http://www.wikidata.org/prop/direct/P47": "borders",

            # Other relations
            "http://www.wikidata.org/prop/direct/P1142": "has ideology",
            "http://www.wikidata.org/prop/direct/P361": "is part of",
            "http://www.wikidata.org/prop/direct/P737": "was influenced by",
            "http://www.wikidata.org/prop/direct/P127": "is owned by",
            "http://www.wikidata.org/prop/direct/P155": "follows",
            "http://www.wikidata.org/prop/direct/P156": "is followed by",
            "http://www.wikidata.org/prop/direct/P138": "is named after"
        }

    def setup_agent(self):
        """Set up the agno agent"""
        try:
            self.agent = Agent(
                model=Ollama(
                    id="qwen3:4b",
                    # Use options to set temperature and other parameters
                    options={
                        "temperature": 0.7,
                        "top_p": 0.8,
                        "top_k": 20,
                        "num_ctx": 4096,
                    }
                ),
                response_model=ProcessedSentence,
                instructions=[
                    "你是一个专业的文本处理助手,负责修正句子中的错误并评估知识的重要性。",
                    "",
                    "### 句子修正规则:",
                    "1. 移除Wikipedia特有标记:如(disambiguation)、(film)、(band)等括号内容",
                    "2. 确保句子语法完整:主语+谓语+宾语结构完整,避免悬空的'and is'、'or'等",
                    "3. 修正明显的语法错误:时态一致、单复数一致、介词使用正确",
                    "4. 清理乱码和特殊字符:如â、€、™等编码问题",
                    "5. 确保句子语义通顺:如果原句无法修复,重新组织语言使其通顺",
                    "6. 不要添加原文没有的信息,只修正错误",
                    "",
                    "### 修正示例:",
                    "- 错误:'Argument (disambiguation) is related to philosophy, logic, and is an.'",
                    "- 修正:'Argument is related to philosophy and logic.'",
                    "",
                    "- 错误:'Beijing is a capital city and are.'",
                    "- 修正:'Beijing is a capital city.'",
                    "",
                    "重要性评分标准(0.0-10.0,以0.1递进):",
                    "",
                    "0.0分 - 完全错误或无意义的信息",
                    "例:'苹果是一种金属'、'太阳从西边升起'、'1+1=3'",
                    "",
                    "0.5分 - 几乎无价值的信息",
                    "例:'某个虚构角色的袜子颜色'、'游戏中NPC的对话第三句话'、'某人昨天早餐吃了什么'",
                    "",
                    "1.0分 - 极其罕见、无实用价值的知识",
                    "例:'某小说背景角色宠物名字'、'某部电影片尾字幕第15行内容'、'某网站用户ID为123456的昵称'",
                    "",
                    "1.5分 - 非常小众的细节信息",
                    "例:'某电影第37分钟路人甲服装'、'某游戏隐藏关卡的背景音乐时长'、'某漫画第200页第3个对话框内容'",
                    "",
                    "2.0分 - 小众专业领域的细节",
                    "例:'稀有矿物在特定温度下颜色变化'、'某种昆虫的第三对触角长度'、'某化学反应的副产物分子式'",
                    "",
                    "2.5分 - 专业人士才关心的技术细节",
                    "例:'软件库特定版本发布日期'、'某算法的时间复杂度系数'、'某种材料的热膨胀系数'",
                    "",
                    "3.0分 - 特定领域的专业知识",
                    "例:'编程语言语法特性'、'某种病毒的基因序列'、'古代某朝代的官职制度'",
                    "",
                    "3.5分 - 有一定价值的专业信息",
                    "例:'某历史朝代特定制度'、'某种药物的作用机制'、'某技术标准的制定时间'",
                    "",
                    "4.0分 - 较少人知道但有意义的知识",
                    "例:'某国家独特文化传统'、'某科学家的重要发现'、'某历史事件的详细过程'",
                    "",
                    "4.5分 - 部分人群感兴趣的知识",
                    "例:'作家创作背景'、'某艺术流派特点'、'某运动项目规则细节'",
                    "",
                    "5.0分 - 中等重要性的一般知识",
                    "例:'城市著名景点'、'某企业发展历史'、'某动物生活习性'",
                    "",
                    "5.5分 - 比较有用的常识",
                    "例:'植物生长环境'、'健康饮食常识'、'基本急救知识'",
                    "",
                    "6.0分 - 多数受教育人群应该知道的知识",
                    "例:'莎士比亚代表作品'、'基本几何定理'、'世界主要货币'",
                    "",
                    "6.5分 - 重要的文化或科学常识",
                    "例:'DNA基本结构'、'牛顿三大定律'、'世界主要宗教'",
                    "",
                    "7.0分 - 重要的基础知识",
                    "例:'二次世界大战时间'、'人体主要器官功能'、'基本数学运算规则'",
                    "",
                    "7.5分 - 非常重要的常识",
                    "例:'光速是宇宙中最快的'、'地球是圆的'、'血液循环基本原理'",
                    "",
                    "8.0分 - 基础教育中的核心知识",
                    "例:'地球绕太阳运行'、'四季形成原理'、'基本语法规则'",
                    "",
                    "8.5分 - 每个人都应该掌握的重要知识",
                    "例:'水的化学式H2O'、'基本安全常识'、'简单数学计算'",
                    "",
                    "9.0分 - 极其重要的基础概念",
                    "例:'人类需要氧气生存'、'火是热的'、'基本方向概念'",
                    "",
                    "9.5分 - 人人必知的核心知识",
                    "例:'一天有24小时'、'一年有12个月'、'基本数字概念'",
                    "",
                    "10.0分 - 最基础、最重要的常识",
                    "例:'人类需要食物和水生存'、'天空是蓝色的'、'石头比羽毛重'",
                    "",
                    "评分时请考虑:",
                    "1. 知识的普及程度 - 有多少人知道这个知识",
                    "2. 实用价值 - 这个知识在日常生活中有多大用处",
                    "3. 教育重要性 - 这个知识在教育体系中的地位",
                    "4. 文化意义 - 这个知识对理解世界的重要性",
                    "",
                    "请直接输出结构化结果,不需要思考过程。"
                ],
                markdown=False
            )
            print("LLM处理器初始化成功")
        except Exception as e:
            print(f"LLM处理器初始化失败: {e}")
            print("将使用基础模式(不使用LLM后处理)")
            self.enable_llm_processing = False

    async def process_sentence_with_llm(self, sentence: str) -> ProcessedSentence:
        """Process a single sentence with the LLM (kept for standalone calls)"""
        try:
            prompt = f"请修正以下句子中的错误并评估其重要性:{sentence}"

            # Call the agent asynchronously via agent.arun
            response = await self.agent.arun(prompt)

            # According to the agno docs, response should already be a ProcessedSentence
            if isinstance(response, ProcessedSentence):
                return response
            else:
                message = response.messages[-1].content
                message = message.replace("```json", "").replace("```", "")
                message = json.loads(message)
                return ProcessedSentence(
                    corrected_sentence=message['corrected_sentence'],
                    importance_score=message['importance_score']
                )

        except Exception as e:
            print(f"LLM处理句子时出错: {e}")
            # On error, return the original sentence with a medium score
            return ProcessedSentence(
                corrected_sentence=sentence,
                importance_score=5.0
            )

    def clean_text(self, text: str) -> str:
        """Clean the text and handle special characters"""
        if not text:
            return ""

        # Handle common Unicode characters
        text = text.replace("–", "-")  # en dash
        text = text.replace("—", "-")  # em dash
        text = text.replace("’", "'")  # right single quotation mark
        text = text.replace("‘", "'")  # left single quotation mark
        text = text.replace("“", '"')  # left double quotation mark
        text = text.replace("”", '"')  # right double quotation mark

        # Handle possible escape sequences
        try:
            text = text.encode('utf-8').decode('utf-8')
        except:
            pass

        # Collapse extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Strip surrounding quotes
        text = text.strip('"\'')

        return text

    def parse_large_json_file(self, file_path: str) -> List[Dict]:
        """Parse a large JSON file, handling possible formatting issues"""
        documents = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read().strip()

            # Try different parsing strategies
            if content.startswith('[') and content.endswith(']'):
                # Standard JSON array
                documents = json.loads(content)
            else:
                # Possibly a stream of concatenated JSON objects
                # Try splitting between }{
                if '}{"' in content:
                    json_strings = content.split('}{')
                    json_strings[0] += '}'  # first object
                    json_strings[-1] = '{' + json_strings[-1]  # last object

                    for i in range(1, len(json_strings) - 1):
                        json_strings[i] = '{' + json_strings[i] + '}'

                    for json_str in json_strings:
                        try:
                            doc = json.loads(json_str)
                            documents.append(doc)
                        except json.JSONDecodeError:
                            continue
                else:
                    # Try parsing as a single JSON object
                    try:
                        documents = [json.loads(content)]
                    except json.JSONDecodeError:
                        pass

        except Exception as e:
            print(f"Error parsing {file_path}: {e}")

        return documents

    def extract_sentences_from_document(self, doc: Dict[str, Any]) -> List[str]:
        """Extract sentences from a document"""
        sentences = []

        title = self.clean_text(doc.get('title', ''))
        text = self.clean_text(doc.get('text', ''))
        entities = doc.get('entities', [])
        triples = doc.get('triples', [])

        # Handle explicit triples
        for triple in triples:
            sentence = self.triple_to_sentence(triple)
            if sentence:
                sentences.append(sentence)

        # Generate basic sentences from entities and text (if the triple sentences are not enough)
        if title and text and len(sentences) < 5:
            # Build sentences from the title and entities
            entity_names = []
            for entity in entities[:15]:
                entity_name = self.clean_text(entity.get('surfaceform', ''))
                if entity_name and len(entity_name) > 2:
                    entity_names.append(entity_name)

            # Generate simple descriptive sentences
            if entity_names:
                important_entities = []
                title_lower = title.lower()
                for entity in entity_names:
                    if (entity.lower() != title_lower and
                            entity not in important_entities and
                            not any(t.lower() in entity.lower() for t in title_lower.split()[:2])):
                        important_entities.append(entity)
                        if len(important_entities) >= 6:
                            break

                if important_entities and len(sentences) < 3:
                    entities_str = ', '.join(important_entities[:3])
                    sentences.append(f"{title} is related to {entities_str}.")

        return sentences

    def triple_to_sentence(self, triple: Dict[str, Any]) -> str:
        """Convert a triple into a natural-language sentence"""
        try:
            subject = triple.get('subject', {})
            predicate = triple.get('predicate', {})
            obj = triple.get('object', {})

            subject_name = self.clean_text(subject.get('surfaceform', ''))
            object_name = self.clean_text(obj.get('surfaceform', ''))
            predicate_uri = predicate.get('uri', '')

            # Check that the subject and object are present
            if not subject_name or not object_name:
                return ""

            # Skip subjects and objects that are too short or meaningless
            if len(subject_name) <= 2 or len(object_name) <= 2:
                return ""

            # Look up the relation text
            relation_text = self.property_mappings.get(predicate_uri, "is related to")

            # Avoid sentences where subject and object are identical
            if subject_name.lower() == object_name.lower():
                return ""

            return f"{subject_name} {relation_text} {object_name}."

        except Exception as e:
            print(f"Error converting triple to sentence: {e}")
            return ""

    async def process_sentence_with_llm_concurrent(self, semaphore: asyncio.Semaphore, sentence: str, index: int, total_sentences: int, start_time: float) -> Dict[str, Any]:
        """LLM processing with a semaphore to limit concurrency"""
        async with semaphore:
            try:
                prompt = f"请修正以下句子中的错误并评估其重要性:{sentence}"

                # Call the agent asynchronously via agent.arun
                response = await self.agent.arun(prompt)

                # According to the agno docs, response should already be a ProcessedSentence
                if isinstance(response, ProcessedSentence):
                    result = {
                        "index": index,
                        "original_sentence": sentence,
                        "corrected_sentence": response.corrected_sentence,
                        "importance_score": response.importance_score
                    }
                else:
                    message = response.messages[-1].content
                    message = message.replace("```json", "").replace("```", "")
                    message = json.loads(message)
                    # print(message)
                    result = {
                        "index": index,
                        "original_sentence": sentence,
                        "corrected_sentence": message['corrected_sentence'],
                        "importance_score": message['importance_score']
                    }

                # Print detailed progress information
                if index % 100 == 0:
                    current_time = time.time()
                    elapsed_time = current_time - start_time
                    avg_time_per_sentence = elapsed_time / (index + 1) if index > 0 else elapsed_time
                    remaining_sentences = total_sentences - (index + 1)
                    estimated_remaining_time = avg_time_per_sentence * remaining_sentences

                    # Format the time display
                    def format_time(seconds):
                        if seconds < 60:
                            return f"{seconds:.1f}秒"
                        elif seconds < 3600:
                            minutes = seconds / 60
                            return f"{minutes:.1f}分钟"
                        else:
                            hours = seconds / 3600
                            return f"{hours:.1f}小时"

                    print(f"已完成第 {index + 1} 个句子的处理")
                    print(f" - 剩余句子数: {remaining_sentences}")
                    print(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句")
                    print(f" - 预估剩余时间: {format_time(estimated_remaining_time)}")
                    print(f" - 已用时间: {format_time(elapsed_time)}")

                return result

            except Exception as e:
                print(f"处理第 {index} 个句子时出错: {e}")
                # On error, return the original sentence with a medium score
                return {
                    "index": index,
                    "original_sentence": sentence,
                    "corrected_sentence": sentence,
                    "importance_score": 5.0
                }

    async def process_sentences_with_llm(self, sentences: List[str]) -> List[Dict[str, Any]]:
        """Process sentences concurrently in batches, saving a checkpoint every 2000 sentences"""
        print(f"开始使用LLM并发处理 {len(sentences)} 个句子(最大并发数:54)...")

        # Record the start time
        start_time = time.time()
        total_sentences = len(sentences)

        # Process in batches of 2000 sentences
        batch_size = 2000
        all_processed_sentences = []

        for batch_start in range(0, total_sentences, batch_size):
            batch_end = min(batch_start + batch_size, total_sentences)
            batch_sentences = sentences[batch_start:batch_end]

            print(f"\n=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===")

            # Create a semaphore to limit concurrency
            semaphore = asyncio.Semaphore(54)

            # Create the tasks for the current batch
            tasks = []
            for i, sentence in enumerate(batch_sentences):
                global_index = batch_start + i
                task = self.process_sentence_with_llm_concurrent(semaphore, sentence, global_index, total_sentences, start_time)
                tasks.append(task)

            # Run the current batch concurrently
            print(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...")
            batch_results = await asyncio.gather(*tasks, return_exceptions=True)

            # Collect the batch results and filter out exceptions
            batch_processed_sentences = []
            batch_error_count = 0

            for result in batch_results:
                if isinstance(result, Exception):
                    print(f"任务执行异常: {result}")
                    batch_error_count += 1
                elif isinstance(result, dict):
                    batch_processed_sentences.append(result)
                else:
                    batch_error_count += 1

            # Sort back into the original order (concurrent execution may reorder results)
            batch_processed_sentences.sort(key=lambda x: x['index'])

            # Drop the index field
            for item in batch_processed_sentences:
                del item['index']

            # Append to the overall results
            all_processed_sentences.extend(batch_processed_sentences)

            # Save a checkpoint
            checkpoint_filename = self.save_checkpoint(all_processed_sentences, batch_end)

            # Print statistics for the current batch
            elapsed_time = time.time() - start_time
            completed_sentences = len(all_processed_sentences)

            print(f"第 {batch_start//batch_size + 1} 批处理完成!")
            print(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}")
            print(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
            print(f" - 已用时间:{elapsed_time/60:.1f}分钟")
            print(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒")
            print(f" - 检查点已保存:{checkpoint_filename}")

            if batch_end < total_sentences:
                remaining_sentences = total_sentences - completed_sentences
                avg_time_per_sentence = elapsed_time / completed_sentences
                estimated_remaining_time = avg_time_per_sentence * remaining_sentences
                print(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟")

        # Print the final statistics
        total_time = time.time() - start_time
        print("\n=== 全部处理完成!===")
        print(f" - 总成功:{len(all_processed_sentences)}")
        print(f" - 总用时:{total_time/60:.1f}分钟")
        print(f" - 平均处理速度:{len(all_processed_sentences)/total_time:.2f}句/秒")

        return all_processed_sentences

    def save_checkpoint(self, processed_sentences: List[Dict[str, Any]], current_count: int) -> str:
        """Save a checkpoint file"""
        # Build the checkpoint file name
        base_name = os.path.splitext(self.output_file)[0]
        checkpoint_filename = f"{base_name}_checkpoint_{current_count}.json"

        # Write the checkpoint
        with open(checkpoint_filename, 'w', encoding='utf-8') as f:
            json.dump({
                "metadata": {
                    "total_processed": len(processed_sentences),
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                    "checkpoint_number": current_count
                },
                "sentences": processed_sentences
            }, f, ensure_ascii=False, indent=2)

        return checkpoint_filename

    async def process_files(self) -> List[Dict[str, Any]]:
        """Process all input files"""
        json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))

        if not json_files:
            print(f"No JSON files found in {self.input_dir}")
            return []

        # Sort files to ensure a consistent processing order
        json_files.sort()

        if self.max_files:
            json_files = json_files[:self.max_files]

        print(f"Found {len(json_files)} JSON files to process")

        all_sentences = []

        for i, file_path in enumerate(json_files):
            print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")

            documents = self.parse_large_json_file(file_path)
            print(f" Parsed {len(documents)} documents")

            for doc in documents:
                sentences = self.extract_sentences_from_document(doc)
                all_sentences.extend(sentences)

            print(f" Generated {len(all_sentences)} total raw sentences so far")

        print(f"总共提取了 {len(all_sentences)} 个原始句子")

        # De-duplicate
        unique_sentences = []
        seen = set()
        for sentence in all_sentences:
            sentence = sentence.strip()
            if sentence and sentence not in seen and len(sentence) > 10:
                unique_sentences.append(sentence)
                seen.add(sentence)

        print(f"去重后剩余 {len(unique_sentences)} 个句子")

        # Post-process the sentences with the LLM
        if self.enable_llm_processing:
            processed_sentences = await self.process_sentences_with_llm(unique_sentences)
        else:
            # Basic mode: no LLM
            processed_sentences = [
                {
                    "original_sentence": sentence,
                    "corrected_sentence": sentence,
                    "importance_score": 5.0
                }
                for sentence in unique_sentences
            ]

        return processed_sentences

    def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
        """Save the processed sentences to files"""
        # Make sure the output directory exists
        os.makedirs(os.path.dirname(self.output_file) if os.path.dirname(self.output_file) else '.', exist_ok=True)

        # Save as JSON with the full information
        json_output_file = self.output_file.replace('.txt', '.json')
        with open(json_output_file, 'w', encoding='utf-8') as f:
            json.dump(processed_sentences, f, ensure_ascii=False, indent=2)

        # Save as plain text (corrected sentences only)
        with open(self.output_file, 'w', encoding='utf-8') as f:
            for item in processed_sentences:
                f.write(item['corrected_sentence'] + '\n')

        # Write a file sorted by importance
        importance_sorted = sorted(processed_sentences, key=lambda x: x['importance_score'], reverse=True)
        importance_file = self.output_file.replace('.txt', '_sorted_by_importance.txt')
        with open(importance_file, 'w', encoding='utf-8') as f:
            for item in importance_sorted:
                f.write(f"[{item['importance_score']:.1f}] {item['corrected_sentence']}\n")

        print(f"保存了 {len(processed_sentences)} 个处理后的句子:")
        print(f" - JSON格式: {json_output_file}")
        print(f" - 文本格式: {self.output_file}")
        print(f" - 重要性排序: {importance_file}")

        # Statistics
        scores = [item['importance_score'] for item in processed_sentences]
        avg_score = sum(scores) / len(scores) if scores else 0
        print(f" - 平均重要性评分: {avg_score:.2f}")
        print(f" - 最高评分: {max(scores):.1f}")
        print(f" - 最低评分: {min(scores):.1f}")

    async def run(self):
        """Run the processing pipeline"""
        print("Starting enhanced TREx to sentences conversion...")
        processed_sentences = await self.process_files()
        self.save_sentences(processed_sentences)
        print("Enhanced conversion completed!")

    def find_latest_checkpoint(self) -> Union[tuple, None]:
        """Find the most recent checkpoint file"""
        base_name = os.path.splitext(self.output_file)[0]
        pattern = f"./output/{base_name}_checkpoint_*.json"
        checkpoint_files = glob.glob(pattern)

        if not checkpoint_files:
            return None

        # Sort by checkpoint number and pick the latest
        latest_file = None
        latest_count = 0

        for file in checkpoint_files:
            try:
                # Extract the number from the file name
                match = re.search(r'checkpoint_(\d+)\.json$', file)
                if match:
                    count = int(match.group(1))
                    if count > latest_count:
                        latest_count = count
                        latest_file = file
            except:
                continue

        if latest_file:
            return latest_file, latest_count
        else:
            return None

    def load_checkpoint(self, checkpoint_file: str) -> List[Dict[str, Any]]:
        """Load already-processed sentences from a checkpoint file"""
        try:
            with open(checkpoint_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if 'sentences' in data:
                return data['sentences']
            else:
                # Old checkpoint file format
                return data
        except Exception as e:
            print(f"加载检查点文件失败: {e}")
            return []


def main():
    """Main entry point"""
    import argparse

    parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with LLM processing')
    parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
    parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path')
    parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
    parser.add_argument('--no_llm', action='store_true', help='Disable LLM processing (basic mode)')
    parser.add_argument('--resume', action='store_true', help='Resume from latest checkpoint if available')

    args = parser.parse_args()

    if not os.path.exists(args.input_dir):
        print(f"Error: Input directory {args.input_dir} does not exist!")
        return

    processor = EnhancedTRExProcessor(
        args.input_dir,
        args.output_file,
        args.max_files,
        enable_llm_processing=not args.no_llm
    )

    # Check whether to resume from a checkpoint
    if args.resume:
        checkpoint_result = processor.find_latest_checkpoint()
        if checkpoint_result:
            latest_checkpoint, latest_count = checkpoint_result
            print(f"发现检查点文件: {latest_checkpoint} (包含 {latest_count} 条记录)")
            confirm = input("是否从检查点恢复?(y/n): ").lower().strip()
            if confirm == 'y':
                processed_sentences = processor.load_checkpoint(latest_checkpoint)
                if processed_sentences:
                    print(f"成功加载 {len(processed_sentences)} 条已处理的句子")
                    processor.save_sentences(processed_sentences)
                    print("从检查点恢复完成!")
                    return
                else:
                    print("检查点文件加载失败,将重新开始处理")
            else:
                print("不从检查点恢复,将重新开始处理")
        else:
            print("未找到检查点文件,将重新开始处理")

    # Run the async pipeline
    asyncio.run(processor.run())


if __name__ == "__main__":
    main()
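A smoke-test sketch for the new preprocessing script, assuming `preprocessing/` is importable as a package from the repository root; `enable_llm_processing=False` selects the basic mode and avoids the Ollama/agno dependency. The same run is available from the command line through the argparse options defined in `main()`:

```python
# Sketch only: equivalent to
#   python preprocessing/trex_to_sentences_simple.py --input_dir dataset/TREx --max_files 2 --no_llm
import asyncio
from preprocessing.trex_to_sentences_simple import EnhancedTRExProcessor

processor = EnhancedTRExProcessor(
    input_dir="dataset/TREx",                   # argparse default
    output_file="trex_sentences_enhanced.txt",
    max_files=2,                                # small smoke test
    enable_llm_processing=False,                # basic mode, no LLM post-processing
)
asyncio.run(processor.run())                    # writes the .txt, .json and *_sorted_by_importance.txt outputs
```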
@@ -45,5 +45,5 @@ CUDA_VISIBLE_DEVICES=0 accelerate launch \
     --use_flash_attn \
     --profile \
     --profile_interval 10\
-    --knowlwdge_num 4096 \
-    --knowlwdge_length 8
+    --knowledge_num 4096 \
+    --knowledge_length 8
@@ -46,5 +46,5 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
     --use_flash_attn \
     --profile \
     --profile_interval 10\
-    --knowlwdge_num 1024 \
-    --knowlwdge_length 8
+    --knowledge_num 16384 \
+    --knowledge_length 64
@@ -291,7 +291,7 @@ def train_epoch(epoch, wandb):
 
 def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
     # load the tokenizer
-    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
+    tokenizer = AutoTokenizer.from_pretrained('/mnt/lzn/Minimind/Minimind/model/minimind_tokenizer')
     # load the model
     model = MiniMindLM(lm_config).to(args.device)
 
@@ -349,7 +349,7 @@ if __name__ == "__main__":
     parser.add_argument('--max_seq_len', default=1024, type=int)  # maximum length of the input sequence
     parser.add_argument('--use_moe', default=False, type=bool)  # whether to use MoE
     parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代")  # disable the database and enable the special fixed-value mode
-    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")  # path of the dataset
+    parser.add_argument("--data_path", type=str, default="/mnt/lzn/Minimind/dataset/dir/pretrain_hq.jsonl")  # path of the dataset
     parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
     # profiling-related arguments
     parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
@@ -406,7 +406,6 @@ if __name__ == "__main__":
         wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config)
     else:
         wandb = None
-
     model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
     train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
     train_sampler = DistributedSampler(train_ds) if ddp else None
@@ -289,8 +289,8 @@ def main():
     parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
     parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
     parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
-    parser.add_argument("--knowlwdge_num", type=int, default=64*64, help="知识库的数据数目")
-    parser.add_argument("--knowlwdge_length", type=int, default=8, help="知识库的句子长度")
+    parser.add_argument("--knowledge_num", type=int, default=64*64, help="知识库的数据数目")
+    parser.add_argument("--knowledge_length", type=int, default=8, help="知识库的句子长度")
     args = parser.parse_args()
 
     #########################################################
@@ -327,8 +327,8 @@ def main():
         use_moe=args.use_moe,
         disable_db=args.disable_db,
         flash_attn=args.use_flash_attn,
-        knowlwdge_num=args.knowlwdge_num,
-        knowlwdge_length=args.knowlwdge_length
+        knowledge_num=args.knowledge_num,
+        knowledge_length=args.knowledge_length
     )
 
     #########################################################