import os
import json
import argparse

import torch
from tqdm import tqdm
from transformers import AutoTokenizer

from model.model_extra import MiniMindLM
from model.LMConfig import LMConfig


def decode_triple(subject_logits, predicate_logits, object_logits, tokenizer):
    """Greedy-decode one (subject, predicate, object) triple from three logit tensors.

    Each logits tensor is arg-maxed over its last dimension and its leading
    dimension is squeezed away, so only a batch of size 1 is supported.
    (The original author's note gives the shape as [1, max_len, vocab_size].)

    Args:
        subject_logits: logits for the subject tokens.
        predicate_logits: logits for the predicate tokens.
        object_logits: logits for the object tokens.
        tokenizer: object exposing ``eos_token_id``, ``pad_token_id`` and
            ``decode(ids, skip_special_tokens=...)``.

    Returns:
        dict with the keys ``"subject"``, ``"predicate"`` and ``"object"``,
        each mapped to the decoded, whitespace-stripped text.
    """
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id

    def clean(ids):
        # ``tolist()`` yields a bare int when the squeezed tensor is 0-dim.
        if isinstance(ids, int):
            ids = [ids]
        # Everything from the first EOS onwards is discarded.
        if eos_id in ids:
            ids = ids[:ids.index(eos_id)]
        # Drop padding tokens wherever they appear.
        return [token_id for token_id in ids if token_id != pad_id]

    def to_text(logits):
        token_ids = logits.argmax(-1).squeeze(0).tolist()
        return tokenizer.decode(clean(token_ids), skip_special_tokens=True).strip()

    return {
        "subject": to_text(subject_logits),
        "predicate": to_text(predicate_logits),
        "object": to_text(object_logits),
    }


def infer_triples(model, tokenizer, sentences, device, max_length: int = 512):
    """Run the model on every sentence and decode one triple per sentence.

    Args:
        model: callable as ``model(input_ids=...)`` whose output exposes
            ``subject_logits``, ``predicate_logits`` and ``object_logits``.
        tokenizer: tokenizer used both to encode the input and decode the output.
        sentences: iterable of input strings.
        device: device the input ids are moved to before the forward pass.
        max_length: length every input is truncated / padded to. Defaults to
            512, the value that used to be hard-coded here.

    Returns:
        list of ``{"input": sentence, "output": [triple]}`` dicts, in input order.
    """
    results = []
    model.eval()
    for sent in tqdm(sentences, desc="推理中"):
        # Encode a single sentence as a fixed-length batch of size 1.
        inputs = tokenizer(
            sent,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding='max_length',
        )
        input_ids = inputs["input_ids"].to(device)
        with torch.no_grad():
            output = model(input_ids=input_ids)
        triple = decode_triple(
            output.subject_logits,
            output.predicate_logits,
            output.object_logits,
            tokenizer,
        )
        results.append({"input": sent, "output": [triple]})
    return results


def _str2bool(value):
    """Parse a command-line boolean; ``type=bool`` would treat "False" as True."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept falsy, matching what bool("") used to return.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _extract_sentences(data):
    """Return the list of input sentences contained in the loaded JSON payload.

    Supported layouts: ``[{"text": ...}, ...]``, ``[{"input": ...}, ...]`` or a
    plain list of strings. The layout is detected from the first element only.
    An empty list yields an empty result instead of raising ``IndexError``.
    """
    if not isinstance(data, list):
        raise ValueError(
            f"input JSON must be a list, got {type(data).__name__}"
        )
    if not data:
        return []
    first = data[0]
    if isinstance(first, dict):
        for key in ("text", "input"):
            if key in first:
                return [item[key] for item in data]
    return data


def main():
    """Command-line entry point: load model and tokenizer, extract triples, save JSON.

    The predictions are written to ``<output_dir>/<input stem>_triples.json``.
    """
    parser = argparse.ArgumentParser(description="MiniMind 三元组抽取推理脚本")
    parser.add_argument('--model_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out/pretrain_512.pth')
    parser.add_argument('--tokenizer_path', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/model/minimind_tokenizer')
    parser.add_argument('--input_json', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/extract/sample_1000.json')
    parser.add_argument('--output_dir', type=str, default='/home/rwkv/RWKV-TS/RETRO_TEST/Minimind/out', help='输出目录')
    parser.add_argument('--device', type=str, default='cuda', help='推理设备')
    # The following parameters must match the ones used for training.
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    # type=bool would turn every non-empty string (including "False") into True.
    parser.add_argument('--use_moe', default=False, type=_str2bool)
    parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代")
    parser.add_argument('--flash_attn', action='store_true', default=True, help="启用FlashAttention")
    # --flash_attn defaults to True and is store_true, so on its own it could
    # never be turned off; this companion flag writes False to the same dest.
    parser.add_argument('--no_flash_attn', dest='flash_attn', action='store_false', help="禁用FlashAttention")
    parser.add_argument('--knowledge_num', type=int, default=960400, help="知识库的数据数目")
    parser.add_argument('--knowledge_length', type=int, default=32, help="知识库的句子长度")
    parser.add_argument('--embeddings_epoch', type=int, default=2, help="embedding训练的epoch数")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Load the model and the tokenizer.
    print("加载模型和分词器...")
    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,
        flash_attn=args.flash_attn,
        knowledge_num=args.knowledge_num,
        knowledge_length=args.knowledge_length,
        embeddings_epoch=args.embeddings_epoch,
    )
    model = MiniMindLM(lm_config)
    # NOTE(security): torch.load unpickles the checkpoint, which can execute
    # arbitrary code. Only load checkpoints from a trusted source (or pass
    # weights_only=True if the installed torch version supports it).
    state_dict = torch.load(args.model_path, map_location=args.device)
    model.load_state_dict(state_dict)
    model.to(args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    with open(args.input_json, 'r', encoding='utf-8') as f:
        data = json.load(f)
    sentences = _extract_sentences(data)

    results = infer_triples(
        model, tokenizer, sentences, args.device, max_length=args.max_seq_len
    )

    # Build the name from the stem so that an input without a ".json" suffix
    # still gets a distinct output file instead of reusing the input's name.
    stem, _ = os.path.splitext(os.path.basename(args.input_json))
    output_path = os.path.join(args.output_dir, f"{stem}_triples.json")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"已保存预测结果到: {output_path}")


if __name__ == "__main__":
    main()