# File listing metadata: 243 lines, 10 KiB, Python.
import os
|
|
import json
|
|
import argparse
|
|
import torch
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from transformers import AutoTokenizer
|
|
from model.model_extra import MiniMindLM
|
|
from model.LMConfig import LMConfig
|
|
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
|
|
|
|
# Load the predicate vocabulary once at import time. The JSON file is expected
# to hold the predicate names; their position in the file is used as the
# class id below.
PREDICATE_VOCAB_PATH = '/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json'
with open(PREDICATE_VOCAB_PATH, mode='r', encoding='utf-8') as f:
    PREDICATE_LIST = json.loads(f.read())

# Number of predicate classes, and the reverse mapping predicate -> class id.
NUM_PREDICATES = len(PREDICATE_LIST)
PREDICATE2ID = dict(zip(PREDICATE_LIST, range(NUM_PREDICATES)))
|
|
|
|
def evaluate_model(model, tokenizer, test_data, device, max_length=512):
    """Evaluate the model on predicate classification only.

    Each sample's input text is tokenized and run through ``model``; the
    arg-max of ``output.predicate_cls_logits`` is taken as the predicted
    predicate id. A prediction counts as correct when the predicted predicate
    appears anywhere among that sample's gold predicates.

    Args:
        model: Model whose forward output exposes ``predicate_cls_logits``.
        tokenizer: Callable tokenizer returning a mapping with ``"input_ids"``.
        test_data: Sequence of dicts with an ``"input"`` text and an optional
            ``"output"`` list of gold triples (dicts with a ``"predicate"``
            key). A missing or null ``"output"`` is treated as "no gold
            labels"; entries that are not dicts are ignored.
        device: Device the input ids are moved to before inference.
        max_length: Length the tokenized input is truncated/padded to.

    Returns:
        A tuple ``(results, all_pred_predicates, all_gold_predicates,
        correct_predictions, total_predictions)`` where ``results`` is one
        dict per sample, the two id lists are aligned pairs of
        (predicted id, id of the first gold predicate) for samples whose
        first gold predicate is in the vocabulary, and the two counters give
        the number of correct predictions and of evaluated samples.

    Note:
        Samples without any gold predicate still count towards
        ``total_predictions`` and can never be correct.
    """
    model.eval()
    results = []
    all_pred_predicates = []
    all_gold_predicates = []
    correct_predictions = 0
    total_predictions = 0

    print("开始评估...")

    # Debug output describing the data set.
    print(f"测试数据样本数量: {len(test_data)}")
    if test_data:
        print(f"第一个样本格式: {type(test_data[0])}")
        print(f"第一个样本内容: {test_data[0]}")
        if isinstance(test_data[0], dict):
            print(f"第一个样本的键: {list(test_data[0].keys())}")

    for i, item in enumerate(tqdm(test_data, desc="评估进度")):
        input_text = item["input"]
        # ``get`` with a default still returns None for an explicit null, so
        # normalise every falsy value to an empty list.
        gold_triples = item.get("output") or []

        # Debug output for the first few samples.
        if i < 3:
            print(f"\n样本 {i+1} 调试信息:")
            print(f" 输入文本: {input_text[:100]}...")
            print(f" 真值三元组数量: {len(gold_triples)}")
            if gold_triples:
                print(f" 真值三元组: {gold_triples[0]}")

        # Model inference.
        # NOTE(review): the input is padded to ``max_length`` but only
        # ``input_ids`` is forwarded (no attention mask) - confirm the model
        # is meant to see the padding tokens.
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding='max_length',
        )
        input_ids = inputs["input_ids"].to(device)

        with torch.no_grad():
            output = model(input_ids=input_ids)

        # Predicate classification result; ``.item()` requires exactly one
        # prediction per forward pass.
        pred_predicate_id = output.predicate_cls_logits.argmax(-1).item()
        if pred_predicate_id < len(PREDICATE_LIST):
            pred_predicate = PREDICATE_LIST[pred_predicate_id]
        else:
            pred_predicate = "<UNK>"

        # Collect every gold predicate. Non-dict entries are skipped so that
        # a malformed triple (e.g. a plain string) cannot crash the run.
        target_predicates = [
            triple["predicate"]
            for triple in gold_triples
            if isinstance(triple, dict) and "predicate" in triple
        ]

        # Correct as soon as the prediction matches any gold predicate.
        is_correct = pred_predicate in target_predicates
        if is_correct:
            correct_predictions += 1
        total_predictions += 1

        # Debug output for the first few samples.
        if i < 3:
            print(f" 预测谓词: {pred_predicate}")
            print(f" 目标谓词: {target_predicates}")
            print(f" 是否正确: {is_correct}")

        # Labels for the detailed report: the first gold predicate acts as
        # the sample's main label, provided it is part of the vocabulary.
        if target_predicates:
            main_target = target_predicates[0]
            if main_target in PREDICATE2ID:
                all_gold_predicates.append(PREDICATE2ID[main_target])
                all_pred_predicates.append(pred_predicate_id)

        results.append({
            "input": input_text,
            "predicted_predicate": pred_predicate,
            "target_predicates": target_predicates,
            "is_correct": is_correct,
        })

    print(f"\n评估完成,总预测数: {total_predictions}, 正确数: {correct_predictions}")
    return results, all_pred_predicates, all_gold_predicates, correct_predictions, total_predictions
