"""Evaluate the predicate-classification head of a MiniMind triple-extraction model.

The script loads a predicate vocabulary, a model checkpoint and a tokenizer,
runs the model on every sample of a JSON test set, prints an accuracy summary
and writes three JSON files (detailed results, accuracy statistics and plain
predictions) into the output directory.
"""
import argparse
import json
import os

import numpy as np
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
from tqdm import tqdm
from transformers import AutoTokenizer

from model.LMConfig import LMConfig
from model.model_extra import MiniMindLM

# Load the predicate vocabulary once at import time; the index of a predicate
# in this list is used as its class id.
PREDICATE_VOCAB_PATH = '/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json'
with open(PREDICATE_VOCAB_PATH, 'r', encoding='utf-8') as f:
    PREDICATE_LIST = json.load(f)
PREDICATE2ID = {p: i for i, p in enumerate(PREDICATE_LIST)}
NUM_PREDICATES = len(PREDICATE_LIST)


def _str2bool(value):
    """Parse a command-line string into a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if lowered in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _normalize_test_data(raw_data):
    """Convert the supported test-set layouts to ``[{"input", "output"}, ...]``.

    The layout is detected from the first element only:

    * ``[{"text": ..., "target": [...]}, ...]`` is renamed to input/output,
    * ``[{"input": ..., "output": [...]}, ...]`` is returned unchanged,
    * anything else is treated as a list of bare sentences without gold
      triples (inference only).

    An empty list yields an empty list instead of raising ``IndexError``.

    Raises:
        ValueError: if the top-level JSON value is not a list.
    """
    if not isinstance(raw_data, list):
        raise ValueError(
            f"test data must be a JSON list, got {type(raw_data).__name__}"
        )
    if not raw_data:
        return []

    first = raw_data[0]
    if isinstance(first, dict) and "text" in first:
        return [
            {"input": item["text"], "output": item.get("target", [])}
            for item in raw_data
        ]
    if isinstance(first, dict) and "input" in first:
        return raw_data
    return [{"input": text, "output": []} for text in raw_data]


def _dump_json(path, payload):
    """Write *payload* to *path* as indented UTF-8 JSON, keeping non-ASCII text."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)


def evaluate_model(model, tokenizer, test_data, device, max_length=512):
    """Run predicate classification over *test_data* and collect the outcomes.

    A prediction counts as correct when the predicted predicate appears among
    the predicates of the sample's gold triples. Every sample increments the
    total, including samples that carry no gold predicate.

    Args:
        model: model called as ``model(input_ids=...)``; its output must expose
            ``predicate_cls_logits``.
        tokenizer: callable tokenizer returning a mapping with ``"input_ids"``.
        test_data: list of dicts with an ``"input"`` text and an optional
            ``"output"`` list of triples (dicts that may hold ``"predicate"``).
        device: device the input ids are moved to.
        max_length: truncation/padding length passed to the tokenizer.

    Returns:
        A 5-tuple ``(results, all_pred_predicates, all_gold_predicates,
        correct_predictions, total_predictions)``. The two id lists are aligned
        and only contain samples whose first gold predicate is in the
        vocabulary.
    """
    model.eval()
    results = []
    all_pred_predicates = []
    all_gold_predicates = []
    correct_predictions = 0
    total_predictions = 0

    print("开始评估...")
    # Debug output describing the data layout.
    print(f"测试数据样本数量: {len(test_data)}")
    if test_data:
        print(f"第一个样本格式: {type(test_data[0])}")
        print(f"第一个样本内容: {test_data[0]}")
        if isinstance(test_data[0], dict):
            print(f"第一个样本的键: {list(test_data[0].keys())}")

    for i, item in enumerate(tqdm(test_data, desc="评估进度")):
        input_text = item["input"]
        # Treat an explicit null "output" like a missing one.
        gold_triples = item.get("output") or []
        verbose = i < 3  # only the first few samples are printed in detail

        if verbose:
            print(f"\n样本 {i+1} 调试信息:")
            print(f" 输入文本: {input_text[:100]}...")
            print(f" 真值三元组数量: {len(gold_triples)}")
            if gold_triples:
                print(f" 真值三元组: {gold_triples[0]}")

        # Model inference.
        # NOTE(review): only input_ids are forwarded; the tokenizer's attention
        # mask is not passed on, so the model sees the padding tokens — confirm
        # that this matches how the model was trained.
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding='max_length',
        )
        input_ids = inputs["input_ids"].to(device)
        with torch.no_grad():
            output = model(input_ids=input_ids)

        # Predicted class id; .item() requires a single-element result, i.e.
        # one sample per forward pass.
        pred_predicate_id = output.predicate_cls_logits.argmax(-1).item()
        if pred_predicate_id < len(PREDICATE_LIST):
            pred_predicate = PREDICATE_LIST[pred_predicate_id]
        else:
            pred_predicate = ""

        # Collect every gold predicate of the sample.
        target_predicates = [
            triple["predicate"] for triple in gold_triples if "predicate" in triple
        ]

        # Correct as soon as the prediction matches any gold predicate.
        is_correct = bool(target_predicates) and pred_predicate in target_predicates
        if is_correct:
            correct_predictions += 1
        total_predictions += 1

        if verbose:
            print(f" 预测谓词: {pred_predicate}")
            print(f" 目标谓词: {target_predicates}")
            print(f" 是否正确: {is_correct}")

        # Label pairs for the detailed report: the first gold predicate is
        # used as the sample's main label.
        if target_predicates:
            main_target = target_predicates[0]
            if main_target in PREDICATE2ID:
                all_gold_predicates.append(PREDICATE2ID[main_target])
                all_pred_predicates.append(pred_predicate_id)

        results.append({
            "input": input_text,
            "predicted_predicate": pred_predicate,
            "target_predicates": target_predicates,
            "is_correct": is_correct
        })

    print(f"\n评估完成,总预测数: {total_predictions}, 正确数: {correct_predictions}")
    return results, all_pred_predicates, all_gold_predicates, correct_predictions, total_predictions


