Added train_embedding for pretraining the embedding model

This commit is contained in:
iomgaa 2025-05-08 21:11:05 +08:00
parent 4ab8064ee0
commit 253576967c

train_embedding.py

@@ -12,31 +12,122 @@ import torch.distributed as dist
 from torch import optim, nn
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader, DistributedSampler, Dataset
 from contextlib import nullcontext
+import random
+import numpy as np
+import json
 
 from transformers import AutoTokenizer
 # Removed: from model.model import MiniMindLM
 from model.LMConfig import LMConfig
-from model.dataset import PretrainDataset
+# from model.dataset import PretrainDataset
 
 warnings.filterwarnings('ignore')
 
-# Define a simple model for pretraining embeddings
-class EmbeddingPretrainer(nn.Module):
+# Define a Word2Vec-style CBOW model
+class CBOWModel(nn.Module):
     def __init__(self, config: LMConfig):
         super().__init__()
-        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
-        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
-        # Tie weights (optional but common)
-        # self.tok_embeddings.weight = self.lm_head.weight
+        self.vocab_size = config.vocab_size
+        self.embedding_dim = config.dim
+
+        # Input embeddings (context words)
+        self.embeddings = nn.Embedding(config.vocab_size, config.dim)
+        # Output weights for target prediction
+        self.output_weights = nn.Linear(config.dim, config.vocab_size, bias=False)
+
+        # Initialize weights
+        self.init_weights()
+
+    def init_weights(self):
+        # Xavier initialization for better convergence
+        nn.init.xavier_uniform_(self.embeddings.weight)
+        nn.init.xavier_uniform_(self.output_weights.weight)
+
+    def forward(self, context_words):
+        # context_words shape: [batch_size, context_size]; context_size may vary
+        # Get embeddings for all context words
+        embeds = self.embeddings(context_words)  # [batch_size, context_size, embedding_dim]
+        # Average the context word embeddings along the context dimension
+        embeds = torch.mean(embeds, dim=1)  # [batch_size, embedding_dim]
+        # Predict the target word
+        output = self.output_weights(embeds)  # [batch_size, vocab_size]
+        return output
 
-    def forward(self, input_ids):
-        hidden_states = self.tok_embeddings(input_ids)
-        logits = self.lm_head(hidden_states)
-        return logits
+
+# Word2Vec CBOW dataset
+class CBOWDataset(Dataset):
+    def __init__(self, data_path, tokenizer, max_length=512, window_size=5):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.window_size = window_size
+        self.max_length = max_length
+        self.samples = self.load_data(data_path)
+
+    def load_data(self, path):
+        samples = []
+        with open(path, 'r', encoding='utf-8') as f:
+            for line_num, line in enumerate(f, 1):
+                data = json.loads(line.strip())
+                samples.append(data)
+        return samples
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, index):
+        sample = self.samples[index]
+
+        # Build the input text
+        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
+        encoding = self.tokenizer(
+            text,
+            max_length=self.max_length,
+            padding='max_length',
+            truncation=True,
+            return_tensors='pt'
+        )
+
+        # Get the token ids
+        input_ids = encoding.input_ids.squeeze()
+        # Filter out padding
+        attention_mask = encoding.attention_mask.squeeze()
+        valid_indices = torch.where(attention_mask == 1)[0]
+        valid_input_ids = input_ids[valid_indices]
+
+        # Make sure there are enough tokens for CBOW training
+        if len(valid_input_ids) <= 2 * self.window_size + 1:
+            # If there are not enough tokens, randomly pick a different sample
+            return self.__getitem__(random.randint(0, len(self.samples) - 1))
+
+        # Randomly pick a center position (excluding the special tokens at the start and end),
+        # making sure there are at least window_size tokens on each side of the center
+        min_center_pos = self.window_size + 1  # skip the leading special token
+        max_center_pos = len(valid_input_ids) - self.window_size - 1  # skip the trailing special token
+
+        if max_center_pos <= min_center_pos:
+            return self.__getitem__(random.randint(0, len(self.samples) - 1))
+
+        center_pos = random.randint(min_center_pos, max_center_pos)
+
+        # Target word (the center word)
+        target = valid_input_ids[center_pos].unsqueeze(0)
+
+        # Context words (words before and after the center word)
+        context = torch.cat([
+            valid_input_ids[center_pos - self.window_size:center_pos],
+            valid_input_ids[center_pos + 1:center_pos + self.window_size + 1]
+        ])
+
+        return context, target
 
 
 def Logger(content):
