From 75753ea765279c1f7809a8295bb2a65a6c6ba3da Mon Sep 17 00:00:00 2001
From: gongjy <2474590974@qq.com>
Date: Fri, 27 Sep 2024 17:19:03 +0800
Subject: [PATCH] Update data preprocessing methods

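PretrainDataset pads each tokenized sample to max_length, so the
cross-entropy loss was previously averaged over padding tokens as
well. The dataset now also returns a per-token loss_mask (1 for real
tokens, 0 for padding), and train_epoch uses it to average the loss
over non-padding positions only. The default batch size is raised
from 48 to 64.

For reference, a minimal sketch of the masked averaging, assuming the
model returns an unreduced per-token loss (reduction='none') that is
flattened alongside the mask; the tensor values are illustrative:

    import torch

    # flattened per-token losses for one batch, shape (batch*seq,)
    per_token_loss = torch.tensor([0.9, 1.1, 0.7, 0.5])
    # 1 = real token, 0 = padding
    loss_mask = torch.tensor([1, 1, 1, 0])
    # sum losses at real tokens, divide by the number of real tokens
    loss = torch.sum(per_token_loss * loss_mask) / loss_mask.sum()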
---
 1-pretrain.py    | 7 +++++--
 model/dataset.py | 8 ++++++--
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/1-pretrain.py b/1-pretrain.py
index 20fdf42..d51b0e3 100644
--- a/1-pretrain.py
+++ b/1-pretrain.py
@@ -45,9 +45,10 @@ def get_lr(it, all):
 
 def train_epoch(epoch, wandb):
     start_time = time.time()
-    for step, (X, Y) in enumerate(train_loader):
+    for step, (X, Y, loss_mask) in enumerate(train_loader):
         X = X.to(args.device)
         Y = Y.to(args.device)
+        loss_mask = loss_mask.to(args.device)
 
         lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch)
         for param_group in optimizer.param_groups:
@@ -56,6 +57,8 @@ def train_epoch(epoch, wandb):
         with ctx:
             out = model(X, Y)
             loss = out.last_loss / args.accumulation_steps
+            loss_mask = loss_mask.view(-1)
+            loss = torch.sum(loss * loss_mask) / loss_mask.sum()
 
         scaler.scale(loss).backward()
 
@@ -129,7 +132,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind Pretraining")
     parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
     parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
-    parser.add_argument("--batch_size", type=int, default=48, help="Batch size")
+    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
     parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate")
     parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu",
                         help="Device to use")
diff --git a/model/dataset.py b/model/dataset.py
index 2e417f6..339695f 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -28,14 +28,18 @@ class PretrainDataset(Dataset):
         sample = self.df.iloc[index]
         text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
         input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
+        text_len = len(input_id)
         # remaining part that falls short of max_length
-        padding_len = self.max_length - len(input_id)
+        padding_len = self.max_length - text_len
         input_id = input_id + [self.padding] * padding_len
+        # 0 means no loss is computed at this position
+        loss_mask = [1] * text_len + [0] * padding_len
 
         input_id = np.array(input_id)
         X = np.array(input_id[:-1]).astype(np.int64)
         Y = np.array(input_id[1:]).astype(np.int64)
-        return torch.from_numpy(X), torch.from_numpy(Y)
+        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
+        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
 
 
 class SFTDataset(Dataset):