update lora-sft
This commit is contained in:
parent
1240829c89
commit
f7127e4310
@ -89,15 +89,16 @@ def train_epoch(epoch, wandb):
|
||||
model.save_pretrained(args.save_dir)
|
||||
|
||||
|
||||
def find_all_linear_names(model):
|
||||
def find_linear_with_keys(model, keys=["wq", "wk"]):
|
||||
cls = torch.nn.Linear
|
||||
lora_module_names = set()
|
||||
linear_names = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, cls):
|
||||
names = name.split('.')
|
||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||
|
||||
return list(lora_module_names)
|
||||
for key in keys:
|
||||
if key in name:
|
||||
linear_names.append(name)
|
||||
break
|
||||
return linear_names
|
||||
|
||||
|
||||
def init_model():
|
||||
@ -106,7 +107,7 @@ def init_model():
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True).to(args.device)
|
||||
|
||||
target_modules = find_all_linear_names(model)
|
||||
target_modules = find_linear_with_keys(model)
|
||||
peft_config = LoraConfig(
|
||||
r=8,
|
||||
target_modules=target_modules
|
||||
@ -123,11 +124,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
|
||||
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu",
|
||||
help="Device to use")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="Weights & Biases project name")
|
||||
parser.add_argument("--num_workers", type=int, default=0, help="Number of workers for data loading")
|
||||
parser.add_argument("--num_workers", type=int, default=1, help="Number of workers for data loading")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
|
||||
parser.add_argument("--warmup_iters", type=int, default=1000, help="Number of warmup iterations")
|
||||
@ -151,6 +153,7 @@ if __name__ == "__main__":
|
||||
|
||||
if args.use_wandb:
|
||||
import wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
|
||||
else:
|
||||
wandb = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user