update data_process

gongjy 2024-10-12 18:47:08 +08:00
parent 36fadc7ef1
commit 2698f6b57d
2 changed files with 18 additions and 50 deletions

View File

@@ -1,44 +1,21 @@
import os
import warnings
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset


def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names:
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


warnings.filterwarnings('ignore')


def init_model():
    device = 'cuda:0'
    # Do model patching and add fast LoRA weights
-    model_name_or_path = "minimind"
-    tokenizer_name_or_path = "minimind"
+    model_name_or_path = "minimind-v1-small"
+    tokenizer_name_or_path = "minimind-v1-small"
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False)
    tokenizer.pad_token = tokenizer.eos_token
    target_modules = find_all_linear_names(model)
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        inference_mode=False,
        target_modules=target_modules
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    model = model.to(device)
    return model, tokenizer
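As a reference point, here is a minimal, self-contained sketch (not part of the commit) of what the target-module discovery above yields: find_all_linear_names collects the leaf names of every nn.Linear submodule and drops lm_head, and that list is what feeds LoraConfig(target_modules=...). The ToyLM class below is a made-up stand-in for the real model.

import torch
import torch.nn as nn

class ToyLM(nn.Module):
    # Hypothetical stand-in model, only used to exercise the helper.
    def __init__(self):
        super().__init__()
        self.block = nn.ModuleDict({'q_proj': nn.Linear(8, 8), 'v_proj': nn.Linear(8, 8)})
        self.lm_head = nn.Linear(8, 16)

def find_all_linear_names(model):
    # Same logic as the helper in the hunk above.
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    lora_module_names.discard('lm_head')  # never wrap the output head in LoRA
    return list(lora_module_names)

print(sorted(find_all_linear_names(ToyLM())))  # -> ['q_proj', 'v_proj']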
@@ -47,15 +24,10 @@ if __name__ == '__main__':
    model, tokenizer = init_model()
    training_args = TrainingArguments(output_dir="./minimind_dpo",
                                      per_device_train_batch_size=1,
-                                      remove_unused_columns=False)
+                                      remove_unused_columns=False,
+                                      report_to="none")
    ################
    # Dataset
    ################
    # Make sure the path is correct and the file exists
    dataset_path = './dataset/dpo/train_data.json'
    # Load the dataset
    train_dataset = load_dataset('json', data_files=dataset_path)
    dpo_trainer = DPOTrainer(
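The hunk stops at the opening of the DPOTrainer call. For context, a minimal sketch of how such a call is typically completed; the argument names follow the older trl API (DPOTrainer accepting beta and tokenizer directly, before they moved into DPOConfig), and the beta and length values are illustrative assumptions rather than values taken from this commit.

dpo_trainer = DPOTrainer(
    model,
    ref_model=None,                       # trl derives the reference policy when None
    args=training_args,
    beta=0.1,                             # DPO temperature (assumed)
    train_dataset=train_dataset['train'],
    tokenizer=tokenizer,
    max_length=512,                       # assumed sequence limits
    max_prompt_length=512,
)
dpo_trainer.train()
dpo_trainer.save_model("./minimind_dpo")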

View File

@@ -114,24 +114,20 @@ def rl_process():
    # Dataset
    ################
-    dataset_path = ['./dataset/dpo/dpo_zh_demo.json',
-                    './dataset/dpo/train_data.json',
-                    './dataset/dpo/huozi_rlhf_data.json', ]
+    dataset_paths = [
+        './dataset/dpo/dpo_zh_demo.json',
+        './dataset/dpo/dpo_train_data.json',
+        './dataset/dpo/huozi_rlhf_data.json',
+    ]
-    train_dataset = load_dataset('json', data_files=dataset_path)
+    train_dataset = load_dataset('json', data_files=dataset_paths)
-    def process(row):
-        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
-        row["reject"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
-        return row
+    merged_data = []
+    for split in train_dataset.keys():
+        merged_data.extend(train_dataset[split])
-    ds = train_dataset.map(
-        process,
-        load_from_cache_file=False,
-    )
-    output_dataset_path = './dataset/dpo/train_data.json'
-    ds['train'].to_json(output_dataset_path, force_ascii=False, orient='records', lines=True)
+    with open('./dataset/dpo/train_data.json', 'w', encoding='utf-8') as f:
+        json.dump(merged_data, f, ensure_ascii=False, indent=4)
if __name__ == "__main__":
@@ -143,7 +139,7 @@ if __name__ == "__main__":
    # 2: sft
    # 3: RL
    ################
-    process_type = 1
+    process_type = 3
    if process_type == 1:
        pretrain_process()
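The two files meet at ./dataset/dpo/train_data.json: running data_process with process_type = 3 makes rl_process write the merged preference pairs there, and the DPO training script above loads that same path. A small, hypothetical sanity check (not part of the commit) for confirming the merged file carries the chosen/rejected fields a DPO trainer expects:

import json

# rl_process writes a single pretty-printed JSON array via json.dump(..., indent=4)
with open('./dataset/dpo/train_data.json', 'r', encoding='utf-8') as f:
    records = json.load(f)

assert isinstance(records, list) and records, "merged dataset is empty"
for rec in records[:5]:
    # field names taken from the code above; other keys depend on the source datasets
    assert 'chosen' in rec and 'rejected' in rec, list(rec.keys())
print(f"{len(records)} preference pairs, first record keys: {list(records[0].keys())}")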