update

parent ee7aaba91d
commit 77d298c3c6

inference.py (20 lines changed)
@@ -7,12 +7,14 @@ from transformers import AutoTokenizer
from model.model_extra import MiniMindLM
from model.LMConfig import LMConfig

def decode_triple(subject_logits, predicate_logits, object_logits, tokenizer):
PREDICATE_VOCAB_PATH = '/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json'
with open(PREDICATE_VOCAB_PATH, 'r', encoding='utf-8') as f:
    PREDICATE_LIST = json.load(f)
print(len(PREDICATE_LIST))
def decode_triple(subject_logits, predicate_logits, object_logits, tokenizer, predicate_cls_logits=None):
    # logits: [1, max_len, vocab_size]
    subject_ids = subject_logits.argmax(-1).squeeze(0).tolist()
    predicate_ids = predicate_logits.argmax(-1).squeeze(0).tolist()
    object_ids = object_logits.argmax(-1).squeeze(0).tolist()
    # strip pad and eos tokens
    def clean(ids):
        if isinstance(ids, int):
            ids = [ids]
@@ -22,8 +24,14 @@ def decode_triple(subject_logits, predicate_logits, object_logits, tokenizer):
        ids = [i for i in ids if i != tokenizer.pad_token_id]
        return ids
    subject = tokenizer.decode(clean(subject_ids), skip_special_tokens=True).strip()
    predicate = tokenizer.decode(clean(predicate_ids), skip_special_tokens=True).strip()
    object_ = tokenizer.decode(clean(object_ids), skip_special_tokens=True).strip()
    # use the classification result for the predicate
    if predicate_cls_logits is not None:
        pred_id = predicate_cls_logits.argmax(-1).item()
        predicate = PREDICATE_LIST[pred_id] if pred_id < len(PREDICATE_LIST) else "<UNK>"
    else:
        predicate_ids = predicate_logits.argmax(-1).squeeze(0).tolist()
        predicate = tokenizer.decode(clean(predicate_ids), skip_special_tokens=True).strip()
    return {"subject": subject, "predicate": predicate, "object": object_}

def infer_triples(model, tokenizer, sentences, device):
@@ -35,13 +43,13 @@ def infer_triples(model, tokenizer, sentences, device):
        input_ids = inputs["input_ids"].to(device)
        with torch.no_grad():
            output = model(input_ids=input_ids)
        triple = decode_triple(output.subject_logits, output.predicate_logits, output.object_logits, tokenizer)
        triple = decode_triple(output.subject_logits, output.predicate_logits, output.object_logits, tokenizer, output.predicate_cls_logits)
        results.append({"input": sent, "output": [triple]})
    return results

def main():
    parser = argparse.ArgumentParser(description="MiniMind triple extraction inference script")
    parser.add_argument('--model_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out/pretrain_512.pth')
    parser.add_argument('--model_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out/pretrain_cls512.pth')
    parser.add_argument('--tokenizer_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/model/minimind_tokenizer')
    parser.add_argument('--input_json', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/extract/sample_1000.json')
    parser.add_argument('--output_dir', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out', help='output directory')
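
Not shown in these hunks is how inference.py builds the model in main(); for the new pretrain_cls512.pth checkpoint (which now contains the classification head) to load cleanly, the head size has to match the predicate vocabulary. A minimal, hypothetical sketch of that wiring, following the constructor used in test.py further below (the LMConfig values are assumptions mirroring the defaults used elsewhere in this commit):

import json
import torch
from model.model_extra import MiniMindLM
from model.LMConfig import LMConfig

# Load the same predicate vocabulary the checkpoint was trained with.
with open('/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json', 'r', encoding='utf-8') as f:
    PREDICATE_LIST = json.load(f)

lm_config = LMConfig(dim=512, n_layers=8, max_seq_len=512)  # assumed values; remaining fields left at their defaults
model = MiniMindLM(lm_config, mode="triple", num_predicates=len(PREDICATE_LIST))
state_dict = torch.load('/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out/pretrain_cls512.pth', map_location='cpu')
model.load_state_dict(state_dict, strict=False)
model.eval()
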
@@ -13,6 +13,11 @@ from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Load the predicate classes (kept consistent with train_extra_accelerate.py)
PREDICATE_VOCAB_PATH = '/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json'
with open(PREDICATE_VOCAB_PATH, 'r', encoding='utf-8') as f:
    PREDICATE_LIST = json.load(f)
PREDICATE2ID = {p: i for i, p in enumerate(PREDICATE_LIST)}

class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
@@ -302,9 +307,8 @@ class TriplePretrainDataset(Dataset):
        return f"{triple['subject']} {triple['predicate']} {triple['object']}"

    def __getitem__(self, index):
        """Return a sample: the input text is tokenized at runtime, the target is pre-tokenized"""
        """Return a sample: the input text is tokenized at runtime, the target is pre-tokenized, and a predicate_label field is added"""
        sample = self.samples[index]

        # Tokenize the input text at runtime (used for language modeling)
        input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
@@ -316,12 +320,22 @@ class TriplePretrainDataset(Dataset):
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)

        # Build the training data
        X = input_ids[:-1]
        Y = input_ids[1:]
        loss_mask = loss_mask[1:]

        # Extract the predicate label
        # First try to take the predicate from the middle of target_sentence
        predicate_label = 0  # default 0
        try:
            # target_sentence format: subject predicate object
            triple_str = sample['target_sentence']
            triple_parts = triple_str.strip().split()
            if len(triple_parts) >= 3:
                predicate = triple_parts[1]
                predicate_label = PREDICATE2ID.get(predicate, 0)
        except Exception:
            predicate_label = 0
        return {
            'input_ids': X,
            'labels': Y,
@@ -329,7 +343,8 @@ class TriplePretrainDataset(Dataset):
            'target_input_ids': sample['target_input_ids'],  # already a tensor
            'target_attention_mask': sample['target_attention_mask'],  # already a tensor
            'target_sentence': sample['target_sentence'],  # string, used for debugging
            'original_text': sample['text']
            'original_text': sample['text'],
            'predicate_label': torch.tensor(predicate_label, dtype=torch.long)
        }

@@ -475,7 +475,7 @@ class MOEFeedForward(nn.Module):

class TripleExtractionHead(nn.Module):
    """Triple extraction task head"""
    def __init__(self, config: LMConfig):
    def __init__(self, config: LMConfig, num_predicates=None):
        super().__init__()
        self.config = config

@@ -506,6 +506,10 @@ class TripleExtractionHead(nn.Module):
        self.subject_output = nn.Linear(config.dim, self.max_subject_len * config.dim, bias=False)
        self.object_output = nn.Linear(config.dim, self.max_object_len * config.dim, bias=False)

        # Classification head
        self.num_predicates = num_predicates if num_predicates is not None else 617
        self.predicate_cls = nn.Linear(config.dim, self.num_predicates)

        print(f"Triple extraction head configuration:")
        print(f"- max subject length: {self.max_subject_len}")
        print(f"- max predicate length: {self.max_predicate_len}")
@@ -520,6 +524,7 @@ class TripleExtractionHead(nn.Module):
            predicate_logits: [batch_size, seq_len, max_predicate_len, vocab_size] - predicate sequence prediction
            subject_logits: [batch_size, seq_len, max_subject_len, vocab_size] - subject sequence prediction
            object_logits: [batch_size, seq_len, max_object_len, vocab_size] - object sequence prediction
            predicate_cls_logits: [batch_size, num_predicates] - predicate classification logits
        """
        batch_size, seq_len, dim = h.shape

@@ -532,6 +537,8 @@ class TripleExtractionHead(nn.Module):
        predicate_features = predicate_features.mean(dim=1)
        predicate_raw = self.predicate_output(predicate_features)  # [batch_size, max_predicate_len * vocab_size]
        predicate_logits = predicate_raw.view(batch_size, self.max_predicate_len, -1)
        # classification logits
        predicate_cls_logits = self.predicate_cls(predicate_features)  # [batch_size, num_predicates]

        # 3. h1 goes through cross-attention (k, v are both h) to produce h2
        h2 = self.cross_attention_subject(h1, h)  # query is h1, key and value are both h
@@ -553,7 +560,7 @@ class TripleExtractionHead(nn.Module):
        object_raw = self.object_output(object_features)  # [batch_size, max_object_len * vocab_size]
        object_logits = object_raw.view(batch_size, self.max_object_len, -1)

        return predicate_logits, subject_logits, object_logits
        return predicate_logits, subject_logits, object_logits, predicate_cls_logits


class MiniMindBlock(nn.Module):
@@ -586,7 +593,7 @@ class MiniMindBlock(nn.Module):
class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None, mode="triple"):
    def __init__(self, params: LMConfig = None, mode="triple", num_predicates=None):
        self.params = params or LMConfig()
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
@@ -599,7 +606,7 @@ class MiniMindLM(PreTrainedModel):
        self.tok_embeddings.weight = self.output.weight

        # Attach the triple extraction task head (trainable)
        self.triple_extraction_head = TripleExtractionHead(params)
        self.triple_extraction_head = TripleExtractionHead(params, num_predicates=num_predicates)
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
@@ -656,7 +663,7 @@ class MiniMindLM(PreTrainedModel):
        )

        # Apply the triple extraction task head
        predicate_logits, subject_logits, object_logits = self.triple_extraction_head(h, pos_cis)
        predicate_logits, subject_logits, object_logits, predicate_cls_logits = self.triple_extraction_head(h, pos_cis)
        predicate_logits = predicate_logits.reshape(input_ids.size(0) * self.params.max_predicate_len, -1)
        subject_logits = subject_logits.reshape(input_ids.size(0) * self.params.max_subject_len, -1)
        object_logits = object_logits.reshape(input_ids.size(0) * self.params.max_object_len, -1)
@@ -685,6 +692,7 @@ class MiniMindLM(PreTrainedModel):
        output.predicate_logits = predicate_logits
        output.subject_logits = subject_logits
        output.object_logits = object_logits
        output.predicate_cls_logits = predicate_cls_logits

        return output

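As a sanity check, the classification branch above is just a linear layer over the pooled predicate features. A minimal, hypothetical shape check with toy sizes (dim=512 and the 617-class default are taken from the code above; the batch size is arbitrary):

import torch
import torch.nn as nn

dim, num_predicates, batch_size = 512, 617, 4
predicate_cls = nn.Linear(dim, num_predicates)
predicate_features = torch.randn(batch_size, dim)        # pooled predicate features, as in the head above
predicate_cls_logits = predicate_cls(predicate_features)
print(predicate_cls_logits.shape)                         # torch.Size([4, 617])
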
stat_predicate_vocab.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import json
from collections import Counter

input_path = '/home/rwkv/RWKV-TS/RETRO_TEST/extract/processed_trex_data.json'
output_path = '/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json'

with open(input_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

predicate_set = set()

for item in data:
    if 'target' in item and isinstance(item['target'], list):
        # deduplicate the predicates within this item using a set
        predicates_in_item = set()
        for triple in item['target']:
            if isinstance(triple, dict) and 'predicate' in triple:
                predicates_in_item.add(triple['predicate'])
        predicate_set.update(predicates_in_item)

predicate_list = list(predicate_set)

with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(predicate_list, f, ensure_ascii=False, indent=2)

print(f'Collected {len(predicate_list)} predicates, saved to {output_path}')

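For context, a small hypothetical illustration of how the other scripts in this commit consume the generated file (the example predicates are invented): the position in the JSON list becomes the class id, so the same predicate_vocab.json must be shared by training, evaluation and inference, since list(predicate_set) carries no fixed order across runs.

import json

with open('/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json', 'r', encoding='utf-8') as f:
    PREDICATE_LIST = json.load(f)            # e.g. ["instance of", "located in", "author", ...] (invented examples)
PREDICATE2ID = {p: i for i, p in enumerate(PREDICATE_LIST)}
label = PREDICATE2ID.get("instance of", 0)   # unknown predicates fall back to class 0, as in the dataset code
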
test.py (new file, 243 lines)
@@ -0,0 +1,243 @@
import os
import json
import argparse
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer
from model.model_extra import MiniMindLM
from model.LMConfig import LMConfig
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report

# Load the predicate vocabulary
PREDICATE_VOCAB_PATH = '/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json'
with open(PREDICATE_VOCAB_PATH, 'r', encoding='utf-8') as f:
    PREDICATE_LIST = json.load(f)
PREDICATE2ID = {p: i for i, p in enumerate(PREDICATE_LIST)}
NUM_PREDICATES = len(PREDICATE_LIST)

def evaluate_model(model, tokenizer, test_data, device):
    """
    Evaluate model performance - predicate classification only
    """
    model.eval()
    results = []
    all_pred_predicates = []
    all_gold_predicates = []
    correct_predictions = 0
    total_predictions = 0

    print("Starting evaluation...")

    # Debug information
    print(f"Number of test samples: {len(test_data)}")
    if test_data:
        print(f"Format of the first sample: {type(test_data[0])}")
        print(f"Content of the first sample: {test_data[0]}")
        if isinstance(test_data[0], dict):
            print(f"Keys of the first sample: {list(test_data[0].keys())}")

    for i, item in enumerate(tqdm(test_data, desc="Evaluating")):
        input_text = item["input"]
        gold_triples = item.get("output", [])

        # Debug info (first few samples)
        if i < 3:
            print(f"\nSample {i+1} debug info:")
            print(f"  Input text: {input_text[:100]}...")
            print(f"  Number of gold triples: {len(gold_triples)}")
            if gold_triples:
                print(f"  Gold triple: {gold_triples[0]}")

        # Model inference
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512, padding='max_length')
        input_ids = inputs["input_ids"].to(device)

        with torch.no_grad():
            output = model(input_ids=input_ids)

        # Get the predicate classification result
        pred_predicate_id = output.predicate_cls_logits.argmax(-1).item()
        pred_predicate = PREDICATE_LIST[pred_predicate_id] if pred_predicate_id < len(PREDICATE_LIST) else "<UNK>"

        # Collect all target predicates
        target_predicates = []
        if gold_triples:
            for triple in gold_triples:
                if "predicate" in triple:
                    target_predicates.append(triple["predicate"])

        # Check whether the prediction is correct (counted as correct if it appears in the target predicate list)
        is_correct = False
        if target_predicates and pred_predicate in target_predicates:
            is_correct = True
            correct_predictions += 1
        total_predictions += 1

        # Debug info (first few samples)
        if i < 3:
            print(f"  Predicted predicate: {pred_predicate}")
            print(f"  Target predicates: {target_predicates}")
            print(f"  Correct: {is_correct}")

        # Collect predicate classification labels (for detailed analysis)
        if target_predicates:
            # Use the first target predicate as the main label
            main_target = target_predicates[0]
            if main_target in PREDICATE2ID:
                all_gold_predicates.append(PREDICATE2ID[main_target])
                all_pred_predicates.append(pred_predicate_id)

        results.append({
            "input": input_text,
            "predicted_predicate": pred_predicate,
            "target_predicates": target_predicates,
            "is_correct": is_correct
        })

    print(f"\nEvaluation finished, total predictions: {total_predictions}, correct: {correct_predictions}")
    return results, all_pred_predicates, all_gold_predicates, correct_predictions, total_predictions

def print_evaluation_summary(results, pred_predicates, gold_predicates, correct_predictions, total_predictions):
    """
    Print the evaluation summary - predicate classification only
    """
    print("\n" + "="*60)
    print("Predicate classification evaluation summary")
    print("="*60)

    # Predicate classification accuracy
    if total_predictions > 0:
        predicate_accuracy = correct_predictions / total_predictions
        print(f"Predicate classification accuracy: {predicate_accuracy:.4f} ({correct_predictions}/{total_predictions})")
    else:
        print("Predicate classification accuracy: cannot be computed (no valid predictions)")

    # Detailed classification report (if there is enough labelled data)
    if pred_predicates and gold_predicates and len(pred_predicates) > 10:
        try:
            print(f"\nDetailed predicate classification report:")
            print(classification_report(gold_predicates, pred_predicates,
                                         target_names=PREDICATE_LIST[:10] + ["..."] if len(PREDICATE_LIST) > 10 else PREDICATE_LIST,
                                         zero_division=0))
        except Exception as e:
            print(f"\nFailed to generate the detailed predicate classification report: {e}")

    # Sample predictions
    print(f"\nPrediction examples (first 5):")
    for i, result in enumerate(results[:5]):
        print(f"Sample {i+1}:")
        print(f"  Input: {result['input'][:100]}...")
        print(f"  Predicted predicate: {result['predicted_predicate']}")
        print(f"  Target predicates: {result['target_predicates']}")
        print(f"  Correct: {'✓' if result['is_correct'] else '✗'}")
        print()

def main():
    parser = argparse.ArgumentParser(description="MiniMind triple extraction model evaluation script")
    parser.add_argument('--model_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out/pretrain_cls512.pth')
    parser.add_argument('--tokenizer_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/model/minimind_tokenizer')
    parser.add_argument('--test_json', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/extract/sample_1000.json')
    parser.add_argument('--output_dir', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out', help='output directory')
    parser.add_argument('--device', type=str, default='cuda', help='inference device')

    # Model configuration parameters
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    parser.add_argument('--use_moe', default=False, type=bool)
    parser.add_argument('--disable_db', action='store_true', help="disable the database feature")
    parser.add_argument('--flash_attn', action='store_true', default=True, help="enable FlashAttention")
    parser.add_argument('--knowledge_num', type=int, default=960400, help="number of entries in the knowledge base")
    parser.add_argument('--knowledge_length', type=int, default=32, help="sentence length of knowledge-base entries")
    parser.add_argument('--embeddings_epoch', type=int, default=2, help="number of epochs for embedding training")

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Load the model and tokenizer
    print("Loading model and tokenizer...")
    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,
        flash_attn=args.flash_attn,
        knowledge_num=args.knowledge_num,
        knowledge_length=args.knowledge_length,
        embeddings_epoch=args.embeddings_epoch
    )

    model = MiniMindLM(lm_config, mode="triple", num_predicates=NUM_PREDICATES)

    # Load model weights
    try:
        state_dict = torch.load(args.model_path, map_location=args.device)
        model.load_state_dict(state_dict, strict=False)
        print(f"Loaded model weights: {args.model_path}")
    except Exception as e:
        print(f"Failed to load model weights: {e}")
        print("Testing with a randomly initialized model")

    model.to(args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    print(f"Predicate vocabulary size: {len(PREDICATE_LIST)}")

    # Load the test data
    print(f"Loading test data: {args.test_json}")
    with open(args.test_json, 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    # Support multiple data formats
    if isinstance(test_data[0], dict) and "text" in test_data[0]:
        # Format: [{"text": "...", "target": [...]}, ...]
        test_data = [{"input": item["text"], "output": item.get("target", [])} for item in test_data]
    elif isinstance(test_data[0], dict) and "input" in test_data[0]:
        # Format: [{"input": "...", "output": [...]}, ...]
        pass
    else:
        # Format: ["sentence", ...] - no ground truth, inference only
        test_data = [{"input": text, "output": []} for text in test_data]

    print(f"Number of test samples: {len(test_data)}")

    # Evaluate the model
    results, pred_predicates, gold_predicates, correct_predictions, total_predictions = evaluate_model(model, tokenizer, test_data, args.device)

    # Print the evaluation results
    print_evaluation_summary(results, pred_predicates, gold_predicates, correct_predictions, total_predictions)

    # Save detailed results
    output_path = os.path.join(args.output_dir, 'evaluation_results.json')
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"\nDetailed evaluation results saved to: {output_path}")

    # Save accuracy statistics
    accuracy_stats = {
        "total_predictions": total_predictions,
        "correct_predictions": correct_predictions,
        "accuracy": correct_predictions / total_predictions if total_predictions > 0 else 0.0,
        "model_path": args.model_path,
        "test_data_path": args.test_json,
        "predicate_vocab_size": len(PREDICATE_LIST),
        "evaluation_timestamp": str(np.datetime64('now'))
    }

    accuracy_path = os.path.join(args.output_dir, 'accuracy_stats.json')
    with open(accuracy_path, 'w', encoding='utf-8') as f:
        json.dump(accuracy_stats, f, indent=2, ensure_ascii=False)
    print(f"Accuracy statistics saved to: {accuracy_path}")

    # Save predictions
    predictions = [{"input": r["input"], "predicted_predicate": r["predicted_predicate"], "gold_predicates": r["target_predicates"]} for r in results]
    pred_output_path = os.path.join(args.output_dir, 'predictions.json')
    with open(pred_output_path, 'w', encoding='utf-8') as f:
        json.dump(predictions, f, indent=2, ensure_ascii=False)
    print(f"Predictions saved to: {pred_output_path}")

if __name__ == "__main__":
    main()

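A small, hypothetical sketch of inspecting the files test.py writes to --output_dir (the paths are the script's argparse defaults; the keys follow the dictionaries built above):

import json
import os

out_dir = '/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out'
with open(os.path.join(out_dir, 'accuracy_stats.json'), encoding='utf-8') as f:
    stats = json.load(f)
print(stats['accuracy'], stats['correct_predictions'], stats['total_predictions'])

with open(os.path.join(out_dir, 'evaluation_results.json'), encoding='utf-8') as f:
    results = json.load(f)
print(results[0]['predicted_predicate'], results[0]['target_predicates'], results[0]['is_correct'])
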
@@ -25,6 +25,7 @@ import swanlab  # replaces the wandb import
import gc  # add garbage collection module
import psutil  # add system resource monitoring module
import os
import json

os.environ['CUDA_VISIBLE_DEVICES'] = '2'
from model.model_extra import MiniMindLM, RMSNorm  # use model_extra
@@ -208,7 +209,7 @@ def compute_cosine_similarity_batch(pred_embeddings, target_embeddings):

    return similarities

def triple_to_sentence(subject_logits, predicate_logits, object_logits, tokenizer):
def triple_to_sentence(subject_logits, predicate_logits, object_logits, tokenizer, predicate_cls_logits=None):
    """
    Convert triple logits into sentences
    Args:
@@ -216,54 +217,54 @@ def triple_to_sentence(subject_logits, predicate_logits, object_logits, tokenize
        predicate_logits: [batch_size, seq_len, max_predicate_len, vocab_size]
        object_logits: [batch_size, seq_len, max_object_len, vocab_size]
        tokenizer: the tokenizer
        predicate_cls_logits: [batch_size, num_predicates]; if provided, the classification result is used for the predicate
    Returns:
        List[List[str]]: triple sentences for every position of every sample
        List[str]: one triple sentence per sample
    """
    batch_size = subject_logits.shape[0]
    predicate_seq_len = predicate_logits.shape[1]
    # subject
    subject_seq_len = subject_logits.shape[1]
    subject_logits_ = subject_logits.reshape(batch_size * subject_seq_len, -1)
    subject_ids = torch.argmax(subject_logits_, dim=-1)
    subject_ids = subject_ids.reshape(batch_size, subject_seq_len)
    # object
    object_seq_len = object_logits.shape[1]
    object_logits_ = object_logits.reshape(batch_size * object_seq_len, -1)
    object_ids = torch.argmax(object_logits_, dim=-1)
    object_ids = object_ids.reshape(batch_size, object_seq_len)

    predicate_logits = predicate_logits.reshape(batch_size * predicate_seq_len, -1)
    subject_logits = subject_logits.reshape(batch_size * subject_seq_len, -1)
    object_logits = object_logits.reshape(batch_size * object_seq_len, -1)
    # predicate
    predicate_texts = []
    if predicate_cls_logits is not None:
        # use the classification result for the predicate
        pred_ids = torch.argmax(predicate_cls_logits, dim=-1)  # [batch_size]
        for i in range(batch_size):
            pred_id = pred_ids[i].item()
            pred_text = PREDICATE_LIST[pred_id] if pred_id < len(PREDICATE_LIST) else "<UNK>"
            predicate_texts.append(pred_text)
    else:
        # keep the original behaviour: use the sequence-generated predicate
        predicate_seq_len = predicate_logits.shape[1]
        predicate_logits_ = predicate_logits.reshape(batch_size * predicate_seq_len, -1)
        predicate_ids = torch.argmax(predicate_logits_, dim=-1)
        predicate_ids = predicate_ids.reshape(batch_size, predicate_seq_len)
        predicate_texts = tokenizer.batch_decode(predicate_ids, skip_special_tokens=True)

    predicate_logits = torch.argmax(predicate_logits, dim=-1)
    subject_logits = torch.argmax(subject_logits, dim=-1)
    object_logits = torch.argmax(object_logits, dim=-1)

    predicate_logits = predicate_logits.reshape(batch_size, predicate_seq_len)
    subject_logits = subject_logits.reshape(batch_size, subject_seq_len)
    object_logits = object_logits.reshape(batch_size, object_seq_len)

    combined_logits = torch.cat([subject_logits, predicate_logits, object_logits], dim=1)

    sentences = tokenizer.batch_decode(combined_logits, skip_special_tokens=True)

    # sentences = []

    # for batch_idx in range(batch_size):
    #     batch_sentences = []
    #     for seq_idx in range(seq_len):
    #         # get the predicted token ids
    #         subject_ids = torch.argmax(subject_logits[batch_idx, seq_idx], dim=-1)
    #         predicate_ids = torch.argmax(predicate_logits[batch_idx, seq_idx], dim=-1)
    #         object_ids = torch.argmax(object_logits[batch_idx, seq_idx], dim=-1)

    #         # convert to text
    #         subject_text = tokenizer.decode(subject_ids, skip_special_tokens=True).strip()
    #         predicate_text = tokenizer.decode(predicate_ids, skip_special_tokens=True).strip()
    #         object_text = tokenizer.decode(object_ids, skip_special_tokens=True).strip()

    #         # join into a sentence (subject + predicate + object)
    #         if subject_text and predicate_text and object_text:
    #             sentence = f"{subject_text} {predicate_text} {object_text}"
    #         else:
    #             sentence = ""

    #         batch_sentences.append(sentence)
    #     sentences.append(batch_sentences)
    # subject and object text
    subject_texts = tokenizer.batch_decode(subject_ids, skip_special_tokens=True)
    object_texts = tokenizer.batch_decode(object_ids, skip_special_tokens=True)

    # join into triple sentences
    sentences = []
    for i in range(batch_size):
        subject = subject_texts[i].strip()
        predicate = predicate_texts[i].strip() if isinstance(predicate_texts[i], str) else str(predicate_texts[i])
        object_ = object_texts[i].strip()
        if subject and predicate and object_:
            sentence = f"{subject} {predicate} {object_}"
        else:
            sentence = ""
        sentences.append(sentence)
    return sentences

def compute_triple_rouge_loss_optimized(subject_logits, predicate_logits, object_logits,
@@ -414,10 +415,17 @@ def get_lr(it, num_iters, learning_rate):
    # cosine learning rate decay
    return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))

# Load the predicate classes
PREDICATE_VOCAB_PATH = '/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json'
with open(PREDICATE_VOCAB_PATH, 'r', encoding='utf-8') as f:
    PREDICATE_LIST = json.load(f)
PREDICATE2ID = {p: i for i, p in enumerate(PREDICATE_LIST)}
NUM_PREDICATES = len(PREDICATE_LIST)

# Model initialization function
def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config, mode="triple")  # set to triple extraction mode
    model = MiniMindLM(lm_config, mode="triple", num_predicates=NUM_PREDICATES)

    # Load pretrained weights
    pretrained_path = "/home/rwkv/RWKV-TS/RETRO_TEST/extract/Experiment_1_2_2_pretrain_512.pth"
@@ -553,6 +561,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
    last_log_time = epoch_start_time

    # Use the DataLoader's built-in iterator; the custom prefetching has been removed
    criterion_predicate = nn.CrossEntropyLoss()
    for step, batch_data in enumerate(train_loader):
        # === start of each step ===

@@ -611,7 +620,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
        loss_start.record()

        # Compute the optimized embedding cosine similarity loss
        loss = compute_triple_rouge_loss_optimized(
        loss_triple = compute_triple_rouge_loss_optimized(
            res.subject_logits, res.predicate_logits, res.object_logits,
            target_input_ids, target_attention_mask, model.tok_embeddings, temperature=args.temperature
        )
@@ -624,8 +633,13 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
            Logger(f"Error: ROUGE loss computation failed: {e}", accelerator)
            import traceback
            Logger(f"Traceback: {traceback.format_exc()}", accelerator)
            loss = res.logits.sum() * 0.0 + 1.0
            loss_triple = res.logits.sum() * 0.0 + 1.0

        # Predicate classification loss
        loss_predicate = criterion_predicate(res.predicate_cls_logits, batch_data['predicate_label'].to(accelerator.device))

        # Total loss
        loss = 0.99 * loss_triple + 0.01 * loss_predicate
        loss = loss / args.accumulation_steps

        # === 5. Backward pass ===
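For reference, a toy, hypothetical sketch of the loss combination introduced in this hunk (shapes follow the docstrings above; the numbers are made up):

import torch
import torch.nn as nn

criterion_predicate = nn.CrossEntropyLoss()
predicate_cls_logits = torch.randn(4, 617)           # [batch_size, num_predicates]
predicate_label = torch.tensor([3, 41, 0, 592])      # class ids taken from PREDICATE2ID
loss_predicate = criterion_predicate(predicate_cls_logits, predicate_label)

loss_triple = torch.tensor(1.0)                      # stand-in for compute_triple_rouge_loss_optimized(...)
loss = 0.99 * loss_triple + 0.01 * loss_predicate    # the weighting used in the training loop
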
@@ -686,7 +700,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
            Logger("=" * 50, accelerator)

            Logger("=== Triple prediction examples ===", accelerator)
            predict_sentences = triple_to_sentence(res.subject_logits, res.predicate_logits, res.object_logits,tokenizer)
            predict_sentences = triple_to_sentence(res.subject_logits, res.predicate_logits, res.object_logits, tokenizer)
            # Show the target sentences of the first 2 samples
            for i, target_sentence in enumerate(target_sentences[:2]):
                Logger(f"Sample {i+1} target: {target_sentence}", accelerator)
@@ -728,7 +742,8 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a

            # Basic training information
            Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
                   f"Loss: {loss.item() * args.accumulation_steps:.6f}, "
                   f"Loss(triple): {loss_triple.item() * args.accumulation_steps:.6f}, "
                   f"Loss(predicate): {loss_predicate.item() * args.accumulation_steps:.6f}, "
                   f"LR: {current_lr:.6f}, "
                   f"Speed: {tokens_per_sec:.2f} tokens/sec | "
                   f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
@@ -740,7 +755,8 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                "epoch": epoch + 1,
                "step": step + 1,
                "total_steps_in_epoch": total_steps_in_epoch,
                "triple_embedding_cosine_loss": loss.item() * args.accumulation_steps,
                "triple_embedding_cosine_loss": loss_triple.item() * args.accumulation_steps,
                "predicate_cross_entropy_loss": loss_predicate.item() * args.accumulation_steps,
                "lr": current_lr,
                "tokens_per_sec": tokens_per_sec,
                "epoch_time_left_seconds": epoch_remaining_time,
@@ -753,7 +769,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
            loss_total = loss.item() * args.accumulation_steps
            if epoch > 1 and best_loss > loss_total and accelerator.is_main_process:
                best_loss = loss_total
                ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
                ckp = f'{args.save_dir}/pretrain_cls{args.dim}{moe_path}.pth'
                unwrapped_model = accelerator.unwrap_model(model)
                accelerator.save(unwrapped_model.state_dict(), ckp)
                Logger(f"Model saved to {ckp}", accelerator)
@@ -945,14 +961,15 @@ def main():
        target_input_ids = torch.stack([item['target_input_ids'] for item in batch])
        target_attention_mask = torch.stack([item['target_attention_mask'] for item in batch])
        target_sentences = [item['target_sentence'] for item in batch]  # for debugging

        predicate_label = torch.stack([item['predicate_label'] for item in batch])
        return {
            'input_ids': input_ids,
            'labels': labels,
            'loss_mask': loss_mask,
            'target_input_ids': target_input_ids,
            'target_attention_mask': target_attention_mask,
            'target_sentences': target_sentences
            'target_sentences': target_sentences,
            'predicate_label': predicate_label
        }

    train_loader = DataLoader(