From 74e9293c9a48a0d1b401b48524369f8c65e9b91e Mon Sep 17 00:00:00 2001 From: Yu Chengzhang Date: Sun, 29 Jun 2025 16:01:36 +0800 Subject: [PATCH] DynamicKV-LLM Extra v1.0.0 --- model/dataset.py | 137 ++++ preprocessing/preprocess_trex.py | 442 ++++++++++++ preprocessing/preprocess_triple.py | 441 ++++++++++++ train_extra_accelerate.py | 1022 ++++++++++++++++++++++++++++ 4 files changed, 2042 insertions(+) create mode 100644 preprocessing/preprocess_trex.py create mode 100644 preprocessing/preprocess_triple.py create mode 100644 train_extra_accelerate.py diff --git a/model/dataset.py b/model/dataset.py index 14acc6c..6658eca 100644 --- a/model/dataset.py +++ b/model/dataset.py @@ -9,6 +9,7 @@ import torch from sklearn.model_selection import train_test_split import os import ast +from tqdm import tqdm os.environ["TOKENIZERS_PARALLELISM"] = "true" @@ -196,6 +197,142 @@ class DPODataset(Dataset): return loss_mask +class TriplePretrainDataset(Dataset): + """ + 优化的三元组预训练数据集 + - 每个样本只保留一个target三元组 + - 预先tokenize所有数据 + - 使用进度条显示处理进度 + """ + def __init__(self, data_path, tokenizer, max_length=512): + super().__init__() + self.tokenizer = tokenizer + self.max_length = max_length + print("🚀 开始加载和预处理三元组数据...") + self.samples = self.load_and_preprocess_data(data_path) + + def load_and_preprocess_data(self, path): + """加载并预处理三元组数据""" + # 1. 加载原始数据 + print("📂 加载原始数据...") + if path.endswith('.json'): + with open(path, 'r', encoding='utf-8') as f: + data = json.load(f) + elif path.endswith('.jsonl'): + data = [] + with open(path, 'r', encoding='utf-8') as f: + for line in f: + if line.strip(): + data.append(json.loads(line.strip())) + else: + raise ValueError(f"Unsupported file format: {path}") + + print(f"📊 原始数据量: {len(data)} 个样本") + + # 2. 数据验证和筛选(只保留一个target) + print("🔍 验证数据格式并选择单个target...") + valid_samples = [] + + for i, sample in enumerate(tqdm(data, desc="验证数据格式")): + if not isinstance(sample, dict) or 'text' not in sample: + continue + + targets = sample.get('target', []) + if not isinstance(targets, list) or len(targets) == 0: + # 如果没有有效的target,创建一个默认的 + selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"} + else: + # 验证并选择第一个有效的target + selected_target = None + for triple in targets: + if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']): + selected_target = triple + break + + # 如果没有找到有效的target,使用默认值 + if selected_target is None: + selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"} + + valid_samples.append({ + 'text': sample['text'], + 'target': selected_target # 只保留一个target + }) + + print(f"✅ 有效样本数: {len(valid_samples)}") + + # 3. 
分批tokenize目标句子 + print("🔤 分批tokenize目标句子...") + + processed_samples = [] + batch_size = 1000 # 每批处理1000个句子,避免内存爆炸 + + for i in tqdm(range(0, len(valid_samples), batch_size), desc="分批tokenize目标句子"): + # 获取当前批次 + batch_samples = valid_samples[i:i + batch_size] + + # 提取当前批次的目标句子 + batch_target_sentences = [self._triple_to_sentence(sample['target']) for sample in batch_samples] + + # 批量tokenize当前批次 + batch_encodings = self.tokenizer( + batch_target_sentences, + max_length=128, # 目标句子通常较短 + padding='max_length', + truncation=True, + return_tensors='pt' + ) + + # 构建当前批次的样本数据 + for j, sample in enumerate(batch_samples): + processed_samples.append({ + 'text': sample['text'], # 保持原始文本,不进行tokenize + 'target_input_ids': batch_encodings.input_ids[j], + 'target_attention_mask': batch_encodings.attention_mask[j], + 'target_sentence': batch_target_sentences[j], # 保留原始句子用于调试 + }) + + print(f"🎉 数据预处理完成! 共处理 {len(processed_samples)} 个样本") + return processed_samples + + def __len__(self): + return len(self.samples) + + def _triple_to_sentence(self, triple): + """将三元组转换为句子格式""" + return f"{triple['subject']} {triple['predicate']} {triple['object']}" + + def __getitem__(self, index): + """返回数据,输入文本在运行时tokenize,目标已预tokenize""" + sample = self.samples[index] + + # 在运行时tokenize输入文本(用于语言建模) + input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}" + encoding = self.tokenizer( + input_text, + max_length=self.max_length, + padding='max_length', + truncation=True, + return_tensors='pt' + ) + input_ids = encoding.input_ids.squeeze() + loss_mask = (input_ids != self.tokenizer.pad_token_id) + + # 构建训练数据 + X = input_ids[:-1] + Y = input_ids[1:] + loss_mask = loss_mask[1:] + + return { + 'input_ids': X, + 'labels': Y, + 'loss_mask': loss_mask, + 'target_input_ids': sample['target_input_ids'], # 已经是tensor + 'target_attention_mask': sample['target_attention_mask'], # 已经是tensor + 'target_sentence': sample['target_sentence'], # 字符串,用于调试 + 'original_text': sample['text'] + } + + class RLAIFDataset(Dataset): def __init__(self, jsonl_path, tokenizer, max_length=1024): super().__init__() diff --git a/preprocessing/preprocess_trex.py b/preprocessing/preprocess_trex.py new file mode 100644 index 0000000..eb31c67 --- /dev/null +++ b/preprocessing/preprocess_trex.py @@ -0,0 +1,442 @@ +import json +import os +import argparse +from typing import List, Dict, Any, Optional +from collections import defaultdict +import pickle +from pathlib import Path + +class WikidataRelationManager: + """Wikidata关系管理器,支持动态获取和缓存""" + + def __init__(self, cache_file: str = "wikidata_relations_cache.pkl", + mapping_file: str = None): + self.cache_file = cache_file + self.mapping_file = mapping_file + self.relations = {} + # 删除了API相关属性 + + # 初始的基础关系映射 + self.base_relations = { + # # 基本关系 + # 'P31': 'instance of', + # 'P279': 'subclass of', + # 'P17': 'country', + # 'P159': 'headquarters location', + # 'P571': 'inception', + + # # 人物关系 + # 'P19': 'place of birth', + # 'P20': 'place of death', + # 'P27': 'country of citizenship', + # 'P106': 'occupation', + # 'P22': 'father', + # 'P25': 'mother', + # 'P26': 'spouse', + # 'P40': 'child', + # 'P69': 'educated at', + # 'P108': 'employer', + + # # 地理关系 + # 'P36': 'capital', + # 'P131': 'located in', + # 'P47': 'shares border with', + # 'P206': 'located on terrain feature', + # 'P1376': 'capital of', + + # # 组织关系 + # 'P112': 'founded by', + # 'P127': 'owned by', + # 'P169': 'chief executive officer', + # 'P488': 'chairperson', + # 'P749': 'parent organization', + + # # 作品关系 + # 'P50': 
'author', + # 'P57': 'director', + # 'P58': 'screenwriter', + # 'P161': 'cast member', + # 'P175': 'performer', + # 'P577': 'publication date', + # 'P123': 'publisher', + # 'P136': 'genre', + + # # 时间关系 + # 'P155': 'follows', + # 'P156': 'followed by', + # 'P580': 'start time', + # 'P582': 'end time', + + # # 体育关系 + # 'P54': 'member of sports team', + # 'P413': 'position played on team', + # 'P118': 'league', + + # # 科学关系 + # 'P275': 'copyright license', + # 'P170': 'creator', + # 'P398': 'child astronomical body', + # 'P397': 'parent astronomical body', + + # # 其他常见关系 + # 'P37': 'official language', + # 'P1923': 'place of marriage', + # 'P737': 'influenced by', + # 'P463': 'member of', + # 'P39': 'position held', + # 'P276': 'location', + # 'P1441': 'present in work', + } + + self.load_cache() + + def load_cache(self): + """加载缓存的关系映射,优先使用JSON映射文件""" + try: + # 优先尝试加载JSON映射文件 + if self.mapping_file and os.path.exists(self.mapping_file): + with open(self.mapping_file, 'r', encoding='utf-8') as f: + self.relations = json.load(f) + print(f"从JSON映射文件加载了 {len(self.relations)} 个关系映射") + return + + # 尝试加载pickle缓存文件 + if os.path.exists(self.cache_file): + with open(self.cache_file, 'rb') as f: + self.relations = pickle.load(f) + print(f"从pickle缓存加载了 {len(self.relations)} 个关系映射") + else: + self.relations = self.base_relations.copy() + print(f"初始化基础关系映射: {len(self.relations)} 个") + except Exception as e: + print(f"加载缓存失败: {e}") + self.relations = self.base_relations.copy() + + def save_cache(self): + """保存关系映射到缓存""" + try: + with open(self.cache_file, 'wb') as f: + pickle.dump(self.relations, f) + print(f"已保存 {len(self.relations)} 个关系映射到缓存") + except Exception as e: + print(f"保存缓存失败: {e}") + + # 删除了网络抓取功能,改为纯离线模式 + + def get_relation_name(self, property_id: str) -> Optional[str]: + """获取关系名称,仅使用本地映射""" + if property_id in self.relations: + return self.relations[property_id] + + # 如果本地映射中没有找到,返回None(表示跳过这个关系) + return None + + # 删除了网络请求相关的批量获取和预加载功能 + +class TRexProcessor: + """T-REx数据集处理器""" + + def __init__(self, relation_manager: WikidataRelationManager): + self.relation_manager = relation_manager + + def extract_predicate_id(self, uri: str) -> str: + """从URI中提取属性ID""" + if uri and 'prop/direct/' in uri: + return uri.split('/')[-1] + elif uri and uri.startswith('P') and uri[1:].isdigit(): + return uri + return uri if uri else 'unknown' + + def get_relation_name(self, predicate_uri: str) -> Optional[str]: + """获取关系的可读名称""" + predicate_id = self.extract_predicate_id(predicate_uri) + return self.relation_manager.get_relation_name(predicate_id) + + # 删除了谓词收集功能,因为不再需要预加载 + + def is_valid_triple(self, triple: Dict[str, Any], confidence_threshold: float, + boundary_threshold: int) -> bool: + """检查三元组是否满足过滤条件""" + try: + # 检查triple是否为字典 + if not isinstance(triple, dict): + return False + + # 检查必要字段 + if not all(key in triple for key in ['subject', 'predicate', 'object']): + return False + + subject = triple['subject'] + predicate = triple['predicate'] + object_info = triple['object'] + + # 检查subject、predicate、object是否都为字典 + if not isinstance(subject, dict) or not isinstance(predicate, dict) or not isinstance(object_info, dict): + return False + + # 检查主语和宾语是否有有效的URI和surfaceform + if not (subject.get('uri') and subject.get('surfaceform')): + return False + if not (object_info.get('uri') and object_info.get('surfaceform')): + return False + if not predicate.get('uri'): + return False + + # 检查置信度(如果存在) + confidence = triple.get('confidence') + if confidence is not None and confidence < confidence_threshold: + 
return False + + # 检查边界信息(如果设置了阈值) + if boundary_threshold > 0: + subject_boundaries = subject.get('boundaries') + object_boundaries = object_info.get('boundaries') + + if not subject_boundaries or not object_boundaries: + return False + + # 检查边界是否为列表且长度至少为2 + if not (isinstance(subject_boundaries, list) and len(subject_boundaries) >= 2): + return False + if not (isinstance(object_boundaries, list) and len(object_boundaries) >= 2): + return False + + try: + # 检查边界长度是否合理 + subject_length = subject_boundaries[1] - subject_boundaries[0] + object_length = object_boundaries[1] - object_boundaries[0] + + if subject_length < boundary_threshold or object_length < boundary_threshold: + return False + except (TypeError, IndexError): + return False + + # 检查文本内容是否合理 + subject_text = subject.get('surfaceform', '').strip() + object_text = object_info.get('surfaceform', '').strip() + + if not subject_text or not object_text: + return False + + # 过滤掉过长或过短的实体 + if len(subject_text) > 100 or len(object_text) > 100: + return False + if len(subject_text) < 2 or len(object_text) < 2: + return False + + return True + + except (KeyError, TypeError, AttributeError): + return False + + def process_single_file(self, file_path: str, confidence_threshold: float, + boundary_threshold: int) -> List[Dict[str, Any]]: + """处理单个JSON文件""" + print(f"Processing file: {file_path}") + + processed_data = [] + + try: + with open(file_path, 'r', encoding='utf-8') as f: + # 读取整个文件作为JSON数组 + print(f"正在加载JSON数组文件: {file_path}") + data_list = json.load(f) + print(f"文件包含 {len(data_list)} 个条目") + + for idx, data in enumerate(data_list): + try: + # 获取基本信息 + text = data.get('text', '').strip() + if not text: + continue + + # 处理三元组 + triples = data.get('triples', []) + if not triples: + continue + + valid_targets = [] + + for triple in triples: + if self.is_valid_triple(triple, confidence_threshold, boundary_threshold): + # 获取关系名称,如果无法解析则跳过这个三元组 + relation_name = self.get_relation_name(triple['predicate']['uri']) + if relation_name is None: + continue # 跳过无法解析的关系 + + target = { + 'subject': triple['subject']['surfaceform'].strip(), + 'predicate': relation_name, + 'object': triple['object']['surfaceform'].strip() + } + valid_targets.append(target) + + # 如果有有效的三元组,添加到结果中 + if valid_targets: + processed_data.append({ + 'text': text, + 'target': valid_targets + }) + + except Exception as e: + if idx <= 10: # 只打印前10个错误 + print(f"处理条目时出错 in {file_path} at index {idx}: {e}") + continue + + except FileNotFoundError: + print(f"文件未找到: {file_path}") + except json.JSONDecodeError as e: + print(f"JSON解析错误 in {file_path}: {e}") + except Exception as e: + print(f"处理文件时出错 {file_path}: {e}") + + print(f"从 {file_path} 提取了 {len(processed_data)} 个有效样本") + return processed_data + + def process_folder(self, folder_path: str, confidence_threshold: float, + boundary_threshold: int) -> List[Dict[str, Any]]: + """处理文件夹中的所有JSON文件""" + all_processed_data = [] + + if not os.path.exists(folder_path): + raise FileNotFoundError(f"文件夹不存在: {folder_path}") + + # 获取所有JSON文件 + json_files = [f for f in os.listdir(folder_path) if f.endswith('.json')] + + if not json_files: + raise ValueError(f"在 {folder_path} 中没有找到JSON文件") + + print(f"找到 {len(json_files)} 个JSON文件") + + for filename in sorted(json_files): + file_path = os.path.join(folder_path, filename) + processed_data = self.process_single_file(file_path, confidence_threshold, boundary_threshold) + all_processed_data.extend(processed_data) + + # 保存最终的关系缓存 + self.relation_manager.save_cache() + + return all_processed_data + + 
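+    # NOTE (editor's addition, not part of the original patch): each item returned by
+    # process_folder() above, and consumed by generate_statistics() below, has the shape
+    #     {"text": "<source paragraph>",
+    #      "target": [{"subject": "...", "predicate": "<readable relation name>", "object": "..."}, ...]}
+    # where the predicate ID has already been resolved to a readable name via WikidataRelationManager.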
+    def generate_statistics(self, processed_data: List[Dict[str, Any]]) -> Dict[str, Any]:
+        """生成数据统计信息"""
+        total_samples = len(processed_data)
+        total_triples = sum(len(sample['target']) for sample in processed_data)
+        
+        # 统计关系类型
+        relation_counts = defaultdict(int)
+        for sample in processed_data:
+            for target in sample['target']:
+                relation_counts[target['predicate']] += 1
+        
+        # 统计文本长度
+        text_lengths = [len(sample['text']) for sample in processed_data]
+        avg_text_length = sum(text_lengths) / len(text_lengths) if text_lengths else 0
+        
+        # 统计每个文本的三元组数量
+        triples_per_text = [len(sample['target']) for sample in processed_data]
+        avg_triples_per_text = sum(triples_per_text) / len(triples_per_text) if triples_per_text else 0
+        
+        return {
+            'total_samples': total_samples,
+            'total_triples': total_triples,
+            'avg_text_length': round(avg_text_length, 2),
+            'avg_triples_per_text': round(avg_triples_per_text, 2),
+            'relation_distribution': dict(sorted(relation_counts.items(),
+                                                 key=lambda x: x[1], reverse=True)),
+            'top_10_relations': dict(list(sorted(relation_counts.items(),
+                                                 key=lambda x: x[1], reverse=True))[:10]),
+            'total_unique_relations': len(relation_counts),
+            'cached_relations': len(self.relation_manager.relations)
+        }
+
+def main():
+    parser = argparse.ArgumentParser(description='处理T-REx数据集(支持动态关系获取)')
+    parser.add_argument('--folder_path', type=str, default='/home/pci/ycz/Code/Minimind/dataset/trex', help='包含JSON文件的文件夹路径')
+    parser.add_argument('--confidence_threshold', type=float, default=0.5,
+                        help='置信度阈值 (默认: 0.5)')
+    parser.add_argument('--boundary_threshold', type=int, default=0,
+                        help='边界长度阈值 (默认: 0, 不过滤)')
+    parser.add_argument('--output', type=str, default='./processed_trex_data.json',
+                        help='输出文件名 (默认: processed_trex_data.json)')
+    parser.add_argument('--stats', type=str, default='trex_statistics.json',
+                        help='统计信息输出文件名 (默认: trex_statistics.json)')
+    parser.add_argument('--cache_file', type=str, default='wikidata_relations_cache.pkl',
+                        help='关系缓存文件名 (默认: wikidata_relations_cache.pkl)')
+    parser.add_argument('--mapping_file', type=str, default="/home/pci/ycz/Code/Minimind/preprocessing/sample_property_mappings.json",
+                        help='JSON映射文件路径 (必须提供,用于关系名称映射)')
+    
+    args = parser.parse_args()
+    
+    print("T-REx数据集处理器(支持动态关系获取)")
+    print("=" * 60)
+    print(f"输入文件夹: {args.folder_path}")
+    print(f"置信度阈值: {args.confidence_threshold}")
+    print(f"边界长度阈值: {args.boundary_threshold}")
+    print(f"输出文件: {args.output}")
+    print(f"关系缓存文件: {args.cache_file}")
+    print(f"JSON映射文件: {args.mapping_file if args.mapping_file else '未指定'}")
+    print("=" * 60)
+    
+    # 检查映射文件是否存在
+    if not args.mapping_file or not os.path.exists(args.mapping_file):
+        print(f"错误: 映射文件不存在或未指定: {args.mapping_file}")
+        print("请确保提供有效的JSON映射文件。")
+        return 1
+    
+    # 创建关系管理器
+    relation_manager = WikidataRelationManager(
+        cache_file=args.cache_file,
+        mapping_file=args.mapping_file
+    )
+    
+    # 创建处理器
+    processor = TRexProcessor(relation_manager)
+    
+    try:
+        # 处理数据
+        processed_data = processor.process_folder(
+            args.folder_path,
+            args.confidence_threshold,
+            args.boundary_threshold
+        )
+        
+        print(f"\n处理完成!总共处理了 {len(processed_data)} 个样本")
+        
+        # 生成统计信息
+        stats = processor.generate_statistics(processed_data)
+        
+        # 保存处理后的数据
+        with open(args.output, 'w', encoding='utf-8') as f:
+            json.dump(processed_data, f, ensure_ascii=False, indent=2)
+        
+        # 保存统计信息
+        with open(args.stats, 'w', encoding='utf-8') as f:
+            json.dump(stats, f, ensure_ascii=False, indent=2)
+        
+        print(f"\n数据已保存到: {args.output}")
+        print(f"统计信息已保存到: {args.stats}")
+        print(f"关系缓存已保存到: {args.cache_file}")
+        
+        # 打印统计摘要
+        print("\n数据统计摘要:")
+        print("=" * 30)
+        print(f"总样本数: {stats['total_samples']}")
+        print(f"总三元组数: {stats['total_triples']}")
+        print(f"唯一关系数: {stats['total_unique_relations']}")
+        print(f"缓存关系数: {stats['cached_relations']}")
+        print(f"平均文本长度: {stats['avg_text_length']}")
+        print(f"平均每文本三元组数: {stats['avg_triples_per_text']}")
+        print("\n前10个最常见关系:")
+        for relation, count in stats['top_10_relations'].items():
+            print(f"  {relation}: {count}")
+        
+    except Exception as e:
+        print(f"处理过程中出错: {e}")
+        return 1
+    
+    return 0
+
+if __name__ == "__main__":
+    exit(main())
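Editor's note: a minimal, hypothetical sketch (not part of the patch) showing the record shape that
preprocess_trex.py writes and that TriplePretrainDataset in model/dataset.py above consumes; the path
below is only the script's default --output and may differ in practice.

    import json

    # Each record pairs a source paragraph with its extracted triples:
    # {"text": "...", "target": [{"subject": "...", "predicate": "...", "object": "..."}, ...]}
    with open("./processed_trex_data.json", "r", encoding="utf-8") as f:
        records = json.load(f)

    sample = records[0]
    assert isinstance(sample["text"], str)
    assert all({"subject", "predicate", "object"} <= set(t) for t in sample["target"])
    print(sample["text"][:80], "->", sample["target"][0])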
{args.cache_file}") + + # 打印统计摘要 + print("\n数据统计摘要:") + print("=" * 30) + print(f"总样本数: {stats['total_samples']}") + print(f"总三元组数: {stats['total_triples']}") + print(f"唯一关系数: {stats['total_unique_relations']}") + print(f"缓存关系数: {stats['cached_relations']}") + print(f"平均文本长度: {stats['avg_text_length']}") + print(f"平均每文本三元组数: {stats['avg_triples_per_text']}") + print("\n前10个最常见关系:") + for relation, count in stats['top_10_relations'].items(): + print(f" {relation}: {count}") + + except Exception as e: + print(f"处理过程中出错: {e}") + return 1 + + return 0 + +if __name__ == "__main__": + exit(main()) diff --git a/preprocessing/preprocess_triple.py b/preprocessing/preprocess_triple.py new file mode 100644 index 0000000..35a04f7 --- /dev/null +++ b/preprocessing/preprocess_triple.py @@ -0,0 +1,441 @@ +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import json +import re +import asyncio +import aiofiles +from concurrent.futures import ThreadPoolExecutor +from preprocessing.agent_system.extractor_agent.agent import DepartmentAgent +from typing import Dict, List, Tuple +import gc +import time +import psutil +from tqdm.asyncio import tqdm as async_tqdm +from tqdm import tqdm + +json_path = "dataset/merged_pretrain_extra.jsonl" +output_path = "dataset/processed_triples.jsonl" + +# 优化后的配置参数 - 降低资源消耗 +BATCH_SIZE = 5000 # 减少批次大小:每批1万条数据 +MAX_CONCURRENT = 200 # 减少并发数:最多50条并发处理 +AGENT_POOL_SIZE = 20 # 大幅减少agent池大小:只创建5个agent实例 + +def get_memory_usage(): + """获取当前内存使用情况""" + process = psutil.Process(os.getpid()) + memory_info = process.memory_info() + memory_mb = memory_info.rss / 1024 / 1024 + return memory_mb + +def print_memory_info(stage=""): + """打印内存使用信息""" + memory_mb = get_memory_usage() + print(f"🔧 {stage} - 内存使用: {memory_mb:.1f} MB") + +# 创建extractor_agent池,避免并发冲突 +def create_extractor_pool(pool_size: int = 5): + """创建extractor_agent池""" + print(f"正在创建 {pool_size} 个agent实例...") + agents = [] + for i in range(pool_size): + try: + agent = DepartmentAgent(model_type="deepseek") + agents.append(agent) + print(f" ✓ Agent {i+1}/{pool_size} 创建成功") + except Exception as e: + print(f" ✗ Agent {i+1} 创建失败: {e}") + print(f"Agent池创建完成,实际创建了 {len(agents)} 个实例") + return agents + +# 延迟初始化agent池 +AGENT_POOL = None +agent_pool_index = 0 + +def get_agent_pool(): + """获取agent池,延迟初始化""" + global AGENT_POOL + if AGENT_POOL is None: + print_memory_info("创建Agent池前") + AGENT_POOL = create_extractor_pool(pool_size=AGENT_POOL_SIZE) + print_memory_info("创建Agent池后") + return AGENT_POOL + +def get_next_agent(): + """轮询获取下一个可用的agent""" + global agent_pool_index + pool = get_agent_pool() + agent = pool[agent_pool_index % len(pool)] + agent_pool_index += 1 + return agent + +def clean_and_split_text(text): + """ + 去除文本开头结尾的标记,并按句子分割 + """ + # 去除开头的<|im_start|>和结尾的<|im_end|> + text = text.strip() + if text.startswith('<|im_start|>'): + text = text[len('<|im_start|>'):] + if text.endswith('<|im_end|>'): + text = text[:-len('<|im_end|>')] + + # 清理文本,去除多余的空白字符 + text = text.strip() + + # 按句子分割(根据句号、问号、感叹号等标点符号) + # 使用正则表达式匹配句子结束标志 + sentence_endings = r'[.!?。!?]' + sentences = re.split(sentence_endings, text) + + # 清理每个句子,去除空白和空句子 + cleaned_sentences = [] + for sentence in sentences: + sentence = sentence.strip() + if sentence and len(sentence) > 5: # 只保留非空且有意义的句子 + cleaned_sentences.append(sentence) + + return cleaned_sentences + +async def extract_triple_from_sentence_async(sentence: str, context: str = None) -> Dict: + """ + 异步使用extractor_agent从句子中提取三元组 + """ + try: + # 
获取一个agent实例 + agent = get_next_agent() + result = await agent.async_run(sentence=sentence, context=context) + return { + "sentence": sentence, + "triple": { + "subject": result.triple.subject, + "predicate": result.triple.predicate, + "object": result.triple.object + }, + "confidence": result.confidence + } + except Exception as e: + return { + "sentence": sentence, + "triple": { + "subject": "", + "predicate": "", + "object": "" + }, + "confidence": 0.0, + "error": str(e) + } + +async def process_paragraph_async(line_num: int, original_text: str, semaphore: asyncio.Semaphore) -> Dict: + """ + 异步处理单个段落 + """ + async with semaphore: # 控制并发数量 + try: + # 清理并分割文本 + sentences = clean_and_split_text(original_text) + + if not sentences: + return None + + # 构建当前段落的结果 + paragraph_result = { + "source_line": line_num, + "original_paragraph": original_text, + "sentences": [], + "triples": [] + } + + # 异步处理所有句子 + tasks = [] + for sentence in sentences: + task = extract_triple_from_sentence_async(sentence, context=original_text) + tasks.append(task) + + # 等待所有句子处理完成 + triple_results = await asyncio.gather(*tasks) + + # 整理结果 + for i, sentence in enumerate(sentences): + paragraph_result["sentences"].append(sentence) + paragraph_result["triples"].append(triple_results[i]) + + return paragraph_result + + except Exception as e: + print(f"处理第 {line_num} 行时出错: {e}") + return None + +async def process_batch_async(batch_data: List[Tuple[int, str]], batch_num: int) -> List[Dict]: + """ + 异步处理一个批次的数据,带进度条和内存监控 + """ + print(f"\n=== 异步处理批次 {batch_num} ===") + print(f"批次大小: {len(batch_data)} 条记录") + print_memory_info(f"批次 {batch_num} 开始前") + + start_time = time.time() + + # 创建信号量控制并发数量 + semaphore = asyncio.Semaphore(MAX_CONCURRENT) + + # 分块处理任务,避免一次性创建太多任务 + chunk_size = 1000 # 每次处理1000个任务 + all_results = [] + + for chunk_start in range(0, len(batch_data), chunk_size): + chunk_end = min(chunk_start + chunk_size, len(batch_data)) + chunk_data = batch_data[chunk_start:chunk_end] + + print(f"处理子块 {chunk_start//chunk_size + 1}/{(len(batch_data)-1)//chunk_size + 1} ({len(chunk_data)} 条记录)") + + # 创建当前块的异步任务 + tasks = [] + for line_num, original_text in chunk_data: + task = process_paragraph_async(line_num, original_text, semaphore) + tasks.append(task) + + # 使用进度条处理当前块 + progress_bar = tqdm(total=len(tasks), desc=f"批次{batch_num}-块{chunk_start//chunk_size + 1}", unit="段落", ncols=100) + + chunk_results = [] + completed_tasks = 0 + + # 使用as_completed来获取完成的任务,并更新进度条 + for coro in asyncio.as_completed(tasks): + try: + result = await coro + chunk_results.append(result) + completed_tasks += 1 + + # 更新进度条 + progress_bar.update(1) + + # 每完成50个任务更新一次描述 + if completed_tasks % 50 == 0: + valid_results = [r for r in chunk_results if r is not None] + progress_bar.set_postfix({ + '有效': len(valid_results), + '完成': completed_tasks, + '成功率': f"{len(valid_results)/completed_tasks*100:.1f}%" + }) + except Exception as e: + print(f"任务执行失败: {e}") + completed_tasks += 1 + progress_bar.update(1) + + progress_bar.close() + all_results.extend(chunk_results) + + # 每个块完成后清理内存 + del tasks, chunk_results + gc.collect() + + print_memory_info(f"批次 {batch_num} 块 {chunk_start//chunk_size + 1} 完成后") + + # 过滤None结果 + valid_results = [result for result in all_results if result is not None] + + # 统计信息 + batch_sentences = sum(len(result["sentences"]) for result in valid_results) + batch_triples = sum( + sum(1 for triple in result["triples"] if triple["confidence"] > 0.0) + for result in valid_results + ) + + end_time = time.time() + processing_time = 
end_time - start_time + + print(f"批次 {batch_num} 异步处理完成:") + print(f" - 有效段落: {len(valid_results)}/{len(batch_data)} ({len(valid_results)/len(batch_data)*100:.1f}%)") + print(f" - 总句子数: {batch_sentences}") + print(f" - 成功三元组: {batch_triples}") + print(f" - 三元组成功率: {batch_triples/batch_sentences*100:.1f}%" if batch_sentences > 0 else "无句子") + print(f" - 处理时间: {processing_time:.2f}秒") + print(f" - 处理速度: {len(batch_data)/processing_time:.2f}段落/秒") + + print_memory_info(f"批次 {batch_num} 完成后") + + return valid_results + +async def write_results_batch(results: List[Dict], output_path: str): + """ + 异步批量写入结果,带进度提示 + """ + try: + print(f"开始批量写入 {len(results)} 条结果...") + + # 准备写入内容 + content_lines = [] + for result in results: + content_lines.append(json.dumps(result, ensure_ascii=False)) + + # 异步批量写入 + async with aiofiles.open(output_path, "a", encoding="utf-8") as f: + await f.write("\n".join(content_lines) + "\n") + + print(f"✓ 成功批量写入 {len(results)} 条结果到 {output_path}") + + except Exception as e: + print(f"✗ 批量写入失败: {e}") + print("尝试逐条写入...") + + # 如果批量写入失败,回退到逐条写入(带进度条) + async with aiofiles.open(output_path, "a", encoding="utf-8") as f: + for result in tqdm(results, desc="逐条写入", unit="条"): + await f.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"✓ 逐条写入完成") + +# 主处理流程 +async def main_async(): + total_processed = 0 + total_sentences = 0 + total_triples = 0 + batch_num = 0 + + print("=== 开始异步批次处理JSONL文件 ===") + print(f"优化后的配置信息:") + print(f" - 批次大小: {BATCH_SIZE:,} 条记录") + print(f" - 最大并发数: {MAX_CONCURRENT}") + print(f" - Agent池大小: {AGENT_POOL_SIZE}") + print(f" - 输入文件: {json_path}") + print(f" - 输出文件: {output_path}") + print() + + print_memory_info("程序开始") + + # 清空输出文件 + async with aiofiles.open(output_path, "w", encoding="utf-8") as f: + pass + + # 读取并处理数据 + with open(json_path, "r", encoding="utf-8") as f_in: + batch_data = [] + + for line_num, line in enumerate(f_in): + if line.strip(): # 跳过空行 + try: + item = json.loads(line) + original_text = item.get("text", "") + + if original_text: + batch_data.append((line_num + 1, original_text)) + + # 当批次达到指定大小时,异步处理这个批次 + if len(batch_data) >= BATCH_SIZE: + batch_num += 1 + + # 异步处理批次 + batch_results = await process_batch_async(batch_data, batch_num) + + # 批量写入结果 + if batch_results: + await write_results_batch(batch_results, output_path) + + # 统计信息 + batch_sentences = sum(len(result["sentences"]) for result in batch_results) + batch_triples = sum( + sum(1 for triple in result["triples"] if triple["confidence"] > 0.0) + for result in batch_results + ) + + total_processed += len(batch_data) + total_sentences += batch_sentences + total_triples += batch_triples + + print(f"\n📊 批次 {batch_num} 累计统计:") + print(f" - 累计处理段落: {total_processed:,}") + print(f" - 累计句子数: {total_sentences:,}") + print(f" - 累计三元组: {total_triples:,}") + print(f" - 整体成功率: {total_triples/total_sentences*100:.1f}%") + print("-" * 80) + + # 清理批次数据,释放内存 + batch_data.clear() + batch_results.clear() + gc.collect() # 强制垃圾回收 + + print_memory_info(f"批次 {batch_num} 清理后") + + except json.JSONDecodeError as e: + print(f"第 {line_num + 1} 行JSON解析错误: {e}") + except Exception as e: + print(f"处理第 {line_num + 1} 行时出错: {e}") + + # 处理最后一个不完整的批次 + if batch_data: + batch_num += 1 + batch_results = await process_batch_async(batch_data, batch_num) + + if batch_results: + await write_results_batch(batch_results, output_path) + + batch_sentences = sum(len(result["sentences"]) for result in batch_results) + batch_triples = sum( + sum(1 for triple in result["triples"] if triple["confidence"] > 0.0) 
+ for result in batch_results + ) + + total_processed += len(batch_data) + total_sentences += batch_sentences + total_triples += batch_triples + + # 最终统计 + print(f"\n{'='*80}") + print(f"🎉 所有批次异步处理完成!") + print(f"{'='*80}") + print(f"最终统计:") + print(f" - 总批次数: {batch_num}") + print(f" - 总段落数: {total_processed:,}") + print(f" - 总句子数: {total_sentences:,}") + print(f" - 总三元组: {total_triples:,}") + print(f" - 整体成功率: {total_triples/total_sentences*100:.1f}%" if total_sentences > 0 else "无有效句子") + print(f" - 输出文件: {output_path}") + print(f"{'='*80}") + + print_memory_info("程序结束前") + + # 显示示例结果 + await show_sample_results() + +async def show_sample_results(): + """显示前几个处理结果作为示例""" + print("\n📋 前3个处理结果示例:") + try: + async with aiofiles.open(output_path, "r", encoding="utf-8") as f: + i = 0 + async for line in f: + if i >= 3: + break + item = json.loads(line) + print(f"\n--- 段落 {i+1} (来源行: {item['source_line']}) ---") + print(f"原始段落: {item['original_paragraph'][:100]}...") + print(f"句子数量: {len(item['sentences'])}") + if item['triples']: + for j, triple in enumerate(item['triples'][:2]): # 只显示前2个三元组 + print(f" 句子 {j+1}: {triple['sentence'][:50]}...") + if triple['confidence'] > 0: + print(f" 三元组: {triple['triple']['subject']} -> {triple['triple']['predicate']} -> {triple['triple']['object']}") + print(f" 置信度: {triple['confidence']:.2f}") + else: + print(f" 提取失败: {triple.get('error', '未知错误')}") + i += 1 + except Exception as e: + print(f"读取示例结果时出错: {e}") + +def main(): + """主入口函数""" + try: + # 运行异步主函数 + asyncio.run(main_async()) + except KeyboardInterrupt: + print("\n⚠️ 用户中断处理") + except Exception as e: + print(f"❌ 处理过程中出现错误: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/train_extra_accelerate.py b/train_extra_accelerate.py new file mode 100644 index 0000000..d8f5fcd --- /dev/null +++ b/train_extra_accelerate.py @@ -0,0 +1,1022 @@ +import os +# 设置环境变量 - 将wandb替换为SwanLab +# os.environ["SWANLAB_MODE"] = "online" # SwanLab使用在线模式 +import platform +import argparse +from tqdm import tqdm +import time +import math +import warnings +import pandas as pd +import torch +from torch import optim, nn +from torch.utils.data import DataLoader +from contextlib import nullcontext +from typing import Optional +import datetime # Add datetime for time formatting +from accelerate import Accelerator +from accelerate.utils import set_seed +from accelerate.utils import DeepSpeedPlugin +from accelerate.utils import DistributedDataParallelKwargs +from transformers import AutoTokenizer, get_cosine_schedule_with_warmup +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity +import swanlab # 替换wandb导入 +import gc # 添加垃圾回收模块 +import psutil # 添加系统资源监控模块 + +from model.model_extra import MiniMindLM, RMSNorm # 使用model_extra +from model.LMConfig import LMConfig +from model.dataset import TriplePretrainDataset # 只需要三元组数据集 + +warnings.filterwarnings('ignore') + +# 基于嵌入的余弦相似度损失计算函数 +def compute_embedding_cosine_loss(subject_logits, predicate_logits, object_logits, + target_triples, tokenizer, tok_embeddings, + pooling_method='mean', max_targets=5, temperature=1.0): + """ + 基于嵌入的余弦相似度损失计算 + Args: + subject_logits: [batch_size, max_subject_len, vocab_size] + predicate_logits: [batch_size, max_predicate_len, vocab_size] + object_logits: [batch_size, max_object_len, vocab_size] + target_triples: List[List[str]] - 每个样本的多个目标句子 + tokenizer: 分词器 + tok_embeddings: 模型的token嵌入层 + pooling_method: 句子嵌入的池化方法 ('mean', 'max', 'cls') + max_targets: int - 
每个样本最大目标句子数量 + temperature: float - Softmax温度参数,控制预测的平滑度 + Returns: + torch.Tensor: 余弦相似度损失 + """ + if not target_triples or len(target_triples) == 0: + # 创建一个与输入张量相关的损失,保持在计算图中 + dummy_loss = subject_logits.sum() * 0.0 + 1.0 # 这样创建的张量会保持梯度 + return dummy_loss + + batch_size = subject_logits.shape[0] + + # 1. 获取预测的嵌入表示 + pred_embeddings = get_prediction_embeddings( + subject_logits, predicate_logits, object_logits, + tok_embeddings, pooling_method, temperature + ) # [batch_size, embed_dim] + + # 2. 获取目标的嵌入表示 + target_embeddings = get_target_embeddings( + target_triples, tokenizer, tok_embeddings, pooling_method, max_targets + ) # [batch_size, max_targets, embed_dim] + + # 3. 计算余弦相似度 + similarities = compute_cosine_similarity_batch(pred_embeddings, target_embeddings) + # [batch_size, max_targets] + + # 4. 选择最高相似度(最小损失) + best_similarities = torch.max(similarities, dim=-1)[0] # [batch_size] + + # 5. 转换为损失 (1 - cosine_similarity) + loss = 1.0 - best_similarities.mean() + + # 确保损失值在合理范围内(保持计算图连接) + loss = torch.clamp(loss, min=0.0, max=2.0) + + return loss + +def get_prediction_embeddings(subject_logits, predicate_logits, object_logits, + tok_embeddings, pooling_method='mean', temperature=1.0): + """ + 从预测logits获取句子嵌入(使用soft embedding保持梯度) + """ + batch_size = subject_logits.shape[0] + + # 使用softmax获取概率分布,而不是argmax + subject_probs = torch.softmax(subject_logits / temperature, dim=-1) # [batch_size, max_subject_len, vocab_size] + predicate_probs = torch.softmax(predicate_logits / temperature, dim=-1) # [batch_size, max_predicate_len, vocab_size] + object_probs = torch.softmax(object_logits / temperature, dim=-1) # [batch_size, max_object_len, vocab_size] + + # 使用概率分布与嵌入矩阵进行加权求和,得到soft embeddings + # tok_embeddings.weight: [vocab_size, embed_dim] + subject_embeddings = torch.matmul(subject_probs, tok_embeddings.weight) # [batch_size, max_subject_len, embed_dim] + predicate_embeddings = torch.matmul(predicate_probs, tok_embeddings.weight) # [batch_size, max_predicate_len, embed_dim] + object_embeddings = torch.matmul(object_probs, tok_embeddings.weight) # [batch_size, max_object_len, embed_dim] + + # 拼接所有部分的嵌入 + combined_embeddings = torch.cat([subject_embeddings, predicate_embeddings, object_embeddings], dim=1) + # [batch_size, total_len, embed_dim] + + # 池化得到句子嵌入 + if pooling_method == 'mean': + # 简单平均池化 + sentence_embeddings = combined_embeddings.mean(dim=1) + elif pooling_method == 'max': + sentence_embeddings = combined_embeddings.max(dim=1)[0] + elif pooling_method == 'cls': + # 使用第一个token作为句子表示 + sentence_embeddings = combined_embeddings[:, 0, :] + else: + sentence_embeddings = combined_embeddings.mean(dim=1) + + return sentence_embeddings # [batch_size, embed_dim] + +def get_target_embeddings(target_triples, tokenizer, tok_embeddings, pooling_method='mean', max_targets=5): + """ + 批量获取目标句子的嵌入表示 + Args: + target_triples: List[List[str]] - 每个样本的目标句子列表 + max_targets: int - 每个样本最大目标句子数量,不足补空字符串,超过则截取前max_targets个 + """ + batch_size = len(target_triples) + + if not target_triples: + # 如果没有目标句子,返回与嵌入层相关的零嵌入(保持计算图) + embed_dim = tok_embeddings.embedding_dim + # 使用嵌入层的权重创建零张量,保持计算图连接 + zero_embeddings = tok_embeddings.weight[:1, :].expand(batch_size, max_targets, embed_dim) * 0.0 + return zero_embeddings + + # 标准化每个样本的目标数量为max_targets + normalized_targets = [] + for targets in target_triples: + if len(targets) >= max_targets: + # 超过max_targets,取前max_targets个 + normalized_targets.extend(targets[:max_targets]) + else: + # 不足max_targets,补空字符串 + normalized_targets.extend(targets) + 
normalized_targets.extend([''] * (max_targets - len(targets))) + + # 现在 normalized_targets 的长度是 batch_size * max_targets + assert len(normalized_targets) == batch_size * max_targets + + # 批量tokenize所有目标句子 + tokenized = tokenizer( + normalized_targets, + padding=True, + truncation=True, + return_tensors='pt', + max_length=128 # 可以调整 + ) + + # 移到正确的设备 + input_ids = tokenized['input_ids'].to(tok_embeddings.weight.device) + attention_mask = tokenized['attention_mask'].to(tok_embeddings.weight.device) + + # 获取token嵌入 + token_embeddings = tok_embeddings(input_ids) # [batch_size * max_targets, seq_len, embed_dim] + + # 应用attention mask并池化 + if pooling_method == 'mean': + # 使用attention mask进行加权平均 + masked_embeddings = token_embeddings * attention_mask.unsqueeze(-1) + sentence_embeddings = masked_embeddings.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True).clamp(min=1e-8) + elif pooling_method == 'max': + # 在有效token上取最大值 + masked_embeddings = token_embeddings.masked_fill( + ~attention_mask.unsqueeze(-1).bool(), float('-inf') + ) + sentence_embeddings = masked_embeddings.max(dim=1)[0] + else: + sentence_embeddings = token_embeddings.mean(dim=1) + + # 重新整形为 [batch_size, max_targets, embed_dim] + embed_dim = sentence_embeddings.shape[-1] + target_embeddings = sentence_embeddings.view(batch_size, max_targets, embed_dim) + + return target_embeddings + +def compute_cosine_similarity_batch(pred_embeddings, target_embeddings): + """ + 批量计算余弦相似度 + Args: + pred_embeddings: [batch_size, embed_dim] + target_embeddings: [batch_size, max_targets, embed_dim] + Returns: + similarities: [batch_size, max_targets] + """ + # 标准化 + pred_norm = torch.nn.functional.normalize(pred_embeddings, p=2, dim=-1) # [batch_size, embed_dim] + target_norm = torch.nn.functional.normalize(target_embeddings, p=2, dim=-1) # [batch_size, max_targets, embed_dim] + + # 计算余弦相似度 + # pred_norm: [batch_size, 1, embed_dim] + # target_norm: [batch_size, max_targets, embed_dim] + similarities = torch.sum(pred_norm.unsqueeze(1) * target_norm, dim=-1) + # [batch_size, max_targets] + + return similarities + +def triple_to_sentence(subject_logits, predicate_logits, object_logits, tokenizer): + """ + 将三元组logits转换为句子 + Args: + subject_logits: [batch_size, seq_len, max_subject_len, vocab_size] + predicate_logits: [batch_size, seq_len, max_predicate_len, vocab_size] + object_logits: [batch_size, seq_len, max_object_len, vocab_size] + tokenizer: 分词器 + Returns: + List[List[str]]: 每个样本每个位置的三元组句子 + """ + batch_size = subject_logits.shape[0] + predicate_seq_len = predicate_logits.shape[1] + subject_seq_len = subject_logits.shape[1] + object_seq_len = object_logits.shape[1] + + predicate_logits = predicate_logits.reshape(batch_size*predicate_seq_len, -1) + subject_logits = subject_logits.reshape(batch_size*subject_seq_len, -1) + object_logits = object_logits.reshape(batch_size*object_seq_len, -1) + + predicate_logits = torch.argmax(predicate_logits, dim=-1) + subject_logits = torch.argmax(subject_logits, dim=-1) + object_logits = torch.argmax(object_logits, dim=-1) + + predicate_logits = predicate_logits.reshape(batch_size, predicate_seq_len) + subject_logits = subject_logits.reshape(batch_size, subject_seq_len) + object_logits = object_logits.reshape(batch_size, object_seq_len) + + combined_logits = torch.cat([subject_logits, predicate_logits, object_logits], dim=1) + + sentences = tokenizer.batch_decode(combined_logits, skip_special_tokens=True) + + # sentences = [] + + # for batch_idx in range(batch_size): + # batch_sentences = [] + # for seq_idx in 
range(seq_len): + # # 获取预测的token ids + # subject_ids = torch.argmax(subject_logits[batch_idx, seq_idx], dim=-1) + # predicate_ids = torch.argmax(predicate_logits[batch_idx, seq_idx], dim=-1) + # object_ids = torch.argmax(object_logits[batch_idx, seq_idx], dim=-1) + + # # 转换为文本 + # subject_text = tokenizer.decode(subject_ids, skip_special_tokens=True).strip() + # predicate_text = tokenizer.decode(predicate_ids, skip_special_tokens=True).strip() + # object_text = tokenizer.decode(object_ids, skip_special_tokens=True).strip() + + # # 拼接为句子 (主语 + 谓语 + 宾语) + # if subject_text and predicate_text and object_text: + # sentence = f"{subject_text} {predicate_text} {object_text}" + # else: + # sentence = "" + + # batch_sentences.append(sentence) + # sentences.append(batch_sentences) + + return sentences + +def compute_triple_rouge_loss_optimized(subject_logits, predicate_logits, object_logits, + target_input_ids, target_attention_mask, tok_embeddings, temperature=1.0): + """ + 优化的三元组嵌入余弦相似度损失计算(单个target版本) + Args: + subject_logits: [batch_size, max_subject_len, vocab_size] + predicate_logits: [batch_size, max_predicate_len, vocab_size] + object_logits: [batch_size, max_object_len, vocab_size] + target_input_ids: [batch_size, target_seq_len] - 预tokenized的目标句子 + target_attention_mask: [batch_size, target_seq_len] - 目标句子的attention mask + tok_embeddings: 模型的token嵌入层 + temperature: float - Softmax温度参数,控制预测的平滑度 + Returns: + torch.Tensor: 嵌入余弦相似度损失 (标量) + """ + batch_size = subject_logits.shape[0] + + # ✅ 修复:确保target数据在正确的设备上 + device = tok_embeddings.weight.device + target_input_ids = target_input_ids.to(device) + target_attention_mask = target_attention_mask.to(device) + + # 1. 获取预测的嵌入表示(使用soft embedding保持梯度) + subject_probs = torch.softmax(subject_logits / temperature, dim=-1) + predicate_probs = torch.softmax(predicate_logits / temperature, dim=-1) + object_probs = torch.softmax(object_logits / temperature, dim=-1) + + # 使用概率分布与嵌入矩阵进行加权求和 + subject_embeddings = torch.matmul(subject_probs, tok_embeddings.weight) + predicate_embeddings = torch.matmul(predicate_probs, tok_embeddings.weight) + object_embeddings = torch.matmul(object_probs, tok_embeddings.weight) + + # 拼接所有部分的嵌入并平均池化 + combined_embeddings = torch.cat([subject_embeddings, predicate_embeddings, object_embeddings], dim=1) + pred_embeddings = combined_embeddings.mean(dim=1) # [batch_size, embed_dim] + + # 2. 获取目标的嵌入表示(直接使用预tokenized的数据) + target_embeddings = tok_embeddings(target_input_ids) # [batch_size, target_seq_len, embed_dim] + + # 使用attention mask进行加权平均池化 + masked_embeddings = target_embeddings * target_attention_mask.unsqueeze(-1) + target_pooled = masked_embeddings.sum(dim=1) / target_attention_mask.sum(dim=1, keepdim=True).clamp(min=1e-8) + # [batch_size, embed_dim] + + # 3. 计算余弦相似度 + pred_norm = torch.nn.functional.normalize(pred_embeddings, p=2, dim=-1) + target_norm = torch.nn.functional.normalize(target_pooled, p=2, dim=-1) + + # 计算余弦相似度 + similarities = torch.sum(pred_norm * target_norm, dim=-1) # [batch_size] + + # 4. 
转换为损失 (1 - cosine_similarity) + loss = 1.0 - similarities.mean() + + # 确保损失值在合理范围内 + loss = torch.clamp(loss, min=0.0, max=2.0) + + return loss + +def compute_triple_rouge_loss(subject_logits, predicate_logits, object_logits, target_triples, tokenizer, tok_embeddings, max_targets=5, temperature=1.0): + """ + 原始版本的三元组损失计算(保留用于兼容性) + Args: + subject_logits: [batch_size, max_subject_len, vocab_size] + predicate_logits: [batch_size, max_predicate_len, vocab_size] + object_logits: [batch_size, max_object_len, vocab_size] + target_triples: List[List[str]] - 每个样本的多个真值三元组句子 + tokenizer: 分词器 + tok_embeddings: 模型的token嵌入层 + max_targets: int - 每个样本最大目标句子数量 + temperature: float - Softmax温度参数,控制预测的平滑度 + Returns: + torch.Tensor: 嵌入余弦相似度损失 (标量) + """ + return compute_embedding_cosine_loss( + subject_logits, predicate_logits, object_logits, + target_triples, tokenizer, tok_embeddings, pooling_method='mean', max_targets=max_targets, temperature=temperature + ) + +# 内存监控辅助函数 +def get_memory_usage(): + """获取当前内存使用情况""" + process = psutil.Process() + memory_info = process.memory_info() + return { + 'rss_mb': memory_info.rss / 1024 / 1024, # 物理内存使用量(MB) + 'vms_mb': memory_info.vms / 1024 / 1024, # 虚拟内存使用量(MB) + } + +def get_cuda_memory_usage(): + """获取CUDA内存使用情况""" + if torch.cuda.is_available(): + return { + 'cuda_allocated_mb': torch.cuda.memory_allocated() / 1024 / 1024, + 'cuda_reserved_mb': torch.cuda.memory_reserved() / 1024 / 1024, + 'cuda_max_allocated_mb': torch.cuda.max_memory_allocated() / 1024 / 1024, + } + return {} + +def get_tensor_memory_size(tensor_list): + """计算tensor列表的总内存占用(MB)""" + total_size = 0 + for batch in tensor_list: + if isinstance(batch, (list, tuple)): + for tensor in batch: + if isinstance(tensor, torch.Tensor): + total_size += tensor.numel() * tensor.element_size() + elif isinstance(batch, torch.Tensor): + total_size += batch.numel() * batch.element_size() + return total_size / 1024 / 1024 # 转换为MB + +def log_memory_status(step, accelerator, stage="", detailed=False): + """记录内存状态""" + if not accelerator.is_main_process: + return + + memory_info = get_memory_usage() + cuda_info = get_cuda_memory_usage() + + log_msg = f"[Memory Monitor] Step {step} {stage} - " + log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB" + + if cuda_info: + log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB" + log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB" + + if detailed: + log_msg += f", System VMS: {memory_info['vms_mb']:.2f}MB" + if cuda_info: + log_msg += f", CUDA max allocated: {cuda_info['cuda_max_allocated_mb']:.2f}MB" + + Logger(log_msg, accelerator) + +# 日志记录函数 +def Logger(msg, accelerator=None): + # 如果没有提供accelerator,则只在主进程打印 + if accelerator is None or accelerator.is_main_process: + print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}") + +# Helper function to format seconds into HH:MM:SS +def format_time(seconds): + return str(datetime.timedelta(seconds=int(seconds))) + +# 获取学习率函数 +def get_lr(it, num_iters, learning_rate): + # 余弦学习率衰减 + return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters)) + +# 初始化模型函数 +def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None): + tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer') + model = MiniMindLM(lm_config, mode="triple") # 设置为三元组模式 + + # 加载预训练权重 + pretrained_path = "./out/Experiment_1_2_2_pretrain_512.pth" + Logger(f"Loading pretrained weights from {pretrained_path}") + + try: + # 加载预训练的state_dict + pretrained_state_dict = 
torch.load(pretrained_path, map_location='cpu') + Logger(f"Successfully loaded pretrained state_dict with {len(pretrained_state_dict)} parameters") + + # 获取当前模型的state_dict + model_state_dict = model.state_dict() + + # 统计加载情况 + loaded_params = [] + skipped_params = [] + + # 逐个加载兼容的权重 + for name, param in pretrained_state_dict.items(): + if name in model_state_dict: + if model_state_dict[name].shape == param.shape: + model_state_dict[name].copy_(param) + loaded_params.append(name) + else: + Logger(f"Warning: Shape mismatch for {name}, expected {model_state_dict[name].shape}, got {param.shape}") + skipped_params.append(f"{name} (shape mismatch)") + else: + skipped_params.append(f"{name} (not found in model2)") + + Logger(f"Loaded {len(loaded_params)} parameters from pretrained weights") + Logger(f"Skipped {len(skipped_params)} parameters") + + # 显示一些关键加载的参数 + key_loaded = [name for name in loaded_params if any(key in name for key in ['tok_embeddings', 'layers.0', 'knowledge_dataset', 'output', 'norm'])] + if key_loaded: + Logger("Key loaded parameters:") + for name in key_loaded[:5]: # 只显示前5个 + Logger(f" ✅ {name}") + if len(key_loaded) > 5: + Logger(f" ... and {len(key_loaded) - 5} more") + + # 显示跳过的参数(应该主要是triple_extraction_head相关的) + triple_skipped = [name for name in skipped_params if 'triple_extraction_head' in name] + if triple_skipped: + Logger("Triple extraction head parameters (newly initialized):") + for name in triple_skipped[:3]: # 只显示前3个 + Logger(f" 🆕 {name}") + if len(triple_skipped) > 3: + Logger(f" ... and {len(triple_skipped) - 3} more") + + except Exception as e: + Logger(f"Error loading pretrained weights: {e}") + Logger("Falling back to default initialization...") + + # 默认模型初始化(备用方案) + Logger("Performing default model initialization...") + + # 初始化嵌入层权重 + nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02) + + # 初始化输出层权重(如果不共享权重的话) + if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight: + nn.init.normal_(model.output.weight, mean=0.0, std=0.02) + + # 初始化所有线性层 + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + # 使用Xavier/Glorot初始化 + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + # 嵌入层使用正态分布初始化 + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, RMSNorm): + # RMSNorm的权重初始化为1 + if hasattr(module, 'weight'): + nn.init.ones_(module.weight) + + # 初始化位置编码相关参数 + if hasattr(model.knowledge_dataset, 'keys'): + nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02) + + Logger("Default model initialization completed") + + # 如果提供了预训练的嵌入权重,加载它们 + if pretrained_embedding_path: + Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}") + pretrained_embeddings = torch.load(pretrained_embedding_path) + model.tok_embeddings.weight.data.copy_(pretrained_embeddings) + model.output.weight.data.copy_(pretrained_embeddings) # 共享权重 + + + + Logger(f"Database embeddings and sentences stored in model") + + Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万') + return model, tokenizer + +def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer): + # 三元组提取训练模式:不需要传统的交叉熵损失函数 + epoch_start_time = time.time() + total_steps_in_epoch = len(train_loader) + total_training_steps = args.epochs * total_steps_in_epoch + moe_path = '_moe' if 
args.use_moe else '' + best_loss = float('10000') + + # 初始化CUDA事件变量 - 只保留GPU计算时间追踪 + forward_start = forward_end = loss_start = loss_end = backward_start = backward_end = optimizer_start = optimizer_end = None + + # 添加CUDA事件来分析GPU性能 (只在主进程进行) + if args.profile and accelerator.is_main_process: + forward_start = torch.cuda.Event(enable_timing=True) + forward_end = torch.cuda.Event(enable_timing=True) + loss_start = torch.cuda.Event(enable_timing=True) + loss_end = torch.cuda.Event(enable_timing=True) + backward_start = torch.cuda.Event(enable_timing=True) + backward_end = torch.cuda.Event(enable_timing=True) + optimizer_start = torch.cuda.Event(enable_timing=True) + optimizer_end = torch.cuda.Event(enable_timing=True) + + # 移除自定义预取机制,使用DataLoader内置预取 + # 记录初始内存状态 + if args.memory_monitor: + memory_info = get_memory_usage() + cuda_info = get_cuda_memory_usage() + log_msg = f"[Memory Monitor] Training start - System RSS: {memory_info['rss_mb']:.2f}MB" + if cuda_info: + log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB" + Logger(log_msg, accelerator) + + # 在开始循环前初始化日志记录所需变量 + last_log_time = epoch_start_time + + # 使用DataLoader内置的iterator,移除自定义预取 + for step, batch_data in enumerate(train_loader): + # === 每个step开始 === + + try: + # === 1. 数据准备 === + # 直接使用DataLoader提供的数据 + if not isinstance(batch_data, dict): + raise ValueError("期望字典格式的批次数据,请确保使用 TriplePretrainDataset") + + X = batch_data['input_ids'] + Y = batch_data['labels'] + loss_mask = batch_data['loss_mask'] + target_input_ids = batch_data['target_input_ids'] + target_attention_mask = batch_data['target_attention_mask'] + target_sentences = batch_data['target_sentences'] # 用于调试输出 + + # === 2. 学习率更新 === + if scheduler is not None: + scheduler.step() + + # === 3. 前向传播 === + # 计时GPU前向传播 + if args.profile and accelerator.is_main_process and forward_start is not None: + forward_start.record() + + # 前向传播 + with ctx: + if step == 0 and args.embedding_epoch == epoch: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.freeze_embedding = True + Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator) + res = model(X, step=step) + + # 计时GPU前向传播结束 + if args.profile and accelerator.is_main_process and forward_end is not None: + forward_end.record() + + # === 4. 损失计算 === + # 三元组提取模式:只使用ROUGE Loss进行三元组损失计算 + Logger("三元组提取训练模式", accelerator) if step == 0 else None + + # 确保有三元组输出 + if not (hasattr(res, 'predicate_logits') and hasattr(res, 'subject_logits') and hasattr(res, 'object_logits')): + raise ValueError("模型没有输出三元组logits,请检查模型配置") + + # 确保有目标数据 + if target_input_ids is None: + raise ValueError("没有三元组目标数据,请检查数据格式") + + # 计算三元组损失 + try: + Logger("使用预tokenized三元组目标数据", accelerator) if step == 0 else None + + # 计时GPU损失计算 + if args.profile and accelerator.is_main_process and loss_start is not None: + loss_start.record() + + # 计算优化后的嵌入余弦相似度损失 + loss = compute_triple_rouge_loss_optimized( + res.subject_logits, res.predicate_logits, res.object_logits, + target_input_ids, target_attention_mask, model.tok_embeddings, temperature=args.temperature + ) + + # 计时GPU损失计算结束 + if args.profile and accelerator.is_main_process and loss_end is not None: + loss_end.record() + + except Exception as e: + Logger(f"Error: ROUGE loss computation failed: {e}", accelerator) + import traceback + Logger(f"Traceback: {traceback.format_exc()}", accelerator) + loss = res.logits.sum() * 0.0 + 1.0 + + loss = loss / args.accumulation_steps + + # === 5. 
反向传播 === + # 计时GPU反向传播 + if args.profile and accelerator.is_main_process and backward_start is not None: + backward_start.record() + + # 反向传播 + accelerator.backward(loss) + + # 计时GPU反向传播结束 + if args.profile and accelerator.is_main_process and backward_end is not None: + backward_end.record() + + # === 6. 优化器步骤 === + # 计时GPU优化器步骤 + if args.profile and accelerator.is_main_process and optimizer_start is not None: + optimizer_start.record() + + # 优化器步骤 + optimizer.step() + optimizer.zero_grad() + + # 计时GPU优化器步骤结束 + if args.profile and accelerator.is_main_process and optimizer_end is not None: + optimizer_end.record() + + # === 7. 日志记录 === + # 打印训练信息 (只在主进程进行) + if (step + 1) % args.log_interval == 0 and accelerator.is_main_process: + current_time = time.time() + + # 计算GPU性能指标 + if args.profile and accelerator.is_main_process: + torch.cuda.synchronize() + + # 获取GPU时间 + try: + forward_time = forward_start.elapsed_time(forward_end) if forward_start is not None and forward_end is not None else 0 + loss_time = loss_start.elapsed_time(loss_end) if loss_start is not None and loss_end is not None else 0 + backward_time = backward_start.elapsed_time(backward_end) if backward_start is not None and backward_end is not None else 0 + optimizer_time = optimizer_start.elapsed_time(optimizer_end) if optimizer_start is not None and optimizer_end is not None else 0 + iter_time = (current_time - last_log_time) * 1000 / args.log_interval # avg ms per iteration since last log + + # 打印GPU性能分析 + if (step + 1) % (args.log_interval * args.profile_interval) == 0: + # 计算GPU时间 + gpu_time_total = (forward_time + loss_time + backward_time + optimizer_time) / args.log_interval + + Logger(f"=== GPU性能分析 (平均每步) ===", accelerator) + Logger(f"前向传播: {forward_time/args.log_interval:.2f}ms, " + f"损失计算: {loss_time/args.log_interval:.2f}ms, " + f"反向传播: {backward_time/args.log_interval:.2f}ms, " + f"优化器: {optimizer_time/args.log_interval:.2f}ms", accelerator) + Logger(f"GPU总时间: {gpu_time_total:.2f}ms, " + f"实际迭代时间: {iter_time:.2f}ms, " + f"GPU利用率: {gpu_time_total/iter_time*100:.1f}%", accelerator) + Logger("=" * 50, accelerator) + + Logger("=== 三元组预测示例 ===", accelerator) + predict_sentences = triple_to_sentence(res.subject_logits, res.predicate_logits, res.object_logits,tokenizer) + # 显示前2个样本的目标句子 + for i, target_sentence in enumerate(target_sentences[:2]): + Logger(f"样本{i+1}目标: {target_sentence}", accelerator) + Logger(f"样本{i+1}预测: {predict_sentences[i]}", accelerator) + Logger("==================", accelerator) + + # 重置GPU事件 + forward_start = torch.cuda.Event(enable_timing=True) + forward_end = torch.cuda.Event(enable_timing=True) + loss_start = torch.cuda.Event(enable_timing=True) + loss_end = torch.cuda.Event(enable_timing=True) + backward_start = torch.cuda.Event(enable_timing=True) + backward_end = torch.cuda.Event(enable_timing=True) + optimizer_start = torch.cuda.Event(enable_timing=True) + optimizer_end = torch.cuda.Event(enable_timing=True) + except RuntimeError as e: + if "Both events must be recorded" in str(e): + Logger(f"Warning: CUDA events not properly recorded, skipping performance analysis: {e}", accelerator) + else: + raise e + + # 计算基本指标 + current_lr = optimizer.param_groups[0]['lr'] + epoch_elapsed_time = current_time - epoch_start_time + epoch_steps_done = step + 1 + epoch_avg_step_time = epoch_elapsed_time / epoch_steps_done + epoch_remaining_time = epoch_avg_step_time * (total_steps_in_epoch - epoch_steps_done) + + total_elapsed_time = current_time - overall_start_time + total_steps_done = epoch * 
total_steps_in_epoch + epoch_steps_done + total_avg_step_time = total_elapsed_time / total_steps_done if total_steps_done > 0 else 0 + total_remaining_time = total_avg_step_time * (total_training_steps - total_steps_done) if total_steps_done > 0 else 0 + + # 计算训练速度 + interval_elapsed_time = current_time - last_log_time + tokens_processed_interval = args.log_interval * args.batch_size * args.max_seq_len + tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0 + last_log_time = current_time + + # 基本训练信息 + Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, " + f"Loss: {loss.item() * args.accumulation_steps:.6f}, " + f"LR: {current_lr:.6f}, " + f"Speed: {tokens_per_sec:.2f} tokens/sec | " + f"Epoch Time Left: {format_time(epoch_remaining_time)} | " + f"Total Time Left: {format_time(total_remaining_time)}", accelerator) + + # SwanLab日志记录 + if args.use_swanlab and accelerator.is_main_process and swanlab_run: + log_dict = { + "epoch": epoch + 1, + "step": step + 1, + "total_steps_in_epoch": total_steps_in_epoch, + "triple_embedding_cosine_loss": loss.item() * args.accumulation_steps, + "lr": current_lr, + "tokens_per_sec": tokens_per_sec, + "epoch_time_left_seconds": epoch_remaining_time, + "total_time_left_seconds": total_remaining_time + } + swanlab_run.log(log_dict) + + # === 8. 模型保存 === + # 保存模型 (只在主进程进行) + loss_total = loss.item() * args.accumulation_steps + if epoch > 1 and best_loss > loss_total and accelerator.is_main_process: + best_loss = loss_total + ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth' + unwrapped_model = accelerator.unwrap_model(model) + accelerator.save(unwrapped_model.state_dict(), ckp) + Logger(f"Model saved to {ckp}", accelerator) + + except Exception as e: + Logger(f"Error in training step: {e}", accelerator) + import traceback + Logger(traceback.format_exc(), accelerator) + + # 清理内存,防止内存泄漏 + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # 训练epoch结束时清理内存 + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + +def main(): + parser = argparse.ArgumentParser(description="MiniMind Triple Extraction Training with Accelerate") + parser.add_argument("--out_dir", type=str, default="out") + parser.add_argument("--epochs", type=int, default=4) + parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数") + parser.add_argument("--batch_size", type=int, default=192) + parser.add_argument("--learning_rate", type=float, default=2e-4) + parser.add_argument("--dtype", type=str, default="bfloat16") + parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数 + parser.add_argument("--swanlab_project", type=str, default="MiniMind-TripleExtraction") # 替换wandb参数 + parser.add_argument("--num_workers", type=int, default=1) + parser.add_argument("--accumulation_steps", type=int, default=32) + parser.add_argument("--grad_clip", type=float, default=1.0) + parser.add_argument("--warmup_iters", type=int, default=0) + parser.add_argument("--log_interval", type=int, default=100) + parser.add_argument("--save_interval", type=int, default=10000) + parser.add_argument('--dim', default=512, type=int) + parser.add_argument('--n_layers', default=8, type=int) + parser.add_argument('--max_seq_len', default=512, type=int) + parser.add_argument('--use_moe', default=False, type=bool) + parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代") + parser.add_argument("--data_path", type=str, 
+    parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
+    parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
+    parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
+    parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
+    parser.add_argument("--knowledge_num", type=int, default=960400, help="知识库的数据数目")
+    parser.add_argument("--knowledge_length", type=int, default=32, help="知识库的句子长度")
+    parser.add_argument("--database_init_path", type=str, default="./dataset/combined_prepare.json", help="数据库初始化路径")
+    parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
+    parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="聚类结果缓存文件路径")
+    parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
+    parser.add_argument("--memory_monitor", action="store_true", default=False, help="启用内存监控")
+    parser.add_argument("--memory_monitor_interval", type=int, default=10, help="内存监控间隔(步数)")
+    parser.add_argument("--max_targets", type=int, default=5, help="每个样本最大目标句子数量,用于批处理优化")
+    parser.add_argument("--temperature", type=float, default=1.0, help="Softmax温度参数,用于控制预测的平滑度")
+    parser.add_argument("--detailed_timing", action="store_true", default=True, help="启用详细的时间追踪分析")
+    # 移除dataset_type参数,此训练脚本专用于三元组提取训练
+    # parser.add_argument("--dataset_type", type=str, default="pretrain", choices=["pretrain", "triple"], help="数据集类型:pretrain(标准预训练)或triple(三元组)")
+    args = parser.parse_args()
+
+    #########################################################
+    # 初始化accelerator和deepspeed
+    #########################################################
+    # 设置ddp_kwargs以处理未使用的参数
+    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+    # 创建DeepSpeedPlugin对象
+    ds_plugin = DeepSpeedPlugin(
+        gradient_accumulation_steps=args.accumulation_steps,
+        gradient_clipping=args.grad_clip,
+        zero_stage=2,  # 使用ZeRO-2优化
+        offload_optimizer_device="none",  # 不将优化器状态卸载到CPU
+        offload_param_device="none",  # 不将参数卸载到CPU
+    )
+    accelerator = Accelerator(
+        kwargs_handlers=[ddp_kwargs],
+        deepspeed_plugin=ds_plugin,
+        mixed_precision="bf16" if args.dtype == "bfloat16" else "fp16" if args.dtype == "float16" else "no"
+    )
+
+    #########################################################
+    # 设置随机种子
+    #########################################################
+    set_seed(1337 + accelerator.process_index)
+
+    #########################################################
+    # 配置模型
+    #########################################################
+    lm_config = LMConfig(
+        dim=args.dim,
+        n_layers=args.n_layers,
+        max_seq_len=args.max_seq_len,
+        use_moe=args.use_moe,
+        disable_db=args.disable_db,
+        flash_attn=args.use_flash_attn,
+        knowledge_num=args.knowledge_num,
+        knowledge_length=args.knowledge_length,
+        embeddings_epoch=args.embedding_epoch
+    )
+
+    #########################################################
+    # 创建保存目录
+    #########################################################
+    args.save_dir = os.path.join(args.out_dir)
+    if accelerator.is_main_process:
+        os.makedirs(args.save_dir, exist_ok=True)
+        os.makedirs(args.out_dir, exist_ok=True)
+
+    #########################################################
+    # 设置数据类型
+    #########################################################
+    pt_dtype = {'float32': 
torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype] + + + ######################################################### + # 配置SwanLab + ######################################################### + # 设置SwanLab运行名称 + args.swanlab_run_name = f"MiniMind-TripleExtraction-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}" + + # 合并args和lm_config为一个字典(无论是否使用SwanLab都需要,用于打印配置信息) + config_dict = vars(args).copy() + config_dict.update(vars(lm_config)) + + # 初始化SwanLab实验实例 + swanlab_run = None + if args.use_swanlab and accelerator.is_main_process: + # 初始化SwanLab + swanlab_run = swanlab.init( + project=args.swanlab_project, + experiment_name=args.swanlab_run_name, + description="MiniMind三元组提取训练实验,使用ROUGE损失优化三元组抽取性能", + config=config_dict + # 设置SwanLab服务器地址和API Key + # host="http://100.123.118.114:11071", + # api_key="LesBT7HRq23HNBrOPKP8S" + ) + else: + swanlab_run = None + + ######################################################### + # 打印信息 + ######################################################### + # 计算每次迭代的token数量 + tokens_per_iter = args.batch_size * lm_config.max_seq_len + if accelerator.is_main_process: + Logger(f"tokens_per_iter: {tokens_per_iter}", accelerator) + Logger("Configuration:", accelerator) + for key, value in config_dict.items(): + Logger(f" {key}: {value}", accelerator) + + + ######################################################### + # 设置自动混合精度上下文 + ######################################################### + ctx = nullcontext() if accelerator.device.type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype) + + ######################################################### + # 初始化模型和tokenizer + ######################################################### + model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args) + # 将accelerator传递给init_model函数中的Logger调用 + Logger(f'模型初始化完成', accelerator) + + ######################################################### + # 处理位置编码张量问题 + ######################################################### + if hasattr(model, "pos_cis_real"): + Logger(f'检测到pos_cis_real实数张量,将其设置为参与分布式训练', accelerator) + # 设置模型的_ddp_params_and_buffers_to_ignore属性 + # model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"} + # 兼容旧版本,检查是否仍有pos_cis + elif hasattr(model, "pos_cis"): + Logger(f'检测到pos_cis复数张量,将其设置为不参与分布式训练', accelerator) + # 设置模型的_ddp_params_and_buffers_to_ignore属性 + model._ddp_params_and_buffers_to_ignore = {"pos_cis"} + + ######################################################### + # 创建数据集和数据加载器(专用于三元组提取训练) + ######################################################### + Logger("三元组提取训练:使用 TriplePretrainDataset", accelerator) + train_ds = TriplePretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len) + + # 创建自定义collate_fn来处理优化后的数据格式 + def triple_collate_fn(batch): + # batch是一个包含字典的列表 + input_ids = torch.stack([item['input_ids'] for item in batch]) + labels = torch.stack([item['labels'] for item in batch]) + loss_mask = torch.stack([item['loss_mask'] for item in batch]) + target_input_ids = torch.stack([item['target_input_ids'] for item in batch]) + target_attention_mask = torch.stack([item['target_attention_mask'] for item in batch]) + target_sentences = [item['target_sentence'] for item in batch] # 用于调试 + + return { + 'input_ids': input_ids, + 'labels': labels, + 'loss_mask': loss_mask, + 'target_input_ids': target_input_ids, + 'target_attention_mask': target_attention_mask, + 'target_sentences': target_sentences + } + + train_loader = DataLoader( + 
train_ds, + batch_size=args.batch_size, + pin_memory=False, # ✅ 实验:禁用pin_memory,避免内存固定问题 + drop_last=True, # 修复:避免边界条件导致的死锁 + shuffle=True, + num_workers=0, # ✅ 实验:禁用多进程,避免worker死锁 + # persistent_workers 和 prefetch_factor 在 num_workers=0 时自动禁用 + collate_fn=triple_collate_fn + ) + + ######################################################### + # 创建优化器 + ######################################################### + optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) + + ######################################################### + # 创建学习率调度器 + ######################################################### + total_steps = len(train_loader) * args.epochs + warmup_steps = args.warmup_iters if args.warmup_iters > 0 else int(0.1 * total_steps) + scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=warmup_steps, + num_training_steps=total_steps + ) + + ######################################################### + # 准备训练 + ######################################################### + model, optimizer, train_loader, scheduler = accelerator.prepare( + model, optimizer, train_loader, scheduler + ) + + ######################################################### + # 训练循环 + ######################################################### + overall_start_time = time.time() # Record overall start time + for epoch in range(args.epochs): + Logger(f"开始第{epoch+1}轮训练", accelerator) + train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer) # Pass tokenizer + + # 每个epoch结束后进行内存清理 + Logger(f"第{epoch+1}轮训练完成,进行内存清理", accelerator) + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # 记录epoch结束时的内存状态 + if accelerator.is_main_process: + memory_info = get_memory_usage() + cuda_info = get_cuda_memory_usage() + log_msg = f"[Memory Monitor] Epoch {epoch+1} completed - " + log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB" + if cuda_info: + log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB" + log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB" + Logger(log_msg, accelerator) + + ######################################################### + # 关闭SwanLab + ######################################################### + if args.use_swanlab and accelerator.is_main_process and swanlab_run: + swanlab_run.finish() + +if __name__ == "__main__": + main()
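+
+# 附:启动方式示例(仅为示意性草图,基于上文argparse的默认值;accelerate配置、GPU数量与数据路径
+# 均为假设,请按实际环境调整,并非本补丁的权威用法):
+#
+#   accelerate launch train_extra_accelerate.py \
+#       --data_path ./dataset/processed_trex_data.json \
+#       --database_init_path ./dataset/combined_prepare.json \
+#       --epochs 4 --batch_size 192 --learning_rate 2e-4 --dtype bfloat16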