import json
import torch
from transformers import AutoTokenizer


def analyze_database(json_path, tokenizer_path='./model/minimind_tokenizer'):
    """Analyze the number and quality of entries in database_init.json."""
    print(f"Analyzing database file: {json_path}")

    # 1. Load the tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        print(f"Loaded tokenizer: {tokenizer_path}")
    except Exception as e:
        print(f"Failed to load tokenizer: {e}")
        return

    # 2. Load the JSON file
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            database_data = json.load(f)
        # Extract the 'sentences' list
        sentences_data = database_data.get('sentences', [])
        print(f"Loaded {len(sentences_data)} sentence entries")
    except Exception as e:
        print(f"Failed to load JSON file: {e}")
        return

    # 3. Bail out early if there is nothing to analyze
    if len(sentences_data) == 0:
        print("No valid sentence data found")
        return

    # Sort entries by importance_score, highest first
    sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
    print(f"Sorted by importance_score. Highest: {sorted_sentences[0].get('importance_score', 0.0)}, "
          f"lowest: {sorted_sentences[-1].get('importance_score', 0.0)}")

    # 4. Collect the token length of each sentence
    token_lengths = []
    for i, sentence_data in enumerate(sorted_sentences):
        sentence = sentence_data.get('corrected_sentence', '')
        if not sentence:
            print(f"Warning: entry {i+1} has no corrected_sentence field")
            continue

        # Tokenize the sentence (no special tokens, only the content length matters)
        sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
        token_lengths.append(len(sentence_tokens))

        if i < 5:  # Show the first 5 entries as samples
            print(f"Sample {i+1}: {sentence[:50]}... (tokens: {len(sentence_tokens)})")

    if len(token_lengths) == 0:
        print("No entry contains a corrected_sentence field")
        return

    # 5. Summary statistics
    token_lengths = torch.tensor(token_lengths)
    stats = {
        "Total entries": len(sorted_sentences),
        "Valid entries": len(token_lengths),
        "Token length - mean": token_lengths.float().mean().item(),
        "Token length - min": token_lengths.min().item(),
        "Token length - max": token_lengths.max().item(),
        "Token length - median": token_lengths.median().item(),
        "Token length - std": token_lengths.float().std().item(),
    }

    # Token length histogram
    length_bins = {
        "< 16": (token_lengths < 16).sum().item(),
        "16-32": ((token_lengths >= 16) & (token_lengths < 32)).sum().item(),
        "32-64": ((token_lengths >= 32) & (token_lengths < 64)).sum().item(),
        "64-128": ((token_lengths >= 64) & (token_lengths < 128)).sum().item(),
        "128-256": ((token_lengths >= 128) & (token_lengths < 256)).sum().item(),
        ">= 256": (token_lengths >= 256).sum().item(),
    }

    # Print the results
    print("\n===== Database analysis =====")
    for key, value in stats.items():
        print(f"{key}: {value}")

    print("\n===== Token length distribution =====")
    for bin_name, count in length_bins.items():
        percentage = (count / len(token_lengths)) * 100
        print(f"{bin_name}: {count} ({percentage:.1f}%)")

    print(f"\nConclusion: the database file contains {stats['Valid entries']} valid entries, "
          f"all of which can be loaded into the knowledge base.")

    return stats, length_bins


if __name__ == "__main__":
    # Path to the database file
    database_path = "./dataset/database_init.json"
    analyze_database(database_path)
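
# ---------------------------------------------------------------------------
# Illustrative sketch of the expected input format (an assumption, not copied
# from the real file): only the fields that analyze_database() actually reads
# are shown — a top-level "sentences" list whose items carry
# "corrected_sentence" and "importance_score". The actual database_init.json
# may contain additional fields.
#
# {
#     "sentences": [
#         {"corrected_sentence": "The Eiffel Tower is in Paris.", "importance_score": 0.92},
#         {"corrected_sentence": "Water boils at 100 degrees Celsius.", "importance_score": 0.75}
#     ]
# }
# ---------------------------------------------------------------------------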