update trl
This commit is contained in:
parent
5c4b34bbe3
commit
02adb7bc0d
@@ -3,7 +3,7 @@ import warnings
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
|
||||
-from trl import DPOTrainer
|
||||
+from trl import DPOConfig, DPOTrainer
|
||||
from datasets import load_dataset
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
@@ -23,7 +23,7 @@ def init_model():
|
||||
|
||||
if __name__ == '__main__':
|
||||
model, tokenizer = init_model()
|
||||
-training_args = TrainingArguments(
|
||||
+training_config = DPOConfig(
|
||||
output_dir="./minimind_dpo",
|
||||
per_device_train_batch_size=1,
|
||||
remove_unused_columns=False,
|
||||
@@ -38,7 +38,7 @@ if __name__ == '__main__':
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
ref_model=None,
|
||||
-args=training_args,
|
||||
+args=training_config,
|
||||
beta=0.1,
|
||||
train_dataset=train_dataset['train'],
|
||||
tokenizer=tokenizer,
|
||||
|
@@ -23,6 +23,6 @@ torch==2.1.2
|
||||
transformers==4.44.0
|
||||
jinja2==3.1.2
|
||||
jsonlines==4.0.0
|
||||
-trl==0.8.6
|
||||
+trl==0.11.3
|
||||
ujson==5.1.0
|
||||
wandb==0.18.3
|
||||
|
Loading…
x
Reference in New Issue
Block a user