Minimind/preprocessing/preprocess_pretrain.py
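"""
Build MiniMind pretraining corpora: merge pretrain_hq.jsonl with Wikipedia, Gutenberg and
OpenWebText, chunk and filter each source to the configured token range, and write a main
file (merged_pretrain.jsonl) plus an extra file (merged_pretrain_extra.jsonl) of leftover,
deduplicated samples.
"""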

import json
import os
import pandas as pd
import subprocess
import tarfile
import tempfile
import shutil
from pathlib import Path
import re
import langdetect
from tqdm import tqdm
import logging
import random
import hashlib
from transformers import AutoTokenizer
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration parameters
BASE_DIR = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind"
OUTPUT_FILE = os.path.join(BASE_DIR, "dataset", "merged_pretrain.jsonl")
OUTPUT_FILE_EXTRA = os.path.join(BASE_DIR, "dataset", "merged_pretrain_extra.jsonl")
# Data source paths
PRETRAIN_HQ_PATH = os.path.join(BASE_DIR, "dataset", "pretrain_hq.jsonl")
WIKIPEDIA_PATH = "/home/pci/nas/share/datasets/wikipedia/data/20220301.en"
GUTENBERG_PATH = "/home/pci/nas/share/datasets/gutenberg/data"
OPENWEBTEXT_PATH = "/home/pci/nas/share/datasets/openwebtext/subsets"
# Token-length limits
MIN_TOKENS = 410
MAX_TOKENS = 490
# Dataset quality and sampling-ratio configuration - main file
DATASET_CONFIG = {
"pretrain_hq": {"quality": "high", "sample_ratio": 1.0, "max_samples": None},  # high quality, use everything
"wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": 5000000},  # high quality, use all, at most 5M records
"gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 3000000},  # medium quality, sample 80%, at most 3M records
"openwebtext": {"quality": "low", "sample_ratio": 0.2, "max_samples": 2000000}  # low quality, sample 20%, at most 2M records
}
# Configuration for the extra file - remaining data
DATASET_CONFIG_EXTRA = {
"wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": None},  # all remaining records
"gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 5000000},  # 80% of the remainder, at most 5M records
"openwebtext": {"quality": "low", "sample_ratio": 0.6, "max_samples": 4000000}  # 60% of the remainder, at most 4M records
}
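# How the two configs interact: sample_ratio is a per-record keep probability and
# max_samples is a hard cap (None means uncapped), both applied by should_sample() below.
# The main pass samples according to DATASET_CONFIG; the extra pass re-reads the same
# sources with DATASET_CONFIG_EXTRA and skips anything the main pass already selected.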
# Global state: hashes of chunks already selected for the main file
selected_data_hashes = {
"wikipedia": set(),
"gutenberg": set(),
"openwebtext": set()
}
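# Filled during the main pass and consulted during the extra pass, so chunks that already
# landed in merged_pretrain.jsonl are not written again to merged_pretrain_extra.jsonl.
# The MD5 digests are kept in memory, so exact deduplication costs RAM proportional to
# the number of selected chunks.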
# Tokenizer handle, initialized lazily by init_tokenizer()
tokenizer = None
def init_tokenizer():
"""初始化tokenizer"""
global tokenizer
try:
# Try the local MiniMind tokenizer first
local_tokenizer_path = os.path.join(BASE_DIR, "model", "minimind_tokenizer")
if os.path.exists(local_tokenizer_path):
tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path)
logger.info("Local MiniMind tokenizer initialized successfully")
else:
# If the local tokenizer is missing, fall back to GPT-2 in offline mode
tokenizer = AutoTokenizer.from_pretrained("gpt2", local_files_only=True)
logger.info("GPT-2 tokenizer initialized successfully (offline)")
except Exception as e:
logger.error(f"Error initializing tokenizer: {e}")
logger.info("Trying to use a simple fallback tokenizer...")
# Fall back to a simple word-based tokenization scheme
tokenizer = None
logger.warning("Using simple word-based tokenization as fallback")
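# Note: local_files_only=True means the GPT-2 fallback only succeeds if the model is
# already in the local Hugging Face cache; with neither tokenizer available, the script
# falls back to the word-count heuristics in count_tokens / truncate_to_token_limit.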
def count_tokens(text):
"""计算文本的token数量"""
if tokenizer is None:
init_tokenizer()
if tokenizer is not None:
try:
tokens = tokenizer.encode(text, add_special_tokens=False)
return len(tokens)
except Exception:
pass
# If tokenization fails or the tokenizer is None, fall back to a rough estimate
return int(len(text.split()) * 1.3)  # rough estimate, cast to int
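# Illustrative arithmetic for the fallback path: a 400-word chunk is estimated at
# int(400 * 1.3) = 520 tokens; the 1.3 factor is a rough words-to-tokens ratio for
# English text.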
def is_english_text(text, threshold=0.8):
"""检测文本是否为英文"""
try:
if len(text) < 50: # skip detection for very short texts
return True
detected_lang = langdetect.detect(text)
return detected_lang == 'en'
except Exception:
# If detection fails, fall back to a simple English-character-ratio check
english_chars = sum(1 for char in text if char.isascii() and (char.isalpha() or char in ' .,!?-'))
total_chars = len(text.replace('\n', ' ').replace('\t', ' '))
return (english_chars / max(total_chars, 1)) > threshold
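# langdetect is not deterministic across runs unless its seed is pinned; if reproducible
# detection matters, an optional tweak (not part of the original script) is:
#     from langdetect import DetectorFactory
#     DetectorFactory.seed = 0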
def truncate_to_token_limit(text, target_tokens=MAX_TOKENS):
"""将文本截断到目标token数量"""
if tokenizer is None:
init_tokenizer()
if tokenizer is not None:
try:
tokens = tokenizer.encode(text, add_special_tokens=False)
if len(tokens) <= target_tokens:
return text
# Truncate to the target length
truncated_tokens = tokens[:target_tokens]
truncated_text = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
# Try to cut at a sentence boundary to keep complete sentences
sentences = truncated_text.split('.')
if len(sentences) > 1:
# Keep every sentence except the last, likely incomplete, one
truncated_text = '.'.join(sentences[:-1]) + '.'
return truncated_text
except Exception:
pass
# If processing fails or the tokenizer is None, fall back to a character-count estimate
estimated_chars = int(target_tokens / 1.3 * 4)  # rough estimate (~1.3 tokens and ~4 characters per word)
text = text[:estimated_chars]
# Try to cut at a sentence boundary to keep complete sentences
sentences = text.split('.')
if len(sentences) > 1:
text = '.'.join(sentences[:-1]) + '.'
return text
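# Worked example of the fallback estimate: with target_tokens = MAX_TOKENS = 490 the cut
# happens at int(490 / 1.3 * 4) = 1507 characters, after which the text is trimmed back
# to the last complete sentence.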
def split_text_into_chunks(text, target_chunk_size=1500):
"""将长文本分割成多个中等长度的段落块"""
# 清理文本
text = text.strip()
if not text:
return []
# Collapse excessive newlines and spaces
text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
text = re.sub(r' +', ' ', text)
chunks = []
# Split by paragraph
paragraphs = text.split('\n\n')
current_chunk = ""
for paragraph in paragraphs:
paragraph = paragraph.strip()
if not paragraph:
continue
# If the current chunk plus the new paragraph still fits, append the paragraph
if len(current_chunk) + len(paragraph) < target_chunk_size:
if current_chunk:
current_chunk += "\n\n" + paragraph
else:
current_chunk = paragraph
else:
# If the current chunk is non-empty, save it
if current_chunk:
chunks.append(current_chunk)
# If the paragraph itself is very long, split it further
if len(paragraph) > target_chunk_size * 2:
# Split the long paragraph by sentence
sentences = re.split(r'(?<=[.!?])\s+', paragraph)
temp_chunk = ""
for sentence in sentences:
if len(temp_chunk) + len(sentence) < target_chunk_size:
if temp_chunk:
temp_chunk += " " + sentence
else:
temp_chunk = sentence
else:
if temp_chunk:
chunks.append(temp_chunk)
temp_chunk = sentence
if temp_chunk:
current_chunk = temp_chunk
else:
current_chunk = ""
else:
current_chunk = paragraph
# Append the final chunk
if current_chunk:
chunks.append(current_chunk)
return chunks
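# Behaviour sketch (illustrative numbers): a 5,000-character article with normal paragraph
# breaks yields three to four chunks of roughly target_chunk_size characters each, split
# first on blank lines and then on sentence boundaries for oversized paragraphs.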
def format_text_for_pretrain(text):
"""将文本格式化为预训练格式并检查token长度"""
# 清理文本
text = text.strip()
if not text:
return None
# Collapse excessive newlines and spaces
text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
text = re.sub(r' +', ' ', text)
# Check the token length
token_count = count_tokens(text)
# Skip if too short
if token_count < MIN_TOKENS:
return None
# Truncate if too long
if token_count > MAX_TOKENS:
text = truncate_to_token_limit(text, MAX_TOKENS)
token_count = count_tokens(text)
# Re-check that the length is within the target range
if token_count < MIN_TOKENS or token_count > MAX_TOKENS:
return None
# Wrap in the pretraining format
formatted_text = f"<|im_start|>{text}<|im_end|>"
return formatted_text
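# Accepted chunks come back as a single string of the form
#     <|im_start|>...original text...<|im_end|>
# and only when the token count lands inside [MIN_TOKENS, MAX_TOKENS]; anything else
# returns None so callers can simply skip it.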
def get_text_hash(text):
"""获取文本的哈希值,用于去重"""
return hashlib.md5(text.encode('utf-8')).hexdigest()
def should_sample(dataset_name, current_count, config_dict=None, is_extra_mode=False):
"""根据配置决定是否采样当前记录"""
if config_dict is None:
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
config = config_dict[dataset_name]
# Stop once the maximum sample count has been reached
if config["max_samples"] and current_count >= config["max_samples"]:
return False
# Otherwise decide randomly according to the sampling ratio
return random.random() < config["sample_ratio"]
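# Example (illustrative, main-file config): should_sample("openwebtext", 0) returns True
# about 20% of the time, and always returns False once current_count reaches that
# dataset's 2,000,000 cap.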
def process_pretrain_hq():
"""处理已有的高质量预训练数据 - 直接输出,不做任何处理"""
logger.info("Processing pretrain_hq.jsonl...")
count = 0
with open(PRETRAIN_HQ_PATH, 'r', encoding='utf-8') as f:
for line in tqdm(f, desc="Processing pretrain_hq"):
try:
data = json.loads(line.strip())
text = data.get('text', '').strip()
if text: # emit any non-empty text directly, without further checks
if should_sample("pretrain_hq", count):
yield {"text": text}
count += 1
except json.JSONDecodeError:
continue
logger.info(f"Processed {count} records from pretrain_hq.jsonl")
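# Each line of pretrain_hq.jsonl is expected to be a JSON object with a "text" field,
# e.g. {"text": "..."}; the content is assumed to already be in the final pretraining
# format, since it is written out untouched.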
def process_wikipedia(is_extra_mode=False):
"""处理Wikipedia数据"""
mode_text = "extra" if is_extra_mode else "main"
logger.info(f"Processing Wikipedia data ({mode_text} mode)...")
count = 0
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
# Collect all English Wikipedia parquet files
wiki_files = list(Path(WIKIPEDIA_PATH).glob("*.parquet"))
for file_path in tqdm(wiki_files, desc=f"Processing Wikipedia files ({mode_text})"):
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
break
try:
df = pd.read_parquet(file_path)
for _, row in df.iterrows():
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
break
text = row.get('text', '').strip()
if text and len(text) > 200: # pre-filter texts that are too short
# First split the long text into medium-sized chunks
chunks = split_text_into_chunks(text, target_chunk_size=2000)
for chunk in chunks:
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
# In extra mode, skip chunks already selected for the main file
if is_extra_mode and chunk_hash in selected_data_hashes["wikipedia"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("wikipedia", count, config_dict, is_extra_mode):
# In main mode, record the hash
if not is_extra_mode:
selected_data_hashes["wikipedia"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
continue
logger.info(f"Processed {count} records from Wikipedia ({mode_text} mode)")
def process_gutenberg(is_extra_mode=False):
"""处理Gutenberg数据"""
mode_text = "extra" if is_extra_mode else "main"
logger.info(f"Processing Gutenberg data ({mode_text} mode)...")
count = 0
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
# Collect all Gutenberg training parquet files
gutenberg_files = list(Path(GUTENBERG_PATH).glob("train-*.parquet"))
for file_path in tqdm(gutenberg_files, desc=f"Processing Gutenberg files ({mode_text})"):
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
break
try:
df = pd.read_parquet(file_path)
for _, row in df.iterrows():
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
break
text = row.get('text', '').strip()
if text and len(text) > 300 and is_english_text(text): # pre-filter
# First split the long text into medium-sized chunks
chunks = split_text_into_chunks(text, target_chunk_size=1800)
for chunk in chunks:
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
# In extra mode, skip chunks already selected for the main file
if is_extra_mode and chunk_hash in selected_data_hashes["gutenberg"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("gutenberg", count, config_dict, is_extra_mode):
# In main mode, record the hash
if not is_extra_mode:
selected_data_hashes["gutenberg"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
continue
logger.info(f"Processed {count} records from Gutenberg ({mode_text} mode)")
def process_openwebtext(is_extra_mode=False):
"""处理OpenWebText数据"""
mode_text = "extra" if is_extra_mode else "main"
logger.info(f"Processing OpenWebText data ({mode_text} mode)...")
count = 0
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
max_files = 5 # limit the number of files processed to keep the runtime manageable
# Collect the list of tar files
tar_files = list(Path(OPENWEBTEXT_PATH).glob("*.tar"))[:max_files]
for tar_path in tqdm(tar_files, desc=f"Processing OpenWebText files ({mode_text})"):
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
try:
with tarfile.open(tar_path, 'r') as outer_tar:
# Use a temporary directory to unpack the outer tar
with tempfile.TemporaryDirectory() as temp_dir:
outer_tar.extractall(temp_dir)
# Process the extracted .xz files
for root, dirs, files in os.walk(temp_dir):
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
for file in files:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
if file.endswith('.xz'):
xz_path = os.path.join(root, file)
# Use another temporary directory for the .xz file
with tempfile.TemporaryDirectory() as xz_temp_dir:
try:
# Decompress the .xz file with the xz command-line tool (subprocess is imported at the top)
decompressed_path = os.path.join(xz_temp_dir, file[:-3])  # strip the .xz suffix
with open(decompressed_path, 'wb') as decompressed_file:  # close the handle instead of leaking it
subprocess.run(['xz', '-dc', xz_path], stdout=decompressed_file, check=True)
# Check whether the decompressed file is itself a tar archive
if tarfile.is_tarfile(decompressed_path):
# Process the inner tar file
with tarfile.open(decompressed_path, 'r') as inner_tar:
with tempfile.TemporaryDirectory() as inner_temp_dir:
inner_tar.extractall(inner_temp_dir)
# Process the final .txt files
for inner_root, inner_dirs, inner_files in os.walk(inner_temp_dir):
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
for txt_file in inner_files:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
if txt_file.endswith('.txt'):
txt_path = os.path.join(inner_root, txt_file)
try:
with open(txt_path, 'r', encoding='utf-8', errors='ignore') as f:
text = f.read().strip()
if text and len(text) > 500 and is_english_text(text):
# First split the long text into medium-sized chunks
chunks = split_text_into_chunks(text, target_chunk_size=1600)
for chunk in chunks:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
# In extra mode, skip chunks already selected for the main file
if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
# In main mode, record the hash
if not is_extra_mode:
selected_data_hashes["openwebtext"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.debug(f"Error reading txt file {txt_path}: {e}")
continue
else:
# Not a tar archive: treat the decompressed file directly as text
try:
with open(decompressed_path, 'r', encoding='utf-8', errors='ignore') as f:
text = f.read().strip()
if text and len(text) > 500 and is_english_text(text):
chunks = split_text_into_chunks(text, target_chunk_size=1600)
for chunk in chunks:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
# In extra mode, skip chunks already selected for the main file
if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
# In main mode, record the hash
if not is_extra_mode:
selected_data_hashes["openwebtext"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.debug(f"Error reading decompressed file {decompressed_path}: {e}")
continue
except Exception as e:
logger.debug(f"Error processing xz file {xz_path}: {e}")
continue
except Exception as e:
logger.error(f"Error processing {tar_path}: {e}")
continue
logger.info(f"Processed {count} records from OpenWebText ({mode_text} mode)")
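# Archive layout assumed by the walker above: each OpenWebText subset .tar contains .xz
# members, and each .xz usually decompresses to an inner tar of plain .txt documents
# (with a plain-text fallback when it does not); only the first max_files archives are read.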
def merge_datasets():
"""合并所有数据集,生成主文件和额外文件"""
logger.info("Starting dataset merging...")
logger.info("Main dataset configuration:")
for name, config in DATASET_CONFIG.items():
logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")
logger.info("Extra dataset configuration:")
for name, config in DATASET_CONFIG_EXTRA.items():
logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")
# Make sure the output directories exist
os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
os.makedirs(os.path.dirname(OUTPUT_FILE_EXTRA), exist_ok=True)
# Statistics
main_dataset_stats = {}
extra_dataset_stats = {}
# Phase 1: generate the main file
logger.info("="*50)
logger.info("PHASE 1: Generating main dataset file")
logger.info("="*50)
with open(OUTPUT_FILE, 'w', encoding='utf-8') as outfile:
main_total_count = 0
# Process each dataset (main mode)
main_datasets = [
("pretrain_hq", process_pretrain_hq),
("wikipedia", lambda: process_wikipedia(is_extra_mode=False)),
("gutenberg", lambda: process_gutenberg(is_extra_mode=False)),
("openwebtext", lambda: process_openwebtext(is_extra_mode=False))
]
for dataset_name, dataset_func in main_datasets:
logger.info(f"Processing {dataset_name} for main file...")
dataset_count = 0
try:
for record in dataset_func():
json.dump(record, outfile, ensure_ascii=False)
outfile.write('\n')
dataset_count += 1
main_total_count += 1
# Log progress every 5000 records
if main_total_count % 5000 == 0:
logger.info(f"Main file: Processed {main_total_count} total records")
# Save statistics
main_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG[dataset_name]
}
except Exception as e:
logger.error(f"Error processing {dataset_name} for main file: {e}")
main_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG[dataset_name]
}
logger.info(f"Main file - Completed {dataset_name}: {dataset_count} records")
logger.info(f"Main file generation completed. Total records: {main_total_count}")
# Phase 2: generate the extra file
logger.info("="*50)
logger.info("PHASE 2: Generating extra dataset file")
logger.info("="*50)
with open(OUTPUT_FILE_EXTRA, 'w', encoding='utf-8') as outfile:
extra_total_count = 0
# Process each dataset (extra mode) - pretrain_hq is excluded
extra_datasets = [
("wikipedia", lambda: process_wikipedia(is_extra_mode=True)),
("gutenberg", lambda: process_gutenberg(is_extra_mode=True)),
("openwebtext", lambda: process_openwebtext(is_extra_mode=True))
]
for dataset_name, dataset_func in extra_datasets:
logger.info(f"Processing {dataset_name} for extra file...")
dataset_count = 0
try:
for record in dataset_func():
json.dump(record, outfile, ensure_ascii=False)
outfile.write('\n')
dataset_count += 1
extra_total_count += 1
# Log progress every 5000 records
if extra_total_count % 5000 == 0:
logger.info(f"Extra file: Processed {extra_total_count} total records")
# Save statistics
extra_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG_EXTRA[dataset_name]
}
except Exception as e:
logger.error(f"Error processing {dataset_name} for extra file: {e}")
extra_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG_EXTRA[dataset_name]
}
logger.info(f"Extra file - Completed {dataset_name}: {dataset_count} records")
logger.info(f"Extra file generation completed. Total records: {extra_total_count}")
# Print detailed statistics
print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count)
logger.info("All dataset processing completed successfully!")
logger.info(f"Main file saved to: {OUTPUT_FILE}")
logger.info(f"Extra file saved to: {OUTPUT_FILE_EXTRA}")
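# The two phases must run in this order: phase 1 populates selected_data_hashes while it
# writes the main file, and phase 2 relies on those hashes to keep the extra file disjoint
# from the main file.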
def print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count):
"""打印详细统计信息"""
print("\n" + "="*100)
print("DATASET PROCESSING SUMMARY")
print("="*100)
# Main file statistics
print("\nMAIN FILE (merged_pretrain.jsonl):")
print("-" * 90)
print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Main':<12} {'Quality':<8}")
print("-" * 90)
for dataset_name, stats in main_dataset_stats.items():
selected = stats['selected']
config = stats['config']
ratio = config['sample_ratio']
max_limit = config['max_samples'] if config['max_samples'] else "No limit"
percentage = (selected / main_total_count * 100) if main_total_count > 0 else 0
quality = config['quality']
print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")
print("-" * 90)
print(f"{'MAIN TOTAL':<15} {main_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")
# Extra file statistics
print("\nEXTRA FILE (merged_pretrain_extra.jsonl):")
print("-" * 90)
print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Extra':<12} {'Quality':<8}")
print("-" * 90)
for dataset_name, stats in extra_dataset_stats.items():
selected = stats['selected']
config = stats['config']
ratio = config['sample_ratio']
max_limit = config['max_samples'] if config['max_samples'] else "No limit"
percentage = (selected / extra_total_count * 100) if extra_total_count > 0 else 0
quality = config['quality']
print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")
print("-" * 90)
print(f"{'EXTRA TOTAL':<15} {extra_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")
# Overall statistics
total_records = main_total_count + extra_total_count
print("\nOVERALL STATISTICS:")
print("-" * 50)
print(f"Main file records: {main_total_count:>10,}")
print(f"Extra file records: {extra_total_count:>10,}")
print(f"Total records: {total_records:>10,}")
print(f"Token range per sample: {MIN_TOKENS}-{MAX_TOKENS} tokens")
# Quality-distribution statistics
quality_stats = {}
for dataset_name, stats in main_dataset_stats.items():
quality = stats['config']['quality']
if quality not in quality_stats:
quality_stats[quality] = {'main': 0, 'extra': 0}
quality_stats[quality]['main'] += stats['selected']
for dataset_name, stats in extra_dataset_stats.items():
quality = stats['config']['quality']
if quality not in quality_stats:
quality_stats[quality] = {'main': 0, 'extra': 0}
quality_stats[quality]['extra'] += stats['selected']
print("\nQUALITY DISTRIBUTION:")
print("-" * 60)
print(f"{'Quality':<12} {'Main File':<12} {'Extra File':<12} {'Total':<12} {'%':<8}")
print("-" * 60)
for quality in sorted(quality_stats.keys()):
main_count = quality_stats[quality]['main']
extra_count = quality_stats[quality]['extra']
total_count = main_count + extra_count
percentage = (total_count / total_records * 100) if total_records > 0 else 0
print(f"{quality.capitalize():<12} {main_count:<12,} {extra_count:<12,} {total_count:<12,} {percentage:<8.2f}%")
print("-" * 60)
print(f"{'Total':<12} {main_total_count:<12,} {extra_total_count:<12,} {total_records:<12,} {'100.00%':<8}")
print(f"\nFiles saved to:")
print(f" Main file: {OUTPUT_FILE}")
print(f" Extra file: {OUTPUT_FILE_EXTRA}")
print("="*100)
def main():
"""主函数"""
try:
# 设置随机种子以确保结果可重现
random.seed(42)
# 检查依赖包
try:
import langdetect
from transformers import AutoTokenizer
except ImportError as e:
logger.error(f"Missing dependencies: {e}")
logger.error("Please install: pip install langdetect transformers")
return
# Initialize the tokenizer
init_tokenizer()
# Check that the input file exists
if not os.path.exists(PRETRAIN_HQ_PATH):
logger.error(f"pretrain_hq.jsonl not found at {PRETRAIN_HQ_PATH}")
return
# Start merging the datasets
merge_datasets()
logger.info("All processing completed successfully!")
except Exception as e:
logger.error(f"Error in main process: {e}")
raise
if __name__ == "__main__":
main()
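# Usage sketch (assumes the repository layout implied by BASE_DIR; adjust the dataset
# paths near the top of the file for the local machine):
#     python preprocessing/preprocess_pretrain.py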