From c4049416770d9bfc4e388a4e3d4846ee8ab189b8 Mon Sep 17 00:00:00 2001
From: gongjy <2474590974@qq.com>
Date: Mon, 16 Sep 2024 20:58:46 +0800
Subject: [PATCH] add gradient accumulation for pretrain
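
Accumulate gradients over several micro-batches before each optimizer step
so the effective batch size grows without extra memory: the loss is divided
by accumulation_steps, and unscale / clip / step / zero_grad now run only
once every accumulation_steps iterations, with the logged loss multiplied
back for readability.

A minimal standalone sketch of the pattern this patch introduces. The model,
optimizer and data below are toy stand-ins so the snippet runs on its own
(1-pretrain.py uses the real Transformer, its train_loader and out.last_loss);
it is an illustration of the accumulation loop, not the actual script:

    import torch
    from torch import nn

    # toy stand-ins; 1-pretrain.py builds the real model, optimizer and loader
    model = nn.Linear(16, 16).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    data = [(torch.randn(4, 16), torch.randn(4, 16)) for _ in range(32)]

    accumulation_steps = 8
    scaler = torch.cuda.amp.GradScaler()                  # float16 here so loss scaling matters
    ctx = torch.autocast(device_type='cuda', dtype=torch.float16)

    optimizer.zero_grad(set_to_none=True)
    for step, (X, Y) in enumerate(data):
        X, Y = X.cuda(), Y.cuda()
        with ctx:
            # 1-pretrain.py uses out.last_loss; MSE is only a stand-in
            loss = nn.functional.mse_loss(model(X), Y) / accumulation_steps
        scaler.scale(loss).backward()                     # gradients add up across micro-batches

        if (step + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)                    # back to true gradient magnitudes
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)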

---
 1-pretrain.py | 26 +++++++++++---------------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/1-pretrain.py b/1-pretrain.py
index dcf3f06..ce427aa 100644
--- a/1-pretrain.py
+++ b/1-pretrain.py
@@ -37,34 +37,30 @@ def get_lr(it, all):
     return min_lr + coeff * (learning_rate - min_lr)
 
 
-def train_epoch(epoch):
+def train_epoch(epoch, accumulation_steps=8):
     start_time = time.time()
-
     for step, (X, Y) in enumerate(train_loader):
         X = X.to(device)
         Y = Y.to(device)
 
-        # set the learning rate
         lr = get_lr(epoch * iter_per_epoch + step, epochs * iter_per_epoch)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
-        # forward pass and loss computation
         with ctx:
             out = model(X, Y)
-            loss = out.last_loss
+            loss = out.last_loss / accumulation_steps
 
-        # backward pass
         scaler.scale(loss).backward()
 
-        # gradient clipping and parameter update
-        scaler.unscale_(optimizer)
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-        scaler.step(optimizer)
-        scaler.update()
+        if (step + 1) % accumulation_steps == 0:
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
 
-        # zero the gradients
-        optimizer.zero_grad(set_to_none=True)
+            scaler.step(optimizer)
+            scaler.update()
+
+            optimizer.zero_grad(set_to_none=True)
 
         if step % 100 == 0:
             spend_time = time.time() - start_time
@@ -74,7 +70,7 @@ def train_epoch(epoch):
                     epochs,
                     step,
                     iter_per_epoch,
-                    loss.item(),
+                    loss.item() * accumulation_steps,
                     optimizer.param_groups[-1]['lr'],
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
 
@@ -134,7 +130,7 @@ if __name__ == "__main__":
     out_dir = 'out'
     epochs = 20
     batch_size = 64
-    learning_rate = 1e-4
+    learning_rate = 2e-4
     device = 'cuda:0'
     dtype = 'bfloat16'
     save_dir = os.path.join(out_dir)
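
With the values above, each optimizer step now sees batch_size * accumulation_steps
= 64 * 8 = 512 sequences; the learning-rate bump from 1e-4 to 2e-4 presumably
compensates for that larger effective batch. A quick check of the configured
values (accumulation_steps is the train_epoch default added by this patch;
nothing here is a measured result):

    batch_size = 64
    accumulation_steps = 8                        # train_epoch default from this patch
    effective_batch = batch_size * accumulation_steps
    print(effective_batch)                        # 512 sequences per optimizer.step()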