update infer
This commit is contained in:
parent
7ce71f24bc
commit
ee7aaba91d
99
inference.py
Normal file
99
inference.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from model.model_extra import MiniMindLM
|
||||||
|
from model.LMConfig import LMConfig
|
||||||
|
|
||||||
|
def decode_triple(subject_logits, predicate_logits, object_logits, tokenizer):
|
||||||
|
# logits: [1, max_len, vocab_size]
|
||||||
|
subject_ids = subject_logits.argmax(-1).squeeze(0).tolist()
|
||||||
|
predicate_ids = predicate_logits.argmax(-1).squeeze(0).tolist()
|
||||||
|
object_ids = object_logits.argmax(-1).squeeze(0).tolist()
|
||||||
|
# 去除pad和eos
|
||||||
|
def clean(ids):
|
||||||
|
if isinstance(ids, int):
|
||||||
|
ids = [ids]
|
||||||
|
if tokenizer.eos_token_id in ids:
|
||||||
|
ids = ids[:ids.index(tokenizer.eos_token_id)]
|
||||||
|
if tokenizer.pad_token_id in ids:
|
||||||
|
ids = [i for i in ids if i != tokenizer.pad_token_id]
|
||||||
|
return ids
|
||||||
|
subject = tokenizer.decode(clean(subject_ids), skip_special_tokens=True).strip()
|
||||||
|
predicate = tokenizer.decode(clean(predicate_ids), skip_special_tokens=True).strip()
|
||||||
|
object_ = tokenizer.decode(clean(object_ids), skip_special_tokens=True).strip()
|
||||||
|
return {"subject": subject, "predicate": predicate, "object": object_}
|
||||||
|
|
||||||
|
def infer_triples(model, tokenizer, sentences, device):
|
||||||
|
results = []
|
||||||
|
model.eval()
|
||||||
|
for sent in tqdm(sentences, desc="推理中"):
|
||||||
|
# 编码
|
||||||
|
inputs = tokenizer(sent, return_tensors="pt", truncation=True, max_length=512, padding='max_length')
|
||||||
|
input_ids = inputs["input_ids"].to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids=input_ids)
|
||||||
|
triple = decode_triple(output.subject_logits, output.predicate_logits, output.object_logits, tokenizer)
|
||||||
|
results.append({"input": sent, "output": [triple]})
|
||||||
|
return results
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="MiniMind 三元组抽取推理脚本")
|
||||||
|
parser.add_argument('--model_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out/pretrain_512.pth')
|
||||||
|
parser.add_argument('--tokenizer_path', type=str,default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/model/minimind_tokenizer')
|
||||||
|
parser.add_argument('--input_json', type=str,default='/home/rwkv/RWKV-TS/RETRO_TEST/extract/sample_1000.json')
|
||||||
|
parser.add_argument('--output_dir', type=str,default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out', help='输出目录')
|
||||||
|
parser.add_argument('--device', type=str, default='cuda', help='推理设备')
|
||||||
|
# 以下参数与train保持一致
|
||||||
|
parser.add_argument('--dim', default=512, type=int)
|
||||||
|
parser.add_argument('--n_layers', default=8, type=int)
|
||||||
|
parser.add_argument('--max_seq_len', default=512, type=int)
|
||||||
|
parser.add_argument('--use_moe', default=False, type=bool)
|
||||||
|
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代")
|
||||||
|
parser.add_argument('--flash_attn', action='store_true', default=True, help="启用FlashAttention")
|
||||||
|
parser.add_argument('--knowledge_num', type=int, default=960400,help="知识库的数据数目")
|
||||||
|
parser.add_argument('--knowledge_length', type=int, default=32,help="知识库的句子长度")
|
||||||
|
parser.add_argument('--embeddings_epoch', type=int, default=2, help="embedding训练的epoch数")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 加载模型和分词器
|
||||||
|
print("加载模型和分词器...")
|
||||||
|
lm_config = LMConfig(
|
||||||
|
dim=args.dim,
|
||||||
|
n_layers=args.n_layers,
|
||||||
|
max_seq_len=args.max_seq_len,
|
||||||
|
use_moe=args.use_moe,
|
||||||
|
disable_db=args.disable_db,
|
||||||
|
flash_attn=args.flash_attn,
|
||||||
|
knowledge_num=args.knowledge_num,
|
||||||
|
knowledge_length=args.knowledge_length,
|
||||||
|
embeddings_epoch=args.embeddings_epoch
|
||||||
|
)
|
||||||
|
model = MiniMindLM(lm_config)
|
||||||
|
model.load_state_dict(torch.load(args.model_path, map_location=args.device))
|
||||||
|
model.to(args.device)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
|
||||||
|
|
||||||
|
with open(args.input_json, 'r', encoding='utf-8') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
# 支持两种格式:[{"text":...}, ...] 或 ["句子", ...]
|
||||||
|
if isinstance(data[0], dict) and "text" in data[0]:
|
||||||
|
sentences = [item["text"] for item in data]
|
||||||
|
elif isinstance(data[0], dict) and "input" in data[0]:
|
||||||
|
sentences = [item["input"] for item in data]
|
||||||
|
else:
|
||||||
|
sentences = data
|
||||||
|
|
||||||
|
results = infer_triples(model, tokenizer, sentences, args.device)
|
||||||
|
|
||||||
|
output_path = os.path.join(args.output_dir, os.path.basename(args.input_json).replace('.json', '_triples.json'))
|
||||||
|
with open(output_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||||
|
print(f"已保存预测结果到: {output_path}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -420,7 +420,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
model = MiniMindLM(lm_config, mode="triple") # 设置为三元组模式
|
model = MiniMindLM(lm_config, mode="triple") # 设置为三元组模式
|
||||||
|
|
||||||
# 加载预训练权重
|
# 加载预训练权重
|
||||||
pretrained_path = "./out/Experiment_1_2_2_pretrain_512.pth"
|
pretrained_path = "/home/rwkv/RWKV-TS/RETRO_TEST/extract/Experiment_1_2_2_pretrain_512.pth"
|
||||||
Logger(f"Loading pretrained weights from {pretrained_path}")
|
Logger(f"Loading pretrained weights from {pretrained_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -794,7 +794,7 @@ def main():
|
|||||||
parser.add_argument('--max_seq_len', default=512, type=int)
|
parser.add_argument('--max_seq_len', default=512, type=int)
|
||||||
parser.add_argument('--use_moe', default=False, type=bool)
|
parser.add_argument('--use_moe', default=False, type=bool)
|
||||||
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代")
|
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代")
|
||||||
parser.add_argument("--data_path", type=str, default="/home/rwkv/RWKV-TS/RETRO_TEST/extract/sample_1000.json")
|
parser.add_argument("--data_path", type=str, default="/home/rwkv/RWKV-TS/RETRO_TEST/extract/processed_trex_data.json")
|
||||||
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
|
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
|
||||||
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
|
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
|
||||||
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
|
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user