update data_process
This commit is contained in:
parent
2698f6b57d
commit
135421690e
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
|
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
@ -11,8 +12,8 @@ warnings.filterwarnings('ignore')
|
|||||||
def init_model():
|
def init_model():
|
||||||
device = 'cuda:0'
|
device = 'cuda:0'
|
||||||
# Do model patching and add fast LoRA weights
|
# Do model patching and add fast LoRA weights
|
||||||
model_name_or_path = "minimind-v1-small"
|
model_name_or_path = "minimind-v1"
|
||||||
tokenizer_name_or_path = "minimind-v1-small"
|
tokenizer_name_or_path = "minimind-v1"
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
|
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False)
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False)
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
@ -22,10 +23,14 @@ def init_model():
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
model, tokenizer = init_model()
|
model, tokenizer = init_model()
|
||||||
training_args = TrainingArguments(output_dir="./minimind_dpo",
|
training_args = TrainingArguments(
|
||||||
per_device_train_batch_size=1,
|
output_dir="./minimind_dpo",
|
||||||
remove_unused_columns=False,
|
per_device_train_batch_size=1,
|
||||||
report_to="none")
|
remove_unused_columns=False,
|
||||||
|
report_to="none",
|
||||||
|
save_steps=2000,
|
||||||
|
learning_rate=4e-5
|
||||||
|
)
|
||||||
|
|
||||||
dataset_path = './dataset/dpo/train_data.json'
|
dataset_path = './dataset/dpo/train_data.json'
|
||||||
train_dataset = load_dataset('json', data_files=dataset_path)
|
train_dataset = load_dataset('json', data_files=dataset_path)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user