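"""Analyze dataset/database_init.json: count its sentence entries, sort them by
importance_score, and report the token-length distribution of each
corrected_sentence using the tokenizer loaded from ./model/minimind_tokenizer
by default."""
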
import json
import os
import torch
from transformers import AutoTokenizer


def analyze_database(json_path, tokenizer_path='./model/minimind_tokenizer'):
    """Analyze the number and quality of data entries in database_init.json."""

print(f"开始分析数据库文件: {json_path}")
|
|||
|
|
|||
|
# 1. 加载tokenizer
|
|||
|
try:
|
|||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
|||
|
print(f"成功加载tokenizer: {tokenizer_path}")
|
|||
|
except Exception as e:
|
|||
|
print(f"加载tokenizer失败: {e}")
|
|||
|
return
|
|||
|
|
|||
    # 2. Load the JSON file
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            database_data = json.load(f)

        # Extract the sentences list
        sentences_data = database_data.get('sentences', [])
        print(f"Loaded {len(sentences_data)} sentence entries")
    except Exception as e:
        print(f"Failed to load JSON file: {e}")
        return

    # 3. Analyze the sentence length distribution
    if len(sentences_data) == 0:
        print("No valid sentence data found")
        return

    # Sort by importance_score
    sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
    print(f"Sorted by importance_score; highest score: {sorted_sentences[0].get('importance_score', 0.0)}, lowest score: {sorted_sentences[-1].get('importance_score', 0.0)}")

    # Collect sentence token-length statistics
    token_lengths = []
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
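    # Note: pad_token_id falls back to 0 when the tokenizer defines no pad token;
    # it is not used by the length statistics below, which count raw tokens only.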

    # 4. Analyze the token length distribution
    for i, sentence_data in enumerate(sorted_sentences):
        sentence = sentence_data.get('corrected_sentence', '')
        if not sentence:
            print(f"Warning: entry {i+1} has no corrected_sentence field")
            continue

        # Convert the sentence to tokens
        sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
        token_lengths.append(len(sentence_tokens))

        if i < 5:  # Show the first 5 samples
            print(f"Sample {i+1}: {sentence[:50]}... (tokens: {len(sentence_tokens)})")

    # 5. Summary statistics
    token_lengths = torch.tensor(token_lengths)
    stats = {
        "total entries": len(sorted_sentences),
        "valid entries": len(token_lengths),
        "token length - mean": token_lengths.float().mean().item(),
        "token length - min": token_lengths.min().item(),
        "token length - max": token_lengths.max().item(),
        "token length - median": token_lengths.median().item(),
        "token length - std": token_lengths.float().std().item(),
    }

    # Token length distribution buckets
    length_bins = {
        "< 16": (token_lengths < 16).sum().item(),
        "16-32": ((token_lengths >= 16) & (token_lengths < 32)).sum().item(),
        "32-64": ((token_lengths >= 32) & (token_lengths < 64)).sum().item(),
        "64-128": ((token_lengths >= 64) & (token_lengths < 128)).sum().item(),
        "128-256": ((token_lengths >= 128) & (token_lengths < 256)).sum().item(),
        ">= 256": (token_lengths >= 256).sum().item(),
    }

    # Print the statistics
    print("\n===== Database analysis results =====")
    for key, value in stats.items():
        print(f"{key}: {value}")

    print("\n===== Token length distribution =====")
    for bin_name, count in length_bins.items():
        percentage = (count / len(token_lengths)) * 100
        print(f"{bin_name}: {count} ({percentage:.1f}%)")

    print(f"\nConclusion: the database file contains {stats['valid entries']} valid entries, all of which can be loaded into the knowledge base.")

    return stats, length_bins


if __name__ == "__main__":
    # Path to the database file
    database_path = "./dataset/database_init.json"
    analyze_database(database_path)
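
# Note: analyze_database() returns (stats, length_bins), so the script can also be
# imported and called from other code to consume the results programmatically, e.g.:
#   stats, bins = analyze_database("./dataset/database_init.json")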