update lora-sft

gongjy 2024-10-11 17:43:52 +08:00
parent 3a034a47c8
commit 36fadc7ef1


@@ -16,6 +16,7 @@ from peft import get_peft_model, LoraConfig, TaskType
 from torch.utils.data import DataLoader
 from model.LMConfig import LMConfig
 from model.dataset import SFTDataset
+from model.model import Transformer

 warnings.filterwarnings('ignore')
@@ -96,8 +97,6 @@ def find_all_linear_names(model):
             names = name.split('.')
             lora_module_names.add(names[0] if len(names) == 1 else names[-1])
-    if 'lm_head' in lora_module_names:
-        lora_module_names.remove('lm_head')
     return list(lora_module_names)
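With the `lm_head` exclusion gone, every linear-layer name the scan turns up becomes a LoRA target, including the output head. A minimal sketch of how the helper reads after this commit; the module loop and the `isinstance` filter are assumptions based on the usual recipe, only the diff lines above are confirmed:

```python
import torch.nn as nn

def find_all_linear_names(model):
    """Collect the leaf names of nn.Linear modules; peft matches
    target_modules against these name suffixes."""
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):  # assumed filter, not shown in the hunk
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    # After this commit 'lm_head' is no longer removed, so the output
    # head (when it is an nn.Linear) also receives an adapter.
    return list(lora_module_names)
```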
@@ -109,11 +108,7 @@ def init_model():
     target_modules = find_all_linear_names(model)
     peft_config = LoraConfig(
-        task_type=TaskType.CAUSAL_LM,
         r=8,
-        lora_alpha=16,
-        lora_dropout=0.1,
-        inference_mode=False,
         target_modules=target_modules
     )
     model = get_peft_model(model, peft_config)
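The trimmed call leans on peft's defaults for the dropped keyword arguments. Spelled out under that assumption (peft's documented defaults are roughly `lora_alpha=8`, `lora_dropout=0.0`, `inference_mode=False`, `task_type=None`), the config resolves to something like:

```python
from peft import LoraConfig

target_modules = ["q_proj", "v_proj"]  # placeholder; the script uses find_all_linear_names(model)

# Roughly equivalent to the trimmed call, with the defaults written out
# (an assumption about peft's defaults, not part of the commit):
peft_config = LoraConfig(
    r=8,                   # rank of the low-rank update
    lora_alpha=8,          # scaling = lora_alpha / r; drops from 16/8 = 2.0 to 8/8 = 1.0
    lora_dropout=0.0,      # was 0.1 before this commit
    inference_mode=False,  # same as the old explicit value
    target_modules=target_modules,
)
```

If the old adapter scaling and dropout are wanted, `lora_alpha=16` and `lora_dropout=0.1` would have to be passed explicitly again.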
@@ -126,7 +121,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind LoRA Fine-tuning")
     parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
     parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
-    parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
+    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
     parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
     parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
     parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
@@ -162,7 +157,7 @@ if __name__ == "__main__":
     model, tokenizer = init_model()
-    df = pd.read_csv('./dataset/sft_data.csv')
+    df = pd.read_csv('./dataset/sft_data_single.csv')
     df = df.sample(frac=1.0)
     train_ds = SFTDataset(df, tokenizer, max_length=max_seq_len)
     train_loader = DataLoader(
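Here `df.sample(frac=1.0)` draws all rows in random order, i.e. it shuffles the dataset before it is wrapped in `SFTDataset`. A small optional variant, if a reproducible shuffle were wanted (not part of the commit):

```python
# Deterministic shuffle with a clean integer index (optional tweak, not in the commit):
df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)
```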
@@ -175,7 +170,10 @@ if __name__ == "__main__":
     )
     scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
-    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+    optimizer = optim.Adam(
+        filter(lambda p: p.requires_grad, model.parameters()),
+        lr=args.learning_rate
+    )
     if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
         Logger("compiling the model... (takes a ~minute)")