update trl

This commit is contained in:
gongjy 2024-10-13 22:44:28 +08:00
parent 5c4b34bbe3
commit 02adb7bc0d
2 changed files with 4 additions and 4 deletions

View File

@ -3,7 +3,7 @@ import warnings
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer
from trl import DPOConfig, DPOTrainer
from datasets import load_dataset
warnings.filterwarnings('ignore')
@ -23,7 +23,7 @@ def init_model():
if __name__ == '__main__':
model, tokenizer = init_model()
training_args = TrainingArguments(
training_config = DPOConfig(
output_dir="./minimind_dpo",
per_device_train_batch_size=1,
remove_unused_columns=False,
@ -38,7 +38,7 @@ if __name__ == '__main__':
dpo_trainer = DPOTrainer(
model,
ref_model=None,
args=training_args,
args=training_config,
beta=0.1,
train_dataset=train_dataset['train'],
tokenizer=tokenizer,

View File

@ -23,6 +23,6 @@ torch==2.1.2
transformers==4.44.0
jinja2==3.1.2
jsonlines==4.0.0
trl==0.8.6
trl==0.11.3
ujson==5.1.0
wandb==0.18.3