add accumulation_grad for pretrain
commit c404941677
parent 8c18b324d0
@@ -37,34 +37,30 @@ def get_lr(it, all):
     return min_lr + coeff * (learning_rate - min_lr)


-def train_epoch(epoch):
+def train_epoch(epoch, accumulation_steps=8):
     start_time = time.time()

     for step, (X, Y) in enumerate(train_loader):
         X = X.to(device)
         Y = Y.to(device)

-        # set the learning rate
         lr = get_lr(epoch * iter_per_epoch + step, epochs * iter_per_epoch)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr

-        # forward pass and loss computation
         with ctx:
             out = model(X, Y)
-            loss = out.last_loss
+            loss = out.last_loss / accumulation_steps

-        # backward pass
         scaler.scale(loss).backward()

-        # gradient clipping and parameter update
-        scaler.unscale_(optimizer)
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-        scaler.step(optimizer)
-        scaler.update()
-
-        # zero the gradients
-        optimizer.zero_grad(set_to_none=True)
+        if (step + 1) % accumulation_steps == 0:
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            scaler.step(optimizer)
+            scaler.update()
+
+            optimizer.zero_grad(set_to_none=True)

         if step % 100 == 0:
             spend_time = time.time() - start_time
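For context, a minimal self-contained sketch of the gradient-accumulation pattern this hunk introduces. The tiny Linear model, MSE loss, and dummy batches are placeholders for illustration only; the `accumulation_steps` default of 8, the scaler calls, and the clip threshold of 1.0 follow the diff, and a CUDA device is assumed as in the script:

import torch

# assumptions: a CUDA device and PyTorch AMP, mirroring the diff's scaler/ctx usage
device = 'cuda:0'
accumulation_steps = 8

model = torch.nn.Linear(16, 1).to(device)                  # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
scaler = torch.cuda.amp.GradScaler()
criterion = torch.nn.MSELoss()
train_loader = [(torch.randn(4, 16), torch.randn(4, 1)) for _ in range(32)]  # dummy batches

optimizer.zero_grad(set_to_none=True)
for step, (X, Y) in enumerate(train_loader):
    X, Y = X.to(device), Y.to(device)
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        # divide so the gradients summed over 8 micro-batches match one 8x-larger batch
        loss = criterion(model(X), Y) / accumulation_steps
    scaler.scale(loss).backward()   # gradients accumulate; no optimizer step yet

    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)  # unscale before clipping so the 1.0 threshold is in real units
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)  # reset for the next accumulation window

The key design point is that backward() is called on every micro-batch so gradients keep summing, while unscale/clip/step/update/zero_grad run only once per accumulation window.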
@@ -74,7 +70,7 @@ def train_epoch(epoch):
                     epochs,
                     step,
                     iter_per_epoch,
-                    loss.item(),
+                    loss.item() * accumulation_steps,
                     optimizer.param_groups[-1]['lr'],
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

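Because the per-micro-batch loss is now divided by accumulation_steps before backward(), the log line multiplies it back so the printed loss stays on the same scale as before this commit. A quick sanity check with made-up numbers:

accumulation_steps = 8
full_batch_loss = 2.4                           # hypothetical unscaled loss
scaled = full_batch_loss / accumulation_steps   # what backward() sees: 0.3
print(scaled * accumulation_steps)              # what gets logged: 2.4 again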
@@ -134,7 +130,7 @@ if __name__ == "__main__":
     out_dir = 'out'
     epochs = 20
     batch_size = 64
-    learning_rate = 1e-4
+    learning_rate = 2e-4
     device = 'cuda:0'
     dtype = 'bfloat16'
     save_dir = os.path.join(out_dir)
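With batch_size = 64 and the default accumulation_steps = 8, each optimizer step now integrates 64 * 8 = 512 samples; the larger effective batch is a plausible reason for doubling the peak learning rate to 2e-4, though the commit message does not say so. A hypothetical bit of bookkeeping:

batch_size = 64
accumulation_steps = 8
effective_batch_size = batch_size * accumulation_steps  # 512 samples per optimizer.step()
print(effective_batch_size)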