Minimind/test.py
2025-07-05 03:03:43 +00:00

243 lines
10 KiB
Python

import os
import json
import argparse
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer
from model.model_extra import MiniMindLM
from model.LMConfig import LMConfig
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
# Load the predicate vocabulary once at import time. Everything below is
# derived from this JSON file: the ordered predicate list, the reverse
# predicate -> index lookup, and the class count handed to the model.
PREDICATE_VOCAB_PATH = '/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json'
with open(PREDICATE_VOCAB_PATH, mode='r', encoding='utf-8') as vocab_file:
    PREDICATE_LIST = json.load(vocab_file)
NUM_PREDICATES = len(PREDICATE_LIST)
# Position in PREDICATE_LIST is the class id; on duplicates the last index wins.
PREDICATE2ID = dict(zip(PREDICATE_LIST, range(NUM_PREDICATES)))
def _collect_target_predicates(gold_triples):
    """Return the "predicate" value of every dict-shaped gold triple, in order."""
    # Non-dict entries (e.g. plain strings) are skipped: a string that merely
    # contains the substring "predicate" would pass an `in` test and then fail
    # on subscripting.
    return [
        triple["predicate"]
        for triple in gold_triples
        if isinstance(triple, dict) and "predicate" in triple
    ]


def evaluate_model(model, tokenizer, test_data, device):
    """
    Evaluate the model on predicate classification only.

    Each sample is tokenized and run through the model one at a time; the
    argmax of ``output.predicate_cls_logits`` is taken as the predicted
    predicate id. A prediction counts as correct when the predicted predicate
    string appears anywhere in the sample's gold predicates.

    Args:
        model: Callable as ``model(input_ids=...)`` returning an object with a
            ``predicate_cls_logits`` tensor. ``.item()`` is called on its
            argmax, so it must reduce to a single element per sample.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` tensor.
        test_data: Sequence of dicts with an ``"input"`` text and an optional
            ``"output"`` list of triples (dicts carrying a ``"predicate"`` key).
            A missing or ``None`` output is treated as "no gold triples".
        device: Device the input ids are moved to before inference.

    Returns:
        Tuple ``(results, all_pred_predicates, all_gold_predicates,
        correct_predictions, total_predictions)``. ``results`` holds one dict
        per sample. The two id lists are aligned with each other and only
        contain samples whose first gold predicate is in ``PREDICATE2ID``.
        ``total_predictions`` counts every sample, including those without
        gold predicates (which can never be correct).
    """
    model.eval()
    results = []
    all_pred_predicates = []
    all_gold_predicates = []
    correct_predictions = 0
    total_predictions = 0
    print("开始评估...")
    # Debug output describing the shape of the incoming data.
    print(f"测试数据样本数量: {len(test_data)}")
    if test_data:
        print(f"第一个样本格式: {type(test_data[0])}")
        print(f"第一个样本内容: {test_data[0]}")
        if isinstance(test_data[0], dict):
            print(f"第一个样本的键: {list(test_data[0].keys())}")
    for i, item in enumerate(tqdm(test_data, desc="评估进度")):
        input_text = item["input"]
        # `or []` also covers an explicit JSON null, which .get() would
        # otherwise return as None and break len()/iteration below.
        gold_triples = item.get("output") or []
        # Debug output for the first few samples.
        if i < 3:
            print(f"\n样本 {i+1} 调试信息:")
            print(f" 输入文本: {input_text[:100]}...")
            print(f" 真值三元组数量: {len(gold_triples)}")
            if gold_triples:
                print(f" 真值三元组: {gold_triples[0]}")
        # Model inference. Every sample is padded/truncated to 512 tokens.
        # NOTE(review): only input_ids are passed to the model (no attention
        # mask) — confirm the model is meant to see the padding tokens.
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512, padding='max_length')
        input_ids = inputs["input_ids"].to(device)
        with torch.no_grad():
            output = model(input_ids=input_ids)
        # Predicted class id and its predicate string ("<UNK>" if out of range).
        pred_predicate_id = output.predicate_cls_logits.argmax(-1).item()
        pred_predicate = PREDICATE_LIST[pred_predicate_id] if pred_predicate_id < len(PREDICATE_LIST) else "<UNK>"
        # Gather all gold predicates for this sample.
        target_predicates = _collect_target_predicates(gold_triples)
        # Correct when the prediction matches any gold predicate; an empty
        # gold list therefore always yields False.
        is_correct = pred_predicate in target_predicates
        if is_correct:
            correct_predictions += 1
        total_predictions += 1
        # Debug output for the first few samples.
        if i < 3:
            print(f" 预测谓词: {pred_predicate}")
            print(f" 目标谓词: {target_predicates}")
            print(f" 是否正确: {is_correct}")
        # Collect id pairs for the detailed per-class report; the first gold
        # predicate serves as the single reference label.
        if target_predicates:
            main_target = target_predicates[0]
            if main_target in PREDICATE2ID:
                all_gold_predicates.append(PREDICATE2ID[main_target])
                all_pred_predicates.append(pred_predicate_id)
        results.append({
            "input": input_text,
            "predicted_predicate": pred_predicate,
            "target_predicates": target_predicates,
            "is_correct": is_correct
        })
    print(f"\n评估完成,总预测数: {total_predictions}, 正确数: {correct_predictions}")
    return results, all_pred_predicates, all_gold_predicates, correct_predictions, total_predictions
