diff --git a/model/dataset.py b/model/dataset.py
index 6658eca..47281f4 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -14,6 +14,70 @@ from tqdm import tqdm
 
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
+def process_sample_filter(data_args):
+    """处理单个样本的过滤逻辑"""
+    sample, valid_predicates = data_args
+    if 'target' in sample and isinstance(sample['target'], list):
+        # 过滤target中的低频谓词
+        valid_targets = []
+        for triple in sample['target']:
+            if isinstance(triple, dict) and 'predicate' in triple:
+                if triple['predicate'] in valid_predicates:
+                    valid_targets.append(triple)
+
+        # 如果还有有效的target,保留这个样本
+        if valid_targets:
+            sample['target'] = valid_targets
+            return sample
+        else:
+            return None
+    else:
+        # 如果没有target信息,保留样本
+        return sample
+
+
+def process_sample_validation(data_args):
+    """处理单个样本的验证逻辑"""
+    sample, predicate_vocab = data_args
+    if not isinstance(sample, dict) or 'text' not in sample:
+        return None
+
+    targets = sample.get('target', [])
+    if not isinstance(targets, list) or len(targets) == 0:
+        # 如果没有有效的target,创建一个默认的
+        selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
+    else:
+        # 验证并选择target,优先选择占比小的谓词
+        selected_target = None
+        min_percentage = float('inf')
+
+        for triple in targets:
+            if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
+                predicate = triple['predicate']
+
+                # 使用predicate_vocab中的统计信息
+                if predicate in predicate_vocab:
+                    stats = predicate_vocab[predicate]
+                    if isinstance(stats, dict) and 'percentage' in stats:
+                        percentage = stats['percentage']
+                        if percentage < min_percentage:
+                            min_percentage = percentage
+                            selected_target = triple
+                    elif selected_target is None:
+                        selected_target = triple
+                elif selected_target is None:
+                    selected_target = triple
+
+        # 如果没有找到有效的target,使用默认值
+        if selected_target is None:
+            selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
+
+    return {
+        'text': sample['text'],
+        'target': selected_target  # 只保留一个target
+    }
+
+
 class PretrainDataset(Dataset):
     def __init__(self, data_path, tokenizer, max_length=512):
         super().__init__()
@@ -204,15 +268,94 @@ class TriplePretrainDataset(Dataset):
     - 预先tokenize所有数据
     - 使用进度条显示处理进度
     """
-    def __init__(self, data_path, tokenizer, max_length=512):
+    def __init__(self, data_path=None, predicate_vocab_path=None, samples=None, tokenizer=None, max_length=512):
         super().__init__()
         self.tokenizer = tokenizer
         self.max_length = max_length
-        print("🚀 开始加载和预处理三元组数据...")
-        self.samples = self.load_and_preprocess_data(data_path)
+        self.val_samples = None
+        self.predicate_to_id = {}  # 初始化
+        if samples is None:
+            self.predicate_vocab = self.load_predicate_vocab(predicate_vocab_path)
+            print("🚀 开始加载和预处理三元组数据...")
+            self.samples, self.val_samples = self.load_and_preprocess_data(data_path)
+            print("🚀 加载和预处理三元组数据完成")
+        else:
+            cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
+            data_filename = os.path.basename(data_path).split('.')[0]
+            predicate_to_id_path = os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json')
+            self.predicate_to_id = self.load_predicate_vocab(predicate_to_id_path)
+            self.samples = samples
+            print("🚀 加载和预处理三元组数据完成")
+    def load_predicate_vocab(self, path):
+        with open(path, 'r', encoding='utf-8') as f:
+            predicate_vocab = json.load(f)
+        return predicate_vocab
+
+    def get_val_samples(self):
+        return self.val_samples
+
+    def clear_cache(self, data_path):
+        """清除缓存文件"""
+        cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
+        data_filename = os.path.basename(data_path).split('.')[0]
+        cache_files = [
+            os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
+            os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
+            os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
+            os.path.join(cache_dir, f'{data_filename}_val_samples.json')
+        ]
+        for cache_file in cache_files:
+            if os.path.exists(cache_file):
+                os.remove(cache_file)
+                print(f"🗑️ 已删除缓存文件: {cache_file}")
+
+        if os.path.exists(cache_dir) and not os.listdir(cache_dir):
+            os.rmdir(cache_dir)
+            print(f"🗑️ 已删除空的缓存目录: {cache_dir}")
+
     def load_and_preprocess_data(self, path):
         """加载并预处理三元组数据"""
+        # 生成缓存文件名(基于数据文件路径)
+        cache_dir = os.path.join(os.path.dirname(path), 'cache')
+        os.makedirs(cache_dir, exist_ok=True)
+
+        data_filename = os.path.basename(path).split('.')[0]
+        cache_files = {
+            'predicate_vocab': os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
+            'predicate_to_id': os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
+            'train_samples': os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
+            'val_samples': os.path.join(cache_dir, f'{data_filename}_val_samples.json')
+        }
+
+        # 检查缓存文件是否存在
+        cache_exists = all(os.path.exists(cache_file) for cache_file in cache_files.values())
+
+        if cache_exists:
+            print("📁 发现缓存文件,直接加载...")
+            # 从缓存加载
+            with open(cache_files['predicate_vocab'], 'r', encoding='utf-8') as f:
+                self.predicate_vocab = json.load(f)
+
+            with open(cache_files['predicate_to_id'], 'r', encoding='utf-8') as f:
+                self.predicate_to_id = json.load(f)
+
+            with open(cache_files['train_samples'], 'r', encoding='utf-8') as f:
+                train_samples = json.load(f)
+
+            with open(cache_files['val_samples'], 'r', encoding='utf-8') as f:
+                val_samples = json.load(f)
+
+            print(f"✅ 从缓存加载完成:")
+            print(f"✅ 谓词词表大小: {len(self.predicate_vocab)}")
+            print(f"✅ 训练集大小: {len(train_samples)}")
+            print(f"✅ 测试集大小: {len(val_samples)}")
+
+            return train_samples, val_samples
+
+        # 缓存不存在,重新处理数据
+        print("📂 缓存不存在,开始加载和处理原始数据...")
+
         # 1. 加载原始数据
         print("📂 加载原始数据...")
         if path.endswith('.json'):
@@ -228,71 +371,92 @@ class TriplePretrainDataset(Dataset):
             raise ValueError(f"Unsupported file format: {path}")
 
         print(f"📊 原始数据量: {len(data)} 个样本")
+
+        # 2. 使用self.predicate_vocab过滤占比小于0.01%的谓词数据
+        print("🔍 过滤低频谓词数据...")
+        print(f"📊 谓词统计数据: 总共{len(self.predicate_vocab)}个谓词")
 
-        # 2. 数据验证和筛选(只保留一个target)
-        print("🔍 验证数据格式并选择单个target...")
+        # 3.获取占比大于等于0.01%的谓词
+        valid_predicates = set()
+        for predicate, stats in self.predicate_vocab.items():
+            if isinstance(stats, dict) and 'percentage' in stats:
+                if stats['percentage'] >= 0.01:
+                    valid_predicates.add(predicate)
+            else:
+                # 如果不是统计格式,假设是有效谓词
+                valid_predicates.add(predicate)
+
+        print(f"📊 占比≥0.01%的谓词: {len(valid_predicates)}个")
+
+        # 4.过滤数据:去除包含低频谓词的数据(单进程处理)
+        original_count = len(data)
+        filtered_data = []
+
+        print("🚀 开始过滤低频谓词数据...")
+        for sample in tqdm(data, desc="过滤低频谓词"):
+            result = process_sample_filter((sample, valid_predicates))
+            if result is not None:
+                filtered_data.append(result)
+
+        data = filtered_data
+        print(f"✅ 过滤完成: 去除前{original_count}条,去除后{len(data)}条")
+
+        # 5. 去除self.predicate_vocab中占比小于0.01%的谓词,并创建谓词到序号的映射
+        print("🔍 更新谓词词表并创建序号映射...")
+        original_vocab_size = len(self.predicate_vocab)
+        filtered_predicate_vocab = {}
+
+        for predicate, stats in self.predicate_vocab.items():
+            if isinstance(stats, dict) and 'percentage' in stats:
+                if stats['percentage'] >= 0.01:
+                    filtered_predicate_vocab[predicate] = stats
+            else:
+                # 如果不是统计格式,保留
+                filtered_predicate_vocab[predicate] = stats
+
+        # 创建谓词到序号的映射字典
+        self.predicate_to_id = {predicate: idx for idx, predicate in enumerate(filtered_predicate_vocab.keys())}
+        self.predicate_vocab = filtered_predicate_vocab
+        print(f"✅ 谓词词表更新: 去除前{original_vocab_size}个,去除后{len(self.predicate_vocab)}个")
+        print(f"✅ 谓词映射创建: {len(self.predicate_to_id)}个谓词对应序号")
+
+        # 6. 数据验证和筛选(只保留一个target),优先选择占比小的谓词以平衡数据(单进程处理)
+        print("🔍 验证数据格式并选择单个target(平衡数据)...")
         valid_samples = []
-        for i, sample in enumerate(tqdm(data, desc="验证数据格式")):
-            if not isinstance(sample, dict) or 'text' not in sample:
-                continue
-
-            targets = sample.get('target', [])
-            if not isinstance(targets, list) or len(targets) == 0:
-                # 如果没有有效的target,创建一个默认的
-                selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
-            else:
-                # 验证并选择第一个有效的target
-                selected_target = None
-                for triple in targets:
-                    if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
-                        selected_target = triple
-                        break
-
-                # 如果没有找到有效的target,使用默认值
-                if selected_target is None:
-                    selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
-
-            valid_samples.append({
-                'text': sample['text'],
-                'target': selected_target  # 只保留一个target
-            })
+        print("🚀 开始验证数据格式...")
+        for sample in tqdm(data, desc="验证数据格式"):
+            result = process_sample_validation((sample, self.predicate_vocab))
+            if result is not None:
+                valid_samples.append(result)
 
         print(f"✅ 有效样本数: {len(valid_samples)}")
+
+        # 7.拆分训练集合与测试集合
+        import random
+        random.seed(42)
+        val_samples = random.sample(valid_samples, min(1000, len(valid_samples)))
+        train_samples = [sample for sample in valid_samples if sample not in val_samples]
+        print(f"✅ 训练集大小: {len(train_samples)}")
+        print(f"✅ 测试集大小: {len(val_samples)}")
+
+        # 8. 保存到缓存文件
+        print("💾 保存处理结果到缓存文件...")
+        with open(cache_files['predicate_vocab'], 'w', encoding='utf-8') as f:
+            json.dump(self.predicate_vocab, f, ensure_ascii=False, indent=2)
-        # 3. 分批tokenize目标句子
-        print("🔤 分批tokenize目标句子...")
+        with open(cache_files['predicate_to_id'], 'w', encoding='utf-8') as f:
+            json.dump(self.predicate_to_id, f, ensure_ascii=False, indent=2)
 
-        processed_samples = []
-        batch_size = 1000  # 每批处理1000个句子,避免内存爆炸
+        with open(cache_files['train_samples'], 'w', encoding='utf-8') as f:
+            json.dump(train_samples, f, ensure_ascii=False, indent=2)
 
-        for i in tqdm(range(0, len(valid_samples), batch_size), desc="分批tokenize目标句子"):
-            # 获取当前批次
-            batch_samples = valid_samples[i:i + batch_size]
-
-            # 提取当前批次的目标句子
-            batch_target_sentences = [self._triple_to_sentence(sample['target']) for sample in batch_samples]
-
-            # 批量tokenize当前批次
-            batch_encodings = self.tokenizer(
-                batch_target_sentences,
-                max_length=128,  # 目标句子通常较短
-                padding='max_length',
-                truncation=True,
-                return_tensors='pt'
-            )
-
-            # 构建当前批次的样本数据
-            for j, sample in enumerate(batch_samples):
-                processed_samples.append({
-                    'text': sample['text'],  # 保持原始文本,不进行tokenize
-                    'target_input_ids': batch_encodings.input_ids[j],
-                    'target_attention_mask': batch_encodings.attention_mask[j],
-                    'target_sentence': batch_target_sentences[j],  # 保留原始句子用于调试
-                })
+        with open(cache_files['val_samples'], 'w', encoding='utf-8') as f:
+            json.dump(val_samples, f, ensure_ascii=False, indent=2)
 
-        print(f"🎉 数据预处理完成! 共处理 {len(processed_samples)} 个样本")
-        return processed_samples
+        print("✅ 缓存文件保存完成")
+
+        return train_samples, val_samples
 
     def __len__(self):
         return len(self.samples)
@@ -302,10 +466,10 @@ class TriplePretrainDataset(Dataset):
         return f"{triple['subject']} {triple['predicate']} {triple['object']}"
 
     def __getitem__(self, index):
-        """返回数据,输入文本在运行时tokenize,目标已预tokenize"""
+        """返回数据,用于谓词分类任务"""
        sample = self.samples[index]
 
-        # 在运行时tokenize输入文本(用于语言建模)
+        # 在运行时tokenize输入文本
         input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
         encoding = self.tokenizer(
             input_text,
@@ -317,19 +481,18 @@ class TriplePretrainDataset(Dataset):
         input_ids = encoding.input_ids.squeeze()
         loss_mask = (input_ids != self.tokenizer.pad_token_id)
 
+        # 获取谓词分类标签
+        target_predicate = sample['target']['predicate']
+        predicate_label = self.predicate_to_id.get(target_predicate, 0)  # 默认为0如果找不到
+
         # 构建训练数据
         X = input_ids[:-1]
-        Y = input_ids[1:]
         loss_mask = loss_mask[1:]
 
         return {
             'input_ids': X,
-            'labels': Y,
-            'loss_mask': loss_mask,
-            'target_input_ids': sample['target_input_ids'],  # 已经是tensor
-            'target_attention_mask': sample['target_attention_mask'],  # 已经是tensor
-            'target_sentence': sample['target_sentence'],  # 字符串,用于调试
-            'original_text': sample['text']
+            'labels': torch.tensor(predicate_label, dtype=torch.long),  # 谓词分类标签
+            'loss_mask': loss_mask
         }
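Note on the dataset.py changes above: `TriplePretrainDataset.__getitem__` no longer returns pre-tokenized target sentences for language modelling; each sample is reduced to its token ids plus a single predicate-class index taken from the cached `predicate_to_id` mapping. The snippet below is a minimal sketch (not project code) of what one collated batch looks like and how it feeds a cross-entropy objective; the example predicates, batch size and token ids are made up for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical slice of the cached {data}_predicate_to_id.json mapping
# (the real mapping keeps only predicates with percentage >= 0.01).
predicate_to_id = {"导演": 0, "作者": 1, "位于": 2}

# Shape of a batch produced by the new __getitem__ (batch of 4,
# max_length=512 -> 511 positions after dropping the last id for X).
batch = {
    "input_ids": torch.randint(0, 6400, (4, 511)),           # toy token ids
    "labels": torch.tensor([0, 2, 1, 0], dtype=torch.long),  # predicate_to_id[target['predicate']]
    "loss_mask": torch.ones(4, 511, dtype=torch.bool),
}

# Stand-in for res.predicate_class returned by the model head.
logits = torch.randn(4, len(predicate_to_id))
loss = nn.CrossEntropyLoss()(logits, batch["labels"])
print(loss.item())
```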
diff --git a/model/model_extra.py b/model/model_extra.py
index fca8c54..2e0cce0 100644
--- a/model/model_extra.py
+++ b/model/model_extra.py
@@ -489,8 +489,8 @@ class TripleExtractionHead(nn.Module):
         self.self_attn_norm = RMSNorm(config.dim, eps=config.norm_eps)
 
         # 交叉注意力机制(用于主语和宾语提取)
-        self.cross_attention_subject = CrossAttention(config)
-        self.cross_attention_object = CrossAttention(config)
+        # self.cross_attention_subject = CrossAttention(config)
+        # self.cross_attention_object = CrossAttention(config)
 
         # 归一化层
         self.subject_norm = RMSNorm(config.dim, eps=config.norm_eps)
@@ -498,13 +498,13 @@ class TripleExtractionHead(nn.Module):
         # Feed Forward 网络
         self.predicate_ff = FeedForward(config)
-        self.subject_ff = FeedForward(config)
-        self.object_ff = FeedForward(config)
+        # self.subject_ff = FeedForward(config)
+        # self.object_ff = FeedForward(config)
 
         # 输出投影层 - 修改为支持序列预测
-        self.predicate_output = nn.Linear(config.dim, self.max_predicate_len * config.dim, bias=False)
-        self.subject_output = nn.Linear(config.dim, self.max_subject_len * config.dim, bias=False)
-        self.object_output = nn.Linear(config.dim, self.max_object_len * config.dim, bias=False)
+        self.predicate_output = nn.Linear(config.dim, 264, bias=False)
+        # self.subject_output = nn.Linear(config.dim, self.max_subject_len * config.dim, bias=False)
+        # self.object_output = nn.Linear(config.dim, self.max_object_len * config.dim, bias=False)
 
         print(f"三元组提取任务头配置:")
         print(f"- 主语最大长度: {self.max_subject_len}")
@@ -530,30 +530,29 @@ class TripleExtractionHead(nn.Module):
         # 2. h1通过feed_forward得到谓语输出
         predicate_features = self.predicate_ff(h1)
         predicate_features = predicate_features.mean(dim=1)
-        predicate_raw = self.predicate_output(predicate_features)  # [batch_size, max_predicate_len * vocab_size]
-        predicate_logits = predicate_raw.view(batch_size, self.max_predicate_len, -1)
+        predicate_class = self.predicate_output(predicate_features)  # [batch_size, num_predicates]
 
-        # 3. h1通过交叉注意力(k,v都是h)得到h2
-        h2 = self.cross_attention_subject(h1, h)  # query是h1,key和value都是h
-        h2 = h1 + h2  # 残差连接
+        # # 3. h1通过交叉注意力(k,v都是h)得到h2
+        # h2 = self.cross_attention_subject(h1, h)  # query是h1,key和value都是h
+        # h2 = h1 + h2  # 残差连接
 
-        # 4. h2通过feed_forward得到主语输出
-        subject_features = self.subject_ff(self.subject_norm(h2))
-        subject_features = subject_features.mean(dim=1)
-        subject_raw = self.subject_output(subject_features)  # [batch_size, max_subject_len * vocab_size]
-        subject_logits = subject_raw.view(batch_size, self.max_subject_len, -1)
+        # # 4. h2通过feed_forward得到主语输出
+        # subject_features = self.subject_ff(self.subject_norm(h2))
+        # subject_features = subject_features.mean(dim=1)
+        # subject_raw = self.subject_output(subject_features)  # [batch_size, max_subject_len * vocab_size]
+        # subject_logits = subject_raw.view(batch_size, self.max_subject_len, -1)
 
-        # 5. h2通过交叉注意力(k,v都是h)得到h3
-        h3 = self.cross_attention_object(h2, h)  # query是h2,key和value都是h
-        h3 = h2 + h3  # 残差连接
+        # # 5. h2通过交叉注意力(k,v都是h)得到h3
+        # h3 = self.cross_attention_object(h2, h)  # query是h2,key和value都是h
+        # h3 = h2 + h3  # 残差连接
 
-        # 6. h3通过feed_forward得到宾语输出
-        object_features = self.object_ff(self.object_norm(h3))
-        object_features = object_features.mean(dim=1)
-        object_raw = self.object_output(object_features)  # [batch_size, max_object_len * vocab_size]
-        object_logits = object_raw.view(batch_size, self.max_object_len, -1)
+        # # 6. h3通过feed_forward得到宾语输出
+        # object_features = self.object_ff(self.object_norm(h3))
+        # object_features = object_features.mean(dim=1)
+        # object_raw = self.object_output(object_features)  # [batch_size, max_object_len * vocab_size]
+        # object_logits = object_raw.view(batch_size, self.max_object_len, -1)
 
-        return predicate_logits, subject_logits, object_logits
+        return predicate_class
 
 
 class MiniMindBlock(nn.Module):
@@ -656,18 +655,8 @@ class MiniMindLM(PreTrainedModel):
                             )
 
         # 应用三元组提取任务头
-        predicate_logits, subject_logits, object_logits = self.triple_extraction_head(h, pos_cis)
-        predicate_logits = predicate_logits.reshape(input_ids.size(0)*self.params.max_predicate_len, -1)
-        subject_logits = subject_logits.reshape(input_ids.size(0)*self.params.max_subject_len, -1)
-        object_logits = object_logits.reshape(input_ids.size(0)*self.params.max_object_len, -1)
-
-        predicate_logits = self.output(predicate_logits)
-        subject_logits = self.output(subject_logits)
-        object_logits = self.output(object_logits)
-
-        predicate_logits = predicate_logits.reshape(input_ids.size(0), self.params.max_predicate_len, -1)
-        subject_logits = subject_logits.reshape(input_ids.size(0), self.params.max_subject_len, -1)
-        object_logits = object_logits.reshape(input_ids.size(0), self.params.max_object_len, -1)
+        predicate_class = self.triple_extraction_head(h, pos_cis)
+
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.output(self.norm(h)[:, slice_indices, :])
@@ -682,9 +671,7 @@ class MiniMindLM(PreTrainedModel):
 
         # 添加三元组提取结果
         # 注意:现在的维度是 [batch_size, seq_len, max_len, vocab_size]
-        output.predicate_logits = predicate_logits
-        output.subject_logits = subject_logits
-        output.object_logits = object_logits
+        output.predicate_class = predicate_class
 
         return output
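With the subject/object branches commented out, the surviving path through `TripleExtractionHead` is effectively self-attention block → feed-forward → mean-pool over the sequence → linear classifier over 264 predicates, and `MiniMindLM` now exposes it as `output.predicate_class`. The sketch below reproduces only that shape flow with stock PyTorch layers; `dim=512` and the attention/FFN internals are stand-ins for illustration, not the project's own RMSNorm/FeedForward implementations.

```python
import torch
import torch.nn as nn

class PredicateClassifierSketch(nn.Module):
    """Shape-level stand-in for the trimmed TripleExtractionHead."""
    def __init__(self, dim: int = 512, num_predicates: int = 264):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
        self.norm = nn.LayerNorm(dim)  # RMSNorm in the real model
        self.predicate_ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim))
        self.predicate_output = nn.Linear(dim, num_predicates, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: [batch, seq_len, dim] hidden states from the backbone
        attn_out, _ = self.self_attn(h, h, h)
        h1 = h + attn_out                                      # residual connection
        feats = self.predicate_ff(self.norm(h1)).mean(dim=1)   # pool over the sequence
        return self.predicate_output(feats)                    # [batch, num_predicates]

logits = PredicateClassifierSketch()(torch.randn(2, 16, 512))
print(logits.shape)  # torch.Size([2, 264])
```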
diff --git a/preprocessing/merge_output_json.py b/preprocessing/merge_output_json.py
deleted file mode 100644
index bfa26e7..0000000
--- a/preprocessing/merge_output_json.py
+++ /dev/null
@@ -1,225 +0,0 @@
-#!/usr/bin/env python3
-"""
-JSON文件合并脚本
-读取多个JSON文件并合并为一个JSON文件
-"""
-
-import json
-import os
-from typing import Dict, List, Any, Union
-
-# 需要合并的JSON文件列表
-JSON_FILES_TO_MERGE = [
-    "output/trex_sentences_enhanced_checkpoint_360000.json"
-]
-for i in range(1, 1010):
-    JSON_FILES_TO_MERGE.append(f"output/trex_sentences_enhanced_batch_{i}.json")
-
-def load_json_file(file_path: str) -> Union[Dict, List, None]:
-    """加载JSON文件"""
-    if not os.path.exists(file_path):
-        print(f"警告: 文件 {file_path} 不存在")
-        return None
-
-    try:
-        with open(file_path, 'r', encoding='utf-8') as f:
-            data = json.load(f)
-        print(f"成功加载: {file_path}")
-        return data
-    except json.JSONDecodeError as e:
-        print(f"错误: 无法解析JSON文件 {file_path} - {e}")
-        return None
-    except Exception as e:
-        print(f"错误: 读取文件 {file_path} 失败 - {e}")
-        return None
-
-def merge_json_data(data1: Union[Dict, List], data2: Union[Dict, List]) -> Union[Dict, List]:
-    """合并两个JSON数据结构"""
-
-    # 如果两个都是列表,直接合并
-    if isinstance(data1, list) and isinstance(data2, list):
-        print(f"合并两个列表: {len(data1)} + {len(data2)} = {len(data1) + len(data2)} 项")
-        return data1 + data2
-
-    # 如果两个都是字典
-    elif isinstance(data1, dict) and isinstance(data2, dict):
-        print("合并两个字典结构")
-        merged = data1.copy()
-
-        # 特殊处理:如果都有'sentences'字段且为列表,合并sentences
-        if 'sentences' in data1 and 'sentences' in data2:
-            if isinstance(data1['sentences'], list) and isinstance(data2['sentences'], list):
-                print(f"合并sentences字段: {len(data1['sentences'])} + {len(data2['sentences'])} = {len(data1['sentences']) + 
len(data2['sentences'])} 项") - merged['sentences'] = data1['sentences'] + data2['sentences'] - - # 更新metadata if exists - if 'metadata' in merged: - if isinstance(merged['metadata'], dict): - merged['metadata']['total_sentences'] = len(merged['sentences']) - merged['metadata']['merged_from'] = [os.path.basename(f) for f in JSON_FILES_TO_MERGE if os.path.exists(f)] - - # 合并其他字段 - for key, value in data2.items(): - if key != 'sentences' and key not in merged: - merged[key] = value - - return merged - - # 普通字典合并 - for key, value in data2.items(): - if key in merged: - # 如果key重复且都是列表,合并列表 - if isinstance(merged[key], list) and isinstance(value, list): - merged[key] = merged[key] + value - # 如果key重复且都是字典,递归合并 - elif isinstance(merged[key], dict) and isinstance(value, dict): - merged[key] = merge_json_data(merged[key], value) - else: - # 其他情况保留第二个文件的值 - merged[key] = value - print(f"字段 '{key}' 被覆盖") - else: - merged[key] = value - - return merged - - # 类型不匹配的情况,创建一个包含两者的新结构 - else: - print("数据类型不匹配,创建包含两者的新结构") - return { - "data_from_save.json": data1, - "data_from_save2.json": data2, - "merged_at": "test.py" - } - -def save_merged_json(data: Union[Dict, List], output_path: str): - """保存合并后的JSON数据""" - try: - # 确保输出目录存在 - os.makedirs(os.path.dirname(output_path), exist_ok=True) - - with open(output_path, 'w', encoding='utf-8') as f: - json.dump(data, f, ensure_ascii=False, indent=2) - - print(f"合并结果已保存到: {output_path}") - - # 显示统计信息 - if isinstance(data, dict): - if 'sentences' in data and isinstance(data['sentences'], list): - print(f"总计句子数: {len(data['sentences'])}") - print(f"总计字段数: {len(data)}") - elif isinstance(data, list): - print(f"总计列表项数: {len(data)}") - - except Exception as e: - print(f"错误: 保存文件失败 - {e}") - -def remove_duplicates_from_sentences(data: Union[Dict, List]) -> Union[Dict, List]: - """从合并结果中移除重复的句子(基于句子内容)""" - if isinstance(data, dict) and 'sentences' in data: - if isinstance(data['sentences'], list): - original_count = len(data['sentences']) - seen_sentences = set() - unique_sentences = [] - - for item in data['sentences']: - if isinstance(item, dict): - # 如果是字典,使用sentence字段或corrected_sentence字段作为唯一标识 - sentence_key = item.get('sentence') or item.get('corrected_sentence') or item.get('original_sentence') - elif isinstance(item, str): - sentence_key = item - else: - sentence_key = str(item) - - if sentence_key and sentence_key not in seen_sentences: - seen_sentences.add(sentence_key) - unique_sentences.append(item) - - data['sentences'] = unique_sentences - - # 更新metadata - if 'metadata' in data and isinstance(data['metadata'], dict): - data['metadata']['total_sentences'] = len(unique_sentences) - data['metadata']['duplicates_removed'] = original_count - len(unique_sentences) - - print(f"去重完成: {original_count} -> {len(unique_sentences)} (移除了 {original_count - len(unique_sentences)} 个重复项)") - - return data - -def merge_multiple_json_data(data_list: List[Union[Dict, List]]) -> Union[Dict, List]: - """合并多个JSON数据结构""" - if not data_list: - return {} - - if len(data_list) == 1: - return data_list[0] - - print(f"准备合并 {len(data_list)} 个JSON数据结构") - - # 从第一个数据开始,逐步合并其他数据 - merged_data = data_list[0] - - for i, data in enumerate(data_list[1:], 1): - print(f"正在合并第 {i+1} 个数据结构...") - merged_data = merge_json_data(merged_data, data) - - return merged_data - -def main(): - """主函数""" - print("=== JSON文件合并脚本 ===") - - # 输出路径 - output_path = "output/merged.json" - - print(f"准备合并以下文件:") - for i, file_path in enumerate(JSON_FILES_TO_MERGE, 1): - print(f" {i}. 
{file_path}") - print(f"输出文件: {output_path}") - print() - - # 加载所有文件 - loaded_data = [] - successfully_loaded = [] - - for file_path in JSON_FILES_TO_MERGE: - data = load_json_file(file_path) - if data is not None: - loaded_data.append(data) - successfully_loaded.append(file_path) - - # 检查是否至少有一个文件加载成功 - if not loaded_data: - print("错误: 没有文件能够成功加载,退出") - return - - print(f"成功加载了 {len(loaded_data)} 个文件:") - for file_path in successfully_loaded: - print(f" ✓ {file_path}") - - if len(loaded_data) < len(JSON_FILES_TO_MERGE): - failed_count = len(JSON_FILES_TO_MERGE) - len(loaded_data) - print(f"警告: {failed_count} 个文件加载失败") - print() - - # 合并所有数据 - if len(loaded_data) == 1: - print("只有一个文件可用,直接使用...") - merged_data = loaded_data[0] - else: - print("开始合并所有文件...") - merged_data = merge_multiple_json_data(loaded_data) - - # 去重处理 - print("\n检查并去除重复项...") - merged_data = remove_duplicates_from_sentences(merged_data) - - # 保存合并结果 - print("\n保存合并结果...") - save_merged_json(merged_data, output_path) - - print("\n=== 合并完成 ===") - print(f"合并了 {len(successfully_loaded)} 个文件的数据") - -if __name__ == "__main__": - main() diff --git a/preprocessing/test_preprocess_small.py b/preprocessing/test_preprocess_small.py deleted file mode 100644 index 9034eb0..0000000 --- a/preprocessing/test_preprocess_small.py +++ /dev/null @@ -1,61 +0,0 @@ -#!/usr/bin/env python3 -""" -小规模测试预处理脚本 -""" - -import sys -import os - -# 添加路径 -sys.path.append('/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/preprocessing') - -# 导入主模块 -from preprocess_pretrain import * - -# 修改配置为小规模测试 -DATASET_CONFIG["wikipedia"]["max_samples"] = 100 -DATASET_CONFIG["gutenberg"]["max_samples"] = 50 -DATASET_CONFIG["openwebtext"]["max_samples"] = 20 - -DATASET_CONFIG_EXTRA["wikipedia"]["max_samples"] = 50 -DATASET_CONFIG_EXTRA["gutenberg"]["max_samples"] = 30 -DATASET_CONFIG_EXTRA["openwebtext"]["max_samples"] = 15 - -# 修改输出路径 -OUTPUT_FILE = "/tmp/test_main.jsonl" -OUTPUT_FILE_EXTRA = "/tmp/test_extra.jsonl" - -def test_small_scale(): - """小规模测试""" - print("Starting small scale test...") - - # 设置随机种子 - random.seed(42) - - try: - # 初始化tokenizer - init_tokenizer() - - # 开始合并数据集 - merge_datasets() - - # 检查输出文件 - if os.path.exists(OUTPUT_FILE): - with open(OUTPUT_FILE, 'r') as f: - main_lines = len(f.readlines()) - print(f"Main file created: {main_lines} lines") - - if os.path.exists(OUTPUT_FILE_EXTRA): - with open(OUTPUT_FILE_EXTRA, 'r') as f: - extra_lines = len(f.readlines()) - print(f"Extra file created: {extra_lines} lines") - - print("Small scale test completed successfully!") - - except Exception as e: - print(f"Test failed: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - test_small_scale() \ No newline at end of file diff --git a/preprocessing/trex_to_sentences_simple.py b/preprocessing/trex_to_sentences_simple.py deleted file mode 100644 index e5167b9..0000000 --- a/preprocessing/trex_to_sentences_simple.py +++ /dev/null @@ -1,1238 +0,0 @@ -#!/usr/bin/env python3 -""" -TREx数据集增强预处理脚本 -使用vLLM OpenAI兼容API进行句子后处理和重要性评分 - -支持两个独立步骤: -1. 句子提取:从TREx数据集提取句子并保存为JSON -2. 
LLM处理:读取JSON文件进行LLM后处理和重要性评分 -""" - -import json -import os -import glob -from typing import List, Dict, Any, Union, Set -import re -import asyncio -import time -import logging -from datetime import datetime -import requests -from pydantic import BaseModel, Field -import aiohttp -import concurrent.futures - -# 设置日志系统 -def setup_logging(): - """设置日志系统""" - # 确保logs目录存在 - os.makedirs('logs', exist_ok=True) - - # 创建日志文件名(包含时间戳) - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - log_file = f'logs/trex_processor_{timestamp}.log' - - # 配置日志格式 - log_format = '%(asctime)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s' - - # 配置root logger - logging.basicConfig( - level=logging.INFO, - format=log_format, - handlers=[ - logging.FileHandler(log_file, encoding='utf-8'), - logging.StreamHandler() # 同时输出到控制台 - ] - ) - - # 获取logger - logger = logging.getLogger(__name__) - logger.info(f"日志系统初始化完成,日志文件: {log_file}") - return logger - -# 全局日志对象 -logger = setup_logging() - -class ProcessedSentence(BaseModel): - """处理后的句子结构""" - corrected_sentence: str = Field( - ..., - description="修正后的句子,只修正语法错误、乱码和不通顺的地方,不进行额外润色" - ) - importance_score: float = Field( - ..., - description="重要性评分,范围0.0-10.0,以0.1递进。评判这个知识在现实世界中的常用程度和重要度", - ge=0.0, - le=10.0 - ) - - -class EnhancedTRExProcessor: - def __init__(self, input_dir: str = None, output_file: str = None, max_files: int = None, - sentences_json: str = None, enable_llm_processing: bool = True): - self.input_dir = input_dir - - # 确保output目录存在 - os.makedirs('output', exist_ok=True) - - # 确保所有输出文件都在output目录中 - if output_file: - if not output_file.startswith('output/'): - self.output_file = os.path.join('output', output_file) - else: - self.output_file = output_file - else: - self.output_file = None - - if sentences_json: - if not sentences_json.startswith('output/'): - self.sentences_json = os.path.join('output', sentences_json) - else: - self.sentences_json = sentences_json - else: - self.sentences_json = "output/extracted_sentences.json" - - self.max_files = max_files - self.enable_llm_processing = enable_llm_processing - - # Ollama API配置 - self.model_name = "gemma3:latest" # Ollama模型名称 - self.ollama_base_url = "http://localhost:11434" # Ollama服务器地址 - self.batch_size_per_request = 8 # 每个API请求处理的句子数量(Ollama建议较小批次) - self.max_concurrent_requests = 2 # 最大并发请求数(Ollama建议较低并发) - self.request_timeout = 180 # 请求超时时间(秒) - self.retry_attempts = 3 # 重试次数 - - # 统计信息 - self.total_requests = 0 - self.successful_requests = 0 - self.failed_requests = 0 - - logger.info(f"处理器初始化完成 - 模型: {self.model_name}, 批次大小: {self.batch_size_per_request}, 并发数: {self.max_concurrent_requests}") - - # 扩展的Wikidata属性映射 - self.property_mappings = { - # 基本关系 - "http://www.wikidata.org/prop/direct/P31": "is a", - "http://www.wikidata.org/prop/direct/P279": "is a type of", - - # 人物相关 - "http://www.wikidata.org/prop/direct/P106": "works as", - "http://www.wikidata.org/prop/direct/P27": "is a citizen of", - "http://www.wikidata.org/prop/direct/P19": "was born in", - "http://www.wikidata.org/prop/direct/P20": "died in", - "http://www.wikidata.org/prop/direct/P569": "was born on", - "http://www.wikidata.org/prop/direct/P570": "died on", - "http://www.wikidata.org/prop/direct/P22": "has father", - "http://www.wikidata.org/prop/direct/P25": "has mother", - "http://www.wikidata.org/prop/direct/P26": "is married to", - - # 组织相关 - "http://www.wikidata.org/prop/direct/P102": "is a member of", - "http://www.wikidata.org/prop/direct/P108": "works for", - "http://www.wikidata.org/prop/direct/P159": "has 
headquarters in", - "http://www.wikidata.org/prop/direct/P112": "was founded by", - "http://www.wikidata.org/prop/direct/P571": "was founded in", - "http://www.wikidata.org/prop/direct/P169": "has CEO", - - # 地理相关 - "http://www.wikidata.org/prop/direct/P17": "is located in", - "http://www.wikidata.org/prop/direct/P131": "is located in", - "http://www.wikidata.org/prop/direct/P36": "has capital", - "http://www.wikidata.org/prop/direct/P47": "borders", - - # 其他关系 - "http://www.wikidata.org/prop/direct/P1142": "has ideology", - "http://www.wikidata.org/prop/direct/P361": "is part of", - "http://www.wikidata.org/prop/direct/P737": "was influenced by", - "http://www.wikidata.org/prop/direct/P127": "is owned by", - "http://www.wikidata.org/prop/direct/P155": "follows", - "http://www.wikidata.org/prop/direct/P156": "is followed by", - "http://www.wikidata.org/prop/direct/P138": "is named after" - } - - def get_system_prompt(self) -> str: - """获取系统提示""" - return """You are a professional text processing assistant responsible for correcting errors in sentences and evaluating the importance of knowledge. - -### Sentence Correction Rules: -1. Remove Wikipedia-specific markers: such as (disambiguation), (film), (band), etc. in parentheses -2. Ensure grammatical completeness: complete subject+predicate+object structure, avoid dangling 'and is', 'or', etc. -3. Fix obvious grammatical errors: tense consistency, singular/plural consistency, correct preposition usage -4. Clean up garbled text and special characters: such as â, €, ™ and other encoding issues -5. Ensure semantic fluency: if the original sentence cannot be fixed, reorganize the language to make it coherent -6. Do not add information not present in the original text, only correct errors - -### Correction Examples: -- Error: 'Argument (disambiguation) is related to philosophy, logic, and is an.' -- Corrected: 'Argument is related to philosophy and logic.' - -- Error: 'Beijing is a capital city and are.' -- Corrected: 'Beijing is a capital city.' 
- -Importance scoring criteria (0.0-10.0, in increments of 0.1): - -0.0 points - Completely incorrect or meaningless information -Examples: 'Apple is a metal', 'The sun rises from the west', '1+1=3' - -0.5 points - Almost worthless information -Examples: 'Color of a fictional character's socks', 'Third line of dialogue from a game NPC', 'What someone had for breakfast yesterday' - -1.0 points - Extremely rare, non-practical knowledge -Examples: 'Pet name of a minor novel character', 'Content of the 15th line in movie end credits', 'Nickname of website user ID 123456' - -1.5 points - Very niche detailed information -Examples: 'Outfit of a passerby at minute 37 in a movie', 'Duration of background music in a game's hidden level', 'Content of the 3rd dialogue box on page 200 of a manga' - -2.0 points - Details in niche professional fields -Examples: 'Color change of rare minerals at specific temperatures', 'Length of an insect's third antenna', 'Molecular formula of chemical reaction byproducts' - -2.5 points - Technical details only professionals care about -Examples: 'Release date of specific software library version', 'Time complexity coefficient of an algorithm', 'Thermal expansion coefficient of a material' - -3.0 points - Professional knowledge in specific fields -Examples: 'Programming language syntax features', 'Gene sequence of a virus', 'Official system of ancient dynasties' - -3.5 points - Professional information with some value -Examples: 'Specific system of historical dynasty', 'Mechanism of action of a drug', 'Development time of a technical standard' - -4.0 points - Meaningful knowledge known by few -Examples: 'Unique cultural traditions of a country', 'Important discoveries by a scientist', 'Detailed process of historical events' - -4.5 points - Knowledge of interest to some groups -Examples: 'Author's creative background', 'Characteristics of an art movement', 'Detailed rules of a sport' - -5.0 points - General knowledge of moderate importance -Examples: 'Famous attractions in cities', 'Development history of a company', 'Living habits of animals' - -5.5 points - Fairly useful common sense -Examples: 'Plant growth environment', 'Healthy eating common sense', 'Basic first aid knowledge' - -6.0 points - Knowledge most educated people should know -Examples: 'Shakespeare's representative works', 'Basic geometric theorems', 'Major world currencies' - -6.5 points - Important cultural or scientific common sense -Examples: 'Basic structure of DNA', 'Newton's three laws', 'Major world religions' - -7.0 points - Important foundational knowledge -Examples: 'Time period of World War II', 'Functions of major human organs', 'Basic mathematical operation rules' - -7.5 points - Very important common sense -Examples: 'Light speed is the fastest in the universe', 'Earth is round', 'Basic principles of blood circulation' - -8.0 points - Core knowledge in basic education -Examples: 'Earth orbits the sun', 'Principle of seasonal formation', 'Basic grammar rules' - -8.5 points - Important knowledge everyone should master -Examples: 'Chemical formula of water H2O', 'Basic safety common sense', 'Simple mathematical calculations' - -9.0 points - Extremely important basic concepts -Examples: 'Humans need oxygen to survive', 'Fire is hot', 'Basic directional concepts' - -9.5 points - Core knowledge everyone must know -Examples: 'A day has 24 hours', 'A year has 12 months', 'Basic number concepts' - -10.0 points - Most basic and important common sense -Examples: 'Humans need food and water to survive', 
'The sky is blue', 'Stones are heavier than feathers' - -When scoring, please consider: -1. Popularity of knowledge - How many people know this knowledge -2. Practical value - How useful this knowledge is in daily life -3. Educational importance - The position of this knowledge in the education system -4. Cultural significance - The importance of this knowledge for understanding world - -Please respond with valid JSON in the following format: -{ - "corrected_sentence": "corrected sentence here", - "importance_score": evaluation score -}""" - - async def process_batch_with_vllm_api(self, sentences: List[str]) -> List[Dict[str, Any]]: - """使用vLLM OpenAI兼容API处理一批句子""" - processed_sentences = [] - - async with aiohttp.ClientSession() as session: - # 创建并发任务 - semaphore = asyncio.Semaphore(self.max_concurrent_requests) - tasks = [] - - # 将句子分成小批次 - for i in range(0, len(sentences), self.batch_size_per_request): - batch_sentences = sentences[i:i + self.batch_size_per_request] - task = self.process_single_batch_request(session, semaphore, batch_sentences, i) - tasks.append(task) - - # 等待所有任务完成 - batch_results = await asyncio.gather(*tasks, return_exceptions=True) - - # 收集结果 - for result in batch_results: - if isinstance(result, Exception): - logger.error(f"批次处理出错: {result}") - continue - if result: - processed_sentences.extend(result) - - return processed_sentences - - async def process_single_batch_request(self, session: aiohttp.ClientSession, semaphore: asyncio.Semaphore, - sentences: List[str], batch_index: int) -> List[Dict[str, Any]]: - """处理单个批次的API请求""" - async with semaphore: - for attempt in range(self.retry_attempts): - try: - # 为每个句子创建单独的消息 - messages = [] - for sentence in sentences: - messages.append({ - "role": "user", - "content": f"Please correct the errors in the following sentence and evaluate its importance: {sentence}" - }) - - # 构建Ollama请求数据 - request_data = { - "model": self.model_name, - "messages": [ - {"role": "system", "content": self.get_system_prompt()} - ] + messages, - "stream": False, - "options": { - "temperature": 0.2, - "num_predict": 500 * len(sentences) # 为每个句子分配足够的token - }, - "format": "json" # Ollama的JSON格式参数 - } - - # 发送请求到Ollama - async with session.post( - f'{self.ollama_base_url}/api/chat', - json=request_data, - timeout=aiohttp.ClientTimeout(total=self.request_timeout) - ) as response: - - if response.status == 200: - result = await response.json() - return self.parse_ollama_response(result, sentences, batch_index) - else: - error_text = await response.text() - logger.warning(f"API请求失败 (批次 {batch_index}, 尝试 {attempt + 1}/{self.retry_attempts}): {response.status} - {error_text}") - - if attempt == self.retry_attempts - 1: # 最后一次尝试 - logger.error(f"批次 {batch_index} 在 {self.retry_attempts} 次尝试后仍然失败") - self.failed_requests += len(sentences) - return self.create_default_responses(sentences) - else: - # 等待后重试 - await asyncio.sleep(2 ** attempt) # 指数退避 - continue - - except asyncio.TimeoutError: - logger.warning(f"批次 {batch_index} 请求超时 (尝试 {attempt + 1}/{self.retry_attempts})") - if attempt == self.retry_attempts - 1: - logger.error(f"批次 {batch_index} 在 {self.retry_attempts} 次尝试后仍然超时") - self.failed_requests += len(sentences) - return self.create_default_responses(sentences) - else: - await asyncio.sleep(2 ** attempt) - continue - - except Exception as e: - logger.warning(f"处理批次 {batch_index} 时出错 (尝试 {attempt + 1}/{self.retry_attempts}): {e}") - if attempt == self.retry_attempts - 1: - logger.error(f"批次 {batch_index} 在 {self.retry_attempts} 次尝试后仍然失败") - 
self.failed_requests += len(sentences) - return self.create_default_responses(sentences) - else: - await asyncio.sleep(2 ** attempt) - continue - - # 如果所有重试都失败了 - return self.create_default_responses(sentences) - - def parse_ollama_response(self, response: Dict[str, Any], original_sentences: List[str], batch_index: int) -> List[Dict[str, Any]]: - """解析Ollama响应""" - processed_sentences = [] - - try: - # Ollama的响应格式 - message = response.get('message', {}) - content = message.get('content', '') - - if not content: - logger.warning(f"批次 {batch_index} 没有返回内容") - return self.create_default_responses(original_sentences) - - # 尝试解析JSON响应 - try: - # 如果返回的是单个JSON对象 - if content.strip().startswith('{') and content.strip().endswith('}'): - response_data = json.loads(content) - processed_sentence = ProcessedSentence( - corrected_sentence=response_data.get('corrected_sentence', original_sentences[0] if original_sentences else ""), - importance_score=float(response_data.get('importance_score', 5.0)) - ) - - processed_sentences.append({ - "original_sentence": original_sentences[0] if original_sentences else "", - "corrected_sentence": processed_sentence.corrected_sentence, - "importance_score": processed_sentence.importance_score - }) - self.successful_requests += 1 - - # 如果有多个句子但只返回一个结果,为其他句子创建默认响应 - for i in range(1, len(original_sentences)): - processed_sentences.append({ - "original_sentence": original_sentences[i], - "corrected_sentence": original_sentences[i], - "importance_score": 5.0 - }) - self.failed_requests += 1 - - else: - # 尝试解析多个JSON对象 - json_objects = [] - for line in content.split('\n'): - line = line.strip() - if line.startswith('{') and line.endswith('}'): - try: - json_objects.append(json.loads(line)) - except: - continue - - if json_objects: - for i, (sentence, json_obj) in enumerate(zip(original_sentences, json_objects)): - try: - processed_sentence = ProcessedSentence( - corrected_sentence=json_obj.get('corrected_sentence', sentence), - importance_score=float(json_obj.get('importance_score', 5.0)) - ) - - processed_sentences.append({ - "original_sentence": sentence, - "corrected_sentence": processed_sentence.corrected_sentence, - "importance_score": processed_sentence.importance_score - }) - self.successful_requests += 1 - except Exception as e: - logger.warning(f"解析JSON对象失败: {e}") - processed_sentences.append({ - "original_sentence": sentence, - "corrected_sentence": sentence, - "importance_score": 5.0 - }) - self.failed_requests += 1 - - # 为剩余句子创建默认响应 - for i in range(len(json_objects), len(original_sentences)): - processed_sentences.append({ - "original_sentence": original_sentences[i], - "corrected_sentence": original_sentences[i], - "importance_score": 5.0 - }) - self.failed_requests += 1 - else: - logger.warning(f"批次 {batch_index} 无法解析JSON响应: {content}") - return self.create_default_responses(original_sentences) - - except (json.JSONDecodeError, ValueError) as e: - logger.warning(f"批次 {batch_index} 解析响应JSON失败: {e}") - logger.warning(f"原始内容: {content}") - return self.create_default_responses(original_sentences) - - except Exception as e: - logger.error(f"解析批次 {batch_index} 响应时出错: {e}") - return self.create_default_responses(original_sentences) - - return processed_sentences - - def create_default_responses(self, sentences: List[str]) -> List[Dict[str, Any]]: - """为失败的请求创建默认响应""" - default_responses = [] - for sentence in sentences: - default_responses.append({ - "original_sentence": sentence, - "corrected_sentence": sentence, - "importance_score": 5.0 - }) - return 
default_responses - - async def process_sentences_with_vllm_api(self, sentences: List[str]) -> List[Dict[str, Any]]: - """使用Ollama API处理句子""" - logger.info(f"开始使用Ollama API处理 {len(sentences)} 个句子...") - print(f"开始使用Ollama API处理 {len(sentences)} 个句子...") - - start_time = time.time() - total_sentences = len(sentences) - total_processed_count = 0 - - # 检查Ollama服务状态 - if not self.check_ollama_status(): - logger.error("Ollama服务状态异常,无法继续处理") - print("错误:Ollama服务状态异常,请检查服务是否正常运行") - return [] - - # 分大批次处理(用于检查点保存) - large_batch_size = 1000 # 每1000个句子保存一次检查点 - all_processed_sentences = [] - - for large_batch_start in range(0, total_sentences, large_batch_size): - large_batch_end = min(large_batch_start + large_batch_size, total_sentences) - large_batch_sentences = sentences[large_batch_start:large_batch_end] - large_batch_number = large_batch_start // large_batch_size + 1 - - logger.info(f"=== 处理大批次 {large_batch_number} ({large_batch_start + 1}-{large_batch_end}/{total_sentences}) ===") - print(f"\n=== 处理大批次 {large_batch_number} ({large_batch_start + 1}-{large_batch_end}/{total_sentences}) ===") - - large_batch_start_time = time.time() - - # 处理当前大批次 - batch_processed = await self.process_batch_with_vllm_api(large_batch_sentences) - all_processed_sentences.extend(batch_processed) - total_processed_count += len(batch_processed) - - # 保存当前大批次的检查点 - checkpoint_filename = self.save_batch_checkpoint(batch_processed, large_batch_number, total_processed_count) - - # 打印进度 - large_batch_time = time.time() - large_batch_start_time - elapsed_time = time.time() - start_time - - logger.info(f"大批次 {large_batch_number} 处理完成!") - logger.info(f" - 当前批次:成功 {len(batch_processed)},用时 {large_batch_time/60:.1f}分钟") - logger.info(f" - 总体进度:{total_processed_count}/{total_sentences} ({total_processed_count/total_sentences*100:.1f}%)") - logger.info(f" - 已用时间:{elapsed_time/60:.1f}分钟") - logger.info(f" - 批次检查点已保存:{checkpoint_filename}") - - print(f"大批次 {large_batch_number} 处理完成!") - print(f" - 当前批次:成功 {len(batch_processed)},用时 {large_batch_time/60:.1f}分钟") - print(f" - 总体进度:{total_processed_count}/{total_sentences} ({total_processed_count/total_sentences*100:.1f}%)") - print(f" - 已用时间:{elapsed_time/60:.1f}分钟") - print(f" - 批次检查点已保存:{checkpoint_filename}") - - if large_batch_end < total_sentences: - remaining_sentences = total_sentences - total_processed_count - avg_time_per_sentence = elapsed_time / total_processed_count - estimated_remaining_time = avg_time_per_sentence * remaining_sentences - logger.info(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟") - print(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟") - - # 打印最终统计 - total_time = time.time() - start_time - logger.info(f"=== 全部处理完成!===") - logger.info(f" - 总成功:{self.successful_requests}") - logger.info(f" - 总失败:{self.failed_requests}") - logger.info(f" - 总用时:{total_time/60:.1f}分钟") - logger.info(f" - 平均处理速度:{total_processed_count/total_time:.2f}句/秒") - - print(f"\n=== 全部处理完成!===") - print(f" - 总成功:{self.successful_requests}") - print(f" - 总失败:{self.failed_requests}") - print(f" - 总用时:{total_time/60:.1f}分钟") - print(f" - 平均处理速度:{total_processed_count/total_time:.2f}句/秒") - - return all_processed_sentences - - def check_ollama_status(self) -> bool: - """检查Ollama服务是否正常运行""" - try: - # 检查Ollama API是否响应 - response = requests.get(f'{self.ollama_base_url}/api/tags', timeout=10) - - if response.status_code == 200: - models = response.json() - model_names = [model.get('name', 'unknown') for model in models.get('models', [])] - logger.info(f"Ollama服务状态正常,可用模型: {model_names}") 
- - # 检查目标模型是否可用 - if self.model_name in model_names: - logger.info(f"目标模型 {self.model_name} 可用") - return True - else: - logger.warning(f"目标模型 {self.model_name} 不在可用模型列表中: {model_names}") - logger.info("尝试拉取模型...") - # 尝试拉取模型 - try: - pull_response = requests.post( - f'{self.ollama_base_url}/api/pull', - json={"name": self.model_name}, - timeout=300 # 5分钟超时 - ) - if pull_response.status_code == 200: - logger.info(f"成功拉取模型 {self.model_name}") - return True - else: - logger.error(f"拉取模型失败: {pull_response.status_code}") - return False - except Exception as e: - logger.error(f"拉取模型时出错: {e}") - return False - else: - logger.error(f"Ollama API响应异常,状态码: {response.status_code}") - return False - - except requests.exceptions.RequestException as e: - logger.error(f"无法连接到Ollama API: {e}") - return False - except Exception as e: - logger.error(f"检查Ollama状态时出错: {e}") - return False - - def clean_text(self, text: str) -> str: - """清理文本,处理特殊字符""" - if not text: - return "" - - # 处理常见的Unicode字符 - text = text.replace("–", "-") # en dash - text = text.replace("—", "-") # em dash - text = text.replace("'", "'") # right single quotation mark - text = text.replace("'", "'") # left single quotation mark - text = text.replace(""", '"') # left double quotation mark - text = text.replace(""", '"') # right double quotation mark - - # 处理可能的转义序列 - try: - text = text.encode('utf-8').decode('utf-8') - except: - pass - - # 清理多余的空格 - text = re.sub(r'\s+', ' ', text).strip() - - # 移除可能的引号 - text = text.strip('"\'') - - return text - - def parse_large_json_file(self, file_path: str) -> List[Dict]: - """解析大型JSON文件,处理可能的格式问题""" - documents = [] - - try: - with open(file_path, 'r', encoding='utf-8') as f: - content = f.read().strip() - - # 尝试不同的解析方法 - if content.startswith('[') and content.endswith(']'): - # 标准JSON数组 - documents = json.loads(content) - else: - # 可能是连续的JSON对象 - # 尝试在}{"之间分割 - if '}{"' in content: - json_strings = content.split('}{') - json_strings[0] += '}' # 第一个对象 - json_strings[-1] = '{' + json_strings[-1] # 最后一个对象 - - for i in range(1, len(json_strings) - 1): - json_strings[i] = '{' + json_strings[i] + '}' - - for json_str in json_strings: - try: - doc = json.loads(json_str) - documents.append(doc) - except json.JSONDecodeError: - continue - else: - # 尝试作为单个JSON对象 - try: - documents = [json.loads(content)] - except json.JSONDecodeError: - pass - - except Exception as e: - print(f"Error parsing {file_path}: {e}") - - return documents - - def extract_sentences_from_document(self, doc: Dict[str, Any]) -> List[str]: - """从文档中提取句子""" - sentences = [] - - title = self.clean_text(doc.get('title', '')) - text = self.clean_text(doc.get('text', '')) - entities = doc.get('entities', []) - triples = doc.get('triples', []) - - # 处理显式三元组 - for triple in triples: - sentence = self.triple_to_sentence(triple) - if sentence: - sentences.append(sentence) - - # 从实体和文本中生成基本句子(如果三元组句子不够) - if title and text and len(sentences) < 5: - # 基于标题和实体生成句子 - entity_names = [] - for entity in entities[:15]: - entity_name = self.clean_text(entity.get('surfaceform', '')) - if entity_name and len(entity_name) > 2: - entity_names.append(entity_name) - - # 生成简单的描述句子 - if entity_names: - important_entities = [] - title_lower = title.lower() - for entity in entity_names: - if (entity.lower() != title_lower and - entity not in important_entities and - not any(t.lower() in entity.lower() for t in title_lower.split()[:2])): - important_entities.append(entity) - if len(important_entities) >= 6: - break - - if important_entities and len(sentences) < 
3: - entities_str = ', '.join(important_entities[:3]) - sentences.append(f"{title} is related to {entities_str}.") - - return sentences - - def triple_to_sentence(self, triple: Dict[str, Any]) -> str: - """将三元组转换为自然语言句子""" - try: - subject = triple.get('subject', {}) - predicate = triple.get('predicate', {}) - obj = triple.get('object', {}) - - subject_name = self.clean_text(subject.get('surfaceform', '')) - object_name = self.clean_text(obj.get('surfaceform', '')) - predicate_uri = predicate.get('uri', '') - - # 检查是否有有效的主语和宾语 - if not subject_name or not object_name: - return "" - - # 检查主语和宾语是否过短或无意义 - if len(subject_name) <= 2 or len(object_name) <= 2: - return "" - - # 获取关系文本 - relation_text = self.property_mappings.get(predicate_uri, "is related to") - - # 避免重复的主语宾语 - if subject_name.lower() == object_name.lower(): - return "" - - return f"{subject_name} {relation_text} {object_name}." - - except Exception as e: - print(f"Error converting triple to sentence: {e}") - return "" - - def save_batch_checkpoint(self, processed_sentences: List[Dict[str, Any]], batch_number: int, total_processed_count: int) -> str: - """保存当前批次的检查点文件""" - # 生成检查点文件名,确保在output目录中 - base_name = os.path.splitext(os.path.basename(self.output_file))[0] - checkpoint_filename = os.path.join('output', f"{base_name}_batch_{batch_number}.json") - - # 保存检查点 - with open(checkpoint_filename, 'w', encoding='utf-8') as f: - json.dump({ - "metadata": { - "batch_number": batch_number, - "batch_size": len(processed_sentences), - "total_processed_count": total_processed_count, - "timestamp": time.strftime("%Y-%m-%d %H:%M:%S") - }, - "sentences": processed_sentences - }, f, ensure_ascii=False, indent=2) - - return checkpoint_filename - - async def process_files(self) -> List[Dict[str, Any]]: - """处理所有文件""" - json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json")) - - if not json_files: - print(f"No JSON files found in {self.input_dir}") - return [] - - # 排序文件以确保一致的处理顺序 - json_files.sort() - - if self.max_files: - json_files = json_files[:self.max_files] - - print(f"Found {len(json_files)} JSON files to process") - - all_sentences = [] - - for i, file_path in enumerate(json_files): - print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}") - - documents = self.parse_large_json_file(file_path) - print(f" Parsed {len(documents)} documents") - - for doc in documents: - sentences = self.extract_sentences_from_document(doc) - all_sentences.extend(sentences) - - print(f" Generated {len(all_sentences)} total raw sentences so far") - - print(f"总共提取了 {len(all_sentences)} 个原始句子") - - # 去重 - unique_sentences = [] - seen = set() - for sentence in all_sentences: - sentence = sentence.strip() - if sentence and sentence not in seen and len(sentence) > 10: - unique_sentences.append(sentence) - seen.add(sentence) - - print(f"去重后剩余 {len(unique_sentences)} 个句子") - - # 保存原始句子到JSON文件 - sentences_data = { - "metadata": { - "total_sentences": len(unique_sentences), - "extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), - "source_files": len(json_files), - "max_files_limit": self.max_files - }, - "sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences] - } - - with open(self.sentences_json, 'w', encoding='utf-8') as f: - json.dump(sentences_data, f, ensure_ascii=False, indent=2) - - print(f"句子提取完成!已保存到: {self.sentences_json}") - print(f"总计句子数: {len(unique_sentences)}") - - return unique_sentences - - def save_sentences(self, processed_sentences: List[Dict[str, Any]]): - 
"""保存处理后的句子到文件""" - # 确保输出目录存在 - os.makedirs('output', exist_ok=True) - - # 保存为JSON格式,包含完整信息 - json_output_file = self.output_file.replace('.txt', '.json') - with open(json_output_file, 'w', encoding='utf-8') as f: - json.dump(processed_sentences, f, ensure_ascii=False, indent=2) - - # 保存为简单文本格式(仅修正后的句子) - with open(self.output_file, 'w', encoding='utf-8') as f: - for item in processed_sentences: - f.write(item['corrected_sentence'] + '\n') - - # 生成重要性排序文件 - importance_sorted = sorted(processed_sentences, key=lambda x: x['importance_score'], reverse=True) - importance_file = self.output_file.replace('.txt', '_sorted_by_importance.txt') - with open(importance_file, 'w', encoding='utf-8') as f: - for item in importance_sorted: - f.write(f"[{item['importance_score']:.1f}] {item['corrected_sentence']}\n") - - print(f"保存了 {len(processed_sentences)} 个处理后的句子:") - print(f" - JSON格式: {json_output_file}") - print(f" - 文本格式: {self.output_file}") - print(f" - 重要性排序: {importance_file}") - - # 统计信息 - scores = [item['importance_score'] for item in processed_sentences] - avg_score = sum(scores) / len(scores) if scores else 0 - print(f" - 平均重要性评分: {avg_score:.2f}") - print(f" - 最高评分: {max(scores):.1f}") - print(f" - 最低评分: {min(scores):.1f}") - - async def run(self): - """运行处理流程""" - print("Starting enhanced TREx to sentences conversion...") - processed_sentences = await self.process_files() - self.save_sentences(processed_sentences) - print("Enhanced conversion completed!") - - def find_latest_checkpoint(self) -> Union[tuple, None]: - """查找最新的检查点文件""" - base_name = os.path.splitext(os.path.basename(self.output_file))[0] - pattern = os.path.join('output', f"{base_name}_checkpoint_*.json") - checkpoint_files = glob.glob(pattern) - - if not checkpoint_files: - return None - - # 按检查点编号排序,获取最新的 - latest_file = None - latest_count = 0 - - for file in checkpoint_files: - try: - # 从文件名中提取数字 - match = re.search(r'checkpoint_(\d+)\.json$', file) - if match: - count = int(match.group(1)) - if count > latest_count: - latest_count = count - latest_file = file - except: - continue - - if latest_file: - return latest_file, latest_count - else: - return None - - def load_checkpoint(self, checkpoint_file: str) -> List[Dict[str, Any]]: - """从检查点文件加载已处理的句子""" - try: - with open(checkpoint_file, 'r', encoding='utf-8') as f: - data = json.load(f) - - if 'sentences' in data: - return data['sentences'] - else: - # 旧格式的检查点文件 - return data - except Exception as e: - print(f"加载检查点文件失败: {e}") - return [] - - def get_processed_sentences_from_checkpoints(self) -> Set[str]: - """从检查点文件中获取已处理过的句子集合""" - if not self.output_file: - return set() - - processed_sentences = set() - base_name = os.path.splitext(os.path.basename(self.output_file))[0] - - # 首先查找新格式的批次文件 - batch_pattern = os.path.join('output', f"{base_name}_batch_*.json") - batch_files = glob.glob(batch_pattern) - - if batch_files: - print(f"找到 {len(batch_files)} 个批次检查点文件") - batch_files.sort() # 确保按顺序处理 - - for batch_file in batch_files: - try: - with open(batch_file, 'r', encoding='utf-8') as f: - data = json.load(f) - - sentences_data = data.get('sentences', []) - for item in sentences_data: - original_sentence = item.get('original_sentence', '') - if original_sentence: - processed_sentences.add(original_sentence) - - batch_number = data.get('metadata', {}).get('batch_number', 'unknown') - print(f" - 从批次 {batch_number} 加载了 {len(sentences_data)} 个句子") - - except Exception as e: - print(f"读取批次文件 {batch_file} 失败: {e}") - continue - - print(f"从批次文件总计加载了 {len(processed_sentences)} 
个已处理的句子") - logger.info(f"从批次文件总计加载了 {len(processed_sentences)} 个已处理的句子") - return processed_sentences - - # 如果没有批次文件,尝试查找旧格式的检查点文件 - old_pattern = os.path.join('output', f"{base_name}_checkpoint_*.json") - checkpoint_files = glob.glob(old_pattern) - - if not checkpoint_files: - print("未找到检查点文件,将从头开始处理") - return set() - - # 找到最新的检查点文件 - latest_file = None - latest_count = 0 - - for file in checkpoint_files: - try: - match = re.search(r'checkpoint_(\d+)\.json$', file) - if match: - count = int(match.group(1)) - if count > latest_count: - latest_count = count - latest_file = file - except: - continue - - if latest_file: - print(f"找到旧格式检查点: {latest_file} (包含 {latest_count} 条记录)") - logger.info(f"找到旧格式检查点: {latest_file} (包含 {latest_count} 条记录)") - try: - with open(latest_file, 'r', encoding='utf-8') as f: - data = json.load(f) - - sentences_data = data.get('sentences', []) - for item in sentences_data: - original_sentence = item.get('original_sentence', '') - if original_sentence: - processed_sentences.add(original_sentence) - - print(f"从旧格式检查点加载了 {len(processed_sentences)} 个已处理的句子") - logger.info(f"从旧格式检查点加载了 {len(processed_sentences)} 个已处理的句子") - - except Exception as e: - print(f"读取检查点文件失败: {e}") - return set() - - return processed_sentences - - async def process_with_llm(self): - """步骤2:从JSON文件读取句子并进行vLLM处理(保持兼容性)""" - await self.process_with_vllm_api() - - async def process_with_vllm_api(self): - """步骤2:从JSON文件读取句子并进行vLLM处理""" - if not self.enable_llm_processing: - print("Error: LLM processing is disabled!") - return - - if not self.output_file: - print("Error: output_file is required for LLM processing!") - return - - print("=== 步骤2:vLLM处理 ===") - - # 读取句子JSON文件 - if not os.path.exists(self.sentences_json): - print(f"Error: Sentences file {self.sentences_json} not found!") - print("请先运行步骤1进行句子提取") - return - - print(f"正在读取句子文件: {self.sentences_json}") - - try: - with open(self.sentences_json, 'r', encoding='utf-8') as f: - data = json.load(f) - - all_sentences = [item["sentence"] for item in data.get("sentences", [])] - print(f"从文件中读取了 {len(all_sentences)} 个句子") - - except Exception as e: - print(f"读取句子文件失败: {e}") - return - - # 获取已处理的句子 - processed_sentences_set = self.get_processed_sentences_from_checkpoints() - - # 过滤出未处理的句子 - unprocessed_sentences = [] - for sentence in all_sentences: - if sentence not in processed_sentences_set: - unprocessed_sentences.append(sentence) - - print(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: {len(processed_sentences_set)})") - logger.info(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: {len(processed_sentences_set)})") - - if not unprocessed_sentences: - print("所有句子都已处理完成!") - - # 如果有检查点,直接从最新检查点生成最终文件 - if processed_sentences_set: - latest_checkpoint = self.find_latest_checkpoint() - if latest_checkpoint: - checkpoint_file, _ = latest_checkpoint - processed_data = self.load_checkpoint(checkpoint_file) - self.save_sentences(processed_data) - print("已从检查点生成最终输出文件") - return - - # 处理未处理的句子 - print("开始vLLM处理...") - - # 处理新句子(现在返回空列表,数据保存在批次检查点中) - await self.process_sentences_with_vllm_api(unprocessed_sentences) - - # 处理完成后,合并所有批次检查点生成最终文件 - print("合并所有批次检查点生成最终文件...") - all_processed_sentences = self.merge_all_batch_checkpoints() - - if all_processed_sentences: - # 保存最终结果 - self.save_sentences(all_processed_sentences) - print("vLLM处理完成!") - else: - print("警告:没有找到任何处理结果") - - def merge_all_batch_checkpoints(self) -> List[Dict[str, Any]]: - """合并所有批次检查点文件""" - if not self.output_file: - return [] - - base_name = 
os.path.splitext(os.path.basename(self.output_file))[0] - - # 查找所有批次检查点文件 - batch_pattern = os.path.join('output', f"{base_name}_batch_*.json") - batch_files = glob.glob(batch_pattern) - - if not batch_files: - # 如果没有批次文件,尝试查找旧格式的检查点文件 - old_pattern = os.path.join('output', f"{base_name}_checkpoint_*.json") - old_files = glob.glob(old_pattern) - if old_files: - print("找到旧格式检查点文件,尝试读取最新的...") - latest_checkpoint = self.find_latest_checkpoint() - if latest_checkpoint: - checkpoint_file, _ = latest_checkpoint - return self.load_checkpoint(checkpoint_file) - return [] - - print(f"找到 {len(batch_files)} 个批次检查点文件") - - all_sentences = [] - batch_files.sort() # 确保按顺序处理 - - for batch_file in batch_files: - try: - with open(batch_file, 'r', encoding='utf-8') as f: - data = json.load(f) - - batch_sentences = data.get('sentences', []) - all_sentences.extend(batch_sentences) - - batch_number = data.get('metadata', {}).get('batch_number', 'unknown') - batch_size = len(batch_sentences) - print(f" - 批次 {batch_number}: {batch_size} 个句子") - - except Exception as e: - print(f"读取批次文件 {batch_file} 失败: {e}") - continue - - print(f"总计合并了 {len(all_sentences)} 个句子") - return all_sentences - - def extract_sentences(self): - """步骤1:从TREx数据集提取句子并保存为JSON""" - if not self.input_dir: - print("Error: input_dir is required for sentence extraction!") - return - - print("=== 步骤1:句子提取 ===") - print("开始从TREx数据集提取句子...") - - json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json")) - - if not json_files: - print(f"No JSON files found in {self.input_dir}") - return - - # 排序文件以确保一致的处理顺序 - json_files.sort() - - if self.max_files: - json_files = json_files[:self.max_files] - - print(f"Found {len(json_files)} JSON files to process") - - all_sentences = [] - - for i, file_path in enumerate(json_files): - print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}") - - documents = self.parse_large_json_file(file_path) - print(f" Parsed {len(documents)} documents") - - for doc in documents: - sentences = self.extract_sentences_from_document(doc) - all_sentences.extend(sentences) - - print(f" Generated {len(all_sentences)} total raw sentences so far") - - print(f"总共提取了 {len(all_sentences)} 个原始句子") - - # 去重 - unique_sentences = [] - seen = set() - for sentence in all_sentences: - sentence = sentence.strip() - if sentence and sentence not in seen and len(sentence) > 10: - unique_sentences.append(sentence) - seen.add(sentence) - - print(f"去重后剩余 {len(unique_sentences)} 个句子") - - # 保存原始句子到JSON文件 - sentences_data = { - "metadata": { - "total_sentences": len(unique_sentences), - "extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), - "source_files": len(json_files), - "max_files_limit": self.max_files - }, - "sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences] - } - - with open(self.sentences_json, 'w', encoding='utf-8') as f: - json.dump(sentences_data, f, ensure_ascii=False, indent=2) - - print(f"句子提取完成!已保存到: {self.sentences_json}") - print(f"总计句子数: {len(unique_sentences)}") - - return unique_sentences - - -def main(): - """主函数""" - import argparse - - parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with vLLM processing') - - # 选择运行模式 - parser.add_argument('--step', choices=['extract', 'llm', 'all'], default='llm', - help='运行步骤: extract=仅提取句子, llm=仅vLLM处理, all=完整流程') - - # 文件路径参数 - parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files') - 
parser.add_argument('--sentences_json', default='extracted_sentences.json', help='JSON file for extracted sentences (will be saved in output/)') - parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path (will be saved in output/)') - - # 处理参数 - parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)') - parser.add_argument('--no_llm', action='store_true', help='Disable vLLM processing (basic mode)') - - args = parser.parse_args() - - # 根据步骤验证参数 - if args.step in ['extract', 'all']: - if not os.path.exists(args.input_dir): - print(f"Error: Input directory {args.input_dir} does not exist!") - return - - if args.step in ['llm', 'all']: - if args.no_llm: - print("Error: Cannot run vLLM step with --no_llm flag!") - return - - # 创建处理器 - processor = EnhancedTRExProcessor( - input_dir=args.input_dir, - sentences_json=args.sentences_json, - output_file=args.output_file, - max_files=args.max_files, - enable_llm_processing=not args.no_llm - ) - - # 根据选择的步骤运行 - if args.step == 'extract': - print("=== 运行模式:仅句子提取 ===") - processor.extract_sentences() - - elif args.step == 'llm': - print("=== 运行模式:仅vLLM处理 ===") - asyncio.run(processor.process_with_vllm_api()) - - elif args.step == 'all': - print("=== 运行模式:完整流程 ===") - - # 步骤1:提取句子 - print("\n--- 开始步骤1:句子提取 ---") - sentences = processor.extract_sentences() - - if not sentences: - print("句子提取失败,退出") - return - - if args.no_llm: - print("vLLM处理已禁用,流程结束") - return - - # 步骤2:vLLM处理 - print("\n--- 开始步骤2:vLLM处理 ---") - asyncio.run(processor.process_with_vllm_api()) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 5355e0a..d9002d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,6 +127,7 @@ dependencies = [ "regex==2024.11.6", "requests==2.32.3", "rich==13.7.1", + "rouge-score>=0.1.2", "rpds-py==0.24.0", "s3transfer==0.13.0", "safetensors==0.5.3", diff --git a/train_extra_accelerate.py b/train_extra_accelerate.py index d8f5fcd..1015281 100644 --- a/train_extra_accelerate.py +++ b/train_extra_accelerate.py @@ -397,6 +397,66 @@ def log_memory_status(step, accelerator, stage="", detailed=False): Logger(log_msg, accelerator) +# 验证函数 +def validate_model(model, val_loader, accelerator, ctx, args): + """ + 验证模型性能 + Args: + model: 模型 + val_loader: 验证集数据加载器 + accelerator: accelerator对象 + ctx: 上下文管理器 + args: 参数 + Returns: + dict: 包含平均损失和准确率的字典 + """ + model.eval() + + total_loss = 0.0 + correct_predictions = 0 + total_predictions = 0 + num_batches = 0 + + criterion = nn.CrossEntropyLoss() + + with torch.no_grad(): + for batch_data in val_loader: + try: + # 数据准备 + X = batch_data['input_ids'].to(accelerator.device) + Y = batch_data['labels'] + + # 前向传播 + with ctx: + res = model(X, step=0) # 验证时step设为0 + loss = criterion(res.predicate_class.cpu(), Y.cpu()) + + # 计算准确率 + predicted_classes = torch.argmax(res.predicate_class, dim=1) + predicted_classes = predicted_classes.to(Y.device) + correct_predictions += (predicted_classes == Y).sum().item() + total_predictions += Y.size(0) + + # 累计损失 + total_loss += loss.item() + num_batches += 1 + + except Exception as e: + Logger(f"验证时出错: {e}", accelerator) + continue + + # 计算平均值 + avg_loss = total_loss / num_batches if num_batches > 0 else 0.0 + accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0 + + model.train() # 重新设置为训练模式 + + return { + 'avg_loss': avg_loss, + 'accuracy': accuracy, + 'total_samples': total_predictions + } + # 日志记录函数 def 
Logger(msg, accelerator=None): # 如果没有提供accelerator,则只在主进程打印 @@ -515,7 +575,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万') return model, tokenizer -def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer): +def train_epoch(epoch, accelerator, model, train_loader,val_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer): # 三元组提取训练模式:不需要传统的交叉熵损失函数 epoch_start_time = time.time() total_steps_in_epoch = len(train_loader) @@ -563,9 +623,9 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a X = batch_data['input_ids'] Y = batch_data['labels'] loss_mask = batch_data['loss_mask'] - target_input_ids = batch_data['target_input_ids'] - target_attention_mask = batch_data['target_attention_mask'] - target_sentences = batch_data['target_sentences'] # 用于调试输出 + # target_input_ids = batch_data['target_input_ids'] + # target_attention_mask = batch_data['target_attention_mask'] + # target_sentences = batch_data['target_sentences'] # 用于调试输出 # === 2. 学习率更新 === if scheduler is not None: @@ -590,36 +650,34 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a # === 4. 损失计算 === # 三元组提取模式:只使用ROUGE Loss进行三元组损失计算 - Logger("三元组提取训练模式", accelerator) if step == 0 else None + # Logger("三元组提取训练模式", accelerator) if step == 0 else None - # 确保有三元组输出 - if not (hasattr(res, 'predicate_logits') and hasattr(res, 'subject_logits') and hasattr(res, 'object_logits')): - raise ValueError("模型没有输出三元组logits,请检查模型配置") + # # 确保有三元组输出 + # if not (hasattr(res, 'predicate_logits') and hasattr(res, 'subject_logits') and hasattr(res, 'object_logits')): + # raise ValueError("模型没有输出三元组logits,请检查模型配置") - # 确保有目标数据 - if target_input_ids is None: - raise ValueError("没有三元组目标数据,请检查数据格式") + # # 确保有目标数据 + # if target_input_ids is None: + # raise ValueError("没有三元组目标数据,请检查数据格式") - # 计算三元组损失 + # 计算分类损失 try: - Logger("使用预tokenized三元组目标数据", accelerator) if step == 0 else None + Logger("使用分类交叉熵损失", accelerator) if step == 0 else None # 计时GPU损失计算 if args.profile and accelerator.is_main_process and loss_start is not None: loss_start.record() - # 计算优化后的嵌入余弦相似度损失 - loss = compute_triple_rouge_loss_optimized( - res.subject_logits, res.predicate_logits, res.object_logits, - target_input_ids, target_attention_mask, model.tok_embeddings, temperature=args.temperature - ) + # 计算交叉熵损失 + criterion = nn.CrossEntropyLoss() + loss = criterion(res.predicate_class, Y) # 计时GPU损失计算结束 if args.profile and accelerator.is_main_process and loss_end is not None: loss_end.record() except Exception as e: - Logger(f"Error: ROUGE loss computation failed: {e}", accelerator) + Logger(f"Error: 分类损失计算失败: {e}", accelerator) import traceback Logger(f"Traceback: {traceback.format_exc()}", accelerator) loss = res.logits.sum() * 0.0 + 1.0 @@ -683,13 +741,13 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a f"GPU利用率: {gpu_time_total/iter_time*100:.1f}%", accelerator) Logger("=" * 50, accelerator) - Logger("=== 三元组预测示例 ===", accelerator) - predict_sentences = triple_to_sentence(res.subject_logits, res.predicate_logits, res.object_logits,tokenizer) - # 显示前2个样本的目标句子 - for i, target_sentence in enumerate(target_sentences[:2]): - Logger(f"样本{i+1}目标: {target_sentence}", accelerator) - Logger(f"样本{i+1}预测: {predict_sentences[i]}", accelerator) - 
Logger("==================", accelerator) + # Logger("=== 三元组预测示例 ===", accelerator) + # predict_sentences = triple_to_sentence(res.subject_logits, res.predicate_logits, res.object_logits,tokenizer) + # # 显示前2个样本的目标句子 + # for i, target_sentence in enumerate(target_sentences[:2]): + # Logger(f"样本{i+1}目标: {target_sentence}", accelerator) + # Logger(f"样本{i+1}预测: {predict_sentences[i]}", accelerator) + Logger("=======val dataset=========", accelerator) # 重置GPU事件 forward_start = torch.cuda.Event(enable_timing=True) @@ -734,11 +792,20 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a # SwanLab日志记录 if args.use_swanlab and accelerator.is_main_process and swanlab_run: + Logger("=======val dataset=========", accelerator) + + # 验证集评估 + val_results = validate_model(model, val_loader, accelerator, ctx, args) + Logger(f"验证集结果 - 平均损失: {val_results['avg_loss']:.6f}, 准确率: {val_results['accuracy']:.4f}, 样本数: {val_results['total_samples']}", accelerator) + log_dict = { "epoch": epoch + 1, "step": step + 1, "total_steps_in_epoch": total_steps_in_epoch, - "triple_embedding_cosine_loss": loss.item() * args.accumulation_steps, + "train_loss": loss.item() * args.accumulation_steps, + "val_loss": val_results['avg_loss'], + "val_accuracy": val_results['accuracy'], + "val_samples": val_results['total_samples'], "lr": current_lr, "tokens_per_sec": tokens_per_sec, "epoch_time_left_seconds": epoch_remaining_time, @@ -776,7 +843,7 @@ def main(): parser.add_argument("--out_dir", type=str, default="out") parser.add_argument("--epochs", type=int, default=4) parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数") - parser.add_argument("--batch_size", type=int, default=192) + parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--learning_rate", type=float, default=2e-4) parser.add_argument("--dtype", type=str, default="bfloat16") parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数 @@ -793,6 +860,7 @@ def main(): parser.add_argument('--use_moe', default=False, type=bool) parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代") parser.add_argument("--data_path", type=str, default="./dataset/processed_trex_data.json") + parser.add_argument("--predicate_vocab_path", type=str, default="./dataset/predicate_stats.json", help="Path to predicate vocabulary/statistics file") parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)") parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)") @@ -932,7 +1000,8 @@ def main(): # 创建数据集和数据加载器(专用于三元组提取训练) ######################################################### Logger("三元组提取训练:使用 TriplePretrainDataset", accelerator) - train_ds = TriplePretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len) + train_ds = TriplePretrainDataset(data_path=args.data_path, predicate_vocab_path=args.predicate_vocab_path, tokenizer=tokenizer, max_length=lm_config.max_seq_len) + val_ds = TriplePretrainDataset(data_path=args.data_path,samples=train_ds.get_val_samples(), predicate_vocab_path=args.predicate_vocab_path, tokenizer=tokenizer, max_length=lm_config.max_seq_len) # 创建自定义collate_fn来处理优化后的数据格式 def triple_collate_fn(batch): @@ -940,17 +1009,17 @@ def main(): input_ids = torch.stack([item['input_ids'] for item in batch]) labels = torch.stack([item['labels'] 
for item in batch]) loss_mask = torch.stack([item['loss_mask'] for item in batch]) - target_input_ids = torch.stack([item['target_input_ids'] for item in batch]) - target_attention_mask = torch.stack([item['target_attention_mask'] for item in batch]) - target_sentences = [item['target_sentence'] for item in batch] # 用于调试 + # target_input_ids = torch.stack([item['target_input_ids'] for item in batch]) + # target_attention_mask = torch.stack([item['target_attention_mask'] for item in batch]) + # target_sentences = [item['target_sentence'] for item in batch] # 用于调试 return { 'input_ids': input_ids, 'labels': labels, 'loss_mask': loss_mask, - 'target_input_ids': target_input_ids, - 'target_attention_mask': target_attention_mask, - 'target_sentences': target_sentences + # 'target_input_ids': target_input_ids, + # 'target_attention_mask': target_attention_mask, + # 'target_sentences': target_sentences } train_loader = DataLoader( @@ -963,6 +1032,15 @@ def main(): # persistent_workers 和 prefetch_factor 在 num_workers=0 时自动禁用 collate_fn=triple_collate_fn ) + val_loader = DataLoader( + val_ds, + batch_size=args.batch_size, + pin_memory=False, + drop_last=True, + shuffle=False, + num_workers=0, + collate_fn=triple_collate_fn + ) ######################################################### # 创建优化器 @@ -993,7 +1071,7 @@ def main(): overall_start_time = time.time() # Record overall start time for epoch in range(args.epochs): Logger(f"开始第{epoch+1}轮训练", accelerator) - train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer) # Pass tokenizer + train_epoch(epoch, accelerator, model, train_loader,val_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer) # Pass tokenizer # 每个epoch结束后进行内存清理 Logger(f"第{epoch+1}轮训练完成,进行内存清理", accelerator)
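
The bulk of the removed TREx script above is checkpoint bookkeeping. As a reading aid, here is a minimal, self-contained sketch of the checkpoint-discovery pattern that script relied on: among files named 'output/<name>_checkpoint_<count>.json' (layout taken from the removed code), pick the one with the largest numeric suffix. The standalone function form below is an illustration only, not part of the patch.

import glob
import os
import re
from typing import Optional, Tuple

def find_latest_checkpoint(output_file: str) -> Optional[Tuple[str, int]]:
    """Return (path, count) for the checkpoint with the largest numeric suffix, or None."""
    base_name = os.path.splitext(os.path.basename(output_file))[0]
    pattern = os.path.join('output', f'{base_name}_checkpoint_*.json')
    best: Optional[Tuple[str, int]] = None
    for path in glob.glob(pattern):
        match = re.search(r'checkpoint_(\d+)\.json$', path)
        if match:
            count = int(match.group(1))
            if best is None or count > best[1]:
                best = (path, count)
    return best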
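
The training change itself is small: the previous embedding/ROUGE objective is replaced by plain cross-entropy over predicate classes, and validation reports accuracy from the same logits. Below is a minimal sketch with dummy tensors, assuming res.predicate_class has shape [batch_size, num_predicates] and the labels Y are integer class ids; the shapes are inferred from the diff, and the variable names here are stand-ins.

import torch
import torch.nn as nn

batch_size, num_predicates = 4, 8                         # hypothetical sizes
logits = torch.randn(batch_size, num_predicates)          # stands in for res.predicate_class
labels = torch.randint(0, num_predicates, (batch_size,))  # stands in for Y

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)                          # same call pattern as the new loss

predicted = torch.argmax(logits, dim=1)                   # accuracy bookkeeping as in validate_model
accuracy = (predicted == labels).float().mean().item()
print(f"loss={loss.item():.4f}, accuracy={accuracy:.4f}")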
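
Putting the new pieces together, the evaluation cadence introduced by the patch is: build a held-out loader (shuffle=False, drop_last=True, same collate_fn as training), and on each logging interval switch the model to eval mode, run the loader under no_grad, then switch back to train mode before resuming. The sketch below mirrors that flow with a toy classifier and random data; evaluate is a hypothetical stand-in for validate_model, and the toy dataset replaces the real validation split.

import torch
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model: torch.nn.Module, val_loader: DataLoader) -> dict:
    """Eval-mode pass over the validation loader, then back to train mode."""
    model.eval()
    criterion = torch.nn.CrossEntropyLoss()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for X, Y in val_loader:
            logits = model(X)                              # stands in for res.predicate_class
            total_loss += criterion(logits, Y).item()
            correct += (logits.argmax(dim=1) == Y).sum().item()
            total += Y.size(0)
    model.train()
    return {"val_loss": total_loss / max(len(val_loader), 1),
            "val_accuracy": correct / max(total, 1)}

# Toy usage, just to show the call pattern on the logging interval.
model = torch.nn.Linear(16, 8)
val_ds = TensorDataset(torch.randn(32, 16), torch.randint(0, 8, (32,)))
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, drop_last=True)
print(evaluate(model, val_loader))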