Update data preprocessing methods

This commit is contained in:
gongjy 2024-09-27 17:19:03 +08:00
parent 1cc73836d4
commit 75753ea765
2 changed files with 11 additions and 4 deletions

View File

@ -45,9 +45,10 @@ def get_lr(it, all):
def train_epoch(epoch, wandb):
start_time = time.time()
for step, (X, Y) in enumerate(train_loader):
for step, (X, Y, loss_mask) in enumerate(train_loader):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch)
for param_group in optimizer.param_groups:
@ -56,6 +57,8 @@ def train_epoch(epoch, wandb):
with ctx:
out = model(X, Y)
loss = out.last_loss / args.accumulation_steps
loss_mask = loss_mask.view(-1)
loss = torch.sum(loss * loss_mask) / loss_mask.sum()
scaler.scale(loss).backward()
@ -129,7 +132,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
parser.add_argument("--batch_size", type=int, default=48, help="Batch size")
parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu",
help="Device to use")

View File

@ -28,14 +28,18 @@ class PretrainDataset(Dataset):
sample = self.df.iloc[index]
text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
text_len = len(input_id)
# 没满最大长度的剩余部分
padding_len = self.max_length - len(input_id)
padding_len = self.max_length - text_len
input_id = input_id + [self.padding] * padding_len
# 0表示不计算损失
loss_mask = [1] * text_len + [0] * padding_len
input_id = np.array(input_id)
X = np.array(input_id[:-1]).astype(np.int64)
Y = np.array(input_id[1:]).astype(np.int64)
return torch.from_numpy(X), torch.from_numpy(Y)
loss_mask = np.array(loss_mask[1:]).astype(np.int64)
return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
class SFTDataset(Dataset):