diff --git a/4-lora_sft.py b/4-lora_sft.py
index c5e0948..482844d 100644
--- a/4-lora_sft.py
+++ b/4-lora_sft.py
@@ -89,15 +89,16 @@ def train_epoch(epoch, wandb):
     model.save_pretrained(args.save_dir)
 
 
-def find_all_linear_names(model):
+def find_linear_with_keys(model, keys=["wq", "wk"]):
     cls = torch.nn.Linear
-    lora_module_names = set()
+    linear_names = []
     for name, module in model.named_modules():
         if isinstance(module, cls):
-            names = name.split('.')
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
-
-    return list(lora_module_names)
+            for key in keys:
+                if key in name:
+                    linear_names.append(name)
+                    break
+    return linear_names
 
 
 def init_model():
@@ -106,7 +107,7 @@ def init_model():
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False)
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True).to(args.device)
 
-    target_modules = find_all_linear_names(model)
+    target_modules = find_linear_with_keys(model)
     peft_config = LoraConfig(
         r=8,
         target_modules=target_modules
@@ -123,11 +124,12 @@ if __name__ == "__main__":
     parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
-    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu",
+                        help="Device to use")
     parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
     parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
     parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="Weights & Biases project name")
-    parser.add_argument("--num_workers", type=int, default=0, help="Number of workers for data loading")
+    parser.add_argument("--num_workers", type=int, default=1, help="Number of workers for data loading")
     parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
     parser.add_argument("--warmup_iters", type=int, default=1000, help="Number of warmup iterations")
@@ -151,6 +153,7 @@ if __name__ == "__main__":
 
     if args.use_wandb:
         import wandb
+        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
     else:
         wandb = None
 
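
For reference, the sketch below illustrates what the renamed helper now selects: every `nn.Linear` submodule whose dotted name contains one of the keys (`wq`/`wk` by default). The `ToyAttention` module and its dimensions are invented for this example and are not part of the patch; the real MiniMind attention layers simply expose similarly named projections.

```python
import torch
import torch.nn as nn


def find_linear_with_keys(model, keys=["wq", "wk"]):
    """Return the full dotted names of nn.Linear modules whose name contains any key."""
    cls = torch.nn.Linear
    linear_names = []
    for name, module in model.named_modules():
        if isinstance(module, cls):
            for key in keys:
                if key in name:
                    linear_names.append(name)
                    break
    return linear_names


# Toy stand-in for an attention block, used only to make this sketch self-contained.
class ToyAttention(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wv = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)


toy = nn.Sequential(ToyAttention(), ToyAttention())
print(find_linear_with_keys(toy))
# ['0.wq', '0.wk', '1.wq', '1.wk'] -- only the query/key projections are selected
```

Because PEFT matches `target_modules` entries against full module names (exact or suffix match), these dotted names can be passed straight to `LoraConfig`, so LoRA adapters are now attached only to the query/key projections rather than to every linear layer as with the old `find_all_linear_names`.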