Added argparse so hyperparameters can be passed on the command line

Yu Chengzhang 2024-09-24 12:41:58 +08:00
parent ef9a592d14
commit 51dcf51c5d
3 changed files with 206 additions and 219 deletions

View File

@@ -1,5 +1,6 @@
 import os
 import platform
+import argparse
 import time
 import math
 import warnings
@@ -23,66 +24,65 @@ def Logger(content):

 def get_lr(it, all):
-    warmup_iters = 0
+    warmup_iters = args.warmup_iters
     lr_decay_iters = all
-    min_lr = learning_rate / 10
+    min_lr = args.learning_rate / 10
     if it < warmup_iters:
-        return learning_rate * it / warmup_iters
+        return args.learning_rate * it / warmup_iters
     if it > lr_decay_iters:
         return min_lr
     decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
     assert 0 <= decay_ratio <= 1
     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
-    return min_lr + coeff * (learning_rate - min_lr)
+    return min_lr + coeff * (args.learning_rate - min_lr)


-def train_epoch(epoch, wandb, accumulation_steps=8):
+def train_epoch(epoch, wandb):
     start_time = time.time()
     for step, (X, Y) in enumerate(train_loader):
-        X = X.to(device)
-        Y = Y.to(device)
+        X = X.to(args.device)
+        Y = Y.to(args.device)

-        lr = get_lr(epoch * iter_per_epoch + step, epochs * iter_per_epoch)
+        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr

         with ctx:
             out = model(X, Y)
-            loss = out.last_loss / accumulation_steps
+            loss = out.last_loss / args.accumulation_steps

         scaler.scale(loss).backward()

-        if (step + 1) % accumulation_steps == 0:
+        if (step + 1) % args.accumulation_steps == 0:
             scaler.unscale_(optimizer)
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad(set_to_none=True)

-        if step % 100 == 0:
+        if step % args.log_interval == 0:
             spend_time = time.time() - start_time
             Logger(
                 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
                     epoch,
-                    epochs,
+                    args.epochs,
                     step,
                     iter_per_epoch,
-                    loss.item() * accumulation_steps,
+                    loss.item() * args.accumulation_steps,
                     optimizer.param_groups[-1]['lr'],
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

-            if wandb != None:
-                wandb.log({"loss": loss.item() * accumulation_steps,
+            if wandb is not None:
+                wandb.log({"loss": loss.item() * args.accumulation_steps,
                            "lr": optimizer.param_groups[-1]['lr'],
                            "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

-        if (step + 1) % 1000 == 0 and (not ddp or dist.get_rank() == 0):
+        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
             model.eval()
-            # torch.save(model.state_dict(), '{}/iter_{}.pth'.format(save_dir, int(step + epoch * iter_per_epoch)))
             moe_path = '_moe' if lm_config.use_moe else ''
-            ckp = f'{save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
+            ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
             if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                 state_dict = model.module.state_dict()
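The accumulation logic above follows the standard PyTorch AMP gradient-accumulation recipe. For reference, a self-contained toy sketch of the same scale/accumulate/unscale/clip/step cycle; the linear model, random data, and hyperparameter values are placeholders, and a CUDA device is assumed:

    import torch
    from torch import nn, optim

    model = nn.Linear(16, 1).cuda()                      # placeholder for Transformer(lm_config)
    optimizer = optim.Adam(model.parameters(), lr=2e-4)
    scaler = torch.cuda.amp.GradScaler()
    accumulation_steps, grad_clip = 8, 1.0

    data = [(torch.randn(4, 16).cuda(), torch.randn(4, 1).cuda()) for _ in range(32)]
    for step, (X, Y) in enumerate(data):
        with torch.cuda.amp.autocast():
            # divide so the summed gradients equal one full-batch gradient
            loss = nn.functional.mse_loss(model(X), Y) / accumulation_steps
        scaler.scale(loss).backward()                    # gradients accumulate across micro-batches
        if (step + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)                   # unscale before clipping in real units
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)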
@@ -97,17 +97,8 @@ def init_model():
     def count_parameters(model):
         return sum(p.numel() for p in model.parameters() if p.requires_grad)

-    # model init
-    model = Transformer(lm_config).to(device)
+    model = Transformer(lm_config).to(args.device)
     moe_path = '_moe' if lm_config.use_moe else ''
-    # ckp = f'{save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
-    #
-    # state_dict = torch.load(ckp, map_location=device)
-    # unwanted_prefix = '_orig_mod.'
-    # for k, v in list(state_dict.items()):
-    #     if k.startswith(unwanted_prefix):
-    #         state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
-    # model.load_state_dict(state_dict, strict=False)
     Logger(f'LLM总参数量{count_parameters(model) / 1e6:.3f} 百万')
     return model
@@ -125,81 +116,78 @@ def init_distributed_mode():
     torch.cuda.set_device(DEVICE)


 # torchrun --nproc_per_node 2 1-pretrain.py
-# I/O
 if __name__ == "__main__":
-    # -----------------------------------------------------------------------------
+    parser = argparse.ArgumentParser(description="MiniMind Pretraining")
+    parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
+    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
+    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
+    parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
+    parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
+    parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
+    parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="Weights & Biases project name")
+    parser.add_argument("--num_workers", type=int, default=8, help="Number of workers for data loading")
+    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_data.bin", help="Path to training data")
+    parser.add_argument("--ddp", action="store_true", help="Use DistributedDataParallel")
+    parser.add_argument("--accumulation_steps", type=int, default=8, help="Gradient accumulation steps")
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
+    parser.add_argument("--warmup_iters", type=int, default=0, help="Number of warmup iterations")
+    parser.add_argument("--log_interval", type=int, default=100, help="Logging interval")
+    parser.add_argument("--save_interval", type=int, default=1000, help="Model saving interval")
+    args = parser.parse_args()

     lm_config = LMConfig()
     max_seq_len = lm_config.max_seq_len
-    out_dir = 'out'
-    epochs = 20
-    batch_size = 64
-    learning_rate = 2e-4
-    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-    dtype = 'bfloat16'
-    save_dir = os.path.join(out_dir)
-    os.makedirs(save_dir, exist_ok=True)
-    os.makedirs(out_dir, exist_ok=True)
-    tokens_per_iter = batch_size * max_seq_len
+    args.save_dir = os.path.join(args.out_dir)
+    os.makedirs(args.save_dir, exist_ok=True)
+    os.makedirs(args.out_dir, exist_ok=True)
+    tokens_per_iter = args.batch_size * max_seq_len
     torch.manual_seed(1337)
-    device_type = device if "cuda" in device else "cpu"
-    use_wandb = True  # 是否使用wandb
-    wandb_project = "MiniMind-Pretrain"
-    wandb_run_name = f"MiniMind-Pretrain-Epoch-{epochs}-BatchSize-{batch_size}-LearningRate-{learning_rate}"
-    ctx = (
-        nullcontext()
-        if device_type == "cpu"
-        else torch.cuda.amp.autocast()
-    )
+    device_type = "cuda" if "cuda" in args.device else "cpu"
+    args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
+    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
     ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
     ddp_local_rank, DEVICE = 0, "cuda:0"
     if ddp:
         init_distributed_mode()
-        device = torch.device(DEVICE)
+        args.device = torch.device(DEVICE)

-    if use_wandb and (not ddp or ddp_local_rank == 0):
+    if args.use_wandb and (not ddp or ddp_local_rank == 0):
         import wandb
-        wandb.init(project=wandb_project, name=wandb_run_name)
+        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
     else:
         wandb = None
-    # -----------------------------------------------------------------------------

-    # -----init dataloader------
-    data_path_list = ['./dataset/pretrain_data.bin']
+    data_path_list = [args.data_path]
     train_ds = PretrainDataset(data_path_list, max_length=max_seq_len, memmap=True)
     train_sampler = DistributedSampler(train_ds) if ddp else None
-    num_workers = 8  # 可以根据系统的 CPU 核心数来调整
     train_loader = DataLoader(
         train_ds,
-        batch_size=batch_size,
+        batch_size=args.batch_size,
         pin_memory=True,
         drop_last=False,
         shuffle=False,
-        num_workers=num_workers,
+        num_workers=args.num_workers,
         sampler=train_sampler
     )

-    # init model
     model = init_model()
-    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == dtype))
-    # optimizer
-    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
-    # compile the model
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == args.dtype))
+    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
     if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
         Logger("compiling the model... (takes a ~minute)")
         unoptimized_model = model
         model = torch.compile(model)

     if ddp:
-        # Ignore the freqs_cis buffer so that DDP does not broadcast it at
-        # construction time since NCCL does not support ComplexFloat
         model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])

-    # training loop
     iter_per_epoch = len(train_loader)
-    for epoch in range(epochs):
+    for epoch in range(args.epochs):
         train_epoch(epoch, wandb)
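Two things are worth noting about the refactor above. First, every flag's default matches the value it replaces, so a bare `python 1-pretrain.py` reproduces the old configuration, with one exception: wandb logging used to be hard-coded on (`use_wandb = True`) and is now opt-in, because an `action="store_true"` flag defaults to False. Second, the new `--ddp` flag appears unused in this diff; DDP detection still keys off the `RANK` environment variable that torchrun sets. Typical invocations:

    # single GPU, with wandb logging re-enabled
    python 1-pretrain.py --use_wandb

    # two-GPU DDP run, per the comment kept in the script
    torchrun --nproc_per_node 2 1-pretrain.py --use_wandb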

View File

@@ -1,5 +1,6 @@
 import os
 import platform
+import argparse
 import time
 import math
 import warnings
@@ -12,7 +13,6 @@ from contextlib import nullcontext
 from torch import optim
 from torch.nn.parallel import DistributedDataParallel
-from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer, AutoModel
 from model.model import Transformer
@@ -28,28 +28,27 @@ def Logger(content):

 def get_lr(it, all):
-    warmup_iters = 0
+    warmup_iters = args.warmup_iters
     lr_decay_iters = all
-    min_lr = learning_rate / epochs
+    min_lr = args.learning_rate / 10
     if it < warmup_iters:
-        return learning_rate * it / warmup_iters
+        return args.learning_rate * it / warmup_iters
     if it > lr_decay_iters:
         return min_lr
     decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
     assert 0 <= decay_ratio <= 1
     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
-    return min_lr + coeff * (learning_rate - min_lr)
+    return min_lr + coeff * (args.learning_rate - min_lr)

-# ------------------------------------------------------------------------------
 def train_epoch(epoch, wandb):
     start_time = time.time()
     for step, (X, Y, loss_mask) in enumerate(train_loader):
-        X = X.to(device)
-        Y = Y.to(device)
-        loss_mask = loss_mask.to(device)
+        X = X.to(args.device)
+        Y = Y.to(args.device)
+        loss_mask = loss_mask.to(args.device)

-        lr = get_lr(epoch * iter_per_epoch + step, epochs * iter_per_epoch)
+        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
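This is the same linear-warmup-plus-cosine-decay schedule as in the pretraining script; the commit also standardizes the floor, which this script previously set to `learning_rate / epochs` and which is now `learning_rate / 10` everywhere. A standalone sketch with a few values worked out (the default arguments are illustrative, and `all` is renamed `all_iters` to avoid shadowing the builtin):

    import math

    def get_lr(it, all_iters, learning_rate=1e-4, warmup_iters=0):
        # linear warmup, then cosine decay from learning_rate down to learning_rate / 10
        min_lr = learning_rate / 10
        if it < warmup_iters:
            return learning_rate * it / warmup_iters
        if it > all_iters:
            return min_lr
        decay_ratio = (it - warmup_iters) / (all_iters - warmup_iters)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (learning_rate - min_lr)

    print(get_lr(0, 1000))     # 1e-4  (decay_ratio = 0, coeff = 1)
    print(get_lr(500, 1000))   # 5.5e-5, halfway down the cosine
    print(get_lr(1000, 1000))  # 1e-5, the floor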
@@ -59,41 +58,38 @@ def train_epoch(epoch, wandb):
         loss_mask = loss_mask.view(-1)
         loss = torch.sum(loss * loss_mask) / loss_mask.sum()

-        # Backward pass
         scaler.scale(loss).backward()
-        # Unscale gradients and clip them
-        scaler.unscale_(optimizer)
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-        # Update parameters
-        scaler.step(optimizer)
-        scaler.update()
-        # Zero the gradients
-        optimizer.zero_grad(set_to_none=True)

-        # 打印日志
-        if step % 100 == 0:
+        if (step + 1) % args.accumulation_steps == 0:
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad(set_to_none=True)

+        if step % args.log_interval == 0:
             spend_time = time.time() - start_time
             Logger(
-                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.8f} epoch_Time:{}min:'.format(
+                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
                     epoch,
-                    epochs,
+                    args.epochs,
                     step,
                     iter_per_epoch,
-                    loss,
+                    loss.item(),
                     optimizer.param_groups[-1]['lr'],
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

-            if use_wandb != None:
-                wandb.log({"loss": loss, "lr": optimizer.param_groups[-1]['lr'],
+            if wandb is not None:
+                wandb.log({"loss": loss.item(),
+                           "lr": optimizer.param_groups[-1]['lr'],
                            "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

-        if (step + 1) % 1000 == 0 and (not ddp or dist.get_rank() == 0):
+        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
             model.eval()
-            # torch.save(model.state_dict(), '{}/sft_iter_{}.pth'.format(save_dir, int(step + epoch * iter_per_epoch)))
             moe_path = '_moe' if lm_config.use_moe else ''
-            ckp = f'{save_dir}/full_sft_{lm_config.dim}{moe_path}.pth'
+            ckp = f'{args.save_dir}/full_sft_{lm_config.dim}{moe_path}.pth'
             if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                 state_dict = model.module.state_dict()
             else:
@@ -103,7 +99,7 @@ def train_epoch(epoch, wandb):
         model.train()


-def init_model(lm_config):
+def init_model():
     tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
     model_from = 1  # 1从权重，2用transformers
@@ -114,7 +110,7 @@ def init_model():
         model = Transformer(lm_config)
         moe_path = '_moe' if lm_config.use_moe else ''
         ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
-        state_dict = torch.load(ckp, map_location=device)
+        state_dict = torch.load(ckp, map_location=args.device)
         unwanted_prefix = '_orig_mod.'
         for k, v in list(state_dict.items()):
             if k.startswith(unwanted_prefix):
@@ -124,7 +120,7 @@ def init_model():
         model = AutoModel.from_pretrained('./minimind', trust_remote_code=True)

     Logger(f'LLM总参数量{count_parameters(model) / 1e6:.3f} 百万')
-    model = model.to(device)
+    model = model.to(args.device)
     return model, tokenizer
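For context on the `unwanted_prefix` loop kept above: saving a model after `torch.compile` prefixes every state-dict key with `_orig_mod.`, so the checkpoint has to be cleaned before it can load into an uncompiled model. A minimal runnable sketch of the same idea (toy module; names illustrative):

    import torch

    def strip_compile_prefix(state_dict, prefix='_orig_mod.'):
        # torch.compile wraps the model, so checkpoints saved from it carry this key prefix
        return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}

    # illustrative round-trip with a toy module
    net = torch.nn.Linear(4, 4)
    ckpt = {f'_orig_mod.{k}': v for k, v in net.state_dict().items()}  # as if saved from a compiled model
    net.load_state_dict(strip_compile_prefix(ckpt))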
@@ -141,83 +137,78 @@ def init_distributed_mode():
     torch.cuda.set_device(DEVICE)


-# I/O
 if __name__ == "__main__":
-    # -----------------------------------------------------------------------------
+    parser = argparse.ArgumentParser(description="MiniMind Full SFT")
+    parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
+    parser.add_argument("--epochs", type=int, default=19, help="Number of epochs")
+    parser.add_argument("--batch_size", type=int, default=40, help="Batch size")
+    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
+    parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
+    parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
+    parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="Weights & Biases project name")
+    parser.add_argument("--num_workers", type=int, default=8, help="Number of workers for data loading")
+    parser.add_argument("--ddp", action="store_true", help="Use DistributedDataParallel")
+    parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
+    parser.add_argument("--warmup_iters", type=int, default=0, help="Number of warmup iterations")
+    parser.add_argument("--log_interval", type=int, default=100, help="Logging interval")
+    parser.add_argument("--save_interval", type=int, default=1000, help="Model saving interval")
+    args = parser.parse_args()

     lm_config = LMConfig()
     max_seq_len = lm_config.max_seq_len
-    out_dir = 'out'
-    epochs = 19
-    gradient_accumulation_steps = 1
-    batch_size = 40
-    learning_rate = 1e-4
-    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-    dtype = 'bfloat16'
-    # dtype = 'float16'
-    save_dir = os.path.join(out_dir)
-    os.makedirs(save_dir, exist_ok=True)
-    tokens_per_iter = gradient_accumulation_steps * batch_size * max_seq_len
-    os.makedirs(out_dir, exist_ok=True)
+    args.save_dir = os.path.join(args.out_dir)
+    os.makedirs(args.save_dir, exist_ok=True)
+    os.makedirs(args.out_dir, exist_ok=True)
+    tokens_per_iter = args.batch_size * max_seq_len
     torch.manual_seed(1337)
-    device_type = device if "cuda" in device else "cpu"
-    use_wandb = True  # 是否使用wandb
-    wandb_project = "MiniMind-Full-SFT"
-    wandb_run_name = f"MiniMind-Full-SFT-Epoch-{epochs}-BatchSize-{batch_size}-LearningRate-{learning_rate}"
-    if use_wandb:
-        import wandb
-        wandb.init(project=wandb_project, name=wandb_run_name)
-    else:
-        wandb = None
-    ctx = (
-        nullcontext()
-        if device_type == "cpu"
-        else torch.cuda.amp.autocast()
-    )
-    ### ddp config
+    device_type = "cuda" if "cuda" in args.device else "cpu"
+    args.wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
+    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
     ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
     ddp_local_rank, DEVICE = 0, "cuda:0"
     if ddp:
         init_distributed_mode()
-        device = torch.device(DEVICE)
-    # -----------------------------------------------------------------------------
+        args.device = torch.device(DEVICE)

+    if args.use_wandb and (not ddp or ddp_local_rank == 0):
+        import wandb
+        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
+    else:
+        wandb = None

-    model, tokenizer = init_model(lm_config)
+    model, tokenizer = init_model()

-    # -----init dataloader------
     df = pd.read_csv('./dataset/sft_data_single.csv')
     df = df.sample(frac=1.0)
     train_ds = SFTDataset(df, tokenizer, max_length=max_seq_len)
     train_sampler = DistributedSampler(train_ds) if ddp else None
     train_loader = DataLoader(
         train_ds,
-        batch_size=batch_size,
-        pin_memory=False,
+        batch_size=args.batch_size,
+        pin_memory=True,
         drop_last=False,
         shuffle=False,
-        num_workers=8,
+        num_workers=args.num_workers,
         sampler=train_sampler
     )

-    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == dtype))
-    # optimizer
-    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
-    iter_per_epoch = len(train_loader)
-    # compile the model
-    if False and not lm_config.use_moe and platform.system() != 'Windows' and float(
-            torch.__version__.split('.')[0]) >= 2:
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == args.dtype))
+    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+    if False and not lm_config.use_moe and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
         Logger("compiling the model... (takes a ~minute)")
         unoptimized_model = model
-        model = torch.compile(model)  # requires PyTorch 2.0
+        model = torch.compile(model)

     if ddp:
-        # Ignore the pos_cis buffer so that DDP does not broadcast it at
-        # construction time since NCCL does not support ComplexFloat
         model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])

-    # training loop
-    for epoch in range(epochs,wandb):
-        train_epoch(epoch)
+    iter_per_epoch = len(train_loader)
+    for epoch in range(args.epochs):
+        train_epoch(epoch, wandb)
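Beyond the argparse migration, this hunk fixes an outright bug: the old epoch loop `for epoch in range(epochs,wandb):` passed the wandb module as range()'s stop argument (a TypeError at startup) and called `train_epoch(epoch)` without the `wandb` argument the function expects. The wandb initialization is also now guarded so that under DDP only the master process opens a run. A condensed sketch of that guard (env-var reads stand in for the script's init_distributed_mode):

    import os

    # torchrun exports RANK/LOCAL_RANK/WORLD_SIZE; their absence means a single-process run
    ddp = int(os.environ.get("RANK", -1)) != -1
    ddp_local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # only the master process should create the wandb run (and write checkpoints);
    # otherwise every rank would open its own run
    if (not ddp) or ddp_local_rank == 0:
        print("master process: init wandb / save checkpoints here")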

View File

@@ -1,5 +1,6 @@
 import os
 import platform
+import argparse
 import time
 import math
 import warnings
@@ -16,32 +17,36 @@ from torch.utils.data import DataLoader
 from model.LMConfig import LMConfig
 from model.dataset import SFTDataset

-warnings.filterwarnings('ignore', category=UserWarning)
+warnings.filterwarnings('ignore')


-def get_lr(it):
-    warmup_iters = 1000
-    lr_decay_iters = 80000
-    min_lr = 1e-5
+def Logger(content):
+    print(content)


+def get_lr(it, all):
+    warmup_iters = args.warmup_iters
+    lr_decay_iters = all
+    min_lr = args.learning_rate / 10
     if it < warmup_iters:
-        return learning_rate * it / warmup_iters
+        return args.learning_rate * it / warmup_iters
     if it > lr_decay_iters:
         return min_lr
     decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
     assert 0 <= decay_ratio <= 1
     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
-    return min_lr + coeff * (learning_rate - min_lr)
+    return min_lr + coeff * (args.learning_rate - min_lr)

-# ------------------------------------------------------------------------------
 def train_epoch(epoch, wandb):
     start_time = time.time()
     for step, (X, Y, loss_mask) in enumerate(train_loader):
-        X = X.to(device)
-        Y = Y.to(device)
-        loss_mask = loss_mask.to(device)
-        lr = get_lr(epoch * iter_per_epoch + step)
+        X = X.to(args.device)
+        Y = Y.to(args.device)
+        loss_mask = loss_mask.to(args.device)
+
+        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
@@ -50,32 +55,38 @@ def train_epoch(epoch, wandb):
         loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=0, reduction='none')
         loss_mask = loss_mask.view(-1)
         loss = torch.sum(loss * loss_mask) / loss_mask.sum()
+        loss = loss / args.accumulation_steps

         scaler.scale(loss).backward()

-        scaler.unscale_(optimizer)
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-        scaler.step(optimizer)
-        scaler.update()
-        optimizer.zero_grad(set_to_none=True)
+        if (step + 1) % args.accumulation_steps == 0:
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad(set_to_none=True)

-        if step % 100 == 0:
+        if step % args.log_interval == 0:
             spend_time = time.time() - start_time
-            print(
+            Logger(
                 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
                     epoch,
-                    epochs,
+                    args.epochs,
                     step,
                     iter_per_epoch,
-                    loss.item(),
+                    loss.item() * args.accumulation_steps,
                     optimizer.param_groups[-1]['lr'],
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

-            if use_wandb != None:
-                wandb.log({"loss": loss.item(), "lr": optimizer.param_groups[-1]['lr'],
+            if wandb is not None:
+                wandb.log({"loss": loss.item() * args.accumulation_steps,
+                           "lr": optimizer.param_groups[-1]['lr'],
                            "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

+        if (step + 1) % args.save_interval == 0:
+            model.save_pretrained(args.save_dir)


 def find_all_linear_names(model):
     cls = torch.nn.Linear
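`find_all_linear_names` gathers the module names that LoRA should target; its body is cut off in this view. A hypothetical reconstruction of what such a helper typically does (not necessarily the repo's exact code):

    def find_all_linear_names(model):
        cls = torch.nn.Linear
        lora_module_names = set()
        for name, module in model.named_modules():
            if isinstance(module, cls):
                # peft matches on the last path segment, e.g. 'wq' in 'layers.0.attention.wq'
                lora_module_names.add(name.split('.')[-1])
        return list(lora_module_names)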
@@ -94,7 +105,7 @@ def init_model():
     model_name_or_path = "./minimind"
     tokenizer_name_or_path = "./minimind"
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False)
-    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True).to(device)
+    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True).to(args.device)

     target_modules = find_all_linear_names(model)
     peft_config = LoraConfig(
@@ -107,73 +118,70 @@ def init_model():
     )
     model = get_peft_model(model, peft_config)
     model.print_trainable_parameters()
-    model = model.to(device)
+    model = model.to(args.device)
     return model, tokenizer


-# I/O
 if __name__ == "__main__":
-    # -----------------------------------------------------------------------------
+    parser = argparse.ArgumentParser(description="MiniMind LoRA Fine-tuning")
+    parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
+    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
+    parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
+    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
+    parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
+    parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
+    parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="Weights & Biases project name")
+    parser.add_argument("--num_workers", type=int, default=0, help="Number of workers for data loading")
+    parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
+    parser.add_argument("--warmup_iters", type=int, default=1000, help="Number of warmup iterations")
+    parser.add_argument("--log_interval", type=int, default=100, help="Logging interval")
+    parser.add_argument("--save_interval", type=int, default=1000, help="Model saving interval")
+    args = parser.parse_args()

     lm_config = LMConfig()
     max_seq_len = lm_config.max_seq_len
-    out_dir = 'out'
-    epochs = 20
-    gradient_accumulation_steps = 1
-    batch_size = 16
-    learning_rate = 1e-4
-    weight_decay = 1e-1
-    device = 'cuda:0'
-    dtype = 'bfloat16'
-    save_dir = os.path.join(out_dir)
-    os.makedirs(save_dir, exist_ok=True)
-    tokens_per_iter = gradient_accumulation_steps * batch_size * max_seq_len
-    os.makedirs(out_dir, exist_ok=True)
+    args.save_dir = os.path.join(args.out_dir)
+    os.makedirs(args.save_dir, exist_ok=True)
+    os.makedirs(args.out_dir, exist_ok=True)
+    tokens_per_iter = args.batch_size * max_seq_len
     torch.manual_seed(1337)
-    device_type = device if "cuda" in device else "cpu"
-    use_wandb = True  # 是否使用wandb
-    wandb_project = "MiniMind-LoRA"
-    wandb_run_name = f"MiniMind-LoRA-Epoch-{epochs}-BatchSize-{batch_size}-LearningRate-{learning_rate}"
-    if use_wandb:
+    device_type = "cuda" if "cuda" in args.device else "cpu"
+    args.wandb_run_name = f"MiniMind-LoRA-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
+
+    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
+
+    if args.use_wandb:
         import wandb
-        wandb.init(project=wandb_project, name=wandb_run_name)
+        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
     else:
         wandb = None
-    ctx = (
-        nullcontext()
-        if device_type == "cpu"
-        else torch.cuda.amp.autocast()
-    )
-    # -----------------------------------------------------------------------------

     model, tokenizer = init_model()

-    # -----init dataloader------
     df = pd.read_csv('./dataset/sft_data.csv')
     df = df.sample(frac=1.0)
     train_ds = SFTDataset(df, tokenizer, max_length=max_seq_len)
     train_loader = DataLoader(
         train_ds,
-        batch_size=batch_size,
-        pin_memory=False,
+        batch_size=args.batch_size,
+        pin_memory=True,
         drop_last=False,
         shuffle=False,
-        num_workers=0,
+        num_workers=args.num_workers,
     )

-    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
-    # optimizer
-    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
-    iter_per_epoch = len(train_loader)
-    # compile the model
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
+    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
     if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
-        print("compiling the model... (takes a ~minute)")
+        Logger("compiling the model... (takes a ~minute)")
         unoptimized_model = model
         model = torch.compile(model)

-    raw_model = model
-    # training loop
-    for epoch in range(epochs):
+    iter_per_epoch = len(train_loader)
+    for epoch in range(args.epochs):
         train_epoch(epoch, wandb)
-
-    model.save_pretrained('minimind')
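Across all three scripts the refactor is the same mechanical move: each module-level constant becomes a flag whose default preserves the old value, every later reference becomes `args.<name>`, and derived values (`args.save_dir`, `args.wandb_run_name`) are attached to the same namespace after parsing, so helpers like `get_lr` and `train_epoch` read everything from one global `args`. (The LoRA script also changes its saving behavior: instead of a single `model.save_pretrained('minimind')` after training, adapters are now written to `args.save_dir` every `--save_interval` steps.) A minimal sketch of that namespace pattern (names illustrative):

    import argparse

    parser = argparse.ArgumentParser(description="demo of the namespace pattern used above")
    parser.add_argument("--out_dir", type=str, default="out")
    args = parser.parse_args()

    # derived values are attached to the parsed namespace after the fact,
    # so downstream functions can read every setting from the single `args` object
    args.save_dir = args.out_dir
    args.wandb_run_name = f"run-outdir-{args.out_dir}"
    print(args)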