import json
import os
import datetime
from typing import List, Dict, Any

# Configuration
json_path = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/combined.json"
prepare_num = 1048576  # number of records to place in database_init.json (2**20); adjust as needed
output_dir = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/"
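
# Assumed input format: combined.json is expected to be a JSON array of triple
# records, each carrying "subject", "predicate", and "object" string fields,
# e.g. (hypothetical entry, not taken from the dataset):
#
#   {"subject": "Paris", "predicate": "is the capital of", "object": "France"}
#
# Records may carry additional fields; those are passed through unchanged into
# combined_train.json.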


def convert_to_database_init_format(sentences: List[str], importance_score: float = 10.0) -> Dict[str, Any]:
    """
    Convert a list of sentences into the database_init.json format.

    Args:
        sentences: List of sentences.
        importance_score: Importance score assigned to every sentence, defaults to 10.0.

    Returns:
        The converted data as a dictionary.
    """
    # Build one record per sentence
    sentence_data = []
    for sentence in sentences:
        sentence_item = {
            "original_sentence": sentence,
            "corrected_sentence": sentence,  # identical to original_sentence
            "importance_score": importance_score
        }
        sentence_data.append(sentence_item)

    # Assemble the full data structure
    result = {
        "metadata": {
            "batch_number": 1,
            "batch_size": len(sentences),
            "total_processed_count": len(sentences),
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "total_sentences": len(sentences),
            "duplicates_removed": 0  # this function does no deduplication, so 0
        },
        "sentences": sentence_data
    }

    return result
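
# Minimal usage sketch for the converter (hypothetical sentence, not taken from
# the dataset):
#
#   batch = convert_to_database_init_format(["Paris is the capital of France"])
#   assert batch["metadata"]["total_sentences"] == 1
#   assert batch["sentences"][0]["importance_score"] == 10.0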


def preprocess_combined_json():
    # Read the raw data
    print("Reading combined.json...")
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    total_count = len(data)
    print(f"Loaded {total_count} records in total")

    # Process every record: join subject, predicate and object into one sentence,
    # and remember the original record behind each sentence
    print("Building sentences from the triples...")
    sentence_to_original = {}  # maps each sentence back to its original record
    all_sentences = []

    for i, item in enumerate(data):
        # Join subject, predicate and object into a single sentence
        sentence = f"{item['subject']} {item['predicate']} {item['object']}"
        all_sentences.append(sentence)

        # Record the sentence-to-original mapping (for duplicate sentences,
        # keep the record from the first occurrence)
        if sentence not in sentence_to_original:
            sentence_to_original[sentence] = item

        if (i + 1) % 100000 == 0:
            print(f"Processed {i + 1}/{total_count} records")

    print(f"Sentence building finished: {len(all_sentences)} sentences")
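
    # Note: list(set(...)) does not preserve order and can change between runs,
    # so which sentences end up in database_init_from_combined.json versus
    # combined_train.json is not deterministic.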
    # Deduplicate
    print("Deduplicating sentences...")
    unique_sentences = list(set(all_sentences))
    duplicates_removed = len(all_sentences) - len(unique_sentences)
    print(f"Deduplication finished: {len(all_sentences)} before, {len(unique_sentences)} after, {duplicates_removed} duplicates removed")

    # Check whether there is enough deduplicated data
    if len(unique_sentences) < prepare_num:
        print(f"Warning: only {len(unique_sentences)} deduplicated sentences available, fewer than the requested {prepare_num}")
        print(f"Using all {len(unique_sentences)} deduplicated sentences")
        selected_sentences = unique_sentences
    else:
        print(f"Selecting the first {prepare_num} deduplicated sentences")
        selected_sentences = unique_sentences[:prepare_num]

    # Convert to the database_init.json format
    print("Converting to the database_init.json format...")
    database_init_data = convert_to_database_init_format(selected_sentences, importance_score=10.0)

    # Update duplicates_removed in the metadata
    database_init_data["metadata"]["duplicates_removed"] = duplicates_removed

    # Save database_init.json
    database_output_path = os.path.join(output_dir, "database_init_from_combined.json")
    print(f"Saving {database_output_path}...")
    with open(database_output_path, "w", encoding="utf-8") as f:
        json.dump(database_init_data, f, ensure_ascii=False, indent=2)

    print(f"database_init_from_combined.json saved with {len(selected_sentences)} records")
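
    # Because the leftover sentences are already deduplicated, combined_train.json
    # keeps exactly one original record per unique remaining sentence (the first
    # occurrence in combined.json); records whose joined sentence is identical
    # are dropped from the training split.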
    # Save the remaining data as the training set (keeping the original format)
    remaining_sentences = unique_sentences[prepare_num:] if len(unique_sentences) > prepare_num else []
    if remaining_sentences:
        # Map the remaining sentences back to their original records
        print(f"Converting the remaining {len(remaining_sentences)} sentences back to the original format...")
        remaining_original_data = []
        for sentence in remaining_sentences:
            if sentence in sentence_to_original:
                remaining_original_data.append(sentence_to_original[sentence])

        print(f"Saving the remaining {len(remaining_original_data)} records as the training set...")
        train_output_path = os.path.join(output_dir, "combined_train.json")
        with open(train_output_path, "w", encoding="utf-8") as f:
            json.dump(remaining_original_data, f, ensure_ascii=False, indent=2)
        print("combined_train.json saved")
    else:
        print("No remaining data for the training set")
        remaining_original_data = []

    print("\nData processing finished!")
    print(f"Original records: {total_count}")
    print(f"Sentences built: {len(all_sentences)}")
    print(f"After deduplication: {len(unique_sentences)}")
    print(f"Used for database_init: {len(selected_sentences)}")
    print(f"Remaining training records: {len(remaining_original_data) if remaining_sentences else 0}")


if __name__ == "__main__":
    preprocess_combined_json()