def print_evaluation_summary(results, pred_predicates, gold_predicates, correct_predictions, total_predictions):
    """Print accuracy, a per-class report and the first five predictions.

    Args:
        results: per-sample result dicts as produced by ``evaluate_model``.
        pred_predicates: predicted class ids aligned with *gold_predicates*.
        gold_predicates: gold class ids.
        correct_predictions: number of correct predictions.
        total_predictions: number of evaluated samples.
    """
    print("\n" + "="*60)
    print("谓词分类评估结果摘要")
    print("="*60)

    # Overall accuracy.
    if total_predictions > 0:
        predicate_accuracy = correct_predictions / total_predictions
        print(f"谓词分类准确率: {predicate_accuracy:.4f} ({correct_predictions}/{total_predictions})")
    else:
        print("谓词分类准确率: 无法计算(没有有效预测)")

    # Detailed report, only when more than ten labelled samples exist.
    if pred_predicates and gold_predicates and len(pred_predicates) > 10:
        # classification_report requires target_names to match the reported
        # classes one-to-one, so the observed class ids are passed explicitly
        # and each one is mapped to its own name.
        labels = sorted(set(gold_predicates) | set(pred_predicates))
        target_names = [
            PREDICATE_LIST[label] if 0 <= label < len(PREDICATE_LIST) else f"<id {label}>"
            for label in labels
        ]
        try:
            print(f"\n谓词分类详细报告:")
            print(classification_report(
                gold_predicates,
                pred_predicates,
                labels=labels,
                target_names=target_names,
                zero_division=0,
            ))
        except Exception as e:
            # The report is best-effort; a failure here must not abort the run.
            print(f"\n谓词分类详细报告生成失败: {e}")

    # Sample predictions.
    print(f"\n预测示例 (前5个):")
    for i, result in enumerate(results[:5]):
        print(f"样本 {i+1}:")
        print(f" 输入: {result['input'][:100]}...")
        print(f" 预测谓词: {result['predicted_predicate']}")
        print(f" 目标谓词: {result['target_predicates']}")
        print(f" 是否正确: {'✓' if result['is_correct'] else '✗'}")
        print()


