diff --git a/5-dpo_train.py b/5-dpo_train.py index 3ae3df7..ef17e2c 100644 --- a/5-dpo_train.py +++ b/5-dpo_train.py @@ -3,7 +3,7 @@ import warnings os.environ['CUDA_VISIBLE_DEVICES'] = '0' from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer -from trl import DPOTrainer +from trl import DPOConfig, DPOTrainer from datasets import load_dataset warnings.filterwarnings('ignore') @@ -23,7 +23,7 @@ def init_model(): if __name__ == '__main__': model, tokenizer = init_model() - training_args = TrainingArguments( + training_config = DPOConfig( output_dir="./minimind_dpo", per_device_train_batch_size=1, remove_unused_columns=False, @@ -38,7 +38,7 @@ if __name__ == '__main__': dpo_trainer = DPOTrainer( model, ref_model=None, - args=training_args, + args=training_config, beta=0.1, train_dataset=train_dataset['train'], tokenizer=tokenizer, diff --git a/requirements.txt b/requirements.txt index 26c8d29..342675d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,6 +23,6 @@ torch==2.1.2 transformers==4.44.0 jinja2==3.1.2 jsonlines==4.0.0 -trl==0.8.6 +trl==0.11.3 ujson==5.1.0 wandb==0.18.3