@@ -52,14 +143,16 @@ def get_lr(current_step, total_steps, lr):
 def train_epoch(epoch, wandb):
-    loss_fct = nn.CrossEntropyLoss(reduction='none', ignore_index=0)  # Assuming 0 is pad_token_id
+    loss_fct = nn.CrossEntropyLoss()
     start_time = time.time()
-    for step, (X, Y, loss_mask) in enumerate(train_loader):
+    total_loss = 0
+    total_samples = 0
+
+    for step, (context, target) in enumerate(train_loader):
         try:
             # Move the data to the device
-            X = X.to(args.device)
-            Y = Y.to(args.device)
-            loss_mask = loss_mask.to(args.device)
+            context = context.to(args.device)
+            target = target.to(args.device)
 
             # Update the learning rate
             lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
@@ -67,28 +160,28 @@ def train_epoch(epoch, wandb):
                 param_group['lr'] = lr
 
             with ctx:
-                logits = model(X)  # Model returns logits directly
-                loss = loss_fct(
-                    logits.view(-1, logits.size(-1)),
-                    Y.view(-1)
-                ).view(Y.size())
-                loss = (loss * loss_mask).sum() / loss_mask.sum()
-                # Removed: loss += res.aux_loss
+                # Forward pass
+                logits = model(context)  # [batch_size, vocab_size]
+                # target has shape [batch_size, 1]; squeeze it to [batch_size] to match what CrossEntropyLoss expects
+                loss = loss_fct(logits, target.squeeze())
                 loss = loss / args.accumulation_steps
 
             # Print data types for debugging
-            if step == 0 and (not ddp or dist.get_rank() == 0):  # Print only for the first step of the first epoch on the main process
+            if step == 0 and (not ddp or dist.get_rank() == 0):
                 Logger("---- Data Type Check ----")
-                Logger(f"X.dtype: {X.dtype}")
-                if hasattr(model, 'module'):  # DDP case
+                Logger(f"context.dtype: {context.dtype}")
+                Logger(f"context.shape: {context.shape}")
+                Logger(f"target.dtype: {target.dtype}")
+                Logger(f"target.shape: {target.shape}")
+                if hasattr(model, 'module'):  # DDP case
                     Logger(f"Model parameter dtype: {next(model.module.parameters()).dtype}")
                 else:  # Non-DDP case
                     Logger(f"Model parameter dtype: {next(model.parameters()).dtype}")
-                Logger(f"logits.dtype: {logits.dtype}")  # Changed from res.logits.dtype
+                Logger(f"logits.dtype: {logits.dtype}")
+                Logger(f"logits.shape: {logits.shape}")
                 Logger(f"loss.dtype: {loss.dtype}")
                 Logger("-------------------------")
 
             scaler.scale(loss).backward()
 
             if (step + 1) % args.accumulation_steps == 0:
@@ -99,52 +192,43 @@ def train_epoch(epoch, wandb):
                 scaler.update()
                 optimizer.zero_grad(set_to_none=True)
 
+            total_loss += loss.item() * args.accumulation_steps
+            total_samples += 1
+
             # Logging
             if step % args.log_interval == 0:
                 spend_time = time.time() - start_time
