Added train_embedding for pretraining the embedding model

This commit is contained in:
iomgaa 2025-05-08 15:41:04 +00:00
parent 0859f54a88
commit 10f15724b4
3 changed files with 544 additions and 65 deletions

View File

@@ -173,31 +173,42 @@ class CrossAttention(nn.Module):
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = 768 // self.num_heads
self.to_q = nn.Linear(768, 768, bias=False)
self.to_k = nn.Linear(768, 768, bias=False)
self.to_v = nn.Linear(768, 768, bias=False)
self.to_out = nn.Linear(768, 768, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
# db = db.permute(0, 2, 1)
batch_size = x.size(0)
q = self.to_q(x)
k = self.to_k(db)
v = self.to_v(db)
# Split into multiple heads
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
attn_scores = attn_scores.masked_fill(context_mask == 0, -1e10)
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, 768)
context = self.to_out(context)
return context
class FeedForward(nn.Module):
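A minimal shape walk-through of the refactored multi-head cross-attention, for illustration only: the call below assumes the class can be constructed with just a placeholder config (the full constructor signature is not shown in this hunk) and that it is importable from the model module.
import torch
B, Lq, Lkv, D = 2, 16, 32, 768                 # head_dim = 768 // 8 = 96
attn = CrossAttention(config=None)             # hypothetical dummy config
x = torch.randn(B, Lq, D)                      # query-side hidden states
db = torch.randn(B, Lkv, D)                    # key/value-side ("db") entries
mask = torch.ones(B, Lq, Lkv)                  # expanded to [B, 8, Lq, Lkv] inside forward
out = attn(x, db, context_mask=mask)           # -> torch.Size([2, 16, 768])
print(out.shape)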

418
train_embedding.py Normal file
View File

@@ -0,0 +1,418 @@
import os
# Set environment variables
os.environ["WANDB_MODE"] = "offline" # or use "dryrun"
import platform
import argparse
import time
import math
import warnings
import pandas as pd
import torch
import torch.distributed as dist
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, DistributedSampler, Dataset
from contextlib import nullcontext
import random
import numpy as np
import json
from transformers import AutoTokenizer
# Removed: from model.model import MiniMindLM
from model.LMConfig import LMConfig
# from model.dataset import PretrainDataset
warnings.filterwarnings('ignore')
# Define a Word2Vec-style CBOW model
class CBOWModel(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.vocab_size = config.vocab_size
self.embedding_dim = config.dim
# Input embeddings (context words)
self.embeddings = nn.Embedding(config.vocab_size, config.dim)
# Output weights for target prediction
self.output_weights = nn.Linear(config.dim, config.vocab_size, bias=False)
# Initialize weights
self.init_weights()
def init_weights(self):
# Xavier initialization for better convergence
nn.init.xavier_uniform_(self.embeddings.weight)
nn.init.xavier_uniform_(self.output_weights.weight)
def forward(self, context_words):
# context_words shape: [batch_size, context_size]; context_size is variable
# Get embeddings for all context words
embeds = self.embeddings(context_words) # [batch_size, context_size, embedding_dim]
# Average the context word embeddings along context dimension
embeds = torch.mean(embeds, dim=1) # [batch_size, embedding_dim]
# Predict the target word
output = self.output_weights(embeds) # [batch_size, vocab_size]
return output
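A quick sanity check of the CBOW forward pass, as a sketch: it assumes the project's LMConfig accepts the same keyword arguments used further below in this script (the model only reads vocab_size and dim), and that CBOWModel is in scope.
import torch
from model.LMConfig import LMConfig
cfg = LMConfig(dim=768, vocab_size=6400, max_seq_len=512, n_layers=1, n_heads=1, n_kv_heads=1)
model = CBOWModel(cfg)
context = torch.randint(0, cfg.vocab_size, (4, 10))  # [batch_size=4, context_size=10]
logits = model(context)                              # [batch_size, vocab_size]
print(logits.shape)                                  # torch.Size([4, 6400])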
# Word2Vec CBOW dataset
class CBOWDataset(Dataset):
def __init__(self, data_path, tokenizer, max_length=512, window_size=5):
super().__init__()
self.tokenizer = tokenizer
self.window_size = window_size
self.max_length = max_length
self.samples = self.load_data(data_path)
def load_data(self, path):
samples = []
with open(path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
data = json.loads(line.strip())
samples.append(data)
return samples
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
sample = self.samples[index]
# Build the input text
text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
encoding = self.tokenizer(
text,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
# Get the token ids
input_ids = encoding.input_ids.squeeze()
# Filter out the padding positions
attention_mask = encoding.attention_mask.squeeze()
valid_indices = torch.where(attention_mask == 1)[0]
valid_input_ids = input_ids[valid_indices]
# Make sure there are enough tokens for CBOW training
if len(valid_input_ids) <= 2 * self.window_size + 1:
# If there are not enough tokens, resample a different example at random
return self.__getitem__(random.randint(0, len(self.samples) - 1))
# Randomly pick a center position, excluding the special tokens at the start and end
# Make sure there are at least window_size tokens on each side of the center
min_center_pos = self.window_size + 1 # skip the leading special token
max_center_pos = len(valid_input_ids) - self.window_size - 1 # skip the trailing special token
if max_center_pos <= min_center_pos:
return self.__getitem__(random.randint(0, len(self.samples) - 1))
center_pos = random.randint(min_center_pos, max_center_pos)
# Target word (the center word)
target = valid_input_ids[center_pos].unsqueeze(0)
# Context words (the words before and after the center word)
context = torch.cat([
valid_input_ids[center_pos - self.window_size:center_pos],
valid_input_ids[center_pos + 1:center_pos + self.window_size + 1]
])
return context, target
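Worked example of the sampling above, with made-up token ids (real ids would come from the tokenizer) and window_size=2:
valid_input_ids = [1, 37, 12, 98, 45, 60, 2]       # bos, five content tokens, eos
window_size, center_pos = 2, 3                     # center_pos is drawn from [3, 4]
target = valid_input_ids[center_pos]               # 98
context = (valid_input_ids[center_pos - window_size:center_pos]
           + valid_input_ids[center_pos + 1:center_pos + window_size + 1])
print(context, target)                             # [37, 12, 45, 60] 98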
def Logger(content):
# Print only if DDP is not in use, or if this is the main DDP process
if not ddp or dist.get_rank() == 0:
print(content)
def get_lr(current_step, total_steps, lr):
# Compute the learning rate: a cosine schedule with a floor of lr/10
# \text{get\_lr}(c, t, l) = \frac{l}{10} + 0.5 \cdot l \cdot \left(1 + \cos\left(\frac{\pi \cdot c}{t}\right)\right)
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
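A quick check of the schedule's endpoints, as a sketch reusing the same formula: the rate starts at 1.1 * lr, passes 0.6 * lr at the midpoint, and decays to 0.1 * lr rather than to zero.
import math
def _lr(c, t, l):
    return l / 10 + 0.5 * l * (1 + math.cos(math.pi * c / t))
lr = 5e-4
assert abs(_lr(0, 100, lr) - 1.1 * lr) < 1e-12     # start of training
assert abs(_lr(50, 100, lr) - 0.6 * lr) < 1e-12    # midpoint
assert abs(_lr(100, 100, lr) - 0.1 * lr) < 1e-12   # end of training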
def train_epoch(epoch, wandb):
loss_fct = nn.CrossEntropyLoss()
start_time = time.time()
total_loss = 0
total_samples = 0
for step, (context, target) in enumerate(train_loader):
try:
# Move the batch to the device
context = context.to(args.device)
target = target.to(args.device)
# Update the learning rate
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with ctx:
# Forward pass
logits = model(context) # [batch_size, vocab_size]
# target is [batch_size, 1]; squeeze it to [batch_size] to match what CrossEntropyLoss expects
loss = loss_fct(logits, target.squeeze())
loss = loss / args.accumulation_steps
# Print data types for debugging
if step == 0 and (not ddp or dist.get_rank() == 0):
Logger("---- Data Type Check ----")
Logger(f"context.dtype: {context.dtype}")
Logger(f"context.shape: {context.shape}")
Logger(f"target.dtype: {target.dtype}")
Logger(f"target.shape: {target.shape}")
if hasattr(model, 'module'): # DDP case
Logger(f"Model parameter dtype: {next(model.module.parameters()).dtype}")
else: # Non-DDP case
Logger(f"Model parameter dtype: {next(model.parameters()).dtype}")
Logger(f"logits.dtype: {logits.dtype}")
Logger(f"logits.shape: {logits.shape}")
Logger(f"loss.dtype: {loss.dtype}")
Logger("-------------------------")
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
total_loss += loss.item() * args.accumulation_steps
total_samples += 1
# Log progress
if step % args.log_interval == 0:
spend_time = time.time() - start_time
avg_loss = total_loss / total_samples if total_samples > 0 else 0
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,
iter_per_epoch,
avg_loss,
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({"loss": avg_loss,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
except Exception as e:
print(f"Error occurred: {str(e)}")
import traceback
traceback.print_exc()
# Modified checkpoint path for error
save_path = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_ERROR.pth'
if os.path.exists(save_path):
os.remove(save_path)
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.embeddings.state_dict()
else:
state_dict = model.embeddings.state_dict()
torch.save(state_dict, save_path)
for name, param in model.named_parameters():
if param.grad is not None and torch.isnan(param.grad).any():
print(f"NaN gradient in parameter: {name}")
for name, param in model.named_parameters():
if param.grad is not None and torch.isnan(param.grad).any():
print(f"Parameter {name} values: {param.data}")
print(f"Parameter {name} gradients: {param.grad}")
raise ValueError("NaN gradient detected")
# Save model once at the end of each epoch
if not ddp or dist.get_rank() == 0:
model.eval()
ckp = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_epoch{epoch+1}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
embedding_state_dict = model.module.embeddings.state_dict()
else:
embedding_state_dict = model.embeddings.state_dict()
torch.save(embedding_state_dict, ckp)
Logger(f"Saved word2vec embedding for epoch {epoch+1} to {ckp}")
model.train()
def init_model(lm_config_params: LMConfig):
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
# Update vocab_size in lm_config if tokenizer has a different one
if tokenizer.vocab_size != lm_config_params.vocab_size:
Logger(f"Updating lm_config.vocab_size from {lm_config_params.vocab_size} to {tokenizer.vocab_size} based on tokenizer.")
lm_config_params.vocab_size = tokenizer.vocab_size
# Build the word2vec CBOW model
model = CBOWModel(lm_config_params).to(args.device)
# Print the model parameter count
Logger(f'CBOW Model total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} Million')
return model, tokenizer
def init_distributed_mode():
if not ddp: return # If distributed data parallel (DDP) is not enabled, return without doing anything.
global ddp_local_rank, DEVICE # Declare these as globals so they can be accessed outside this function.
dist.init_process_group(backend="nccl") # Initialize the distributed process group with the NCCL backend (NVIDIA Collective Communications Library), optimized for GPU-to-GPU communication.
ddp_rank = int(os.environ["RANK"]) # Global rank of this process, from the environment.
ddp_local_rank = int(os.environ["LOCAL_RANK"]) # Local rank of this process, from the environment.
ddp_world_size = int(os.environ["WORLD_SIZE"]) # Total number of processes in the group, from the environment.
DEVICE = f"cuda:{ddp_local_rank}" # Select the GPU device based on the local rank.
torch.cuda.set_device(DEVICE) # Bind this process to its GPU.
# torchrun --nproc_per_node 2 train_embedding.py
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Word2Vec Embedding Training")
parser.add_argument("--out_dir", type=str, default="out_word2vec")
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--learning_rate", type=float, default=5e-4)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Word2Vec-Training")
parser.add_argument("--num_workers", type=int, default=32)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=100)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--dim', default=768, type=int)
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
parser.add_argument('--vocab_size', default=6400, type=int)
parser.add_argument('--window_size', default=5, type=int)
args = parser.parse_args()
# Create LMConfig with relevant parameters for embedding
lm_config = LMConfig(
dim=args.dim,
vocab_size=args.vocab_size, # Will be updated by tokenizer
max_seq_len=args.max_seq_len,
n_layers=1, # Minimal
n_heads=1, # Minimal
n_kv_heads=1 #Minimal
)
args.save_dir = os.path.join(args.out_dir)
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
tokens_per_iter = args.batch_size * lm_config.max_seq_len
print(f"tokens_per_iter: {tokens_per_iter}")
device_type = "cuda" if "cuda" in args.device else "cpu"
# Determine the torch dtype
pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0" # Default values, will be overwritten in DDP
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
if ddp:
init_distributed_mode() # This sets DEVICE and ddp_local_rank
args.device = torch.device(DEVICE) # Ensure args.device is updated
rank = dist.get_rank()
torch.manual_seed(base_seed + rank)
# Also set the CUDA random seed
torch.cuda.manual_seed_all(base_seed + rank) # Use seed_all for DDP
if args.use_wandb and (not ddp or dist.get_rank() == 0): # Check rank for DDP wandb init
import wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
else:
wandb = None
model, tokenizer = init_model(lm_config) # Pass the lm_config instance
# Update lm_config vocab_size again after tokenizer to ensure consistency for save path name
if lm_config.vocab_size != tokenizer.vocab_size:
lm_config.vocab_size = tokenizer.vocab_size
args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
if wandb is not None and (not ddp or dist.get_rank() == 0):
wandb.config.update({'vocab_size': lm_config.vocab_size, 'wandb_run_name': args.wandb_run_name}, allow_val_change=True)
# Collate function to handle context sequences of different lengths
def collate_cbow_batch(batch):
# Unpack contexts and targets
contexts, targets = zip(*batch)
# Length of the longest context in this batch
max_len = max([ctx.size(0) for ctx in contexts])
# Allocate the padded tensor
padded_contexts = torch.zeros(len(contexts), max_len, dtype=torch.long)
# Copy each context into the padded tensor
for i, ctx in enumerate(contexts):
ctx_len = ctx.size(0)
padded_contexts[i, :ctx_len] = ctx
# Stack the targets into a single tensor
stacked_targets = torch.stack(targets)
return padded_contexts, stacked_targets
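For illustration, a sketch of the padding behaviour (assuming collate_cbow_batch is in scope): two contexts of length 4 and 2 become a [2, 4] LongTensor, with the shorter one padded with zeros (token id 0).
import torch
batch = [
    (torch.tensor([37, 12, 45, 60]), torch.tensor([98])),
    (torch.tensor([11, 22]), torch.tensor([33])),
]
contexts, targets = collate_cbow_batch(batch)
print(contexts)        # tensor([[37, 12, 45, 60], [11, 22,  0,  0]])
print(targets.shape)   # torch.Size([2, 1])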
# Create Word2Vec CBOW dataset
train_ds = CBOWDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len, window_size=args.window_size)
train_sampler = DistributedSampler(train_ds, shuffle=True, seed=base_seed) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=True,
shuffle=(train_sampler is None),
num_workers=args.num_workers,
sampler=train_sampler,
collate_fn=collate_cbow_batch
)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if ddp:
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
iter_per_epoch = len(train_loader)
Logger(f"Starting Word2Vec CBOW training for {args.epochs} epochs with {iter_per_epoch} iterations per epoch.")
for epoch in range(args.epochs):
if ddp:
train_sampler.set_epoch(epoch)
train_epoch(epoch, wandb)
if wandb is not None and (not ddp or dist.get_rank() == 0):
wandb.finish()
Logger("Word2Vec embedding training finished.")

View File

@@ -14,6 +14,7 @@ from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, DistributedSampler
from contextlib import nullcontext
from typing import Optional
from transformers import AutoTokenizer
@@ -37,9 +38,10 @@ def get_lr(current_step, total_steps, lr):
def train_epoch(epoch, wandb):
loss_fct = nn.CrossEntropyLoss(reduction='none') # Cross-entropy loss; with reduction='none', nn.CrossEntropyLoss does not aggregate the losses but returns a separate loss value for each element.
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
for step, (X, Y, loss_mask) in enumerate(train_loader):
try:
# Move the batch to the device
X = X.to(args.device)
Y = Y.to(args.device)
@@ -51,27 +53,37 @@ def train_epoch(epoch, wandb):
param_group['lr'] = lr
with ctx:
res = model(X) # get the model output
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size()) # compute the per-token loss
loss = (loss * loss_mask).sum() / loss_mask.sum() # compute the overall loss
# Scale for gradient accumulation; the effective batch size is num_gpus * batch_size_per_gpu * accumulation_steps
).view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask.sum()
loss += res.aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward() # Part of mixed-precision training: the loss is scaled automatically to avoid numerical instability when computing in low precision (e.g. FP16).
# Print data types for debugging
if step == 0 and (not ddp or dist.get_rank() == 0): # Print only for the first step of the first epoch on the main process
Logger("---- Data Type Check ----")
Logger(f"X.dtype: {X.dtype}")
if hasattr(model, 'module'): # DDP case
Logger(f"Model parameter dtype: {next(model.module.parameters()).dtype}")
else: # Non-DDP case
Logger(f"Model parameter dtype: {next(model.parameters()).dtype}")
Logger(f"res.logits.dtype: {res.logits.dtype}")
Logger(f"loss.dtype: {loss.dtype}")
Logger("-------------------------")
scaler.scale(loss).backward()
# Step the optimizer once the accumulation count is reached
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer) # Part of PyTorch automatic mixed precision (AMP): un-scales the gradients that were previously scaled to prevent underflow.
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) # Apply gradient clipping to prevent exploding gradients; gradients are rescaled so their norm does not exceed args.grad_clip.
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer) # Update the model weights via the optimizer, mediated by the scaler for mixed-precision training.
scaler.update() # Update the scale factor for the next iteration based on whether gradients overflowed in this one.
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True) # Clear all gradients for the next iteration; set_to_none=True saves memory by setting gradients to None instead of zeroing them.
optimizer.zero_grad(set_to_none=True)
# Log progress
if step % args.log_interval == 0:
@@ -105,12 +117,45 @@ def train_epoch(epoch, wandb):
torch.save(state_dict, ckp) # save only the parameters
model.train()
except Exception as e:
print(f"Error occurred: {str(e)}")
save_path = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}_nanERROR.pth'
if os.path.exists(save_path):
os.remove(save_path)
def init_model(lm_config):
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(state_dict, save_path)
for name, param in model.named_parameters():
if param.grad is not None and torch.isnan(param.grad).any():
print(f"NaN gradient in parameter: {name}")
for name, param in model.named_parameters():
if param.grad is not None and torch.isnan(param.grad).any():
print(f"Parameter {name} values: {param.data}")
print(f"Parameter {name} gradients: {param.grad}")
raise ValueError("NaN gradient detected")
def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
# Load the model
model = MiniMindLM(lm_config).to(args.device)
# Load pretrained token embeddings if path is provided
if pretrained_embedding_path and os.path.exists(pretrained_embedding_path):
Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
embedding_weights = torch.load(pretrained_embedding_path, map_location=args.device)
model.tok_embeddings.load_state_dict(embedding_weights)
Logger("Successfully loaded pretrained token embeddings.")
elif pretrained_embedding_path:
Logger(f"Warning: Pretrained embedding path {pretrained_embedding_path} provided but file does not exist. Initializing embeddings from scratch.")
# Print the model parameter count
Logger(f'LLM total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
return model, tokenizer
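For reference, the checkpoint written by train_embedding.py is simply the state_dict of an nn.Embedding (a single "weight" tensor of shape [vocab_size, dim]), which is what allows the direct load into model.tok_embeddings above. A minimal inspection sketch, assuming an epoch-3 checkpoint produced with the default dim/vocab settings (the exact filename depends on the values used at training time):
import torch
ckpt = "out_word2vec/word2vec_embedding_dim768_vocab6400_epoch3.pth"  # hypothetical path
weights = torch.load(ckpt, map_location="cpu")
print(list(weights.keys()))      # ['weight']
print(weights["weight"].shape)   # e.g. torch.Size([6400, 768])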
@@ -153,6 +198,7 @@ if __name__ == "__main__":
parser.add_argument('--max_seq_len', default=512, type=int) # Maximum sequence length, capping the length of input sequences.
parser.add_argument('--use_moe', default=False, type=bool) # Whether to use MoE (mixture of experts).
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") # Path to the training dataset.
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
args = parser.parse_args()
lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe) # Create the LMConfig object holding the model configuration.
@@ -163,9 +209,12 @@ if __name__ == "__main__":
print(f"tokens_per_iter: {tokens_per_iter}")
device_type = "cuda" if "cuda" in args.device else "cpu" # Determine the device type.
# Determine the torch dtype
pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
@@ -189,7 +238,7 @@ if __name__ == "__main__":
else:
wandb = None
model, tokenizer = init_model(lm_config)
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None
train_loader = DataLoader(
@@ -202,13 +251,14 @@ if __name__ == "__main__":
sampler=train_sampler
)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16']))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if ddp:
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
torch.autograd.set_detect_anomaly(True)
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
train_epoch(epoch, wandb)