add accumulation_grad for pretrain

gongjy 2024-09-16 20:58:46 +08:00
parent 8c18b324d0
commit c404941677

@@ -37,34 +37,30 @@ def get_lr(it, all):
     return min_lr + coeff * (learning_rate - min_lr)
-def train_epoch(epoch):
+def train_epoch(epoch, accumulation_steps=8):
     start_time = time.time()
     for step, (X, Y) in enumerate(train_loader):
         X = X.to(device)
         Y = Y.to(device)
-        # set the learning rate
         lr = get_lr(epoch * iter_per_epoch + step, epochs * iter_per_epoch)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
-        # forward pass and loss computation
         with ctx:
             out = model(X, Y)
-            loss = out.last_loss
-        # backward pass
+            loss = out.last_loss / accumulation_steps
         scaler.scale(loss).backward()
-        # gradient clipping and parameter update
-        scaler.unscale_(optimizer)
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-        scaler.step(optimizer)
-        scaler.update()
-        # zero the gradients
-        optimizer.zero_grad(set_to_none=True)
+        if (step + 1) % accumulation_steps == 0:
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad(set_to_none=True)

         if step % 100 == 0:
             spend_time = time.time() - start_time
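
For reference, here is a minimal, self-contained sketch of the gradient-accumulation pattern this hunk adopts. The tiny linear model, random data, and hyperparameter values are illustrative stand-ins rather than part of the commit; the real script keeps its own train_loader, ctx, and scaler.

import torch

# Illustrative stand-ins; the real pretrain script builds these elsewhere.
model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
scaler = torch.cuda.amp.GradScaler()
accumulation_steps = 8
data = [(torch.randn(4, 16), torch.randn(4, 1)) for _ in range(32)]

optimizer.zero_grad(set_to_none=True)
for step, (x, y) in enumerate(data):
    x, y = x.cuda(), y.cuda()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)
    # Scale the loss down so the accumulated gradients match one big batch.
    scaler.scale(loss / accumulation_steps).backward()

    # Step the optimizer only once every accumulation_steps micro-batches.
    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)  # unscale before clipping gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

The net effect is an effective batch of batch_size * accumulation_steps samples per optimizer step, at the activation-memory cost of a single micro-batch.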
@@ -74,7 +70,7 @@ def train_epoch(epoch):
                     epochs,
                     step,
                     iter_per_epoch,
-                    loss.item(),
+                    loss.item() * accumulation_steps,
                     optimizer.param_groups[-1]['lr'],
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
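
A note on the logging change above: since backward() now receives loss / accumulation_steps, multiplying loss.item() back by accumulation_steps restores the unscaled per-batch value in the progress print. A toy check with made-up numbers:

accumulation_steps = 8
raw_loss = 2.4                          # hypothetical per-batch loss
scaled = raw_loss / accumulation_steps  # what backward() actually receives
print(scaled * accumulation_steps)      # 2.4 -- what the log line reports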
@@ -134,7 +130,7 @@ if __name__ == "__main__":
     out_dir = 'out'
     epochs = 20
     batch_size = 64
-    learning_rate = 1e-4
+    learning_rate = 2e-4
     device = 'cuda:0'
     dtype = 'bfloat16'
     save_dir = os.path.join(out_dir)
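
The hyperparameter hunk raises the base learning rate alongside the new accumulation. A quick back-of-the-envelope check follows; the idea that the learning rate roughly tracks the larger effective batch is an assumption made for illustration, not something the commit states.

batch_size = 64
accumulation_steps = 8                 # default added to train_epoch above
effective_batch = batch_size * accumulation_steps
print(effective_batch)                 # 512 samples per optimizer step

old_lr, new_lr = 1e-4, 2e-4
print(new_lr / old_lr)                 # 2.0 -- the commit doubles the base LR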