update trl
This commit is contained in:
parent
5c4b34bbe3
commit
02adb7bc0d
@ -3,7 +3,7 @@ import warnings
|
|||||||
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
|
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
|
||||||
from trl import DPOTrainer
|
from trl import DPOConfig, DPOTrainer
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
@ -23,7 +23,7 @@ def init_model():
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
model, tokenizer = init_model()
|
model, tokenizer = init_model()
|
||||||
training_args = TrainingArguments(
|
training_config = DPOConfig(
|
||||||
output_dir="./minimind_dpo",
|
output_dir="./minimind_dpo",
|
||||||
per_device_train_batch_size=1,
|
per_device_train_batch_size=1,
|
||||||
remove_unused_columns=False,
|
remove_unused_columns=False,
|
||||||
@ -38,7 +38,7 @@ if __name__ == '__main__':
|
|||||||
dpo_trainer = DPOTrainer(
|
dpo_trainer = DPOTrainer(
|
||||||
model,
|
model,
|
||||||
ref_model=None,
|
ref_model=None,
|
||||||
args=training_args,
|
args=training_config,
|
||||||
beta=0.1,
|
beta=0.1,
|
||||||
train_dataset=train_dataset['train'],
|
train_dataset=train_dataset['train'],
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
@ -23,6 +23,6 @@ torch==2.1.2
|
|||||||
transformers==4.44.0
|
transformers==4.44.0
|
||||||
jinja2==3.1.2
|
jinja2==3.1.2
|
||||||
jsonlines==4.0.0
|
jsonlines==4.0.0
|
||||||
trl==0.8.6
|
trl==0.11.3
|
||||||
ujson==5.1.0
|
ujson==5.1.0
|
||||||
wandb==0.18.3
|
wandb==0.18.3
|
||||||
|
Loading…
x
Reference in New Issue
Block a user