Use classification to obtain the predicate

This commit is contained in:
Yu Chengzhang 2025-07-05 22:18:32 +08:00
parent fcab661af9
commit 75265f6652
7 changed files with 374 additions and 1669 deletions

View File

@ -14,6 +14,70 @@ from tqdm import tqdm
os.environ["TOKENIZERS_PARALLELISM"] = "true"
def process_sample_filter(data_args):
"""处理单个样本的过滤逻辑"""
sample, valid_predicates = data_args
if 'target' in sample and isinstance(sample['target'], list):
# Filter low-frequency predicates out of target
valid_targets = []
for triple in sample['target']:
if isinstance(triple, dict) and 'predicate' in triple:
if triple['predicate'] in valid_predicates:
valid_targets.append(triple)
# If any valid targets remain, keep this sample
if valid_targets:
sample['target'] = valid_targets
return sample
else:
return None
else:
# No target information: keep the sample
return sample
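For orientation, a minimal usage sketch of process_sample_filter; the sample layout and the valid_predicates set are illustrative assumptions that mirror how load_and_preprocess_data builds them further down:

sample = {
    "text": "Alan Turing was born in London.",
    "target": [
        {"subject": "Alan Turing", "predicate": "place of birth", "object": "London"},
        {"subject": "Alan Turing", "predicate": "some rare predicate", "object": "X"},
    ],
}
valid_predicates = {"place of birth"}
kept = process_sample_filter((sample, valid_predicates))
# kept['target'] now contains only the "place of birth" triple;
# the function returns None when no triple survives the filter.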
def process_sample_validation(data_args):
"""处理单个样本的验证逻辑"""
sample, predicate_vocab = data_args
if not isinstance(sample, dict) or 'text' not in sample:
return None
targets = sample.get('target', [])
if not isinstance(targets, list) or len(targets) == 0:
# No valid target: create a default one
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
else:
# Validate and select a target, preferring the predicate with the smallest share
selected_target = None
min_percentage = float('inf')
for triple in targets:
if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
predicate = triple['predicate']
# Use the statistics stored in predicate_vocab
if predicate in predicate_vocab:
stats = predicate_vocab[predicate]
if isinstance(stats, dict) and 'percentage' in stats:
percentage = stats['percentage']
if percentage < min_percentage:
min_percentage = percentage
selected_target = triple
elif selected_target is None:
selected_target = triple
elif selected_target is None:
selected_target = triple
# Fall back to the default if no valid target was found
if selected_target is None:
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
return {
'text': sample['text'],
'target': selected_target  # keep only a single target
}
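Similarly, a sketch of how process_sample_validation is expected to behave; the per-predicate statistics format (a dict with a 'percentage' field) is an assumption inferred from the checks above, and the example values are made up:

predicate_vocab = {
    "place of birth": {"count": 12000, "percentage": 1.2},
    "award received": {"count": 300, "percentage": 0.03},
}
sample = {
    "text": "Alan Turing received the Smith's Prize.",
    "target": [
        {"subject": "Alan Turing", "predicate": "place of birth", "object": "London"},
        {"subject": "Alan Turing", "predicate": "award received", "object": "Smith's Prize"},
    ],
}
result = process_sample_validation((sample, predicate_vocab))
# result['target'] is the "award received" triple, because that predicate has the
# smaller percentage; picking the rarer predicate helps balance the classes.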
class PretrainDataset(Dataset):
def __init__(self, data_path, tokenizer, max_length=512):
super().__init__()
@ -204,15 +268,94 @@ class TriplePretrainDataset(Dataset):
- Pre-tokenize all data
- Show a progress bar while processing
"""
-def __init__(self, data_path, tokenizer, max_length=512):
def __init__(self, data_path=None, predicate_vocab_path=None, samples=None, tokenizer=None, max_length=512):
super().__init__()
self.tokenizer = tokenizer
self.max_length = max_length
-print("🚀 开始加载和预处理三元组数据...")
self.val_samples = None
-self.samples = self.load_and_preprocess_data(data_path)
self.predicate_to_id = {}  # initialize
if samples is None:
self.predicate_vocab = self.load_predicate_vocab(predicate_vocab_path)
print("🚀 开始加载和预处理三元组数据...")
self.samples,self.val_samples = self.load_and_preprocess_data(data_path)
print("🚀 加载和预处理三元组数据完成")
else:
cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
data_filename = os.path.basename(data_path).split('.')[0]
predicate_to_id_path = os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json')
self.predicate_to_id = self.load_predicate_vocab(predicate_to_id_path)
self.samples = samples
print("🚀 加载和预处理三元组数据完成")
def load_predicate_vocab(self, path):
with open(path, 'r', encoding='utf-8') as f:
predicate_vocab = json.load(f)
return predicate_vocab
def get_val_samples(self):
return self.val_samples
def clear_cache(self, data_path):
"""清除缓存文件"""
cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
data_filename = os.path.basename(data_path).split('.')[0]
cache_files = [
os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
os.path.join(cache_dir, f'{data_filename}_val_samples.json')
]
for cache_file in cache_files:
if os.path.exists(cache_file):
os.remove(cache_file)
print(f"🗑️ 已删除缓存文件: {cache_file}")
if os.path.exists(cache_dir) and not os.listdir(cache_dir):
os.rmdir(cache_dir)
print(f"🗑️ 已删除空的缓存目录: {cache_dir}")
def load_and_preprocess_data(self, path):
"""Load and preprocess the triple data."""
# Build the cache file names (based on the data file path)
cache_dir = os.path.join(os.path.dirname(path), 'cache')
os.makedirs(cache_dir, exist_ok=True)
data_filename = os.path.basename(path).split('.')[0]
cache_files = {
'predicate_vocab': os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
'predicate_to_id': os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
'train_samples': os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
'val_samples': os.path.join(cache_dir, f'{data_filename}_val_samples.json')
}
# Check whether all cache files exist
cache_exists = all(os.path.exists(cache_file) for cache_file in cache_files.values())
if cache_exists:
print("📁 发现缓存文件,直接加载...")
# Load from the cache
with open(cache_files['predicate_vocab'], 'r', encoding='utf-8') as f:
self.predicate_vocab = json.load(f)
with open(cache_files['predicate_to_id'], 'r', encoding='utf-8') as f:
self.predicate_to_id = json.load(f)
with open(cache_files['train_samples'], 'r', encoding='utf-8') as f:
train_samples = json.load(f)
with open(cache_files['val_samples'], 'r', encoding='utf-8') as f:
val_samples = json.load(f)
print(f"✅ 从缓存加载完成:")
print(f"✅ 谓词词表大小: {len(self.predicate_vocab)}")
print(f"✅ 训练集大小: {len(train_samples)}")
print(f"✅ 测试集大小: {len(val_samples)}")
return train_samples, val_samples
# Cache not found: reprocess the raw data
print("📂 缓存不存在,开始加载和处理原始数据...")
# 1. Load the raw data
print("📂 加载原始数据...")
if path.endswith('.json'):
@ -228,71 +371,92 @@ class TriplePretrainDataset(Dataset):
raise ValueError(f"Unsupported file format: {path}")
print(f"📊 原始数据量: {len(data)} 个样本")
# 2. Use self.predicate_vocab to filter out data whose predicate share is below 0.01%
print("🔍 过滤低频谓词数据...")
print(f"📊 谓词统计数据: 总共{len(self.predicate_vocab)}个谓词")
-# 2. Validate the data and keep only a single target
# 3. Collect the predicates whose share is >= 0.01%
-print("🔍 验证数据格式并选择单个target...")
valid_predicates = set()
for predicate, stats in self.predicate_vocab.items():
if isinstance(stats, dict) and 'percentage' in stats:
if stats['percentage'] >= 0.01:
valid_predicates.add(predicate)
else:
# Not in statistics format: assume the predicate is valid
valid_predicates.add(predicate)
print(f"📊 占比≥0.01%的谓词: {len(valid_predicates)}")
# 4. Filter the data: drop samples that contain low-frequency predicates (single-process)
original_count = len(data)
filtered_data = []
print("🚀 开始过滤低频谓词数据...")
for sample in tqdm(data, desc="过滤低频谓词"):
result = process_sample_filter((sample, valid_predicates))
if result is not None:
filtered_data.append(result)
data = filtered_data
print(f"✅ 过滤完成: 去除前{original_count}条,去除后{len(data)}")
# 5. Remove predicates whose share is below 0.01% from self.predicate_vocab and build the predicate-to-index mapping
print("🔍 更新谓词词表并创建序号映射...")
original_vocab_size = len(self.predicate_vocab)
filtered_predicate_vocab = {}
for predicate, stats in self.predicate_vocab.items():
if isinstance(stats, dict) and 'percentage' in stats:
if stats['percentage'] >= 0.01:
filtered_predicate_vocab[predicate] = stats
else:
# Not in statistics format: keep it
filtered_predicate_vocab[predicate] = stats
# Build the predicate-to-index mapping dict
self.predicate_to_id = {predicate: idx for idx, predicate in enumerate(filtered_predicate_vocab.keys())}
self.predicate_vocab = filtered_predicate_vocab
print(f"✅ 谓词词表更新: 去除前{original_vocab_size}个,去除后{len(self.predicate_vocab)}")
print(f"✅ 谓词映射创建: {len(self.predicate_to_id)}个谓词对应序号")
# 6. Validate the data and keep only a single target per sample, preferring predicates with a small share to balance the data (single-process)
print("🔍 验证数据格式并选择单个target平衡数据...")
valid_samples = []
-for i, sample in enumerate(tqdm(data, desc="验证数据格式")):
-if not isinstance(sample, dict) or 'text' not in sample:
-continue
-targets = sample.get('target', [])
-if not isinstance(targets, list) or len(targets) == 0:
-# No valid target: create a default one
-selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
-else:
-# Validate and select the first valid target
-selected_target = None
-for triple in targets:
-if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
-selected_target = triple
-break
-# Fall back to the default if no valid target was found
-if selected_target is None:
-selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
-valid_samples.append({
-'text': sample['text'],
-'target': selected_target  # keep only a single target
-})
print("🚀 开始验证数据格式...")
for sample in tqdm(data, desc="验证数据格式"):
result = process_sample_validation((sample, self.predicate_vocab))
if result is not None:
valid_samples.append(result)
print(f"✅ 有效样本数: {len(valid_samples)}")
# 7. Split into training and validation sets
import random
random.seed(42)
val_samples = random.sample(valid_samples, min(1000, len(valid_samples)))
train_samples = [sample for sample in valid_samples if sample not in val_samples]
print(f"✅ 训练集大小: {len(train_samples)}")
print(f"✅ 测试集大小: {len(val_samples)}")
# 8. Save the results to the cache files
print("💾 保存处理结果到缓存文件...")
with open(cache_files['predicate_vocab'], 'w', encoding='utf-8') as f:
json.dump(self.predicate_vocab, f, ensure_ascii=False, indent=2)
with open(cache_files['predicate_to_id'], 'w', encoding='utf-8') as f:
json.dump(self.predicate_to_id, f, ensure_ascii=False, indent=2)
with open(cache_files['train_samples'], 'w', encoding='utf-8') as f:
json.dump(train_samples, f, ensure_ascii=False, indent=2)
with open(cache_files['val_samples'], 'w', encoding='utf-8') as f:
json.dump(val_samples, f, ensure_ascii=False, indent=2)
print("✅ 缓存文件保存完成")
-# 3. Tokenize the target sentences in batches
-print("🔤 分批tokenize目标句子...")
-processed_samples = []
-batch_size = 1000  # process 1000 sentences per batch to avoid exhausting memory
-for i in tqdm(range(0, len(valid_samples), batch_size), desc="分批tokenize目标句子"):
-# Get the current batch
-batch_samples = valid_samples[i:i + batch_size]
-# Extract the target sentences of the current batch
-batch_target_sentences = [self._triple_to_sentence(sample['target']) for sample in batch_samples]
-# Batch-tokenize the current batch
-batch_encodings = self.tokenizer(
-batch_target_sentences,
-max_length=128,  # target sentences are usually short
-padding='max_length',
-truncation=True,
-return_tensors='pt'
-)
-# Build the samples for the current batch
-for j, sample in enumerate(batch_samples):
-processed_samples.append({
-'text': sample['text'],  # keep the raw text, do not tokenize it here
-'target_input_ids': batch_encodings.input_ids[j],
-'target_attention_mask': batch_encodings.attention_mask[j],
-'target_sentence': batch_target_sentences[j],  # keep the original sentence for debugging
-})
-print(f"🎉 数据预处理完成! 共处理 {len(processed_samples)} 个样本")
-return processed_samples
return train_samples, val_samples
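As a concrete illustration (assuming the default --data_path of ./dataset/processed_trex_data.json used in the training script below), the cache produced by this method would consist of:

./dataset/cache/processed_trex_data_predicate_vocab.json    # filtered predicate statistics
./dataset/cache/processed_trex_data_predicate_to_id.json    # e.g. {"place of birth": 0, "country": 1, ...} (illustrative)
./dataset/cache/processed_trex_data_train_samples.json
./dataset/cache/processed_trex_data_val_samples.json

Deleting these files (or calling clear_cache) forces the filtering, balancing, and train/val split to be recomputed.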
def __len__(self):
return len(self.samples)
@ -302,10 +466,10 @@ class TriplePretrainDataset(Dataset):
return f"{triple['subject']} {triple['predicate']} {triple['object']}" return f"{triple['subject']} {triple['predicate']} {triple['object']}"
def __getitem__(self, index):
-"""Return one sample; the input text is tokenized at runtime, the target is pre-tokenized."""
"""Return one sample for the predicate classification task."""
sample = self.samples[index]
# Tokenize the input text at runtime
input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
encoding = self.tokenizer(
input_text,
@ -317,19 +481,18 @@ class TriplePretrainDataset(Dataset):
input_ids = encoding.input_ids.squeeze()
loss_mask = (input_ids != self.tokenizer.pad_token_id)
# Get the predicate classification label
target_predicate = sample['target']['predicate']
predicate_label = self.predicate_to_id.get(target_predicate, 0)  # defaults to 0 if the predicate is not found
# Build the training example
X = input_ids[:-1]
-Y = input_ids[1:]
loss_mask = loss_mask[1:]
return {
'input_ids': X,
-'labels': Y,
'labels': torch.tensor(predicate_label, dtype=torch.long),  # predicate classification label
'loss_mask': loss_mask
-'target_input_ids': sample['target_input_ids'],  # already a tensor
-'target_attention_mask': sample['target_attention_mask'],  # already a tensor
-'target_sentence': sample['target_sentence'],  # string, used for debugging
-'original_text': sample['text']
}
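Downstream, a batch from this dataset is meant to feed a plain classification loss. A minimal sketch, assuming the step argument and the predicate_class attribute introduced by the model and training-script changes below (device handling omitted, names illustrative):

import torch.nn as nn

criterion = nn.CrossEntropyLoss()
batch = next(iter(train_loader))                     # built with triple_collate_fn (see the training script below)
res = model(batch['input_ids'], step=0)              # res.predicate_class: [batch_size, num_predicates] logits
loss = criterion(res.predicate_class, batch['labels'])
accuracy = (res.predicate_class.argmax(dim=1) == batch['labels']).float().mean()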

View File

@ -489,8 +489,8 @@ class TripleExtractionHead(nn.Module):
self.self_attn_norm = RMSNorm(config.dim, eps=config.norm_eps)
# Cross-attention (used for subject and object extraction)
# self.cross_attention_subject = CrossAttention(config)
# self.cross_attention_object = CrossAttention(config)
# Normalization layers
self.subject_norm = RMSNorm(config.dim, eps=config.norm_eps)
@ -498,13 +498,13 @@ class TripleExtractionHead(nn.Module):
# Feed-forward networks
self.predicate_ff = FeedForward(config)
# self.subject_ff = FeedForward(config)
# self.object_ff = FeedForward(config)
# Output projection layer - modified to support sequence prediction
-self.predicate_output = nn.Linear(config.dim, self.max_predicate_len * config.dim, bias=False)
self.predicate_output = nn.Linear(config.dim, 264, bias=False)
# self.subject_output = nn.Linear(config.dim, self.max_subject_len * config.dim, bias=False)
# self.object_output = nn.Linear(config.dim, self.max_object_len * config.dim, bias=False)
print(f"三元组提取任务头配置:")
print(f"- 主语最大长度: {self.max_subject_len}")
@ -530,30 +530,29 @@ class TripleExtractionHead(nn.Module):
# 2. h1 passes through feed_forward to produce the predicate output
predicate_features = self.predicate_ff(h1)
predicate_features = predicate_features.mean(dim=1)
-predicate_raw = self.predicate_output(predicate_features)  # [batch_size, max_predicate_len * vocab_size]
predicate_class = self.predicate_output(predicate_features)  # [batch_size, num_predicates]
-predicate_logits = predicate_raw.view(batch_size, self.max_predicate_len, -1)
# # 3. h1 passes through cross-attention (k and v are both h) to produce h2
# h2 = self.cross_attention_subject(h1, h)  # the query is h1, key and value are both h
# h2 = h1 + h2  # residual connection
# # 4. h2 passes through feed_forward to produce the subject output
# subject_features = self.subject_ff(self.subject_norm(h2))
# subject_features = subject_features.mean(dim=1)
# subject_raw = self.subject_output(subject_features)  # [batch_size, max_subject_len * vocab_size]
# subject_logits = subject_raw.view(batch_size, self.max_subject_len, -1)
# # 5. h2 passes through cross-attention (k and v are both h) to produce h3
# h3 = self.cross_attention_object(h2, h)  # the query is h2, key and value are both h
# h3 = h2 + h3  # residual connection
# # 6. h3 passes through feed_forward to produce the object output
# object_features = self.object_ff(self.object_norm(h3))
# object_features = object_features.mean(dim=1)
# object_raw = self.object_output(object_features)  # [batch_size, max_object_len * vocab_size]
# object_logits = object_raw.view(batch_size, self.max_object_len, -1)
-return predicate_logits, subject_logits, object_logits
return predicate_class
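In effect, the task head now collapses to a standard sequence classifier. A self-contained sketch of the equivalent computation (the 264-way output size and the mean-pooling follow the lines above; the rest is illustrative, not the exact MiniMind module):

import torch
import torch.nn as nn

class PredicateClassifierSketch(nn.Module):
    """Illustrative stand-in: pooled hidden states -> predicate logits."""
    def __init__(self, dim, num_predicates=264):
        super().__init__()
        self.ff = nn.Linear(dim, dim)                           # stands in for FeedForward(config)
        self.out = nn.Linear(dim, num_predicates, bias=False)   # mirrors self.predicate_output

    def forward(self, h1):                                       # h1: [batch, seq_len, dim]
        feats = self.ff(h1).mean(dim=1)                          # mean-pool over the sequence
        return self.out(feats)                                   # [batch, num_predicates]

head = PredicateClassifierSketch(dim=512)
logits = head(torch.randn(2, 128, 512))   # -> shape [2, 264]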
class MiniMindBlock(nn.Module):
@ -656,18 +655,8 @@ class MiniMindLM(PreTrainedModel):
)
# Apply the triple extraction head
-predicate_logits, subject_logits, object_logits = self.triple_extraction_head(h, pos_cis)
predicate_class = self.triple_extraction_head(h, pos_cis)
-predicate_logits = predicate_logits.reshape(input_ids.size(0)*self.params.max_predicate_len, -1)
-subject_logits = subject_logits.reshape(input_ids.size(0)*self.params.max_subject_len, -1)
-object_logits = object_logits.reshape(input_ids.size(0)*self.params.max_object_len, -1)
-predicate_logits = self.output(predicate_logits)
-subject_logits = self.output(subject_logits)
-object_logits = self.output(object_logits)
-predicate_logits = predicate_logits.reshape(input_ids.size(0), self.params.max_predicate_len, -1)
-subject_logits = subject_logits.reshape(input_ids.size(0), self.params.max_subject_len, -1)
-object_logits = object_logits.reshape(input_ids.size(0), self.params.max_object_len, -1)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
@ -682,9 +671,7 @@ class MiniMindLM(PreTrainedModel):
# Add the triple extraction result
# Note: predicate_class has shape [batch_size, num_predicates]
-output.predicate_logits = predicate_logits
output.predicate_class = predicate_class
-output.subject_logits = subject_logits
-output.object_logits = object_logits
return output

View File

@ -1,225 +0,0 @@
#!/usr/bin/env python3
"""
JSON file merge script
Reads multiple JSON files and merges them into a single JSON file
"""
import json
import os
from typing import Dict, List, Any, Union
# List of JSON files to merge
JSON_FILES_TO_MERGE = [
"output/trex_sentences_enhanced_checkpoint_360000.json"
]
for i in range(1, 1010):
JSON_FILES_TO_MERGE.append(f"output/trex_sentences_enhanced_batch_{i}.json")
def load_json_file(file_path: str) -> Union[Dict, List, None]:
"""加载JSON文件"""
if not os.path.exists(file_path):
print(f"警告: 文件 {file_path} 不存在")
return None
try:
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
print(f"成功加载: {file_path}")
return data
except json.JSONDecodeError as e:
print(f"错误: 无法解析JSON文件 {file_path} - {e}")
return None
except Exception as e:
print(f"错误: 读取文件 {file_path} 失败 - {e}")
return None
def merge_json_data(data1: Union[Dict, List], data2: Union[Dict, List]) -> Union[Dict, List]:
"""合并两个JSON数据结构"""
# 如果两个都是列表,直接合并
if isinstance(data1, list) and isinstance(data2, list):
print(f"合并两个列表: {len(data1)} + {len(data2)} = {len(data1) + len(data2)}")
return data1 + data2
# If both are dicts
elif isinstance(data1, dict) and isinstance(data2, dict):
print("合并两个字典结构")
merged = data1.copy()
# Special case: if both have a 'sentences' list field, merge the sentences
if 'sentences' in data1 and 'sentences' in data2:
if isinstance(data1['sentences'], list) and isinstance(data2['sentences'], list):
print(f"合并sentences字段: {len(data1['sentences'])} + {len(data2['sentences'])} = {len(data1['sentences']) + len(data2['sentences'])}")
merged['sentences'] = data1['sentences'] + data2['sentences']
# Update the metadata if it exists
if 'metadata' in merged:
if isinstance(merged['metadata'], dict):
merged['metadata']['total_sentences'] = len(merged['sentences'])
merged['metadata']['merged_from'] = [os.path.basename(f) for f in JSON_FILES_TO_MERGE if os.path.exists(f)]
# Merge the remaining fields
for key, value in data2.items():
if key != 'sentences' and key not in merged:
merged[key] = value
return merged
# Plain dict merge
for key, value in data2.items():
if key in merged:
# If the key exists in both and both values are lists, concatenate them
if isinstance(merged[key], list) and isinstance(value, list):
merged[key] = merged[key] + value
# If the key exists in both and both values are dicts, merge recursively
elif isinstance(merged[key], dict) and isinstance(value, dict):
merged[key] = merge_json_data(merged[key], value)
else:
# Otherwise keep the value from the second file
merged[key] = value
print(f"字段 '{key}' 被覆盖")
else:
merged[key] = value
return merged
# Types do not match: create a new structure that contains both
else:
print("数据类型不匹配,创建包含两者的新结构")
return {
"data_from_save.json": data1,
"data_from_save2.json": data2,
"merged_at": "test.py"
}
def save_merged_json(data: Union[Dict, List], output_path: str):
"""保存合并后的JSON数据"""
try:
# Make sure the output directory exists
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
print(f"合并结果已保存到: {output_path}")
# Print summary statistics
if isinstance(data, dict):
if 'sentences' in data and isinstance(data['sentences'], list):
print(f"总计句子数: {len(data['sentences'])}")
print(f"总计字段数: {len(data)}")
elif isinstance(data, list):
print(f"总计列表项数: {len(data)}")
except Exception as e:
print(f"错误: 保存文件失败 - {e}")
def remove_duplicates_from_sentences(data: Union[Dict, List]) -> Union[Dict, List]:
"""从合并结果中移除重复的句子(基于句子内容)"""
if isinstance(data, dict) and 'sentences' in data:
if isinstance(data['sentences'], list):
original_count = len(data['sentences'])
seen_sentences = set()
unique_sentences = []
for item in data['sentences']:
if isinstance(item, dict):
# For dicts, use the sentence, corrected_sentence, or original_sentence field as the unique key
sentence_key = item.get('sentence') or item.get('corrected_sentence') or item.get('original_sentence')
elif isinstance(item, str):
sentence_key = item
else:
sentence_key = str(item)
if sentence_key and sentence_key not in seen_sentences:
seen_sentences.add(sentence_key)
unique_sentences.append(item)
data['sentences'] = unique_sentences
# Update the metadata
if 'metadata' in data and isinstance(data['metadata'], dict):
data['metadata']['total_sentences'] = len(unique_sentences)
data['metadata']['duplicates_removed'] = original_count - len(unique_sentences)
print(f"去重完成: {original_count} -> {len(unique_sentences)} (移除了 {original_count - len(unique_sentences)} 个重复项)")
return data
def merge_multiple_json_data(data_list: List[Union[Dict, List]]) -> Union[Dict, List]:
"""合并多个JSON数据结构"""
if not data_list:
return {}
if len(data_list) == 1:
return data_list[0]
print(f"准备合并 {len(data_list)} 个JSON数据结构")
# Start from the first structure and merge the rest one by one
merged_data = data_list[0]
for i, data in enumerate(data_list[1:], 1):
print(f"正在合并第 {i+1} 个数据结构...")
merged_data = merge_json_data(merged_data, data)
return merged_data
def main():
"""主函数"""
print("=== JSON文件合并脚本 ===")
# Output path
output_path = "output/merged.json"
print(f"准备合并以下文件:")
for i, file_path in enumerate(JSON_FILES_TO_MERGE, 1):
print(f" {i}. {file_path}")
print(f"输出文件: {output_path}")
print()
# Load all files
loaded_data = []
successfully_loaded = []
for file_path in JSON_FILES_TO_MERGE:
data = load_json_file(file_path)
if data is not None:
loaded_data.append(data)
successfully_loaded.append(file_path)
# Make sure at least one file loaded successfully
if not loaded_data:
print("错误: 没有文件能够成功加载,退出")
return
print(f"成功加载了 {len(loaded_data)} 个文件:")
for file_path in successfully_loaded:
print(f"{file_path}")
if len(loaded_data) < len(JSON_FILES_TO_MERGE):
failed_count = len(JSON_FILES_TO_MERGE) - len(loaded_data)
print(f"警告: {failed_count} 个文件加载失败")
print()
# Merge all the data
if len(loaded_data) == 1:
print("只有一个文件可用,直接使用...")
merged_data = loaded_data[0]
else:
print("开始合并所有文件...")
merged_data = merge_multiple_json_data(loaded_data)
# Deduplicate
print("\n检查并去除重复项...")
merged_data = remove_duplicates_from_sentences(merged_data)
# Save the merged result
print("\n保存合并结果...")
save_merged_json(merged_data, output_path)
print("\n=== 合并完成 ===")
print(f"合并了 {len(successfully_loaded)} 个文件的数据")
if __name__ == "__main__":
main()

View File

@ -1,61 +0,0 @@
#!/usr/bin/env python3
"""
Small-scale test of the preprocessing script
"""
import sys
import os
# Add the module path
sys.path.append('/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/preprocessing')
# Import the main module
from preprocess_pretrain import *
# Shrink the configuration for a small-scale test
DATASET_CONFIG["wikipedia"]["max_samples"] = 100
DATASET_CONFIG["gutenberg"]["max_samples"] = 50
DATASET_CONFIG["openwebtext"]["max_samples"] = 20
DATASET_CONFIG_EXTRA["wikipedia"]["max_samples"] = 50
DATASET_CONFIG_EXTRA["gutenberg"]["max_samples"] = 30
DATASET_CONFIG_EXTRA["openwebtext"]["max_samples"] = 15
# Override the output paths
OUTPUT_FILE = "/tmp/test_main.jsonl"
OUTPUT_FILE_EXTRA = "/tmp/test_extra.jsonl"
def test_small_scale():
"""小规模测试"""
print("Starting small scale test...")
# Set the random seed
random.seed(42)
try:
# Initialize the tokenizer
init_tokenizer()
# Merge the datasets
merge_datasets()
# Check the output files
if os.path.exists(OUTPUT_FILE):
with open(OUTPUT_FILE, 'r') as f:
main_lines = len(f.readlines())
print(f"Main file created: {main_lines} lines")
if os.path.exists(OUTPUT_FILE_EXTRA):
with open(OUTPUT_FILE_EXTRA, 'r') as f:
extra_lines = len(f.readlines())
print(f"Extra file created: {extra_lines} lines")
print("Small scale test completed successfully!")
except Exception as e:
print(f"Test failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
test_small_scale()

File diff suppressed because it is too large

View File

@ -127,6 +127,7 @@ dependencies = [
"regex==2024.11.6", "regex==2024.11.6",
"requests==2.32.3", "requests==2.32.3",
"rich==13.7.1", "rich==13.7.1",
"rouge-score>=0.1.2",
"rpds-py==0.24.0", "rpds-py==0.24.0",
"s3transfer==0.13.0", "s3transfer==0.13.0",
"safetensors==0.5.3", "safetensors==0.5.3",

View File

@ -397,6 +397,66 @@ def log_memory_status(step, accelerator, stage="", detailed=False):
Logger(log_msg, accelerator)
# Validation function
def validate_model(model, val_loader, accelerator, ctx, args):
"""
Evaluate model performance on the validation set.
Args:
model: the model
val_loader: validation-set DataLoader
accelerator: accelerator object
ctx: context manager
args: arguments
Returns:
dict: contains the average loss and accuracy
"""
model.eval()
total_loss = 0.0
correct_predictions = 0
total_predictions = 0
num_batches = 0
criterion = nn.CrossEntropyLoss()
with torch.no_grad():
for batch_data in val_loader:
try:
# Prepare the data
X = batch_data['input_ids'].to(accelerator.device)
Y = batch_data['labels']
# Forward pass
with ctx:
res = model(X, step=0)  # step is set to 0 during validation
loss = criterion(res.predicate_class.cpu(), Y.cpu())
# Compute accuracy
predicted_classes = torch.argmax(res.predicate_class, dim=1)
predicted_classes = predicted_classes.to(Y.device)
correct_predictions += (predicted_classes == Y).sum().item()
total_predictions += Y.size(0)
# Accumulate the loss
total_loss += loss.item()
num_batches += 1
except Exception as e:
Logger(f"验证时出错: {e}", accelerator)
continue
# Compute the averages
avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
model.train()  # switch back to training mode
return {
'avg_loss': avg_loss,
'accuracy': accuracy,
'total_samples': total_predictions
}
# Logging function
def Logger(msg, accelerator=None):
# If no accelerator is provided, only print on the main process
@ -515,7 +575,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
return model, tokenizer
-def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer):
def train_epoch(epoch, accelerator, model, train_loader, val_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer):
# Triple extraction training mode: the traditional cross-entropy loss function is not needed
epoch_start_time = time.time()
total_steps_in_epoch = len(train_loader)
@ -563,9 +623,9 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
X = batch_data['input_ids']
Y = batch_data['labels']
loss_mask = batch_data['loss_mask']
# target_input_ids = batch_data['target_input_ids']
# target_attention_mask = batch_data['target_attention_mask']
# target_sentences = batch_data['target_sentences']  # used for debug output
# === 2. Learning-rate update ===
if scheduler is not None:
@ -590,36 +650,34 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
# === 4. Loss computation ===
# Triple extraction mode: only the ROUGE loss is used to compute the triple loss
# Logger("三元组提取训练模式", accelerator) if step == 0 else None
# # Make sure the model produced triple outputs
# if not (hasattr(res, 'predicate_logits') and hasattr(res, 'subject_logits') and hasattr(res, 'object_logits')):
# raise ValueError("模型没有输出三元组logits请检查模型配置")
# # Make sure the target data is present
# if target_input_ids is None:
# raise ValueError("没有三元组目标数据,请检查数据格式")
-# Compute the triple loss
# Compute the classification loss
try:
-Logger("使用预tokenized三元组目标数据", accelerator) if step == 0 else None
Logger("使用分类交叉熵损失", accelerator) if step == 0 else None
# Time the GPU loss computation
if args.profile and accelerator.is_main_process and loss_start is not None:
loss_start.record()
-# Compute the optimized embedding cosine-similarity loss
# Compute the cross-entropy loss
-loss = compute_triple_rouge_loss_optimized(
-res.subject_logits, res.predicate_logits, res.object_logits,
-target_input_ids, target_attention_mask, model.tok_embeddings, temperature=args.temperature
-)
criterion = nn.CrossEntropyLoss()
loss = criterion(res.predicate_class, Y)
# Mark the end of the GPU loss computation timing
if args.profile and accelerator.is_main_process and loss_end is not None:
loss_end.record()
except Exception as e:
-Logger(f"Error: ROUGE loss computation failed: {e}", accelerator)
Logger(f"Error: 分类损失计算失败: {e}", accelerator)
import traceback
Logger(f"Traceback: {traceback.format_exc()}", accelerator)
loss = res.logits.sum() * 0.0 + 1.0
@ -683,13 +741,13 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
f"GPU利用率: {gpu_time_total/iter_time*100:.1f}%", accelerator) f"GPU利用率: {gpu_time_total/iter_time*100:.1f}%", accelerator)
Logger("=" * 50, accelerator) Logger("=" * 50, accelerator)
Logger("=== 三元组预测示例 ===", accelerator) # Logger("=== 三元组预测示例 ===", accelerator)
predict_sentences = triple_to_sentence(res.subject_logits, res.predicate_logits, res.object_logits,tokenizer) # predict_sentences = triple_to_sentence(res.subject_logits, res.predicate_logits, res.object_logits,tokenizer)
# 显示前2个样本的目标句子 # # 显示前2个样本的目标句子
for i, target_sentence in enumerate(target_sentences[:2]): # for i, target_sentence in enumerate(target_sentences[:2]):
Logger(f"样本{i+1}目标: {target_sentence}", accelerator) # Logger(f"样本{i+1}目标: {target_sentence}", accelerator)
Logger(f"样本{i+1}预测: {predict_sentences[i]}", accelerator) # Logger(f"样本{i+1}预测: {predict_sentences[i]}", accelerator)
Logger("==================", accelerator) Logger("=======val dataset=========", accelerator)
# Reset the GPU timing events
forward_start = torch.cuda.Event(enable_timing=True)
@ -734,11 +792,20 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
# SwanLab logging
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
Logger("=======val dataset=========", accelerator)
# Evaluate on the validation set
val_results = validate_model(model, val_loader, accelerator, ctx, args)
Logger(f"验证集结果 - 平均损失: {val_results['avg_loss']:.6f}, 准确率: {val_results['accuracy']:.4f}, 样本数: {val_results['total_samples']}", accelerator)
log_dict = {
"epoch": epoch + 1,
"step": step + 1,
"total_steps_in_epoch": total_steps_in_epoch,
-"triple_embedding_cosine_loss": loss.item() * args.accumulation_steps,
"train_loss": loss.item() * args.accumulation_steps,
"val_loss": val_results['avg_loss'],
"val_accuracy": val_results['accuracy'],
"val_samples": val_results['total_samples'],
"lr": current_lr, "lr": current_lr,
"tokens_per_sec": tokens_per_sec, "tokens_per_sec": tokens_per_sec,
"epoch_time_left_seconds": epoch_remaining_time, "epoch_time_left_seconds": epoch_remaining_time,
@ -776,7 +843,7 @@ def main():
parser.add_argument("--out_dir", type=str, default="out") parser.add_argument("--out_dir", type=str, default="out")
parser.add_argument("--epochs", type=int, default=4) parser.add_argument("--epochs", type=int, default=4)
parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数") parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数")
parser.add_argument("--batch_size", type=int, default=192) parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--learning_rate", type=float, default=2e-4) parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--dtype", type=str, default="bfloat16") parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数 parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数
@ -793,6 +860,7 @@ def main():
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代")
parser.add_argument("--data_path", type=str, default="./dataset/processed_trex_data.json")
parser.add_argument("--predicate_vocab_path", type=str, default="./dataset/predicate_stats.json", help="Path to predicate vocabulary/statistics file")
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)") parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)") parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
@ -932,7 +1000,8 @@ def main():
# Create the dataset and data loaders (dedicated to triple extraction training)
#########################################################
Logger("三元组提取训练:使用 TriplePretrainDataset", accelerator)
-train_ds = TriplePretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_ds = TriplePretrainDataset(data_path=args.data_path, predicate_vocab_path=args.predicate_vocab_path, tokenizer=tokenizer, max_length=lm_config.max_seq_len)
val_ds = TriplePretrainDataset(data_path=args.data_path,samples=train_ds.get_val_samples(), predicate_vocab_path=args.predicate_vocab_path, tokenizer=tokenizer, max_length=lm_config.max_seq_len)
# Custom collate_fn to handle the optimized data format
def triple_collate_fn(batch):
@ -940,17 +1009,17 @@ def main():
input_ids = torch.stack([item['input_ids'] for item in batch])
labels = torch.stack([item['labels'] for item in batch])
loss_mask = torch.stack([item['loss_mask'] for item in batch])
# target_input_ids = torch.stack([item['target_input_ids'] for item in batch])
# target_attention_mask = torch.stack([item['target_attention_mask'] for item in batch])
# target_sentences = [item['target_sentence'] for item in batch]  # for debugging
return {
'input_ids': input_ids,
'labels': labels,
'loss_mask': loss_mask,
# 'target_input_ids': target_input_ids,
# 'target_attention_mask': target_attention_mask,
# 'target_sentences': target_sentences
}
train_loader = DataLoader(
@ -963,6 +1032,15 @@ def main():
# persistent_workers and prefetch_factor are automatically disabled when num_workers=0
collate_fn=triple_collate_fn
)
val_loader = DataLoader(
val_ds,
batch_size=args.batch_size,
pin_memory=False,
drop_last=True,
shuffle=False,
num_workers=0,
collate_fn=triple_collate_fn
)
#########################################################
# Create the optimizer
@ -993,7 +1071,7 @@ def main():
overall_start_time = time.time()  # Record overall start time
for epoch in range(args.epochs):
Logger(f"开始第{epoch+1}轮训练", accelerator)
-train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer)  # Pass tokenizer
train_epoch(epoch, accelerator, model, train_loader, val_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer)  # Pass tokenizer
# Clean up memory after each epoch
Logger(f"第{epoch+1}轮训练完成,进行内存清理", accelerator)
Logger(f"{epoch+1}轮训练完成,进行内存清理", accelerator) Logger(f"{epoch+1}轮训练完成,进行内存清理", accelerator)