def print_evaluation_summary(results, pred_predicates, gold_predicates, correct_predictions, total_predictions):
    """
    Print a summary of the predicate-classification evaluation.

    Args:
        results: Per-sample dicts with the keys ``"input"``,
            ``"predicted_predicate"``, ``"target_predicates"`` and
            ``"is_correct"``; the first five are shown as examples.
        pred_predicates: Predicted class ids, aligned with ``gold_predicates``.
        gold_predicates: Reference class ids.
        correct_predictions: Number of samples counted as correct.
        total_predictions: Number of evaluated samples.

    The per-class report is only printed when more than 10 labelled
    predictions are available. Any failure while building it is reported on
    stdout and does not abort the summary.
    """
    print("\n" + "="*60)
    print("谓词分类评估结果摘要")
    print("="*60)
    # Overall accuracy.
    if total_predictions > 0:
        predicate_accuracy = correct_predictions / total_predictions
        print(f"谓词分类准确率: {predicate_accuracy:.4f} ({correct_predictions}/{total_predictions})")
    else:
        print("谓词分类准确率: 无法计算(没有有效预测)")
    # Detailed per-class report (only with enough labelled data).
    if pred_predicates and gold_predicates and len(pred_predicates) > 10:
        try:
            print("\n谓词分类详细报告:")
            # classification_report requires exactly one target name per
            # label, so derive both from the class ids that actually occur
            # instead of passing a fixed-size slice of the vocabulary.
            labels = sorted(set(gold_predicates) | set(pred_predicates))
            target_names = [
                str(PREDICATE_LIST[label]) if 0 <= label < len(PREDICATE_LIST) else "<UNK>"
                for label in labels
            ]
            print(classification_report(gold_predicates, pred_predicates,
                                        labels=labels,
                                        target_names=target_names,
                                        zero_division=0))
        except Exception as e:
            # Best effort: the report is optional, keep printing the rest.
            print(f"\n谓词分类详细报告生成失败: {e}")
    # Example predictions.
    print("\n预测示例 (前5个):")
    for i, result in enumerate(results[:5]):
        print(f"样本 {i+1}:")
        print(f" 输入: {result['input'][:100]}...")
        print(f" 预测谓词: {result['predicted_predicate']}")
        print(f" 目标谓词: {result['target_predicates']}")
        # Both branches used to print an empty string, hiding the outcome.
        print(f" 是否正确: {'✓' if result['is_correct'] else '✗'}")
        print()
def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "1"/"0", "yes"/"no"."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, as it was under plain bool().
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


def _normalize_test_data(raw_data):
    """Convert the supported JSON layouts to a list of {"input", "output"} dicts."""
    # An empty file has no first element to inspect; nothing to evaluate.
    if not raw_data:
        return []
    first = raw_data[0]
    if isinstance(first, dict) and "text" in first:
        # Layout: [{"text": "...", "target": [...]}, ...]
        return [{"input": item["text"], "output": item.get("target", [])} for item in raw_data]
    if isinstance(first, dict) and "input" in first:
        # Layout: [{"input": "...", "output": [...]}, ...] — already in shape.
        return raw_data
    # Layout: ["sentence", ...] — no gold labels, inference only.
    return [{"input": text, "output": []} for text in raw_data]


