Minimind/inference.py
2025-07-05 03:03:43 +00:00

108 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
import argparse
import torch
from tqdm import tqdm
from transformers import AutoTokenizer
from model.model_extra import MiniMindLM
from model.LMConfig import LMConfig
# Predicate vocabulary: decode_triple maps the classification head's argmax
# index to an entry of PREDICATE_LIST. The default path is machine-specific, so
# it can be overridden through the PREDICATE_VOCAB_PATH environment variable.
PREDICATE_VOCAB_PATH = os.environ.get(
    'PREDICATE_VOCAB_PATH',
    '/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json',
)
with open(PREDICATE_VOCAB_PATH, 'r', encoding='utf-8') as f:
    # NOTE(review): indexed by an integer class id in decode_triple, so this is
    # presumably a JSON array of predicate strings — confirm against the file.
    PREDICATE_LIST = json.load(f)
print(len(PREDICATE_LIST))
def decode_triple(subject_logits, predicate_logits, object_logits, tokenizer, predicate_cls_logits=None):
    """Turn per-field model logits into a readable subject/predicate/object dict.

    Args:
        subject_logits, predicate_logits, object_logits: token logits for each
            field, expected as [1, max_len, vocab_size]; each position is decoded
            greedily (argmax).
        tokenizer: object providing ``eos_token_id``, ``pad_token_id`` and
            ``decode(ids, skip_special_tokens=...)``.
        predicate_cls_logits: optional classification logits over PREDICATE_LIST.
            When given, the predicate is looked up in PREDICATE_LIST instead of
            being decoded from ``predicate_logits``; after ``argmax(-1)`` it must
            hold a single element (presumably shape [1, num_predicates]).

    Returns:
        ``{"subject": str, "predicate": str, "object": str}``.
    """
    def greedy_ids(logits):
        # Argmax per position, then drop the leading batch dimension of size 1.
        return logits.argmax(-1).squeeze(0).tolist()

    def to_text(ids):
        # tolist() on a 0-d tensor yields a bare int; normalise it to a list.
        seq = [ids] if isinstance(ids, int) else ids
        eos = tokenizer.eos_token_id
        if eos in seq:
            # Keep only the tokens before the first EOS.
            seq = seq[:seq.index(eos)]
        pad = tokenizer.pad_token_id
        if pad in seq:
            seq = [tok for tok in seq if tok != pad]
        return tokenizer.decode(seq, skip_special_tokens=True).strip()

    subject_ids = greedy_ids(subject_logits)
    object_ids = greedy_ids(object_logits)
    subject = to_text(subject_ids)
    object_ = to_text(object_ids)

    if predicate_cls_logits is None:
        # No classification output: decode the predicate tokens like the other fields.
        predicate = to_text(greedy_ids(predicate_logits))
    else:
        # Predicate comes from the classification head; ids beyond the vocab become "<UNK>".
        label = predicate_cls_logits.argmax(-1).item()
        predicate = "<UNK>" if label >= len(PREDICATE_LIST) else PREDICATE_LIST[label]

    return {"subject": subject, "predicate": predicate, "object": object_}
def infer_triples(model, tokenizer, sentences, device, max_length=512):
    """Run triple extraction on each sentence, one forward pass per sentence.

    Args:
        model: the extraction model (e.g. MiniMindLM); it is switched to eval
            mode here. Its output must expose ``subject_logits``,
            ``predicate_logits``, ``object_logits`` and ``predicate_cls_logits``.
        tokenizer: tokenizer used to encode the inputs and, via
            ``decode_triple``, to decode the predicted tokens.
        sentences: iterable of input strings.
        device: device the encoded ``input_ids`` are moved to.
        max_length: truncation/padding length for every encoded sentence.
            Defaults to 512; pass the model's ``max_seq_len`` when it differs.

    Returns:
        A list of ``{"input": sentence, "output": [triple]}`` dicts in input order.
    """
    results = []
    model.eval()
    # Gradients are never needed for inference; disabling them once for the
    # whole loop also covers the decoding step.
    with torch.no_grad():
        for sent in tqdm(sentences, desc="推理中"):
            # Encode, padding every sentence to the same fixed length.
            # NOTE(review): only input_ids reach the model (no attention mask), so
            # padding positions are visible to it — confirm this matches training.
            inputs = tokenizer(
                sent,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
                padding='max_length',
            )
            input_ids = inputs["input_ids"].to(device)
            output = model(input_ids=input_ids)
            triple = decode_triple(
                output.subject_logits,
                output.predicate_logits,
                output.object_logits,
                tokenizer,
                output.predicate_cls_logits,
            )
            results.append({"input": sent, "output": [triple]})
    return results
