#!/usr/bin/env python3
"""
JSON文件合并脚本
读取多个JSON文件并合并为一个JSON文件
"""

import json
import os
from typing import Dict, List, Any, Union

# 需要合并的JSON文件列表
JSON_FILES_TO_MERGE = [
    "output/trex_sentences_enhanced_checkpoint_360000.json"
]
for i in range(1, 1010):
    JSON_FILES_TO_MERGE.append(f"output/trex_sentences_enhanced_batch_{i}.json")

def load_json_file(file_path: str) -> Union[Dict, List, None]:
    """加载JSON文件"""
    if not os.path.exists(file_path):
        print(f"警告: 文件 {file_path} 不存在")
        return None
    
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"成功加载: {file_path}")
        return data
    except json.JSONDecodeError as e:
        print(f"错误: 无法解析JSON文件 {file_path} - {e}")
        return None
    except Exception as e:
        print(f"错误: 读取文件 {file_path} 失败 - {e}")
        return None

def merge_json_data(data1: Union[Dict, List], data2: Union[Dict, List]) -> Union[Dict, List]:
    """合并两个JSON数据结构"""
    
    # 如果两个都是列表,直接合并
    if isinstance(data1, list) and isinstance(data2, list):
        print(f"合并两个列表: {len(data1)} + {len(data2)} = {len(data1) + len(data2)} 项")
        return data1 + data2
    
    # 如果两个都是字典
    elif isinstance(data1, dict) and isinstance(data2, dict):
        print("合并两个字典结构")
        merged = data1.copy()
        
        # 特殊处理:如果都有'sentences'字段且为列表,合并sentences
        if 'sentences' in data1 and 'sentences' in data2:
            if isinstance(data1['sentences'], list) and isinstance(data2['sentences'], list):
                print(f"合并sentences字段: {len(data1['sentences'])} + {len(data2['sentences'])} = {len(data1['sentences']) + len(data2['sentences'])} 项")
                merged['sentences'] = data1['sentences'] + data2['sentences']
                
                # 更新metadata if exists
                if 'metadata' in merged:
                    if isinstance(merged['metadata'], dict):
                        merged['metadata']['total_sentences'] = len(merged['sentences'])
                        merged['metadata']['merged_from'] = [os.path.basename(f) for f in JSON_FILES_TO_MERGE if os.path.exists(f)]
                
                # 合并其他字段
                for key, value in data2.items():
                    if key != 'sentences' and key not in merged:
                        merged[key] = value
                        
                return merged
        
        # 普通字典合并
        for key, value in data2.items():
            if key in merged:
                # 如果key重复且都是列表,合并列表
                if isinstance(merged[key], list) and isinstance(value, list):
                    merged[key] = merged[key] + value
                # 如果key重复且都是字典,递归合并
                elif isinstance(merged[key], dict) and isinstance(value, dict):
                    merged[key] = merge_json_data(merged[key], value)
                else:
                    # 其他情况保留第二个文件的值
                    merged[key] = value
                    print(f"字段 '{key}' 被覆盖")
            else:
                merged[key] = value
        
        return merged
    
    # 类型不匹配的情况,创建一个包含两者的新结构
    else:
        print("数据类型不匹配,创建包含两者的新结构")
        return {
            "data_from_save.json": data1,
            "data_from_save2.json": data2,
            "merged_at": "test.py"
        }

def save_merged_json(data: Union[Dict, List], output_path: str):
    """保存合并后的JSON数据"""
    try:
        # 确保输出目录存在
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        
        print(f"合并结果已保存到: {output_path}")
        
        # 显示统计信息
        if isinstance(data, dict):
            if 'sentences' in data and isinstance(data['sentences'], list):
                print(f"总计句子数: {len(data['sentences'])}")
            print(f"总计字段数: {len(data)}")
        elif isinstance(data, list):
            print(f"总计列表项数: {len(data)}")
            
    except Exception as e:
        print(f"错误: 保存文件失败 - {e}")

def remove_duplicates_from_sentences(data: Union[Dict, List]) -> Union[Dict, List]:
    """从合并结果中移除重复的句子(基于句子内容)"""
    if isinstance(data, dict) and 'sentences' in data:
        if isinstance(data['sentences'], list):
            original_count = len(data['sentences'])
            seen_sentences = set()
            unique_sentences = []
            
            for item in data['sentences']:
                if isinstance(item, dict):
                    # 如果是字典,使用sentence字段或corrected_sentence字段作为唯一标识
                    sentence_key = item.get('sentence') or item.get('corrected_sentence') or item.get('original_sentence')
                elif isinstance(item, str):
                    sentence_key = item
                else:
                    sentence_key = str(item)
                
                if sentence_key and sentence_key not in seen_sentences:
                    seen_sentences.add(sentence_key)
                    unique_sentences.append(item)
            
            data['sentences'] = unique_sentences
            
            # 更新metadata
            if 'metadata' in data and isinstance(data['metadata'], dict):
                data['metadata']['total_sentences'] = len(unique_sentences)
                data['metadata']['duplicates_removed'] = original_count - len(unique_sentences)
            
            print(f"去重完成: {original_count} -> {len(unique_sentences)} (移除了 {original_count - len(unique_sentences)} 个重复项)")
    
    return data

def merge_multiple_json_data(data_list: List[Union[Dict, List]]) -> Union[Dict, List]:
    """合并多个JSON数据结构"""
    if not data_list:
        return {}
    
    if len(data_list) == 1:
        return data_list[0]
    
    print(f"准备合并 {len(data_list)} 个JSON数据结构")
    
    # 从第一个数据开始,逐步合并其他数据
    merged_data = data_list[0]
    
    for i, data in enumerate(data_list[1:], 1):
        print(f"正在合并第 {i+1} 个数据结构...")
        merged_data = merge_json_data(merged_data, data)
    
    return merged_data

def main():
    """主函数"""
    print("=== JSON文件合并脚本 ===")
    
    # 输出路径
    output_path = "output/merged.json"
    
    print(f"准备合并以下文件:")
    for i, file_path in enumerate(JSON_FILES_TO_MERGE, 1):
        print(f"  {i}. {file_path}")
    print(f"输出文件: {output_path}")
    print()
    
    # 加载所有文件
    loaded_data = []
    successfully_loaded = []
    
    for file_path in JSON_FILES_TO_MERGE:
        data = load_json_file(file_path)
        if data is not None:
            loaded_data.append(data)
            successfully_loaded.append(file_path)
    
    # 检查是否至少有一个文件加载成功
    if not loaded_data:
        print("错误: 没有文件能够成功加载,退出")
        return
    
    print(f"成功加载了 {len(loaded_data)} 个文件:")
    for file_path in successfully_loaded:
        print(f"  ✓ {file_path}")
    
    if len(loaded_data) < len(JSON_FILES_TO_MERGE):
        failed_count = len(JSON_FILES_TO_MERGE) - len(loaded_data)
        print(f"警告: {failed_count} 个文件加载失败")
    print()
    
    # 合并所有数据
    if len(loaded_data) == 1:
        print("只有一个文件可用,直接使用...")
        merged_data = loaded_data[0]
    else:
        print("开始合并所有文件...")
        merged_data = merge_multiple_json_data(loaded_data)
    
    # 去重处理
    print("\n检查并去除重复项...")
    merged_data = remove_duplicates_from_sentences(merged_data)
    
    # 保存合并结果
    print("\n保存合并结果...")
    save_merged_json(merged_data, output_path)
    
    print("\n=== 合并完成 ===")
    print(f"合并了 {len(successfully_loaded)} 个文件的数据")

if __name__ == "__main__":
    main()