def _dump_json(path, payload):
    """Write payload to path as indented UTF-8 JSON, keeping non-ASCII text readable."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)


def main():
    """
    Command-line entry point: load the model, evaluate it, save the results.

    Steps: parse arguments, build ``MiniMindLM`` in "triple" mode, load the
    checkpoint (falling back to random weights if loading fails), read and
    normalize the test JSON, run ``evaluate_model``, print the summary and
    write ``evaluation_results.json``, ``accuracy_stats.json`` and
    ``predictions.json`` into ``--output_dir``.
    """
    parser = argparse.ArgumentParser(description="MiniMind 三元组抽取模型评估脚本")
    parser.add_argument('--model_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out/pretrain_cls512.pth')
    parser.add_argument('--tokenizer_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/model/minimind_tokenizer')
    parser.add_argument('--test_json', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/extract/sample_1000.json')
    parser.add_argument('--output_dir', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out', help='输出目录')
    parser.add_argument('--device', type=str, default='cuda', help='推理设备')
    # Model configuration parameters.
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    # type=bool would turn any non-empty string (including "False") into True.
    parser.add_argument('--use_moe', default=False, type=_str2bool)
    parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能")
    parser.add_argument('--flash_attn', action='store_true', default=True, help="启用FlashAttention")
    # --flash_attn defaults to True, so an explicit off switch is required.
    parser.add_argument('--no_flash_attn', dest='flash_attn', action='store_false', help="禁用FlashAttention")
    parser.add_argument('--knowledge_num', type=int, default=960400, help="知识库的数据数目")
    parser.add_argument('--knowledge_length', type=int, default=32, help="知识库的句子长度")
    parser.add_argument('--embeddings_epoch', type=int, default=2, help="embedding训练的epoch数")
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    # Fall back to CPU instead of crashing when CUDA was requested but is absent.
    device = args.device
    if device.startswith('cuda') and not torch.cuda.is_available():
        print(f"CUDA 不可用,设备 {device} 回退到 cpu")
        device = 'cpu'
    # Load the model and tokenizer.
    print("加载模型和分词器...")
    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,
        flash_attn=args.flash_attn,
        knowledge_num=args.knowledge_num,
        knowledge_length=args.knowledge_length,
        embeddings_epoch=args.embeddings_epoch
    )
    model = MiniMindLM(lm_config, mode="triple", num_predicates=NUM_PREDICATES)
    # Load the checkpoint. NOTE: torch.load unpickles the file, so only
    # checkpoints from a trusted source should be passed in.
    weights_loaded = False
    try:
        state_dict = torch.load(args.model_path, map_location=device)
        model.load_state_dict(state_dict, strict=False)
    except Exception as e:
        # Deliberate best effort: continue with random weights, but remember
        # it so the saved statistics are not mistaken for a real evaluation.
        print(f"加载模型权重失败: {e}")
        print("使用随机初始化的模型进行测试")
    else:
        weights_loaded = True
        print(f"成功加载模型权重: {args.model_path}")
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    print(f"谓词词汇表大小: {len(PREDICATE_LIST)}")
    # Load the test data; several JSON layouts are supported.
    print(f"加载测试数据: {args.test_json}")
    with open(args.test_json, 'r', encoding='utf-8') as f:
        test_data = _normalize_test_data(json.load(f))
    print(f"测试样本数量: {len(test_data)}")
    # Evaluate the model.
    results, pred_predicates, gold_predicates, correct_predictions, total_predictions = evaluate_model(model, tokenizer, test_data, device)
    # Print the evaluation summary.
    print_evaluation_summary(results, pred_predicates, gold_predicates, correct_predictions, total_predictions)
    # Save the detailed per-sample results.
    output_path = os.path.join(args.output_dir, 'evaluation_results.json')
    _dump_json(output_path, results)
    print(f"\n详细评估结果已保存到: {output_path}")
    # Save the accuracy statistics.
    accuracy_stats = {
        "total_predictions": total_predictions,
        "correct_predictions": correct_predictions,
        "accuracy": correct_predictions / total_predictions if total_predictions > 0 else 0.0,
        "model_path": args.model_path,
        "weights_loaded": weights_loaded,
        "test_data_path": args.test_json,
        "predicate_vocab_size": len(PREDICATE_LIST),
        "evaluation_timestamp": str(np.datetime64('now'))
    }
    accuracy_path = os.path.join(args.output_dir, 'accuracy_stats.json')
    _dump_json(accuracy_path, accuracy_stats)
    print(f"准确率统计已保存到: {accuracy_path}")
    # Save the predictions in a compact form.
    predictions = [{"input": r["input"], "predicted_predicate": r["predicted_predicate"], "gold_predicates": r["target_predicates"]} for r in results]
    pred_output_path = os.path.join(args.output_dir, 'predictions.json')
    _dump_json(pred_output_path, predictions)
    print(f"预测结果已保存到: {pred_output_path}")
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()