diff --git a/1-pretrain.py b/1-pretrain.py
index 20fdf42..d51b0e3 100644
--- a/1-pretrain.py
+++ b/1-pretrain.py
@@ -45,9 +45,10 @@ def get_lr(it, all):
 def train_epoch(epoch, wandb):
     start_time = time.time()
-    for step, (X, Y) in enumerate(train_loader):
+    for step, (X, Y, loss_mask) in enumerate(train_loader):
         X = X.to(args.device)
         Y = Y.to(args.device)
+        loss_mask = loss_mask.to(args.device)
 
         lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch)
         for param_group in optimizer.param_groups:
@@ -56,6 +57,8 @@ def train_epoch(epoch, wandb):
         with ctx:
             out = model(X, Y)
             loss = out.last_loss / args.accumulation_steps
+            loss_mask = loss_mask.view(-1)
+            loss = torch.sum(loss * loss_mask) / loss_mask.sum()
 
         scaler.scale(loss).backward()
 
@@ -129,7 +132,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind Pretraining")
     parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
     parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
-    parser.add_argument("--batch_size", type=int, default=48, help="Batch size")
+    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
     parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate")
     parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
 
diff --git a/model/dataset.py b/model/dataset.py
index 2e417f6..339695f 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -28,14 +28,18 @@ class PretrainDataset(Dataset):
         sample = self.df.iloc[index]
         text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
         input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
+        text_len = len(input_id)
         # 没满最大长度的剩余部分
-        padding_len = self.max_length - len(input_id)
+        padding_len = self.max_length - text_len
         input_id = input_id + [self.padding] * padding_len
+        # 0表示不计算损失
+        loss_mask = [1] * text_len + [0] * padding_len
 
         input_id = np.array(input_id)
         X = np.array(input_id[:-1]).astype(np.int64)
         Y = np.array(input_id[1:]).astype(np.int64)
-        return torch.from_numpy(X), torch.from_numpy(Y)
+        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
+        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
 
 class SFTDataset(Dataset):