update data_process

This commit is contained in:
gongjy 2024-10-12 19:46:08 +08:00
parent 2698f6b57d
commit 135421690e

View File

@ -1,5 +1,6 @@
import os import os
import warnings import warnings
os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer from trl import DPOTrainer
@ -11,8 +12,8 @@ warnings.filterwarnings('ignore')
def init_model(): def init_model():
device = 'cuda:0' device = 'cuda:0'
# Do model patching and add fast LoRA weights # Do model patching and add fast LoRA weights
model_name_or_path = "minimind-v1-small" model_name_or_path = "minimind-v1"
tokenizer_name_or_path = "minimind-v1-small" tokenizer_name_or_path = "minimind-v1"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False) tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
@ -22,10 +23,14 @@ def init_model():
if __name__ == '__main__': if __name__ == '__main__':
model, tokenizer = init_model() model, tokenizer = init_model()
training_args = TrainingArguments(output_dir="./minimind_dpo", training_args = TrainingArguments(
per_device_train_batch_size=1, output_dir="./minimind_dpo",
remove_unused_columns=False, per_device_train_batch_size=1,
report_to="none") remove_unused_columns=False,
report_to="none",
save_steps=2000,
learning_rate=4e-5
)
dataset_path = './dataset/dpo/train_data.json' dataset_path = './dataset/dpo/train_data.json'
train_dataset = load_dataset('json', data_files=dataset_path) train_dataset = load_dataset('json', data_files=dataset_path)