update data_process

This commit is contained in:
gongjy 2024-10-23 12:25:45 +08:00
parent 69bcb8dc90
commit 42bd06e55d

View File

@ -75,7 +75,10 @@ def sft_process(contain_history=False):
# 创建DataFrame并追加到CSV文件 # 创建DataFrame并追加到CSV文件
df = pd.DataFrame({'history': history_lst, 'q': q_lst, 'a': a_lst}) df = pd.DataFrame({'history': history_lst, 'q': q_lst, 'a': a_lst})
# 1、默认
df.to_csv(f'./dataset/{file_name}', mode='a', header=False, index=False, lineterminator='\r\n') df.to_csv(f'./dataset/{file_name}', mode='a', header=False, index=False, lineterminator='\r\n')
# 2、若遇到数据 `_csv.Error: need to escape, but no escapechar set` 问题,可加 escapechar='\\' 参数:
# df.to_csv(f'./dataset/{file_name}', mode='a', header=False, index=False, lineterminator='\r\n', escapechar='\\')
chunk_size = 1000 # 每次处理的记录数 chunk_size = 1000 # 每次处理的记录数
data = [] data = []
@ -148,6 +151,6 @@ if __name__ == "__main__":
if process_type == 1: if process_type == 1:
pretrain_process() pretrain_process()
if process_type == 2: if process_type == 2:
sft_process(contain_history=False) sft_process(contain_history=True)
if process_type == 3: if process_type == 3:
rl_process() rl_process()