+                avg_loss = total_loss / total_samples if total_samples > 0 else 0
                 Logger(
                     'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                         epoch + 1,
                         args.epochs,
                         step,
                         iter_per_epoch,
-                        loss.item() * args.accumulation_steps,
+                        avg_loss,
                         optimizer.param_groups[-1]['lr'],
                         spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
 
                 if (wandb is not None) and (not ddp or dist.get_rank() == 0):
-                    wandb.log({"loss": loss.item() * args.accumulation_steps,
+                    wandb.log({"loss": avg_loss,
                                "lr": optimizer.param_groups[-1]['lr'],
                                "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
 
-            # Save the model
-            if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
-                model.eval()
-                # Modified checkpoint path and content
-                ckp = f'{args.save_dir}/pretrained_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}.pth'
-
-                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                    embedding_state_dict = model.module.tok_embeddings.state_dict()
-                else:
-                    embedding_state_dict = model.tok_embeddings.state_dict()
-
-                torch.save(embedding_state_dict, ckp)
-                Logger(f"Saved pretrained embedding to {ckp}")
-                model.train()
-
         except Exception as e:
             print(f"Error occurred: {str(e)}")
+            import traceback
+            traceback.print_exc()
+
             # Modified checkpoint path for error
-            save_path = f'{args.save_dir}/pretrained_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_ERROR.pth'
+            save_path = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_ERROR.pth'
             if os.path.exists(save_path):
                 os.remove(save_path)
 
             if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                state_dict = model.module.tok_embeddings.state_dict()
+                state_dict = model.module.embeddings.state_dict()
             else:
-                state_dict = model.tok_embeddings.state_dict()
-            torch.save(state_dict, save_path)  # Save embedding state dict on error
+                state_dict = model.embeddings.state_dict()
+            torch.save(state_dict, save_path)
 
             for name, param in model.named_parameters():
                 if param.grad is not None and torch.isnan(param.grad).any():
@@ -156,9 +240,23 @@ def train_epoch(epoch, wandb):
                     print(f"Parameter {name} gradients: {param.grad}")
                     raise ValueError("NaN gradient detected")
 
+    # Save model once at the end of each epoch
+    if not ddp or dist.get_rank() == 0:
+        model.eval()
+        ckp = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_epoch{epoch+1}.pth'
+
+        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+            embedding_state_dict = model.module.embeddings.state_dict()
+        else:
+            embedding_state_dict = model.embeddings.state_dict()
+
+        torch.save(embedding_state_dict, ckp)
+        Logger(f"Saved word2vec embedding for epoch {epoch+1} to {ckp}")
+        model.train()
+
 
-def init_model(lm_config_params: LMConfig):  # Renamed for clarity
+def init_model(lm_config_params: LMConfig):
     # Load the tokenizer
     tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
     # Update vocab_size in lm_config if tokenizer has a different one
@@ -166,10 +264,10 @@ def init_model(lm_config_params: LMConfig):
         Logger(f"Updating lm_config.vocab_size from {lm_config_params.vocab_size} to {tokenizer.vocab_size} based on tokenizer.")
         lm_config_params.vocab_size = tokenizer.vocab_size
 
-    # Load the model
-    model = EmbeddingPretrainer(lm_config_params).to(args.device)  # Use EmbeddingPretrainer
+    # Load the word2vec CBOW model
+    model = CBOWModel(lm_config_params).to(args.device)
 
     # Print model parameter count
-    Logger(f'EmbeddingPretrainer total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} Million')
+    Logger(f'CBOW Model total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} Million')
 
     return model, tokenizer
@@ -187,32 +285,27 @@ def init_distributed_mode():
 # torchrun --nproc_per_node 2 train_embedding.py
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="MiniMind Embedding Pretraining")  # Changed description
-    parser.add_argument("--out_dir", type=str, default="out_embedding")  # Changed default out_dir
-    # To reproduce "zero" as quickly as possible, set epochs to 1; otherwise train the limited data for 2~6 epochs.
+    parser = argparse.ArgumentParser(description="MiniMind Word2Vec Embedding Training")
+    parser.add_argument("--out_dir", type=str, default="out_word2vec")
     parser.add_argument("--epochs", type=int, default=3)
-    parser.add_argument("--batch_size", type=int, default=32)  # Smaller batch size might be needed if memory is an issue
+    parser.add_argument("--batch_size", type=int, default=256)
     parser.add_argument("--learning_rate", type=float, default=5e-4)
-    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")  # Use the GPU if available, otherwise the CPU
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
     parser.add_argument("--dtype", type=str, default="bfloat16")
     parser.add_argument("--use_wandb", default=False, action="store_true")
-    parser.add_argument("--wandb_project", type=str, default="MiniMind-Embedding-Pretrain")  # Changed project name
-    parser.add_argument("--num_workers", type=int, default=8)
+    parser.add_argument("--wandb_project", type=str, default="MiniMind-Word2Vec-Training")
+    parser.add_argument("--num_workers", type=int, default=32)
     parser.add_argument("--ddp", action="store_true")
-    parser.add_argument("--accumulation_steps", type=int, default=8)  # Gradient accumulation steps (controls how often parameters are updated)
-    parser.add_argument("--grad_clip", type=float, default=1.0)  # Gradient clipping threshold to prevent exploding gradients
-    # parser.add_argument("--warmup_iters", type=int, default=0)  # Warmup iterations for the learning-rate schedule (Can be kept or removed)
-    parser.add_argument("--log_interval", type=int, default=100)  # Logging interval
-    parser.add_argument("--save_interval", type=int, default=100)  # Checkpoint saving interval
-    parser.add_argument('--local_rank', type=int, default=-1)  # Local process rank for distributed training
-    parser.add_argument('--dim', default=768, type=int)  # Model dimension (controls model size)
-    # Removed n_layers, use_moe as they are not relevant for EmbeddingPretrainer
-    # parser.add_argument('--n_layers', default=8, type=int)
-    parser.add_argument('--max_seq_len', default=512, type=int)  # Maximum input sequence length
-    # parser.add_argument('--use_moe', default=False, type=bool)
-    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")  # Path to the dataset
-    # Add vocab_size to args, though it will be overridden by tokenizer if different
+    parser.add_argument("--accumulation_steps", type=int, default=8)
+    parser.add_argument("--grad_clip", type=float, default=1.0)
+    parser.add_argument("--log_interval", type=int, default=100)
+    parser.add_argument("--save_interval", type=int, default=100)
+    parser.add_argument('--local_rank', type=int, default=-1)
+    parser.add_argument('--dim', default=768, type=int)
+    parser.add_argument('--max_seq_len', default=512, type=int)
+    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
     parser.add_argument('--vocab_size', default=6400, type=int)
+    parser.add_argument('--window_size', default=5, type=int)
 
     args = parser.parse_args()
@@ -222,24 +315,21 @@ if __name__ == "__main__":
         dim=args.dim,
         vocab_size=args.vocab_size,  # Will be updated by tokenizer
         max_seq_len=args.max_seq_len,
-        # n_layers, n_heads, etc. are not directly used by EmbeddingPretrainer but LMConfig requires them
-        # We can set them to default or minimal values if they cause issues, or modify LMConfig
-        # For now, using defaults from LMConfig definition for unneeded params.
         n_layers=1,  # Minimal
         n_heads=1,  # Minimal
         n_kv_heads=1  # Minimal
     )
-    args.save_dir = os.path.join(args.out_dir)  # Create the save directory
-    os.makedirs(args.save_dir, exist_ok=True)  # Create the save directory
-    os.makedirs(args.out_dir, exist_ok=True)  # Create the output directory
-    tokens_per_iter = args.batch_size * lm_config.max_seq_len  # Number of tokens per iteration
+    args.save_dir = os.path.join(args.out_dir)
+    os.makedirs(args.save_dir, exist_ok=True)
+    os.makedirs(args.out_dir, exist_ok=True)
+    tokens_per_iter = args.batch_size * lm_config.max_seq_len
     print(f"tokens_per_iter: {tokens_per_iter}")
-    device_type = "cuda" if "cuda" in args.device else "cpu"  # Determine the device type
+    device_type = "cuda" if "cuda" in args.device else "cpu"
 
     # Determine the torch dtype
     pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
 
-    args.wandb_run_name = f"MiniMind-Embedding-Pretrain-Dim-{args.dim}-Vocab-{lm_config.vocab_size}"  # Updated run name
+    args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
 
     ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
@@ -270,39 +360,59 @@ if __name__ == "__main__":
     # Update lm_config vocab_size again after tokenizer to ensure consistency for save path name
     if lm_config.vocab_size != tokenizer.vocab_size:
         lm_config.vocab_size = tokenizer.vocab_size
-        args.wandb_run_name = f"MiniMind-Embedding-Pretrain-Dim-{args.dim}-Vocab-{lm_config.vocab_size}"
+        args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
         if wandb is not None and (not ddp or dist.get_rank() == 0):
             wandb.config.update({'vocab_size': lm_config.vocab_size, 'wandb_run_name': args.wandb_run_name}, allow_val_change=True)
 
+    # Collate function to handle variable-length context sequences
+    def collate_cbow_batch(batch):
+        # Extract contexts and targets
+        contexts, targets = zip(*batch)
+
+        # Length of the longest context in the current batch
+        max_len = max([ctx.size(0) for ctx in contexts])
+
+        # Create the padded tensor
+        padded_contexts = torch.zeros(len(contexts), max_len, dtype=torch.long)
+
+        # Pad each context
+        for i, ctx in enumerate(contexts):
+            ctx_len = ctx.size(0)
+            padded_contexts[i, :ctx_len] = ctx
+
+        # Stack the targets into a single tensor
+        stacked_targets = torch.stack(targets)
+
+        return padded_contexts, stacked_targets
+
-    train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
-    train_sampler = DistributedSampler(train_ds, shuffle=True, seed=base_seed) if ddp else None  # Added shuffle and seed
+    # Create Word2Vec CBOW dataset
+    train_ds = CBOWDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len, window_size=args.window_size)
+    train_sampler = DistributedSampler(train_ds, shuffle=True, seed=base_seed) if ddp else None
     train_loader = DataLoader(
         train_ds,
         batch_size=args.batch_size,
         pin_memory=True,
-        drop_last=True,  # Set to True for more stable training step counts
-        shuffle=(train_sampler is None),  # Shuffle only if not using DDP sampler
+        drop_last=True,
+        shuffle=(train_sampler is None),
         num_workers=args.num_workers,
-        sampler=train_sampler
+        sampler=train_sampler,
+        collate_fn=collate_cbow_batch
     )
 
-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))  # bfloat16 also uses scaler
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
 
     if ddp:
-        # model._ddp_params_and_buffers_to_ignore = {"pos_cis"}  # Not relevant for EmbeddingPretrainer
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
 
-    # torch.autograd.set_detect_anomaly(True)  # Can be enabled for debugging
     iter_per_epoch = len(train_loader)
-    Logger(f"Starting training for {args.epochs} epochs with {iter_per_epoch} iterations per epoch.")
+    Logger(f"Starting Word2Vec CBOW training for {args.epochs} epochs with {iter_per_epoch} iterations per epoch.")
     for epoch in range(args.epochs):
         if ddp:
-            train_sampler.set_epoch(epoch)  # Important for DDP shuffling
+            train_sampler.set_epoch(epoch)
         train_epoch(epoch, wandb)
 
-    if wandb is not None and (not ddp or dist.get_rank() == 0) :
+    if wandb is not None and (not ddp or dist.get_rank() == 0):
         wandb.finish()
 
-    Logger("Embedding pretraining finished.")
+    Logger("Word2Vec embedding training finished.")