def main():
    """Parse the command line, evaluate the model and save the results.

    Writes ``evaluation_results.json``, ``accuracy_stats.json`` and
    ``predictions.json`` into ``--output_dir``.
    """
    parser = argparse.ArgumentParser(description="MiniMind 三元组抽取模型评估脚本")
    parser.add_argument('--model_path', type=str,
                        default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out/pretrain_cls512.pth')
    parser.add_argument('--tokenizer_path', type=str,
                        default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/model/minimind_tokenizer')
    parser.add_argument('--test_json', type=str,
                        default='/home/rwkv/RWKV-TS/RETRO_TEST/extract/sample_1000.json')
    parser.add_argument('--output_dir', type=str,
                        default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out', help='输出目录')
    parser.add_argument('--device', type=str, default='cuda', help='推理设备')

    # Model configuration.
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    # type=bool would turn "--use_moe False" into True, hence the converter.
    parser.add_argument('--use_moe', default=False, type=_str2bool)
    parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能")
    parser.add_argument('--flash_attn', action='store_true', default=True, help="启用FlashAttention")
    # --flash_attn defaults to True and is store_true, so on its own it can
    # never be switched off; this companion flag writes to the same dest.
    parser.add_argument('--no_flash_attn', dest='flash_attn', action='store_false',
                        help="禁用FlashAttention")
    parser.add_argument('--knowledge_num', type=int, default=960400, help="知识库的数据数目")
    parser.add_argument('--knowledge_length', type=int, default=32, help="知识库的句子长度")
    parser.add_argument('--embeddings_epoch', type=int, default=2, help="embedding训练的epoch数")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Build the model and load the tokenizer.
    print("加载模型和分词器...")
    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,
        flash_attn=args.flash_attn,
        knowledge_num=args.knowledge_num,
        knowledge_length=args.knowledge_length,
        embeddings_epoch=args.embeddings_epoch
    )
    model = MiniMindLM(lm_config, mode="triple", num_predicates=NUM_PREDICATES)

    # Load the checkpoint. Falling back to random weights on failure is the
    # script's deliberate best-effort behaviour.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    try:
        state_dict = torch.load(args.model_path, map_location=args.device)
        load_result = model.load_state_dict(state_dict, strict=False)
        print(f"成功加载模型权重: {args.model_path}")
        # strict=False hides mismatches, so surface them: a checkpoint that
        # matches nothing would otherwise be evaluated as random weights.
        missing_keys = list(getattr(load_result, "missing_keys", []))
        unexpected_keys = list(getattr(load_result, "unexpected_keys", []))
        if missing_keys:
            print(f"警告: 权重文件中缺少 {len(missing_keys)} 个参数, 例如: {missing_keys[:5]}")
        if unexpected_keys:
            print(f"警告: 权重文件中有 {len(unexpected_keys)} 个未使用的参数, 例如: {unexpected_keys[:5]}")
    except Exception as e:
        print(f"加载模型权重失败: {e}")
        print("使用随机初始化的模型进行测试")

    model.to(args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    print(f"谓词词汇表大小: {len(PREDICATE_LIST)}")

    # Load the test data and bring it into the input/output layout.
    print(f"加载测试数据: {args.test_json}")
    with open(args.test_json, 'r', encoding='utf-8') as f:
        test_data = _normalize_test_data(json.load(f))
    print(f"测试样本数量: {len(test_data)}")

    # Evaluate; tokenizer length follows the model's configured sequence length.
    results, pred_predicates, gold_predicates, correct_predictions, total_predictions = evaluate_model(
        model, tokenizer, test_data, args.device, max_length=args.max_seq_len
    )

    # Print the summary.
    print_evaluation_summary(results, pred_predicates, gold_predicates, correct_predictions, total_predictions)

    # Save the detailed per-sample results.
    output_path = os.path.join(args.output_dir, 'evaluation_results.json')
    _dump_json(output_path, results)
    print(f"\n详细评估结果已保存到: {output_path}")

    # Save the accuracy statistics.
    accuracy_stats = {
        "total_predictions": total_predictions,
        "correct_predictions": correct_predictions,
        "accuracy": correct_predictions / total_predictions if total_predictions > 0 else 0.0,
        "model_path": args.model_path,
        "test_data_path": args.test_json,
        "predicate_vocab_size": len(PREDICATE_LIST),
        "evaluation_timestamp": str(np.datetime64('now'))
    }
    accuracy_path = os.path.join(args.output_dir, 'accuracy_stats.json')
    _dump_json(accuracy_path, accuracy_stats)
    print(f"准确率统计已保存到: {accuracy_path}")

    # Save the plain predictions.
    predictions = [
        {
            "input": r["input"],
            "predicted_predicate": r["predicted_predicate"],
            "gold_predicates": r["target_predicates"],
        }
        for r in results
    ]
    pred_output_path = os.path.join(args.output_dir, 'predictions.json')
    _dump_json(pred_output_path, predictions)
    print(f"预测结果已保存到: {pred_output_path}")


if __name__ == "__main__":
    main()