# Minimind/data_process.py

import csv
import itertools
import re
import json

import jsonlines
import pandas as pd
from transformers import AutoTokenizer
from datasets import load_dataset
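
# bos/eos markers used to delimit individual samples; they are assumed to
# correspond to the special tokens of the MiniMind tokenizer loaded in
# __main__ below.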
bos_token = "<s>"
eos_token = "</s>"


def pretrain_process(chunk_size=50000):
    # Stream the pretraining corpus in chunks and write one cleaned text
    # sample per CSV row.
    chunk_idx = 0
    with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
        with open('./dataset/pretrain_data.csv', 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['text'])

            while True:
                chunk = list(itertools.islice(reader, chunk_size))
                if not chunk:
                    break

                for idx, obj in enumerate(chunk):
                    try:
                        content = obj.get('text', '')
                        # Skip overly long samples
                        if len(content) > 512:
                            continue
                        writer.writerow([content])
                    except UnicodeDecodeError as e:
                        print(f"Skipping invalid line {chunk_idx * chunk_size + idx + 1}: {e}")
                        continue
                chunk_idx += 1
                print('chunk:', ((chunk_idx - 1) * chunk_size, chunk_idx * chunk_size), 'process end')
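

# --- Illustrative sketch, not part of the original pipeline. ---
# One plausible way to consume the CSV written above: wrap each sample in the
# bos/eos markers defined at the top of this file and tokenize it. The helper
# name `tokenize_pretrain_csv` is hypothetical, and whether downstream training
# adds these markers itself is an assumption.
def tokenize_pretrain_csv(tokenizer, path='./dataset/pretrain_data.csv'):
    df = pd.read_csv(path)
    # Mark document boundaries so they survive sample packing.
    texts = [f'{bos_token}{t}{eos_token}' for t in df['text'].astype(str)]
    return [tokenizer(text).input_ids for text in texts]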


def sft_process(contain_history=False):
    file_name = 'sft_data.csv'
    if not contain_history:
        file_name = 'sft_data_single.csv'

    def chinese_ratio(text):
        # Match all Chinese characters
        chinese_chars = re.findall(r'[\u4e00-\u9fff]', text)
        # Fraction of the text made up of Chinese characters
        return len(chinese_chars) / len(text) if text else 0
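
    # For example (hypothetical check): chinese_ratio('你好ab') == 0.5,
    # since two of the four characters fall in the CJK range.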

    def process_and_write_data(data):
        q_lst, a_lst, history_lst = [], [], []
        for per in data:
            history, q, a = per['history'], per['q'], per['a']
            if (contain_history and not history) or not q or not a:
                continue
            if len(q) < 10 or len(a) < 5:
                continue
            if len(q) > 512 or len(a) > 512:
                continue
            # Keep only pairs where both q and a are more than 86% Chinese characters
            if not (chinese_ratio(q) > 0.86 and chinese_ratio(a) > 0.86):
                continue
            q_lst.append(q)
            a_lst.append(a)
            if contain_history:
                history_lst.append(history)
            else:
                history_lst.append([])

        # Build a DataFrame and append it to the CSV file
        df = pd.DataFrame({'history': history_lst, 'q': q_lst, 'a': a_lst})
        df.to_csv(f'./dataset/{file_name}', mode='a', header=False, index=False, lineterminator='\r\n')

    chunk_size = 1000  # number of records processed per batch
    data = []
    with open(f'./dataset/{file_name}', 'w', encoding='utf-8') as f:
        f.write('history,q,a\n')

    sft_datasets = ['./dataset/sft_data_zh.jsonl']
    if not contain_history:
        # Currently the same source file is used whether or not history is kept.
        sft_datasets = ['./dataset/sft_data_zh.jsonl']

    chunk_num = 0
    for path in sft_datasets:
        with jsonlines.open(path) as reader:
            for idx, obj in enumerate(reader):
                try:
                    data.append({
                        'history': obj.get('history', ''),
                        'q': obj.get('input', '') + obj.get('q', ''),
                        'a': obj.get('output', '') + obj.get('a', '')
                    })
                    if len(data) >= chunk_size:
                        chunk_num += 1
                        process_and_write_data(data)
                        data = []
                        if chunk_num % 100 == 0:
                            print(f'chunk:{chunk_num} process end')
                except jsonlines.InvalidLineError as e:
                    print(f"Skipping invalid JSON line {idx + 1}: {e}")
                    continue

    # Flush any remaining records
    if data:
        process_and_write_data(data)
        data = []
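

# --- Illustrative sketch, not part of the original pipeline. ---
# A quick sanity check for the CSV written by sft_process(): read it back and
# report basic statistics. Only the 'history,q,a' header written above is taken
# from this file; the helper name and defaults are assumptions.
def inspect_sft_csv(path='./dataset/sft_data_single.csv', n=3):
    df = pd.read_csv(path)
    print(f'{len(df)} samples, columns: {list(df.columns)}')
    print(df.head(n))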


def rl_process():
    ################
    # Dataset
    ################
    dataset_paths = [
        './dataset/dpo/dpo_zh_demo.json',
        './dataset/dpo/dpo_train_data.json',
        './dataset/dpo/huozi_rlhf_data.json',
    ]

    train_dataset = load_dataset('json', data_files=dataset_paths)

    # Merge every split (load_dataset places all files into one 'train' split here)
    merged_data = []
    for split in train_dataset.keys():
        merged_data.extend(train_dataset[split])

    with open('./dataset/dpo/train_data.json', 'w', encoding='utf-8') as f:
        json.dump(merged_data, f, ensure_ascii=False, indent=4)
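

# --- Illustrative sketch, not part of the original pipeline. ---
# A minimal check of the merged DPO file, assuming records follow the common
# chosen/rejected layout (an assumption; the demo files may use other keys).
def check_dpo_records(path='./dataset/dpo/train_data.json'):
    with open(path, 'r', encoding='utf-8') as f:
        records = json.load(f)
    ok = sum(1 for r in records if 'chosen' in r and 'rejected' in r)
    print(f'{ok}/{len(records)} records have chosen/rejected fields')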


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer', use_fast=False)
    print('tokenizer vocabulary size:', len(tokenizer))

    ################
    # 1: pretrain
    # 2: sft
    # 3: RL
    ################
    process_type = 2

    if process_type == 1:
        pretrain_process()
    if process_type == 2:
        sft_process(contain_history=False)
    if process_type == 3:
        rl_process()
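
# --- Illustrative sketch, not part of the original script. ---
# The process_type switch above is hard-coded; an argparse front end could
# select the same stages from the command line (kept as comments so the
# script's behavior is unchanged):
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--stage', choices=['pretrain', 'sft', 'rl'], default='sft')
#   args = parser.parse_args()
#   {'pretrain': pretrain_process,
#    'sft': lambda: sft_process(contain_history=False),
#    'rl': rl_process}[args.stage]()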