update lora-sft

This commit is contained in:
gongjy 2024-11-10 20:40:49 +08:00
parent 1240829c89
commit f7127e4310

View File

@ -89,15 +89,16 @@ def train_epoch(epoch, wandb):
model.save_pretrained(args.save_dir) model.save_pretrained(args.save_dir)
def find_all_linear_names(model): def find_linear_with_keys(model, keys=["wq", "wk"]):
cls = torch.nn.Linear cls = torch.nn.Linear
lora_module_names = set() linear_names = []
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, cls): if isinstance(module, cls):
names = name.split('.') for key in keys:
lora_module_names.add(names[0] if len(names) == 1 else names[-1]) if key in name:
linear_names.append(name)
return list(lora_module_names) break
return linear_names
def init_model(): def init_model():
@ -106,7 +107,7 @@ def init_model():
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False) tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True).to(args.device) model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True).to(args.device)
target_modules = find_all_linear_names(model) target_modules = find_linear_with_keys(model)
peft_config = LoraConfig( peft_config = LoraConfig(
r=8, r=8,
target_modules=target_modules target_modules=target_modules
@ -123,11 +124,12 @@ if __name__ == "__main__":
parser.add_argument("--epochs", type=int, default=20, help="Number of epochs") parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size") parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate") parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use") parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu",
help="Device to use")
parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type") parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases") parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="Weights & Biases project name") parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="Weights & Biases project name")
parser.add_argument("--num_workers", type=int, default=0, help="Number of workers for data loading") parser.add_argument("--num_workers", type=int, default=1, help="Number of workers for data loading")
parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps") parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold") parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
parser.add_argument("--warmup_iters", type=int, default=1000, help="Number of warmup iterations") parser.add_argument("--warmup_iters", type=int, default=1000, help="Number of warmup iterations")
@ -151,6 +153,7 @@ if __name__ == "__main__":
if args.use_wandb: if args.use_wandb:
import wandb import wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name) wandb.init(project=args.wandb_project, name=args.wandb_run_name)
else: else:
wandb = None wandb = None