|
|
|
|
def print_evaluation_summary(results, pred_predicates, gold_predicates, correct_predictions, total_predictions):
    """Print a summary of the predicate-classification evaluation.

    Args:
        results: Per-sample result dicts with the keys ``"input"``,
            ``"predicted_predicate"``, ``"target_predicates"`` and
            ``"is_correct"``; the first five are printed as examples.
        pred_predicates: Predicted predicate ids, aligned with
            ``gold_predicates``.
        gold_predicates: Gold predicate ids.
        correct_predictions: Number of correct predictions.
        total_predictions: Number of evaluated samples.

    The detailed per-class report is only printed when more than ten labelled
    predictions are available. Any error while building it is reported and
    does not abort the summary.
    """
    print("\n" + "="*60)
    print("谓词分类评估结果摘要")
    print("="*60)

    # Overall predicate classification accuracy.
    if total_predictions > 0:
        predicate_accuracy = correct_predictions / total_predictions
        print(f"谓词分类准确率: {predicate_accuracy:.4f} ({correct_predictions}/{total_predictions})")
    else:
        print("谓词分类准确率: 无法计算(没有有效预测)")

    # Detailed classification report (needs enough labelled data).
    if pred_predicates and gold_predicates and len(pred_predicates) > 10:
        try:
            print("\n谓词分类详细报告:")
            # classification_report requires exactly one display name per
            # reported label, so derive both from the ids that actually occur
            # instead of passing a truncated slice of the vocabulary.
            labels = sorted(set(gold_predicates) | set(pred_predicates))
            target_names = [
                str(PREDICATE_LIST[label]) if 0 <= label < len(PREDICATE_LIST) else "<UNK>"
                for label in labels
            ]
            print(classification_report(
                gold_predicates,
                pred_predicates,
                labels=labels,
                target_names=target_names,
                zero_division=0,
            ))
        except Exception as e:
            # Best effort: the report is optional, keep printing the rest.
            print(f"\n谓词分类详细报告生成失败: {e}")

    # A few example predictions.
    print("\n预测示例 (前5个):")
    for i, result in enumerate(results[:5]):
        print(f"样本 {i+1}:")
        print(f" 输入: {result['input'][:100]}...")
        print(f" 预测谓词: {result['predicted_predicate']}")
        print(f" 目标谓词: {result['target_predicates']}")
        print(f" 是否正确: {'✓' if result['is_correct'] else '✗'}")
        print()
|
|
|
|
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept as "false": with the previous ``type=bool`` it
    # was the only value that evaluated to False.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


