update data_process
parent 36fadc7ef1
commit 2698f6b57d
@@ -1,44 +1,21 @@
import os
import warnings
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset

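# Collects the names of all nn.Linear submodules (lm_head excluded) so they can be
# passed to LoraConfig(target_modules=...); on a typical decoder-only model this
# yields projection layers such as 'q_proj', 'k_proj', 'v_proj' and 'o_proj',
# but the exact names depend on the model definition.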
def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
warnings.filterwarnings('ignore')


def init_model():
    device = 'cuda:0'
    # Do model patching and add fast LoRA weights
    model_name_or_path = "minimind"
    tokenizer_name_or_path = "minimind"
    model_name_or_path = "minimind-v1-small"
    tokenizer_name_or_path = "minimind-v1-small"
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False)
    tokenizer.pad_token = tokenizer.eos_token
    target_modules = find_all_linear_names(model)
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        inference_mode=False,
        target_modules=target_modules
    )
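    # With r=8 and lora_alpha=16 the LoRA update is scaled by lora_alpha / r = 2;
    # only the injected adapter weights are trained while the base model stays frozen.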
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    model = model.to(device)
    return model, tokenizer

@@ -47,15 +24,10 @@ if __name__ == '__main__':
    model, tokenizer = init_model()
    training_args = TrainingArguments(output_dir="./minimind_dpo",
                                      per_device_train_batch_size=1,
                                      remove_unused_columns=False)
                                      remove_unused_columns=False,
                                      report_to="none")

    ################
    # Dataset
    ################
    # Make sure the path is correct and the file exists
    dataset_path = './dataset/dpo/train_data.json'

    # Load the dataset
    train_dataset = load_dataset('json', data_files=dataset_path)

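    # The json file is expected to provide the "prompt"/"chosen"/"rejected" fields
    # that TRL's DPOTrainer consumes; rl_process() below writes this file.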
    dpo_trainer = DPOTrainer(
@@ -114,24 +114,20 @@ def rl_process():
    # Dataset
    ################

    dataset_path = ['./dataset/dpo/dpo_zh_demo.json',
                    './dataset/dpo/train_data.json',
                    './dataset/dpo/huozi_rlhf_data.json', ]
    dataset_paths = [
        './dataset/dpo/dpo_zh_demo.json',
        './dataset/dpo/dpo_train_data.json',
        './dataset/dpo/huozi_rlhf_data.json',
    ]

    train_dataset = load_dataset('json', data_files=dataset_path)
    train_dataset = load_dataset('json', data_files=dataset_paths)

    def process(row):
        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
        row["reject"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
        return row
    merged_data = []
    for split in train_dataset.keys():
        merged_data.extend(train_dataset[split])

    ds = train_dataset.map(
        process,
        load_from_cache_file=False,
    )

    output_dataset_path = './dataset/dpo/train_data.json'
    ds['train'].to_json(output_dataset_path, force_ascii=False, orient='records', lines=True)
    with open('./dataset/dpo/train_data.json', 'w', encoding='utf-8') as f:
        json.dump(merged_data, f, ensure_ascii=False, indent=4)
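    # Note: to_json(..., lines=True) above emits JSON Lines, while json.dump() writes
    # merged_data as a single indented JSON array; downstream loading of train_data.json
    # should account for whichever format is kept.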


if __name__ == "__main__":
@@ -143,7 +139,7 @@ if __name__ == "__main__":
    # 2: sft
    # 3: RL
    ################
    process_type = 1
    process_type = 3

    if process_type == 1:
        pretrain_process()