def _str2bool(value):
    """Parse a boolean command-line value such as 'true'/'false', 'yes'/'no', '1'/'0'.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, which is True
    for every non-empty string, so ``--use_moe False`` would enable MoE.
    An empty string maps to False, as ``bool('')`` did.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _extract_sentences(data):
    """Return the list of input sentences from the loaded input JSON.

    Accepted layouts (detected from the first element only):
    ``[{"text": ...}, ...]``, ``[{"input": ...}, ...]`` or ``["sentence", ...]``.
    An empty list yields no sentences instead of raising IndexError.
    """
    if not data:
        return []
    first = data[0]
    if isinstance(first, dict) and "text" in first:
        return [item["text"] for item in data]
    if isinstance(first, dict) and "input" in first:
        return [item["input"] for item in data]
    return data


def main():
    """Command-line entry point: load the model, run extraction, write a JSON file.

    Reads sentences from ``--input_json`` and writes
    ``<output_dir>/<input file stem>_triples.json`` with one
    ``{"input": ..., "output": [triple]}`` record per sentence.
    """
    parser = argparse.ArgumentParser(description="MiniMind 三元组抽取推理脚本")
    parser.add_argument('--model_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out/pretrain_cls512.pth')
    parser.add_argument('--tokenizer_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/model/minimind_tokenizer')
    parser.add_argument('--input_json', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/extract/sample_1000.json')
    parser.add_argument('--output_dir', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out', help='输出目录')
    parser.add_argument('--device', type=str, default='cuda', help='推理设备')
    # The following arguments must match the values used for training.
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    # type=bool would turn any non-empty string (including "False") into True.
    parser.add_argument('--use_moe', default=False, type=_str2bool)
    parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代")
    # --flash_attn defaults to True, so the store_true flag alone can never turn
    # it off; --no_flash_attn stores False into the same destination.
    parser.add_argument('--flash_attn', action='store_true', default=True, help="启用FlashAttention")
    parser.add_argument('--no_flash_attn', dest='flash_attn', action='store_false', help="禁用FlashAttention")
    parser.add_argument('--knowledge_num', type=int, default=960400, help="知识库的数据数目")
    parser.add_argument('--knowledge_length', type=int, default=32, help="知识库的句子长度")
    parser.add_argument('--embeddings_epoch', type=int, default=2, help="embedding训练的epoch数")
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Load the model and the tokenizer.
    print("加载模型和分词器...")
    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,
        flash_attn=args.flash_attn,
        knowledge_num=args.knowledge_num,
        knowledge_length=args.knowledge_length,
        embeddings_epoch=args.embeddings_epoch,
    )
    model = MiniMindLM(lm_config)
    # NOTE: torch.load unpickles the checkpoint, which can execute arbitrary code;
    # only load trusted checkpoint files (recent PyTorch offers weights_only=True).
    model.load_state_dict(torch.load(args.model_path, map_location=args.device))
    model.to(args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    with open(args.input_json, 'r', encoding='utf-8') as f:
        data = json.load(f)
    sentences = _extract_sentences(data)

    # NOTE(review): infer_triples pads/truncates to its own fixed length (512 by
    # default), independently of --max_seq_len; keep the two in step.
    results = infer_triples(model, tokenizer, sentences, args.device)

    # Build the output name from the input file's stem. A plain
    # str.replace('.json', ...) leaves names without a '.json' suffix unchanged,
    # which would overwrite the input when --output_dir is its directory.
    stem = os.path.splitext(os.path.basename(args.input_json))[0]
    output_path = os.path.join(args.output_dir, stem + '_triples.json')
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"已保存预测结果到: {output_path}")


if __name__ == "__main__":
    main()