import json
import os
import pandas as pd
import tarfile
import tempfile
import shutil
from pathlib import Path
import re
import langdetect
from tqdm import tqdm
import logging
import random
import hashlib
from transformers import AutoTokenizer

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# General configuration
BASE_DIR = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind"
OUTPUT_FILE = os.path.join(BASE_DIR, "dataset", "merged_pretrain.jsonl")
OUTPUT_FILE_EXTRA = os.path.join(BASE_DIR, "dataset", "merged_pretrain_extra.jsonl")

# Data source paths
PRETRAIN_HQ_PATH = os.path.join(BASE_DIR, "dataset", "pretrain_hq.jsonl")
WIKIPEDIA_PATH = "/home/pci/nas/share/datasets/wikipedia/data/20220301.en"
GUTENBERG_PATH = "/home/pci/nas/share/datasets/gutenberg/data"
OPENWEBTEXT_PATH = "/home/pci/nas/share/datasets/openwebtext/subsets"

# Token length limits per sample
MIN_TOKENS = 410
MAX_TOKENS = 490

# Quality and sampling configuration for the main output file
DATASET_CONFIG = {
    "pretrain_hq": {"quality": "high", "sample_ratio": 1.0, "max_samples": None},     # high quality, use everything
    "wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": 5000000},    # high quality, use all, at most 5M records
    "gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 3000000},  # medium quality, sample 80%, at most 3M records
    "openwebtext": {"quality": "low", "sample_ratio": 0.2, "max_samples": 2000000}    # low quality, sample 20%, at most 2M records
}

# Configuration for the extra output file - the leftover data
DATASET_CONFIG_EXTRA = {
    "wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": None},       # all of the remainder
    "gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 5000000},  # 80% of the remainder, at most 5M records
    "openwebtext": {"quality": "low", "sample_ratio": 0.6, "max_samples": 4000000}    # 60% of the remainder, at most 4M records
}

# Global state: hashes of the chunks already selected for the main file
selected_data_hashes = {
    "wikipedia": set(),
    "gutenberg": set(),
    "openwebtext": set()
}

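# A minimal sanity-check sketch (an illustrative addition, not called anywhere in this
# script): each entry in the two config dicts above is expected to carry a "quality"
# label, a "sample_ratio" in [0, 1], and an optional "max_samples" cap, which is all
# that should_sample() relies on.
def _check_dataset_config(config_dict):
    for name, cfg in config_dict.items():
        assert cfg["quality"] in ("high", "medium", "low"), name
        assert 0.0 <= cfg["sample_ratio"] <= 1.0, name
        assert cfg["max_samples"] is None or cfg["max_samples"] > 0, name
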
# Tokenizer (initialized lazily by init_tokenizer)
tokenizer = None

def init_tokenizer():
    """Initialize the tokenizer."""
    global tokenizer
    try:
        # Prefer the local MiniMind tokenizer if it exists
        local_tokenizer_path = os.path.join(BASE_DIR, "model", "minimind_tokenizer")
        if os.path.exists(local_tokenizer_path):
            tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path)
            logger.info("Local MiniMind tokenizer initialized successfully")
        else:
            # If the local tokenizer is missing, fall back to GPT-2 (offline mode only)
            tokenizer = AutoTokenizer.from_pretrained("gpt2", local_files_only=True)
            logger.info("GPT-2 tokenizer initialized successfully (offline)")
    except Exception as e:
        logger.error(f"Error initializing tokenizer: {e}")
        logger.info("Trying to use a simple fallback tokenizer...")
        # Fall back to simple word-based token estimation
        tokenizer = None
        logger.warning("Using simple word-based tokenization as fallback")

def count_tokens(text):
    """Count the number of tokens in a text."""
    if tokenizer is None:
        init_tokenizer()

    if tokenizer is not None:
        try:
            tokens = tokenizer.encode(text, add_special_tokens=False)
            return len(tokens)
        except Exception:
            pass

    # If tokenization fails or no tokenizer is available, fall back to a rough
    # word-based estimate (~1.3 tokens per whitespace-separated word), cast to int
    return int(len(text.split()) * 1.3)

def is_english_text(text, threshold=0.8):
    """Detect whether a text is (mostly) English."""
    try:
        if len(text) < 50:  # skip detection for very short texts
            return True
        detected_lang = langdetect.detect(text)
        return detected_lang == 'en'
    except Exception:
        # If detection fails, fall back to the ratio of plain English characters
        english_chars = sum(1 for char in text if char.isascii() and (char.isalpha() or char in ' .,!?-'))
        total_chars = len(text.replace('\n', ' ').replace('\t', ' '))
        return (english_chars / max(total_chars, 1)) > threshold

def truncate_to_token_limit(text, target_tokens=MAX_TOKENS):
    """Truncate a text down to the target token count."""
    if tokenizer is None:
        init_tokenizer()

    if tokenizer is not None:
        try:
            tokens = tokenizer.encode(text, add_special_tokens=False)
            if len(tokens) <= target_tokens:
                return text

            # Truncate to the target length
            truncated_tokens = tokens[:target_tokens]
            truncated_text = tokenizer.decode(truncated_tokens, skip_special_tokens=True)

            # Try to cut at a sentence boundary to keep the text well-formed
            sentences = truncated_text.split('.')
            if len(sentences) > 1:
                # Keep every sentence except the trailing incomplete one
                truncated_text = '.'.join(sentences[:-1]) + '.'

            return truncated_text
        except Exception:
            pass

    # If tokenization fails or no tokenizer is available, estimate by character count
    estimated_chars = int(target_tokens / 1.3 * 4)  # rough estimate
    text = text[:estimated_chars]

    # Try to cut at a sentence boundary to keep the text well-formed
    sentences = text.split('.')
    if len(sentences) > 1:
        text = '.'.join(sentences[:-1]) + '.'

    return text

def split_text_into_chunks(text, target_chunk_size=1500):
    """Split a long text into medium-sized paragraph chunks."""
    # Clean the text
    text = text.strip()
    if not text:
        return []

    # Collapse excessive blank lines and repeated spaces
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
    text = re.sub(r' +', ' ', text)

    chunks = []

    # Split by paragraph
    paragraphs = text.split('\n\n')
    current_chunk = ""

    for paragraph in paragraphs:
        paragraph = paragraph.strip()
        if not paragraph:
            continue

        # If the current chunk plus the new paragraph is still a reasonable size, append it
        if len(current_chunk) + len(paragraph) < target_chunk_size:
            if current_chunk:
                current_chunk += "\n\n" + paragraph
            else:
                current_chunk = paragraph
        else:
            # Save the current chunk if it is non-empty
            if current_chunk:
                chunks.append(current_chunk)

            # If the paragraph itself is very long, split it further
            if len(paragraph) > target_chunk_size * 2:
                # Split the long paragraph by sentence
                sentences = re.split(r'(?<=[.!?])\s+', paragraph)
                temp_chunk = ""

                for sentence in sentences:
                    if len(temp_chunk) + len(sentence) < target_chunk_size:
                        if temp_chunk:
                            temp_chunk += " " + sentence
                        else:
                            temp_chunk = sentence
                    else:
                        if temp_chunk:
                            chunks.append(temp_chunk)
                        temp_chunk = sentence

                if temp_chunk:
                    current_chunk = temp_chunk
                else:
                    current_chunk = ""
            else:
                current_chunk = paragraph

    # Append the final chunk
    if current_chunk:
        chunks.append(current_chunk)

    return chunks

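# A minimal usage sketch (illustrative only, not called by the pipeline): chunking a long
# synthetic article. With the default target_chunk_size, paragraphs are packed together
# until the next one would push a chunk past the target, and oversized paragraphs are
# re-split on sentence boundaries.
def _example_chunking():
    article = "\n\n".join("Paragraph %d. %s" % (i, "Some sentence. " * 30) for i in range(10))
    chunks = split_text_into_chunks(article, target_chunk_size=1500)
    # Each chunk is a plain string, roughly at or below the target size
    return [len(c) for c in chunks]
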
def format_text_for_pretrain(text):
    """Format a text for pretraining and enforce the token-length window."""
    # Clean the text
    text = text.strip()
    if not text:
        return None

    # Collapse excessive blank lines and repeated spaces
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
    text = re.sub(r' +', ' ', text)

    # Check the token length
    token_count = count_tokens(text)

    # Too short: skip
    if token_count < MIN_TOKENS:
        return None

    # Too long: truncate
    if token_count > MAX_TOKENS:
        text = truncate_to_token_limit(text, MAX_TOKENS)
        token_count = count_tokens(text)

    # Re-check that the result is still inside the window
    if token_count < MIN_TOKENS or token_count > MAX_TOKENS:
        return None

    # Wrap in the pretraining format
    formatted_text = f"<|im_start|>{text}<|im_end|>"
    return formatted_text

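# A minimal sketch of the record shape that format_text_for_pretrain() feeds into the
# writer loop in merge_datasets() (illustrative only, not called by the pipeline): each
# output JSONL line is a single JSON object with one "text" field wrapped in
# <|im_start|>/<|im_end|> markers, matching the format already used by pretrain_hq.jsonl.
def _example_output_record():
    sample = "The quick brown fox jumps over the lazy dog. " * 80  # hypothetical filler text
    formatted = format_text_for_pretrain(sample)
    # e.g. {"text": "<|im_start|>The quick brown fox ...<|im_end|>"}, or None if the
    # sample falls outside the MIN_TOKENS-MAX_TOKENS window
    return {"text": formatted} if formatted else None
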
def get_text_hash(text):
    """Return an MD5 hash of the text, used for deduplication."""
    return hashlib.md5(text.encode('utf-8')).hexdigest()

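# A small illustrative helper (an addition, not used by the generators below, which inline
# this check): a chunk counts as a duplicate in extra mode when its hash was already
# recorded for that source during the main pass.
def _is_duplicate_of_main(dataset_name, chunk):
    return get_text_hash(chunk) in selected_data_hashes[dataset_name]
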
def should_sample(dataset_name, current_count, config_dict=None, is_extra_mode=False):
    """Decide, based on the configuration, whether to keep the current record."""
    if config_dict is None:
        config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG

    config = config_dict[dataset_name]

    # Stop once the maximum sample count is reached
    if config["max_samples"] and current_count >= config["max_samples"]:
        return False

    # Otherwise keep the record with probability sample_ratio
    return random.random() < config["sample_ratio"]

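# A rough back-of-the-envelope sketch (an illustrative addition): the expected number of
# records kept from a source is its candidate count scaled by sample_ratio, clipped by
# max_samples. For example, openwebtext in the main file keeps about 20% of candidate
# chunks, up to 2,000,000 records.
def _expected_selected(candidate_count, config):
    expected = candidate_count * config["sample_ratio"]
    if config["max_samples"]:
        expected = min(expected, config["max_samples"])
    return int(expected)
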
def process_pretrain_hq():
    """Process the existing high-quality pretraining data - passed through as-is."""
    logger.info("Processing pretrain_hq.jsonl...")
    count = 0

    with open(PRETRAIN_HQ_PATH, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Processing pretrain_hq"):
            try:
                data = json.loads(line.strip())
                text = data.get('text', '').strip()

                if text:  # any non-empty text is emitted directly, with no further checks
                    if should_sample("pretrain_hq", count):
                        yield {"text": text}
                        count += 1

            except json.JSONDecodeError:
                continue

    logger.info(f"Processed {count} records from pretrain_hq.jsonl")

def process_wikipedia(is_extra_mode=False):
    """Process the Wikipedia data."""
    mode_text = "extra" if is_extra_mode else "main"
    logger.info(f"Processing Wikipedia data ({mode_text} mode)...")
    count = 0
    config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG

    # Collect all English Wikipedia parquet files
    wiki_files = list(Path(WIKIPEDIA_PATH).glob("*.parquet"))

    for file_path in tqdm(wiki_files, desc=f"Processing Wikipedia files ({mode_text})"):
        if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
            break

        try:
            df = pd.read_parquet(file_path)
            for _, row in df.iterrows():
                if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
                    break

                text = row.get('text', '').strip()
                if text and len(text) > 200:  # pre-filter texts that are too short
                    # First split long articles into medium-sized chunks
                    chunks = split_text_into_chunks(text, target_chunk_size=2000)

                    for chunk in chunks:
                        if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
                            break

                        chunk_hash = get_text_hash(chunk)

                        # In extra mode, skip chunks already selected for the main file
                        if is_extra_mode and chunk_hash in selected_data_hashes["wikipedia"]:
                            continue

                        formatted_text = format_text_for_pretrain(chunk)
                        if formatted_text and should_sample("wikipedia", count, config_dict, is_extra_mode):
                            # In main mode, record the hash for later deduplication
                            if not is_extra_mode:
                                selected_data_hashes["wikipedia"].add(chunk_hash)

                            yield {"text": formatted_text}
                            count += 1
        except Exception as e:
            logger.error(f"Error processing {file_path}: {e}")
            continue

    logger.info(f"Processed {count} records from Wikipedia ({mode_text} mode)")

def process_gutenberg(is_extra_mode=False):
    """Process the Gutenberg data."""
    mode_text = "extra" if is_extra_mode else "main"
    logger.info(f"Processing Gutenberg data ({mode_text} mode)...")
    count = 0
    config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG

    # Collect all Gutenberg training files
    gutenberg_files = list(Path(GUTENBERG_PATH).glob("train-*.parquet"))

    for file_path in tqdm(gutenberg_files, desc=f"Processing Gutenberg files ({mode_text})"):
        if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
            break

        try:
            df = pd.read_parquet(file_path)
            for _, row in df.iterrows():
                if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
                    break

                text = row.get('text', '').strip()
                if text and len(text) > 300 and is_english_text(text):  # pre-filter
                    # First split long books into medium-sized chunks
                    chunks = split_text_into_chunks(text, target_chunk_size=1800)

                    for chunk in chunks:
                        if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
                            break

                        chunk_hash = get_text_hash(chunk)

                        # In extra mode, skip chunks already selected for the main file
                        if is_extra_mode and chunk_hash in selected_data_hashes["gutenberg"]:
                            continue

                        formatted_text = format_text_for_pretrain(chunk)
                        if formatted_text and should_sample("gutenberg", count, config_dict, is_extra_mode):
                            # In main mode, record the hash for later deduplication
                            if not is_extra_mode:
                                selected_data_hashes["gutenberg"].add(chunk_hash)

                            yield {"text": formatted_text}
                            count += 1
        except Exception as e:
            logger.error(f"Error processing {file_path}: {e}")
            continue

    logger.info(f"Processed {count} records from Gutenberg ({mode_text} mode)")

def process_openwebtext(is_extra_mode=False):
    """Process the OpenWebText data."""
    mode_text = "extra" if is_extra_mode else "main"
    logger.info(f"Processing OpenWebText data ({mode_text} mode)...")
    count = 0
    config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
    max_files = 5  # limit the number of tar files processed to keep runtime manageable

    # Collect the tar files
    tar_files = list(Path(OPENWEBTEXT_PATH).glob("*.tar"))[:max_files]

    for tar_path in tqdm(tar_files, desc=f"Processing OpenWebText files ({mode_text})"):
        if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
            break

        try:
            with tarfile.open(tar_path, 'r') as outer_tar:
                # Extract the outer tar into a temporary directory
                with tempfile.TemporaryDirectory() as temp_dir:
                    outer_tar.extractall(temp_dir)

                    # Process the extracted .xz files
                    for root, dirs, files in os.walk(temp_dir):
                        if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
                            break

                        for file in files:
                            if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
                                break

                            if file.endswith('.xz'):
                                xz_path = os.path.join(root, file)

                                # Use another temporary directory for the decompressed content
                                with tempfile.TemporaryDirectory() as xz_temp_dir:
                                    try:
                                        # Decompress the .xz file
                                        import subprocess
                                        decompressed_path = os.path.join(xz_temp_dir, file[:-3])  # strip the .xz suffix
                                        with open(decompressed_path, 'wb') as out_f:
                                            subprocess.run(['xz', '-dc', xz_path],
                                                           stdout=out_f,
                                                           check=True)

                                        # Check whether the decompressed file is itself a tar archive
                                        if tarfile.is_tarfile(decompressed_path):
                                            # Process the inner tar file
                                            with tarfile.open(decompressed_path, 'r') as inner_tar:
                                                with tempfile.TemporaryDirectory() as inner_temp_dir:
                                                    inner_tar.extractall(inner_temp_dir)

                                                    # Process the final .txt files
                                                    for inner_root, inner_dirs, inner_files in os.walk(inner_temp_dir):
                                                        if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
                                                            break

                                                        for txt_file in inner_files:
                                                            if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
                                                                break

                                                            if txt_file.endswith('.txt'):
                                                                txt_path = os.path.join(inner_root, txt_file)
                                                                try:
                                                                    with open(txt_path, 'r', encoding='utf-8', errors='ignore') as f:
                                                                        text = f.read().strip()
                                                                        if text and len(text) > 500 and is_english_text(text):
                                                                            # First split long documents into medium-sized chunks
                                                                            chunks = split_text_into_chunks(text, target_chunk_size=1600)

                                                                            for chunk in chunks:
                                                                                if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
                                                                                    break

                                                                                chunk_hash = get_text_hash(chunk)

                                                                                # In extra mode, skip chunks already selected for the main file
                                                                                if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
                                                                                    continue

                                                                                formatted_text = format_text_for_pretrain(chunk)
                                                                                if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
                                                                                    # In main mode, record the hash for later deduplication
                                                                                    if not is_extra_mode:
                                                                                        selected_data_hashes["openwebtext"].add(chunk_hash)

                                                                                    yield {"text": formatted_text}
                                                                                    count += 1
                                                                except Exception as e:
                                                                    logger.debug(f"Error reading txt file {txt_path}: {e}")
                                                                    continue
                                        else:
                                            # Not a tar archive: treat the decompressed file as plain text
                                            try:
                                                with open(decompressed_path, 'r', encoding='utf-8', errors='ignore') as f:
                                                    text = f.read().strip()
                                                    if text and len(text) > 500 and is_english_text(text):
                                                        chunks = split_text_into_chunks(text, target_chunk_size=1600)

                                                        for chunk in chunks:
                                                            if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
                                                                break

                                                            chunk_hash = get_text_hash(chunk)

                                                            # In extra mode, skip chunks already selected for the main file
                                                            if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
                                                                continue

                                                            formatted_text = format_text_for_pretrain(chunk)
                                                            if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
                                                                # In main mode, record the hash for later deduplication
                                                                if not is_extra_mode:
                                                                    selected_data_hashes["openwebtext"].add(chunk_hash)

                                                                yield {"text": formatted_text}
                                                                count += 1
                                            except Exception as e:
                                                logger.debug(f"Error reading decompressed file {decompressed_path}: {e}")
                                                continue

                                    except Exception as e:
                                        logger.debug(f"Error processing xz file {xz_path}: {e}")
                                        continue
        except Exception as e:
            logger.error(f"Error processing {tar_path}: {e}")
            continue

    logger.info(f"Processed {count} records from OpenWebText ({mode_text} mode)")

def merge_datasets():
    """Merge all datasets, producing the main file and the extra file."""
    logger.info("Starting dataset merging...")
    logger.info("Main dataset configuration:")
    for name, config in DATASET_CONFIG.items():
        logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")

    logger.info("Extra dataset configuration:")
    for name, config in DATASET_CONFIG_EXTRA.items():
        logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")

    # Make sure the output directories exist
    os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
    os.makedirs(os.path.dirname(OUTPUT_FILE_EXTRA), exist_ok=True)

    # Statistics
    main_dataset_stats = {}
    extra_dataset_stats = {}

    # Phase 1: generate the main file
    logger.info("="*50)
    logger.info("PHASE 1: Generating main dataset file")
    logger.info("="*50)

    with open(OUTPUT_FILE, 'w', encoding='utf-8') as outfile:
        main_total_count = 0

        # Process each dataset in main mode
        main_datasets = [
            ("pretrain_hq", process_pretrain_hq),
            ("wikipedia", lambda: process_wikipedia(is_extra_mode=False)),
            ("gutenberg", lambda: process_gutenberg(is_extra_mode=False)),
            ("openwebtext", lambda: process_openwebtext(is_extra_mode=False))
        ]

        for dataset_name, dataset_func in main_datasets:
            logger.info(f"Processing {dataset_name} for main file...")
            dataset_count = 0

            try:
                for record in dataset_func():
                    json.dump(record, outfile, ensure_ascii=False)
                    outfile.write('\n')
                    dataset_count += 1
                    main_total_count += 1

                    # Log progress every 5000 records
                    if main_total_count % 5000 == 0:
                        logger.info(f"Main file: Processed {main_total_count} total records")

                # Save statistics
                main_dataset_stats[dataset_name] = {
                    'selected': dataset_count,
                    'config': DATASET_CONFIG[dataset_name]
                }

            except Exception as e:
                logger.error(f"Error processing {dataset_name} for main file: {e}")
                main_dataset_stats[dataset_name] = {
                    'selected': dataset_count,
                    'config': DATASET_CONFIG[dataset_name]
                }

            logger.info(f"Main file - Completed {dataset_name}: {dataset_count} records")

    logger.info(f"Main file generation completed. Total records: {main_total_count}")

    # Phase 2: generate the extra file
    logger.info("="*50)
    logger.info("PHASE 2: Generating extra dataset file")
    logger.info("="*50)

    with open(OUTPUT_FILE_EXTRA, 'w', encoding='utf-8') as outfile:
        extra_total_count = 0

        # Process each dataset in extra mode (pretrain_hq is excluded)
        extra_datasets = [
            ("wikipedia", lambda: process_wikipedia(is_extra_mode=True)),
            ("gutenberg", lambda: process_gutenberg(is_extra_mode=True)),
            ("openwebtext", lambda: process_openwebtext(is_extra_mode=True))
        ]

        for dataset_name, dataset_func in extra_datasets:
            logger.info(f"Processing {dataset_name} for extra file...")
            dataset_count = 0

            try:
                for record in dataset_func():
                    json.dump(record, outfile, ensure_ascii=False)
                    outfile.write('\n')
                    dataset_count += 1
                    extra_total_count += 1

                    # Log progress every 5000 records
                    if extra_total_count % 5000 == 0:
                        logger.info(f"Extra file: Processed {extra_total_count} total records")

                # Save statistics
                extra_dataset_stats[dataset_name] = {
                    'selected': dataset_count,
                    'config': DATASET_CONFIG_EXTRA[dataset_name]
                }

            except Exception as e:
                logger.error(f"Error processing {dataset_name} for extra file: {e}")
                extra_dataset_stats[dataset_name] = {
                    'selected': dataset_count,
                    'config': DATASET_CONFIG_EXTRA[dataset_name]
                }

            logger.info(f"Extra file - Completed {dataset_name}: {dataset_count} records")

    logger.info(f"Extra file generation completed. Total records: {extra_total_count}")

    # Print the detailed summary
    print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count)

    logger.info("All dataset processing completed successfully!")
    logger.info(f"Main file saved to: {OUTPUT_FILE}")
    logger.info(f"Extra file saved to: {OUTPUT_FILE_EXTRA}")

def print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count):
    """Print a detailed processing summary."""
    print("\n" + "="*100)
    print("DATASET PROCESSING SUMMARY")
    print("="*100)

    # Main file statistics
    print("\nMAIN FILE (merged_pretrain.jsonl):")
    print("-" * 90)
    print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Main':<12} {'Quality':<8}")
    print("-" * 90)

    for dataset_name, stats in main_dataset_stats.items():
        selected = stats['selected']
        config = stats['config']
        ratio = config['sample_ratio']
        max_limit = config['max_samples'] if config['max_samples'] else "No limit"
        percentage = (selected / main_total_count * 100) if main_total_count > 0 else 0
        quality = config['quality']

        print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")

    print("-" * 90)
    print(f"{'MAIN TOTAL':<15} {main_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")

    # Extra file statistics
    print("\nEXTRA FILE (merged_pretrain_extra.jsonl):")
    print("-" * 90)
    print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Extra':<12} {'Quality':<8}")
    print("-" * 90)

    for dataset_name, stats in extra_dataset_stats.items():
        selected = stats['selected']
        config = stats['config']
        ratio = config['sample_ratio']
        max_limit = config['max_samples'] if config['max_samples'] else "No limit"
        percentage = (selected / extra_total_count * 100) if extra_total_count > 0 else 0
        quality = config['quality']

        print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")

    print("-" * 90)
    print(f"{'EXTRA TOTAL':<15} {extra_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")

    # Overall statistics
    total_records = main_total_count + extra_total_count
    print("\nOVERALL STATISTICS:")
    print("-" * 50)
    print(f"Main file records:  {main_total_count:>10,}")
    print(f"Extra file records: {extra_total_count:>10,}")
    print(f"Total records:      {total_records:>10,}")
    print(f"Token range per sample: {MIN_TOKENS}-{MAX_TOKENS} tokens")

    # Quality distribution
    quality_stats = {}
    for dataset_name, stats in main_dataset_stats.items():
        quality = stats['config']['quality']
        if quality not in quality_stats:
            quality_stats[quality] = {'main': 0, 'extra': 0}
        quality_stats[quality]['main'] += stats['selected']

    for dataset_name, stats in extra_dataset_stats.items():
        quality = stats['config']['quality']
        if quality not in quality_stats:
            quality_stats[quality] = {'main': 0, 'extra': 0}
        quality_stats[quality]['extra'] += stats['selected']

    print("\nQUALITY DISTRIBUTION:")
    print("-" * 60)
    print(f"{'Quality':<12} {'Main File':<12} {'Extra File':<12} {'Total':<12} {'%':<8}")
    print("-" * 60)
    for quality in sorted(quality_stats.keys()):
        main_count = quality_stats[quality]['main']
        extra_count = quality_stats[quality]['extra']
        total_count = main_count + extra_count
        percentage = (total_count / total_records * 100) if total_records > 0 else 0
        print(f"{quality.capitalize():<12} {main_count:<12,} {extra_count:<12,} {total_count:<12,} {percentage:<8.2f}%")
    print("-" * 60)
    print(f"{'Total':<12} {main_total_count:<12,} {extra_total_count:<12,} {total_records:<12,} {'100.00%':<8}")

    print(f"\nFiles saved to:")
    print(f"  Main file:  {OUTPUT_FILE}")
    print(f"  Extra file: {OUTPUT_FILE_EXTRA}")
    print("="*100)

def main():
    """Entry point."""
    try:
        # Fix the random seed so runs are reproducible
        random.seed(42)

        # Check required dependencies
        try:
            import langdetect
            from transformers import AutoTokenizer
        except ImportError as e:
            logger.error(f"Missing dependencies: {e}")
            logger.error("Please install: pip install langdetect transformers")
            return

        # Initialize the tokenizer
        init_tokenizer()

        # Check the input file
        if not os.path.exists(PRETRAIN_HQ_PATH):
            logger.error(f"pretrain_hq.jsonl not found at {PRETRAIN_HQ_PATH}")
            return

        # Merge the datasets
        merge_datasets()
        logger.info("All processing completed successfully!")

    except Exception as e:
        logger.error(f"Error in main process: {e}")
        raise

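# A small post-run spot check (a hypothetical helper, not part of the pipeline): read the
# first few lines of the merged JSONL and report their token counts, which should fall in
# the MIN_TOKENS-MAX_TOKENS window for everything except pretrain_hq pass-through records.
def _spot_check_output(path=OUTPUT_FILE, n=5):
    with open(path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if i >= n:
                break
            record = json.loads(line)
            print(i, count_tokens(record["text"]))
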
if __name__ == "__main__":
    main()