Minimind/analyze_database.py

import json
import os
import torch
from transformers import AutoTokenizer
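
# Expected structure of database_init.json (inferred from the fields this script reads;
# anything else in the file is ignored):
# {
#     "sentences": [
#         {"corrected_sentence": "...", "importance_score": 0.95},
#         ...
#     ]
# }
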
def analyze_database(json_path, tokenizer_path='./model/minimind_tokenizer'):
    """Analyze the number and quality of entries in a database_init.json file."""
    print(f"Starting analysis of database file: {json_path}")

    # 1. Load the tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        print(f"Loaded tokenizer: {tokenizer_path}")
    except Exception as e:
        print(f"Failed to load tokenizer: {e}")
        return
    # 2. Load the JSON file
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            database_data = json.load(f)
        # Extract the list of sentences
        sentences_data = database_data.get('sentences', [])
        print(f"Loaded {len(sentences_data)} sentence entries")
    except Exception as e:
        print(f"Failed to load JSON file: {e}")
        return
    # 3. Analyze sentence length distribution
    if len(sentences_data) == 0:
        print("No valid sentence data found")
        return

    # Sort by importance_score (highest first)
    sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
    print(f"Sorted by importance_score. Highest score: {sorted_sentences[0].get('importance_score', 0.0)}, "
          f"lowest score: {sorted_sentences[-1].get('importance_score', 0.0)}")
    # Collect per-sentence token lengths
    token_lengths = []
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
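    # Note: pad_token_id falls back to 0 when the tokenizer defines no pad token;
    # it is not actually used in the statistics computed below.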
    # 4. Analyze token length distribution
    for i, sentence_data in enumerate(sorted_sentences):
        sentence = sentence_data.get('corrected_sentence', '')
        if not sentence:
            print(f"Warning: entry {i+1} has no corrected_sentence field")
            continue
        # Tokenize the sentence
        sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
        token_lengths.append(len(sentence_tokens))
        if i < 5:  # show the first 5 entries as samples
            print(f"Sample {i+1}: {sentence[:50]}... (tokens: {len(sentence_tokens)})")
    # 5. Aggregate statistics
    token_lengths = torch.tensor(token_lengths)
    stats = {
        "total_entries": len(sorted_sentences),
        "valid_entries": len(token_lengths),
        "token_length_mean": token_lengths.float().mean().item(),
        "token_length_min": token_lengths.min().item(),
        "token_length_max": token_lengths.max().item(),
        "token_length_median": token_lengths.median().item(),
        "token_length_std": token_lengths.float().std().item(),
    }
    # Token length distribution buckets
    length_bins = {
        "<16": (token_lengths < 16).sum().item(),
        "16-32": ((token_lengths >= 16) & (token_lengths < 32)).sum().item(),
        "32-64": ((token_lengths >= 32) & (token_lengths < 64)).sum().item(),
        "64-128": ((token_lengths >= 64) & (token_lengths < 128)).sum().item(),
        "128-256": ((token_lengths >= 128) & (token_lengths < 256)).sum().item(),
        ">=256": (token_lengths >= 256).sum().item(),
    }
    # Print the statistics
    print("\n===== Database Analysis Results =====")
    for key, value in stats.items():
        print(f"{key}: {value}")

    print("\n===== Token Length Distribution =====")
    for bin_name, count in length_bins.items():
        percentage = (count / len(token_lengths)) * 100
        print(f"{bin_name}: {count} ({percentage:.1f}%)")

    print(f"\nConclusion: the database file contains {stats['valid_entries']} valid entries, "
          f"all of which can be loaded into the knowledge base.")
    return stats, length_bins


if __name__ == "__main__":
    # Path to the database file to analyze
    database_path = "./dataset/database_init.json"
    analyze_database(database_path)
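
# The function can also be used programmatically; it returns (stats, length_bins), e.g.:
#     stats, length_bins = analyze_database("./dataset/database_init.json")
#     print(stats["valid_entries"], length_bins)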