# Minimind/analyze_database.py

import json
import os
import torch
from transformers import AutoTokenizer


def analyze_database(json_path, tokenizer_path='./model/minimind_tokenizer'):
    """Analyze the number and quality of entries in a database_init.json file."""
    print(f"Analyzing database file: {json_path}")

    # 1. Load the tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        print(f"Tokenizer loaded successfully: {tokenizer_path}")
    except Exception as e:
        print(f"Failed to load tokenizer: {e}")
        return
    # 2. Load the JSON file
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            database_data = json.load(f)
        # Extract the list of sentences
        sentences_data = database_data.get('sentences', [])
        print(f"Loaded {len(sentences_data)} sentence entries")
    except Exception as e:
        print(f"Failed to load JSON file: {e}")
        return
    # 3. Analyze the sentence length distribution
    if len(sentences_data) == 0:
        print("No valid sentence data found")
        return

    # Sort by importance_score (highest first)
    sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
    print(f"Sorted by importance_score. Highest score: {sorted_sentences[0].get('importance_score', 0.0)}, lowest score: {sorted_sentences[-1].get('importance_score', 0.0)}")
    # Collect per-sentence token lengths
    token_lengths = []
    # pad_token_id is resolved here for potential padding use; the length analysis below does not need it
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    # 4. Analyze the token length distribution
    for i, sentence_data in enumerate(sorted_sentences):
        sentence = sentence_data.get('corrected_sentence', '')
        if not sentence:
            print(f"Warning: entry {i+1} has no corrected_sentence field")
            continue
        # Tokenize the sentence
        sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
        token_lengths.append(len(sentence_tokens))
        if i < 5:  # Show the first 5 entries as samples
            print(f"Sample {i+1}: {sentence[:50]}... (tokens: {len(sentence_tokens)})")
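
    # Defensive check: if no entry carried a corrected_sentence, token_lengths is
    # empty and the tensor statistics below (min/max/median) would raise, so stop early.
    if len(token_lengths) == 0:
        print("No entries with a corrected_sentence field; skipping statistics")
        return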
    # 5. Aggregate the statistics
    token_lengths = torch.tensor(token_lengths)
    stats = {
        "total entries": len(sorted_sentences),
        "valid entries": len(token_lengths),
        "token length - mean": token_lengths.float().mean().item(),
        "token length - min": token_lengths.min().item(),
        "token length - max": token_lengths.max().item(),
        "token length - median": token_lengths.median().item(),
        "token length - std": token_lengths.float().std().item(),
    }
    # Bucket the length distribution
    length_bins = {
        "< 16": (token_lengths < 16).sum().item(),
        "16-32": ((token_lengths >= 16) & (token_lengths < 32)).sum().item(),
        "32-64": ((token_lengths >= 32) & (token_lengths < 64)).sum().item(),
        "64-128": ((token_lengths >= 64) & (token_lengths < 128)).sum().item(),
        "128-256": ((token_lengths >= 128) & (token_lengths < 256)).sum().item(),
        ">= 256": (token_lengths >= 256).sum().item(),
    }
    # Print the statistics
    print("\n===== Database Analysis Results =====")
    for key, value in stats.items():
        print(f"{key}: {value}")

    print("\n===== Token Length Distribution =====")
    for bin_name, count in length_bins.items():
        percentage = (count / len(token_lengths)) * 100
        print(f"{bin_name}: {count} ({percentage:.1f}%)")

    print(f"\nConclusion: the database file contains {stats['valid entries']} valid entries, all of which can be loaded into the knowledge base.")
    return stats, length_bins


if __name__ == "__main__":
    # Path to the database file to analyze
    database_path = "./dataset/database_init.json"
    analyze_database(database_path)
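
    # Optional follow-up sketch (the output path is illustrative, not part of the dataset layout):
    # capture the returned statistics and persist them for later inspection.
    # result = analyze_database(database_path)
    # if result is not None:
    #     stats, length_bins = result
    #     with open("./dataset/database_stats.json", "w", encoding="utf-8") as f:
    #         json.dump({"stats": stats, "length_bins": length_bins}, f, ensure_ascii=False, indent=2)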