update lora-sft
This commit is contained in:
parent 1240829c89
commit f7127e4310
@@ -89,15 +89,16 @@ def train_epoch(epoch, wandb):
         model.save_pretrained(args.save_dir)


-def find_all_linear_names(model):
+def find_linear_with_keys(model, keys=["wq", "wk"]):
     cls = torch.nn.Linear
-    lora_module_names = set()
+    linear_names = []
     for name, module in model.named_modules():
         if isinstance(module, cls):
-            names = name.split('.')
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
-    return list(lora_module_names)
-
+            for key in keys:
+                if key in name:
+                    linear_names.append(name)
+                    break
+    return linear_names


 def init_model():
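For reference, a minimal, self-contained sketch of what the renamed helper does: it walks model.named_modules() and keeps the qualified names of nn.Linear layers whose name contains one of the given substrings (by default the attention projections "wq"/"wk"). The TinyAttention class below is a made-up example, not code from this repository.

import torch

class TinyAttention(torch.nn.Module):
    # Hypothetical module whose projection names contain the default keys.
    def __init__(self, dim=8):
        super().__init__()
        self.wq = torch.nn.Linear(dim, dim)
        self.wk = torch.nn.Linear(dim, dim)
        self.wv = torch.nn.Linear(dim, dim)  # no matching key, so it is skipped

def find_linear_with_keys(model, keys=["wq", "wk"]):
    # Same logic as the committed version: collect names of nn.Linear modules
    # whose qualified name contains any of the given substrings.
    cls = torch.nn.Linear
    linear_names = []
    for name, module in model.named_modules():
        if isinstance(module, cls):
            for key in keys:
                if key in name:
                    linear_names.append(name)
                    break
    return linear_names

print(find_linear_with_keys(TinyAttention()))  # ['wq', 'wk']

The mutable default keys=["wq", "wk"] is harmless here because the list is only read; passing keys=None and filling in the default inside the function would be the more defensive idiom.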
@@ -106,7 +107,7 @@ def init_model():
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False)
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True).to(args.device)

-    target_modules = find_all_linear_names(model)
+    target_modules = find_linear_with_keys(model)
     peft_config = LoraConfig(
         r=8,
         target_modules=target_modules
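For context, a sketch of how the selected module names typically flow into peft from here. The commit itself only shows r=8 and target_modules, so lora_alpha, lora_dropout, task_type, and the get_peft_model call below are assumed, standard peft usage rather than lines from this file; model and target_modules continue from init_model() above.

from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    r=8,                            # rank shown in the diff
    lora_alpha=16,                  # assumed scaling factor, not in this commit
    lora_dropout=0.05,              # assumed dropout, not in this commit
    target_modules=target_modules,  # names returned by find_linear_with_keys(model)
    task_type="CAUSAL_LM",          # assumed task type for a causal LM
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # only the injected LoRA adapters are trainable

Restricting target_modules to the "wq"/"wk" projections keeps the adapter small; passing more keys (for example the value and output projections) trades memory for capacity.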
@@ -123,11 +124,12 @@ if __name__ == "__main__":
     parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
     parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
-    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu",
+                        help="Device to use")
     parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
     parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
     parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="Weights & Biases project name")
-    parser.add_argument("--num_workers", type=int, default=0, help="Number of workers for data loading")
+    parser.add_argument("--num_workers", type=int, default=1, help="Number of workers for data loading")
     parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
     parser.add_argument("--warmup_iters", type=int, default=1000, help="Number of warmup iterations")
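As a rough illustration of where these flags usually land (none of the code below is from this commit; train_dataset, optimizer, and the forward pass are placeholders), the data-loading and optimization arguments might be wired like this:

from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,                 # placeholder dataset
    batch_size=args.batch_size,
    num_workers=args.num_workers,  # default raised from 0 to 1 in this commit
    shuffle=True,
)

for step, batch in enumerate(train_loader):
    loss = model(**batch).loss / args.accumulation_steps   # placeholder forward pass
    loss.backward()
    if (step + 1) % args.accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()           # placeholder optimizer
        optimizer.zero_grad()

With num_workers=1 the DataLoader prepares batches in a separate worker process rather than the training process, which usually hides data-preprocessing latency at the cost of one extra process.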
@@ -151,6 +153,7 @@ if __name__ == "__main__":

     if args.use_wandb:
         import wandb
+
         wandb.init(project=args.wandb_project, name=args.wandb_run_name)
     else:
         wandb = None
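Finally, a small sketch of the pattern this hunk sets up: wandb ends up bound either to the imported module or to None, so any logging inside train_epoch(epoch, wandb) can be guarded with a single check. The metric names below are illustrative only.

# Inside train_epoch(epoch, wandb), after computing `loss`:
if wandb is not None:
    wandb.log({
        "loss": loss.item(),                    # illustrative metric names
        "lr": optimizer.param_groups[0]["lr"],
        "epoch": epoch,
    })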