diff --git a/data_process.py b/data_process.py index 4d67dd3..4377c00 100644 --- a/data_process.py +++ b/data_process.py @@ -63,7 +63,7 @@ def sft_process(contain_history=False): if len(q) > 512 or len(a) > 512: continue # 判断q和a中中文字符占比是否超过70% - if not (chinese_ratio(q) > 0.86 and chinese_ratio(a) > 0.86): + if not (chinese_ratio(q) > 0.5 and chinese_ratio(a) > 0.5): continue q_lst.append(q) @@ -75,10 +75,11 @@ def sft_process(contain_history=False): # 创建DataFrame并追加到CSV文件 df = pd.DataFrame({'history': history_lst, 'q': q_lst, 'a': a_lst}) - # 1、默认 - df.to_csv(f'./dataset/{file_name}', mode='a', header=False, index=False, lineterminator='\r\n') + # # 1、默认 + # df.to_csv(f'./dataset/{file_name}', mode='a', header=False, index=False, lineterminator='\r\n', encoding='utf-8') # 2、若遇到数据 `_csv.Error: need to escape, but no escapechar set` 问题,可加 escapechar='\\' 参数: - # df.to_csv(f'./dataset/{file_name}', mode='a', header=False, index=False, lineterminator='\r\n', escapechar='\\') + df.to_csv(f'./dataset/{file_name}', mode='a', header=False, index=False, lineterminator='\r\n', escapechar='\\', + encoding='utf-8') chunk_size = 1000 # 每次处理的记录数 data = [] @@ -151,6 +152,6 @@ if __name__ == "__main__": if process_type == 1: pretrain_process() if process_type == 2: - sft_process(contain_history=True) + sft_process(contain_history=False) if process_type == 3: rl_process()