def main():
    """Command-line entry point: load the model, evaluate it, save results.

    Parses the command-line arguments, builds the model from them, loads the
    checkpoint (falling back to random weights when loading fails), runs
    ``evaluate_model`` on the test JSON and writes three files into
    ``--output_dir``: ``evaluation_results.json``, ``accuracy_stats.json``
    and ``predictions.json``.

    Raises:
        ValueError: If the test JSON is not a non-empty list.
    """
    parser = argparse.ArgumentParser(description="MiniMind 三元组抽取模型评估脚本")
    parser.add_argument('--model_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out/pretrain_cls512.pth')
    parser.add_argument('--tokenizer_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/model/minimind_tokenizer')
    parser.add_argument('--test_json', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/extract/sample_1000.json')
    parser.add_argument('--output_dir', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out', help='输出目录')
    parser.add_argument('--device', type=str, default='cuda', help='推理设备')

    # Model configuration parameters.
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    # ``type=bool`` would turn every non-empty string (even "False") into
    # True, so the value is parsed explicitly.
    parser.add_argument('--use_moe', default=False, type=_str2bool)
    parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能")
    parser.add_argument('--flash_attn', action='store_true', default=True, help="启用FlashAttention")
    # ``--flash_attn`` defaults to True and could therefore never be switched
    # off; this flag is the explicit opt-out.
    parser.add_argument('--no_flash_attn', dest='flash_attn', action='store_false', help="禁用FlashAttention")
    parser.add_argument('--knowledge_num', type=int, default=960400, help="知识库的数据数目")
    parser.add_argument('--knowledge_length', type=int, default=32, help="知识库的句子长度")
    parser.add_argument('--embeddings_epoch', type=int, default=2, help="embedding训练的epoch数")

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Fall back to the CPU when CUDA was requested (it is the default) but is
    # not available, instead of crashing when the model is moved.
    device = args.device
    if device.startswith('cuda') and not torch.cuda.is_available():
        print("CUDA 不可用, 回退到 CPU")
        device = 'cpu'

    # Build the model and load the tokenizer.
    print("加载模型和分词器...")
    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,
        flash_attn=args.flash_attn,
        knowledge_num=args.knowledge_num,
        knowledge_length=args.knowledge_length,
        embeddings_epoch=args.embeddings_epoch,
    )

    model = MiniMindLM(lm_config, mode="triple", num_predicates=NUM_PREDICATES)

    # Load the model weights. Failure is deliberately non-fatal: the script
    # then evaluates the randomly initialised model.
    try:
        state_dict = torch.load(args.model_path, map_location=device)
        load_result = model.load_state_dict(state_dict, strict=False)
    except Exception as e:
        print(f"加载模型权重失败: {e}")
        print("使用随机初始化的模型进行测试")
    else:
        print(f"成功加载模型权重: {args.model_path}")
        # ``strict=False`` ignores mismatching keys; surface them so that a
        # partially initialised model does not go unnoticed.
        missing_keys = list(getattr(load_result, 'missing_keys', None) or [])
        unexpected_keys = list(getattr(load_result, 'unexpected_keys', None) or [])
        if missing_keys:
            print(f"警告: 权重文件缺少 {len(missing_keys)} 个参数: {missing_keys[:5]}")
        if unexpected_keys:
            print(f"警告: 权重文件包含 {len(unexpected_keys)} 个未使用的参数: {unexpected_keys[:5]}")

    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    print(f"谓词词汇表大小: {len(PREDICATE_LIST)}")

    # Load the test data.
    print(f"加载测试数据: {args.test_json}")
    with open(args.test_json, 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    # The format is detected from the first element, so fail with a clear
    # message when there is none.
    if not isinstance(test_data, list) or not test_data:
        raise ValueError(f"测试数据为空或不是列表: {args.test_json}")

    # Several input formats are supported.
    first_item = test_data[0]
    if isinstance(first_item, dict) and "text" in first_item:
        # Format: [{"text": "...", "target": [...]}, ...]
        test_data = [{"input": item["text"], "output": item.get("target", [])} for item in test_data]
    elif isinstance(first_item, dict) and "input" in first_item:
        # Format: [{"input": "...", "output": [...]}, ...] - already usable.
        pass
    else:
        # Format: ["sentence", ...] - no ground truth, inference only.
        test_data = [{"input": text, "output": []} for text in test_data]

    print(f"测试样本数量: {len(test_data)}")

    # Evaluate the model.
    results, pred_predicates, gold_predicates, correct_predictions, total_predictions = evaluate_model(model, tokenizer, test_data, device)

    # Print the evaluation summary.
    print_evaluation_summary(results, pred_predicates, gold_predicates, correct_predictions, total_predictions)

    # Save the detailed per-sample results.
    output_path = os.path.join(args.output_dir, 'evaluation_results.json')
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"\n详细评估结果已保存到: {output_path}")

    # Save the accuracy statistics.
    accuracy_stats = {
        "total_predictions": total_predictions,
        "correct_predictions": correct_predictions,
        "accuracy": correct_predictions / total_predictions if total_predictions > 0 else 0.0,
        "model_path": args.model_path,
        "test_data_path": args.test_json,
        "predicate_vocab_size": len(PREDICATE_LIST),
        "evaluation_timestamp": str(np.datetime64('now')),
    }

    accuracy_path = os.path.join(args.output_dir, 'accuracy_stats.json')
    with open(accuracy_path, 'w', encoding='utf-8') as f:
        json.dump(accuracy_stats, f, indent=2, ensure_ascii=False)
    print(f"准确率统计已保存到: {accuracy_path}")

    # Save the plain predictions.
    predictions = [
        {
            "input": r["input"],
            "predicted_predicate": r["predicted_predicate"],
            "gold_predicates": r["target_predicates"],
        }
        for r in results
    ]
    pred_output_path = os.path.join(args.output_dir, 'predictions.json')
    with open(pred_output_path, 'w', encoding='utf-8') as f:
        json.dump(predictions, f, indent=2, ensure_ascii=False)
    print(f"预测结果已保存到: {pred_output_path}")
|
|
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()