Minimind/preprocessing/preprocess_combined_json.py

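"""Split combined.json into two outputs: the first `prepare_num` deduplicated
subject-predicate-object sentences, written to database_init_from_combined.json
in the database_init.json format, and the remaining records (original format),
written to combined_train.json as a training set."""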
import json
import os
import datetime
from typing import List, Dict, Any

# Configuration
json_path = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/combined.json"
prepare_num = 1048576  # number of entries for database_init.json; adjust as needed
output_dir = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/"
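
# Expected shape of combined.json, inferred from the field accesses below
# (an assumption, not a spec; the real file may carry extra keys):
# [
#     {"subject": "...", "predicate": "...", "object": "..."},
#     ...
# ]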


def convert_to_database_init_format(sentences: List[str], importance_score: float = 10.0) -> Dict[str, Any]:
    """
    Convert a list of sentences into the database_init.json format.

    Args:
        sentences: list of sentences
        importance_score: importance score, defaults to 10.0

    Returns:
        The converted data as a dict.
    """
    # Build the per-sentence records
    sentence_data = []
    for sentence in sentences:
        sentence_item = {
            "original_sentence": sentence,
            "corrected_sentence": sentence,  # identical to original_sentence
            "importance_score": importance_score
        }
        sentence_data.append(sentence_item)

    # Assemble the full structure
    result = {
        "metadata": {
            "batch_number": 1,
            "batch_size": len(sentences),
            "total_processed_count": len(sentences),
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "total_sentences": len(sentences),
            "duplicates_removed": 0  # this function does no deduplication, so 0
        },
        "sentences": sentence_data
    }
    return result
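
# A minimal usage sketch with made-up data (not from the real dataset):
#
#     >>> out = convert_to_database_init_format(["cats are mammals"])
#     >>> out["metadata"]["total_sentences"]
#     1
#     >>> out["sentences"][0]["corrected_sentence"]
#     'cats are mammals'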


def preprocess_combined_json():
    # Load the raw data
    print("Reading combined.json...")
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    total_count = len(data)
    print(f"{total_count} entries in total")

    # Join subject, predicate and object into sentences, keeping a map back to the originals
    print("Processing data and joining sentences...")
    sentence_to_original = {}  # maps each sentence to its original record
    all_sentences = []
    for i, item in enumerate(data):
        # Join subject, predicate and object into one sentence
        sentence = f"{item['subject']} {item['predicate']} {item['object']}"
        all_sentences.append(sentence)
        # Map the sentence back to its record (for duplicates, keep the first occurrence)
        if sentence not in sentence_to_original:
            sentence_to_original[sentence] = item
        if (i + 1) % 100000 == 0:
            print(f"Processed {i + 1}/{total_count} entries")
    print(f"Sentence joining done, {len(all_sentences)} sentences in total")

    # Deduplicate; dict.fromkeys (unlike set) keeps first-occurrence order,
    # so the "first prepare_num" slice below is reproducible across runs
    print("Deduplicating...")
    unique_sentences = list(dict.fromkeys(all_sentences))
    duplicates_removed = len(all_sentences) - len(unique_sentences)
    print(f"Deduplication done, before: {len(all_sentences)}, after: {len(unique_sentences)}, duplicates removed: {duplicates_removed}")

    # Check whether there is enough deduplicated data
    if len(unique_sentences) < prepare_num:
        print(f"Warning: deduplicated count ({len(unique_sentences)}) is below the requested amount ({prepare_num})")
        print(f"Using all {len(unique_sentences)} deduplicated sentences")
        selected_sentences = unique_sentences
    else:
        print(f"Selecting the first {prepare_num} deduplicated sentences")
        selected_sentences = unique_sentences[:prepare_num]

    # Convert to the database_init.json format
    print("Converting to the database_init.json format...")
    database_init_data = convert_to_database_init_format(selected_sentences, importance_score=10.0)
    # Patch duplicates_removed in the metadata
    database_init_data["metadata"]["duplicates_removed"] = duplicates_removed

    # Save database_init.json
    database_output_path = os.path.join(output_dir, "database_init_from_combined.json")
    print(f"Saving {database_output_path}...")
    with open(database_output_path, "w", encoding="utf-8") as f:
        json.dump(database_init_data, f, ensure_ascii=False, indent=2)
    print(f"database_init_from_combined.json saved, {len(selected_sentences)} entries")

    # Save the remaining data as the training set (original format)
    remaining_sentences = unique_sentences[prepare_num:] if len(unique_sentences) > prepare_num else []
    if remaining_sentences:
        # Map the remaining sentences back to their original records
        print(f"Converting the remaining {len(remaining_sentences)} entries back to the original format...")
        remaining_original_data = []
        for sentence in remaining_sentences:
            if sentence in sentence_to_original:
                remaining_original_data.append(sentence_to_original[sentence])
        print(f"Saving the remaining {len(remaining_original_data)} entries as the training set...")
        train_output_path = os.path.join(output_dir, "combined_train.json")
        with open(train_output_path, "w", encoding="utf-8") as f:
            json.dump(remaining_original_data, f, ensure_ascii=False, indent=2)
        print("combined_train.json saved")
    else:
        print("No remaining data for the training set")
        remaining_original_data = []

    print("\nDone!")
    print(f"Raw entries: {total_count}")
    print(f"Joined: {len(all_sentences)} sentences")
    print(f"Deduplicated: {len(unique_sentences)} sentences")
    print(f"Used for database_init: {len(selected_sentences)}")
    print(f"Remaining training data: {len(remaining_original_data) if remaining_sentences else 0}")


if __name__ == "__main__":
    preprocess_combined_json()
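
# Run directly (the hard-coded paths above come from the original environment;
# point json_path and output_dir at your own dataset copies first):
#     python Minimind/preprocessing/preprocess_combined_json.py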