From f7127e43101670ff8155724638f337683462b053 Mon Sep 17 00:00:00 2001
From: gongjy <2474590974@qq.com>
Date: Sun, 10 Nov 2024 20:40:49 +0800
Subject: [PATCH] update lora-sft

---
 4-lora_sft.py | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/4-lora_sft.py b/4-lora_sft.py
index c5e0948..482844d 100644
--- a/4-lora_sft.py
+++ b/4-lora_sft.py
@@ -89,15 +89,16 @@ def train_epoch(epoch, wandb):
             model.save_pretrained(args.save_dir)
 
 
-def find_all_linear_names(model):
+def find_linear_with_keys(model, keys=["wq", "wk"]):
     cls = torch.nn.Linear
-    lora_module_names = set()
+    linear_names = []
     for name, module in model.named_modules():
         if isinstance(module, cls):
-            names = name.split('.')
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
-
-    return list(lora_module_names)
+            for key in keys:
+                if key in name:
+                    linear_names.append(name)
+                    break
+    return linear_names
 
 
 def init_model():
@@ -106,7 +107,7 @@ def init_model():
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False)
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True).to(args.device)
 
-    target_modules = find_all_linear_names(model)
+    target_modules = find_linear_with_keys(model)
     peft_config = LoraConfig(
         r=8,
         target_modules=target_modules
@@ -123,11 +124,12 @@ if __name__ == "__main__":
     parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
     parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
-    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu",
+                        help="Device to use")
     parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
     parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
     parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="Weights & Biases project name")
-    parser.add_argument("--num_workers", type=int, default=0, help="Number of workers for data loading")
+    parser.add_argument("--num_workers", type=int, default=1, help="Number of workers for data loading")
     parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
     parser.add_argument("--warmup_iters", type=int, default=1000, help="Number of warmup iterations")
@@ -151,6 +153,7 @@ if __name__ == "__main__":
 
     if args.use_wandb:
         import wandb
+
         wandb.init(project=args.wandb_project, name=args.wandb_run_name)
     else:
         wandb = None