Added train_embedding for pretraining the embedding model

iomgaa 2025-05-08 21:11:05 +08:00
parent 4ab8064ee0
commit 253576967c


@@ -12,31 +12,122 @@ import torch.distributed as dist
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import DataLoader, DistributedSampler, Dataset
from contextlib import nullcontext
import random
import numpy as np
import json
from transformers import AutoTokenizer
# Removed: from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset
# from model.dataset import PretrainDataset
warnings.filterwarnings('ignore')
# Define a simple model for pretraining embeddings
class EmbeddingPretrainer(nn.Module):
# Define a Word2Vec-style CBOW model
class CBOWModel(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
# Tie weights (optional but common)
# self.tok_embeddings.weight = self.lm_head.weight
self.vocab_size = config.vocab_size
self.embedding_dim = config.dim
def forward(self, input_ids):
hidden_states = self.tok_embeddings(input_ids)
logits = self.lm_head(hidden_states)
return logits
# Input embeddings (context words)
self.embeddings = nn.Embedding(config.vocab_size, config.dim)
# Output weights for target prediction
self.output_weights = nn.Linear(config.dim, config.vocab_size, bias=False)
# Initialize weights
self.init_weights()
def init_weights(self):
# Xavier initialization for better convergence
nn.init.xavier_uniform_(self.embeddings.weight)
nn.init.xavier_uniform_(self.output_weights.weight)
def forward(self, context_words):
# context_words shape: [batch_size, context_size]; context_size can vary
# Get embeddings for all context words
embeds = self.embeddings(context_words) # [batch_size, context_size, embedding_dim]
# Average the context word embeddings along context dimension
embeds = torch.mean(embeds, dim=1) # [batch_size, embedding_dim]
# Predict the target word
output = self.output_weights(embeds) # [batch_size, vocab_size]
return output
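For reference, a minimal sketch (not part of this commit) of how the CBOW forward pass behaves; the LMConfig values mirror the minimal defaults used later in this script:

import torch
from model.LMConfig import LMConfig

# Hypothetical smoke test for CBOWModel; all values are the script's assumed defaults.
cfg = LMConfig(dim=768, vocab_size=6400, max_seq_len=512, n_layers=1, n_heads=1, n_kv_heads=1)
cbow = CBOWModel(cfg)
context_ids = torch.randint(0, cfg.vocab_size, (4, 10))  # [batch_size=4, context_size=10]
logits = cbow(context_ids)                                # averaged context embeddings -> [4, 6400]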
# Word2Vec CBOW dataset
class CBOWDataset(Dataset):
def __init__(self, data_path, tokenizer, max_length=512, window_size=5):
super().__init__()
self.tokenizer = tokenizer
self.window_size = window_size
self.max_length = max_length
self.samples = self.load_data(data_path)
def load_data(self, path):
samples = []
with open(path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
data = json.loads(line.strip())
samples.append(data)
return samples
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
sample = self.samples[index]
# Build the input text
text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
encoding = self.tokenizer(
text,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
# Get the token ids
input_ids = encoding.input_ids.squeeze()
# Filter out padding
attention_mask = encoding.attention_mask.squeeze()
valid_indices = torch.where(attention_mask == 1)[0]
valid_input_ids = input_ids[valid_indices]
# Make sure there are enough tokens for CBOW training
if len(valid_input_ids) <= 2 * self.window_size + 1:
# If there are not enough tokens, pick a different sample at random
return self.__getitem__(random.randint(0, len(self.samples) - 1))
# Randomly select a center position (excluding the special tokens at the start and end)
# Make sure there are at least window_size tokens on each side of the center position
min_center_pos = self.window_size + 1 # skip the starting token
max_center_pos = len(valid_input_ids) - self.window_size - 1 # skip the ending token
if max_center_pos <= min_center_pos:
return self.__getitem__(random.randint(0, len(self.samples) - 1))
center_pos = random.randint(min_center_pos, max_center_pos)
# Target word (the center word)
target = valid_input_ids[center_pos].unsqueeze(0)
# Context words (the words before and after the center word)
context = torch.cat([
valid_input_ids[center_pos - self.window_size:center_pos],
valid_input_ids[center_pos + 1:center_pos + self.window_size + 1]
])
return context, target
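As an illustration (a sketch assuming the tokenizer and data paths used elsewhere in this script, not part of the commit), with window_size=5 each sample is a context of 2*window_size token ids around a randomly chosen center, plus that center as the target:

from transformers import AutoTokenizer

# Hypothetical usage of CBOWDataset; paths and sizes are the defaults assumed below.
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
ds = CBOWDataset('./dataset/pretrain_hq.jsonl', tokenizer, max_length=512, window_size=5)
context, target = ds[0]
print(context.shape)  # torch.Size([10]) -- window_size tokens on each side of the center
print(target.shape)   # torch.Size([1])  -- the center token id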
def Logger(content):
@@ -52,14 +143,16 @@ def get_lr(current_step, total_steps, lr):
def train_epoch(epoch, wandb):
loss_fct = nn.CrossEntropyLoss(reduction='none', ignore_index=0) # Assuming 0 is pad_token_id
loss_fct = nn.CrossEntropyLoss()
start_time = time.time()
for step, (X, Y, loss_mask) in enumerate(train_loader):
total_loss = 0
total_samples = 0
for step, (context, target) in enumerate(train_loader):
try:
# Move the data to the device
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
context = context.to(args.device)
target = target.to(args.device)
# Update the learning rate
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
@@ -67,28 +160,28 @@ def train_epoch(epoch, wandb):
param_group['lr'] = lr
with ctx:
logits = model(X) # Model returns logits directly
loss = loss_fct(
logits.view(-1, logits.size(-1)),
Y.view(-1)
).view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask.sum()
# Removed: loss += res.aux_loss
# Forward pass
logits = model(context) # [batch_size, vocab_size]
# target is [batch_size, 1]; squeeze it to [batch_size] to match what CrossEntropyLoss expects
loss = loss_fct(logits, target.squeeze())
loss = loss / args.accumulation_steps
# Print data types for debugging
if step == 0 and (not ddp or dist.get_rank() == 0): # Print only for the first step of the first epoch on the main process
if step == 0 and (not ddp or dist.get_rank() == 0):
Logger("---- Data Type Check ----")
Logger(f"X.dtype: {X.dtype}")
Logger(f"context.dtype: {context.dtype}")
Logger(f"context.shape: {context.shape}")
Logger(f"target.dtype: {target.dtype}")
Logger(f"target.shape: {target.shape}")
if hasattr(model, 'module'): # DDP case
Logger(f"Model parameter dtype: {next(model.module.parameters()).dtype}")
else: # Non-DDP case
Logger(f"Model parameter dtype: {next(model.parameters()).dtype}")
Logger(f"logits.dtype: {logits.dtype}") # Changed from res.logits.dtype
Logger(f"logits.dtype: {logits.dtype}")
Logger(f"logits.shape: {logits.shape}")
Logger(f"loss.dtype: {loss.dtype}")
Logger("-------------------------")
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
@@ -100,51 +193,42 @@ def train_epoch(epoch, wandb):
optimizer.zero_grad(set_to_none=True)
total_loss += loss.item() * args.accumulation_steps
total_samples += 1
# Log progress
if step % args.log_interval == 0:
spend_time = time.time() - start_time
avg_loss = total_loss / total_samples if total_samples > 0 else 0
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,
iter_per_epoch,
loss.item() * args.accumulation_steps,
avg_loss,
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({"loss": loss.item() * args.accumulation_steps,
wandb.log({"loss": avg_loss,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
# Save the model
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
model.eval()
# Modified checkpoint path and content
ckp = f'{args.save_dir}/pretrained_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
embedding_state_dict = model.module.tok_embeddings.state_dict()
else:
embedding_state_dict = model.tok_embeddings.state_dict()
torch.save(embedding_state_dict, ckp)
Logger(f"Saved pretrained embedding to {ckp}")
model.train()
except Exception as e:
print(f"Error occurred: {str(e)}")
import traceback
traceback.print_exc()
# Modified checkpoint path for error
save_path = f'{args.save_dir}/pretrained_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_ERROR.pth'
save_path = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_ERROR.pth'
if os.path.exists(save_path):
os.remove(save_path)
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.tok_embeddings.state_dict()
state_dict = model.module.embeddings.state_dict()
else:
state_dict = model.tok_embeddings.state_dict()
torch.save(state_dict, save_path) # Save embedding state dict on error
state_dict = model.embeddings.state_dict()
torch.save(state_dict, save_path)
for name, param in model.named_parameters():
if param.grad is not None and torch.isnan(param.grad).any():
@@ -157,8 +241,22 @@ def train_epoch(epoch, wandb):
raise ValueError("NaN gradient detected")
# Save model once at the end of each epoch
if not ddp or dist.get_rank() == 0:
model.eval()
ckp = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_epoch{epoch+1}.pth'
def init_model(lm_config_params: LMConfig): # Renamed for clarity
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
embedding_state_dict = model.module.embeddings.state_dict()
else:
embedding_state_dict = model.embeddings.state_dict()
torch.save(embedding_state_dict, ckp)
Logger(f"Saved word2vec embedding for epoch {epoch+1} to {ckp}")
model.train()
def init_model(lm_config_params: LMConfig):
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
# Update vocab_size in lm_config if tokenizer has a different one
@@ -166,10 +264,10 @@ def init_model(lm_config_params: LMConfig): # Renamed for clarity
Logger(f"Updating lm_config.vocab_size from {lm_config_params.vocab_size} to {tokenizer.vocab_size} based on tokenizer.")
lm_config_params.vocab_size = tokenizer.vocab_size
# Load the model
model = EmbeddingPretrainer(lm_config_params).to(args.device) # Use EmbeddingPretrainer
# Load the word2vec CBOW model
model = CBOWModel(lm_config_params).to(args.device)
# Print the model parameter count
Logger(f'EmbeddingPretrainer total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} Million')
Logger(f'CBOW Model total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} Million')
return model, tokenizer
@@ -187,32 +285,27 @@ def init_distributed_mode():
# torchrun --nproc_per_node 2 train_embedding.py
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Embedding Pretraining") # Changed description
parser.add_argument("--out_dir", type=str, default="out_embedding") # Changed default out_dir
# To reproduce "zero" as fast as possible, set epochs to 1; otherwise, make full use of the limited data and train for 2~6 epochs.
parser = argparse.ArgumentParser(description="MiniMind Word2Vec Embedding Training")
parser.add_argument("--out_dir", type=str, default="out_word2vec")
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=32) # Smaller batch size might be needed if memory is an issue
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--learning_rate", type=float, default=5e-4)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用则使用GPU否则使用CPU。
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Embedding-Pretrain") # Changed project name
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--wandb_project", type=str, default="MiniMind-Word2Vec-Training")
parser.add_argument("--num_workers", type=int, default=32)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=8) #梯度累积步数,用于控制梯度更新频率。
parser.add_argument("--grad_clip", type=float, default=1.0) #梯度裁剪阈值,用于防止梯度爆炸。
# parser.add_argument("--warmup_iters", type=int, default=0) #预热迭代次数,用于控制学习率预热过程。 (Can be kept or removed)
parser.add_argument("--log_interval", type=int, default=100) #日志打印间隔,用于控制日志打印的频率。
parser.add_argument("--save_interval", type=int, default=100) #模型保存间隔,用于控制模型保存的频率。
parser.add_argument('--local_rank', type=int, default=-1) #本地进程编号,用于分布式训练。
parser.add_argument('--dim', default=768, type=int) #模型维度,用于控制模型的大小。
# Removed n_layers, use_moe as they are not relevant for EmbeddingPretrainer
# parser.add_argument('--n_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=512, type=int) # Maximum sequence length of the input.
# parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。
# Add vocab_size to args, though it will be overridden by tokenizer if different
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=100)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--dim', default=768, type=int)
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
parser.add_argument('--vocab_size', default=6400, type=int)
parser.add_argument('--window_size', default=5, type=int)
args = parser.parse_args()
@@ -222,24 +315,21 @@ if __name__ == "__main__":
dim=args.dim,
vocab_size=args.vocab_size, # Will be updated by tokenizer
max_seq_len=args.max_seq_len,
# n_layers, n_heads, etc. are not directly used by EmbeddingPretrainer but LMConfig requires them
# We can set them to default or minimal values if they cause issues, or modify LMConfig
# For now, using defaults from LMConfig definition for unneeded params.
n_layers=1, # Minimal
n_heads=1, # Minimal
n_kv_heads=1 #Minimal
)
args.save_dir = os.path.join(args.out_dir) # Save directory path.
os.makedirs(args.save_dir, exist_ok=True) # Create the save directory.
os.makedirs(args.out_dir, exist_ok=True) # Create the output directory.
tokens_per_iter = args.batch_size * lm_config.max_seq_len # Number of tokens per iteration.
args.save_dir = os.path.join(args.out_dir)
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
tokens_per_iter = args.batch_size * lm_config.max_seq_len
print(f"tokens_per_iter: {tokens_per_iter}")
device_type = "cuda" if "cuda" in args.device else "cpu" #确定设备类型。
device_type = "cuda" if "cuda" in args.device else "cpu"
# Determine the torch dtype
pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
args.wandb_run_name = f"MiniMind-Embedding-Pretrain-Dim-{args.dim}-Vocab-{lm_config.vocab_size}" # Updated run name
args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
@@ -270,39 +360,59 @@ if __name__ == "__main__":
# Update lm_config vocab_size again after tokenizer to ensure consistency for save path name
if lm_config.vocab_size != tokenizer.vocab_size:
lm_config.vocab_size = tokenizer.vocab_size
args.wandb_run_name = f"MiniMind-Embedding-Pretrain-Dim-{args.dim}-Vocab-{lm_config.vocab_size}"
args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
if wandb is not None and (not ddp or dist.get_rank() == 0):
wandb.config.update({'vocab_size': lm_config.vocab_size, 'wandb_run_name': args.wandb_run_name}, allow_val_change=True)
# Add a collate function to handle sequences of different lengths
def collate_cbow_batch(batch):
# Unpack contexts and targets
contexts, targets = zip(*batch)
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds, shuffle=True, seed=base_seed) if ddp else None # Added shuffle and seed
# Find the longest context length in the current batch
max_len = max([ctx.size(0) for ctx in contexts])
# Create the padded tensor
padded_contexts = torch.zeros(len(contexts), max_len, dtype=torch.long)
# Pad each context
for i, ctx in enumerate(contexts):
ctx_len = ctx.size(0)
padded_contexts[i, :ctx_len] = ctx
# Stack the targets into a single tensor
stacked_targets = torch.stack(targets)
return padded_contexts, stacked_targets
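A quick illustrative check of the collate function (the tensors below are made up, not drawn from the dataset): shorter contexts are right-padded with zeros and targets are stacked into a [batch_size, 1] tensor.

# Illustrative only: two fake samples with different context lengths.
sample_a = (torch.tensor([5, 6, 7, 8]), torch.tensor([9]))
sample_b = (torch.tensor([1, 2]), torch.tensor([3]))
ctx_batch, tgt_batch = collate_cbow_batch([sample_a, sample_b])
# ctx_batch -> tensor([[5, 6, 7, 8], [1, 2, 0, 0]]); tgt_batch.shape -> torch.Size([2, 1])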
# Create Word2Vec CBOW dataset
train_ds = CBOWDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len, window_size=args.window_size)
train_sampler = DistributedSampler(train_ds, shuffle=True, seed=base_seed) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=True, # Set to True for more stable training step counts
shuffle=(train_sampler is None), # Shuffle only if not using DDP sampler
drop_last=True,
shuffle=(train_sampler is None),
num_workers=args.num_workers,
sampler=train_sampler
sampler=train_sampler,
collate_fn=collate_cbow_batch
)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16'])) # bfloat16 also uses scaler
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if ddp:
# model._ddp_params_and_buffers_to_ignore = {"pos_cis"} # Not relevant for EmbeddingPretrainer
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
# torch.autograd.set_detect_anomaly(True) # Can be enabled for debugging
iter_per_epoch = len(train_loader)
Logger(f"Starting training for {args.epochs} epochs with {iter_per_epoch} iterations per epoch.")
Logger(f"Starting Word2Vec CBOW training for {args.epochs} epochs with {iter_per_epoch} iterations per epoch.")
for epoch in range(args.epochs):
if ddp:
train_sampler.set_epoch(epoch) # Important for DDP shuffling
train_sampler.set_epoch(epoch)
train_epoch(epoch, wandb)
if wandb is not None and (not ddp or dist.get_rank() == 0) :
if wandb is not None and (not ddp or dist.get_rank() == 0):
wandb.finish()
Logger("Embedding pretraining finished.")
Logger("Word2Vec embedding training finished.")