update data_process
This commit is contained in:
parent
36fadc7ef1
commit
2698f6b57d
@@ -1,44 +1,21 @@
 import os
+import warnings
 os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 
-import torch
 from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
 from trl import DPOTrainer
-from peft import get_peft_model, LoraConfig, TaskType
 from datasets import load_dataset
 
-
-def find_all_linear_names(model):
-    cls = torch.nn.Linear
-    lora_module_names = set()
-    for name, module in model.named_modules():
-        if isinstance(module, cls):
-            names = name.split('.')
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
-
-    if 'lm_head' in lora_module_names:
-        lora_module_names.remove('lm_head')
-    return list(lora_module_names)
+warnings.filterwarnings('ignore')
 
 
 def init_model():
     device = 'cuda:0'
     # Do model patching and add fast LoRA weights
-    model_name_or_path = "minimind"
-    tokenizer_name_or_path = "minimind"
+    model_name_or_path = "minimind-v1-small"
+    tokenizer_name_or_path = "minimind-v1-small"
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False)
     tokenizer.pad_token = tokenizer.eos_token
-    target_modules = find_all_linear_names(model)
-    peft_config = LoraConfig(
-        task_type=TaskType.CAUSAL_LM,
-        r=8,
-        lora_alpha=16,
-        lora_dropout=0.1,
-        inference_mode=False,
-        target_modules=target_modules
-    )
-    model = get_peft_model(model, peft_config)
-    model.print_trainable_parameters()
     model = model.to(device)
     return model, tokenizer
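After this hunk, init_model() returns a plain full-parameter model rather than a PEFT-wrapped LoRA model: the find_all_linear_names helper, the LoraConfig block, and the torch/peft imports all go away, and the checkpoint path moves to minimind-v1-small. A minimal smoke test of the new code path, as a sketch only: it assumes the minimind-v1-small weights exist locally and a CUDA GPU is visible, and the prompt string is invented.

import torch

# Sketch: load the now full-parameter model and run a short generation.
# Assumes ./minimind-v1-small is present locally; purely illustrative.
model, tokenizer = init_model()
model.eval()

inputs = tokenizer("hello", return_tensors="pt").to("cuda:0")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))

# Without the PEFT wrapper, every parameter is trainable again:
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable}")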
@@ -47,15 +24,10 @@ if __name__ == '__main__':
     model, tokenizer = init_model()
     training_args = TrainingArguments(output_dir="./minimind_dpo",
                                       per_device_train_batch_size=1,
-                                      remove_unused_columns=False)
+                                      remove_unused_columns=False,
+                                      report_to="none")
 
-    ################
-    # Dataset
-    ################
-    # Make sure the path is correct and the file exists
     dataset_path = './dataset/dpo/train_data.json'
-
-    # Load the dataset
     train_dataset = load_dataset('json', data_files=dataset_path)
 
     dpo_trainer = DPOTrainer(
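Two details in this hunk are worth spelling out: remove_unused_columns=False stops the Hugging Face Trainer from dropping the preference columns before DPOTrainer can read them, and the added report_to="none" disables external experiment logging (wandb/tensorboard). For the loaded JSON to feed DPOTrainer directly, each record should carry the string fields TRL's DPO format expects. A made-up example record, with field names following TRL's prompt/chosen/rejected convention and invented content:

# One illustrative record from ./dataset/dpo/train_data.json
# (content invented; field names follow TRL's DPO convention).
record = {
    "prompt": "What is the capital of France?",
    "chosen": "The capital of France is Paris.",
    "rejected": "France does not have a capital.",
}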
@@ -114,24 +114,20 @@ def rl_process():
     # Dataset
     ################
 
-    dataset_path = ['./dataset/dpo/dpo_zh_demo.json',
-                    './dataset/dpo/train_data.json',
-                    './dataset/dpo/huozi_rlhf_data.json', ]
+    dataset_paths = [
+        './dataset/dpo/dpo_zh_demo.json',
+        './dataset/dpo/dpo_train_data.json',
+        './dataset/dpo/huozi_rlhf_data.json',
+    ]
 
-    train_dataset = load_dataset('json', data_files=dataset_path)
+    train_dataset = load_dataset('json', data_files=dataset_paths)
 
-    def process(row):
-        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
-        row["reject"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
-        return row
+    merged_data = []
+    for split in train_dataset.keys():
+        merged_data.extend(train_dataset[split])
 
-    ds = train_dataset.map(
-        process,
-        load_from_cache_file=False,
-    )
-
-    output_dataset_path = './dataset/dpo/train_data.json'
-    ds['train'].to_json(output_dataset_path, force_ascii=False, orient='records', lines=True)
+    with open('./dataset/dpo/train_data.json', 'w', encoding='utf-8') as f:
+        json.dump(merged_data, f, ensure_ascii=False, indent=4)
 
 
 if __name__ == "__main__":
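The rewritten rl_process() now concatenates the three preference files into one ./dataset/dpo/train_data.json instead of re-templating rows with map(). Iterating a datasets.Dataset yields one dict per row, so merged_data ends up as a flat list of dicts. Two caveats: the new code calls json.dump, so the module needs import json at the top (not visible in this hunk), and the merge assumes all three source files share one schema. A small hedged check that the merged output reloads cleanly (the expected column names are an assumption):

from datasets import load_dataset

# Sketch: re-load the merged file and confirm it parses with the expected columns.
merged = load_dataset('json', data_files='./dataset/dpo/train_data.json')
print(merged['train'].column_names)  # expected: ['prompt', 'chosen', 'rejected']
print(f"{len(merged['train'])} merged preference pairs")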
@@ -143,7 +139,7 @@ if __name__ == "__main__":
     # 2: sft
     # 3: RL
     ################
-    process_type = 1
+    process_type = 3
 
     if process_type == 1:
         pretrain_process()
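With process_type switched to 3, a run now takes the RL branch. Only the pretrain branch is visible in this hunk; a sketch of the presumed full dispatch follows. pretrain_process and rl_process both appear in this diff, while the SFT function name below is a guess and labeled as such.

if __name__ == "__main__":
    ################
    # 1: pretrain
    # 2: sft
    # 3: RL
    ################
    process_type = 3

    if process_type == 1:
        pretrain_process()
    elif process_type == 2:
        sft_process()  # hypothetical name; the SFT branch is outside this hunk
    elif process_type == 3:
        rl_process()   # merges the DPO preference files as shown above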