# Minimind/train_embedding.py


import os
# Set environment variables
os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
import platform
import argparse
import time
import math
import warnings
import pandas as pd
import torch
import torch.distributed as dist
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, DistributedSampler, Dataset
from contextlib import nullcontext
import random
import numpy as np
import json
from transformers import AutoTokenizer
# Removed: from model.model import MiniMindLM
from model.LMConfig import LMConfig
# from model.dataset import PretrainDataset
warnings.filterwarnings('ignore')
# Define a Word2Vec-style CBOW model
class CBOWModel(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.embedding_dim = config.dim
        # Input embeddings (context words)
        self.embeddings = nn.Embedding(config.vocab_size, config.dim)
        # Output weights for target prediction
        self.output_weights = nn.Linear(config.dim, config.vocab_size, bias=False)
        # Initialize weights
        self.init_weights()

    def init_weights(self):
        # Xavier initialization for better convergence
        nn.init.xavier_uniform_(self.embeddings.weight)
        nn.init.xavier_uniform_(self.output_weights.weight)

    def forward(self, context_words):
        # context_words shape: [batch_size, context_size]; context_size may vary
        # Get embeddings for all context words
        embeds = self.embeddings(context_words)  # [batch_size, context_size, embedding_dim]
        # Average the context word embeddings along the context dimension
        embeds = torch.mean(embeds, dim=1)  # [batch_size, embedding_dim]
        # Predict the target word
        output = self.output_weights(embeds)  # [batch_size, vocab_size]
        return output
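
# Illustrative sketch (not used by the training script): one CBOW forward/loss step on dummy
# data. The config values and tensor shapes below are assumptions for demonstration only; in
# training, the context size is 2 * window_size and targets come from CBOWDataset.
def _cbow_forward_example():
    cfg = LMConfig(dim=8, vocab_size=100, max_seq_len=32, n_layers=1, n_heads=1, n_kv_heads=1)
    model = CBOWModel(cfg)
    context = torch.randint(0, cfg.vocab_size, (4, 10))  # [batch_size=4, context_size=10]
    logits = model(context)                              # [4, 100] = [batch_size, vocab_size]
    target = torch.randint(0, cfg.vocab_size, (4,))      # one target token id per sample
    return nn.CrossEntropyLoss()(logits, target)
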
# Word2Vec CBOW dataset
class CBOWDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512, window_size=5):
        super().__init__()
        self.tokenizer = tokenizer
        self.window_size = window_size
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the input text
        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Get the token ids
        input_ids = encoding.input_ids.squeeze()
        # Filter out padding
        attention_mask = encoding.attention_mask.squeeze()
        valid_indices = torch.where(attention_mask == 1)[0]
        valid_input_ids = input_ids[valid_indices]
        # Make sure there are enough tokens for CBOW training
        if len(valid_input_ids) <= 2 * self.window_size + 1:
            # Too few tokens: fall back to a randomly chosen different sample
            return self.__getitem__(random.randint(0, len(self.samples) - 1))
        # Randomly pick a center position, excluding the special tokens at both ends,
        # and make sure there are at least window_size tokens on each side of it
        min_center_pos = self.window_size + 1  # skip the leading BOS token
        max_center_pos = len(valid_input_ids) - self.window_size - 1  # skip the trailing EOS token
        if max_center_pos <= min_center_pos:
            return self.__getitem__(random.randint(0, len(self.samples) - 1))
        center_pos = random.randint(min_center_pos, max_center_pos)
        # Target word (the center word)
        target = valid_input_ids[center_pos].unsqueeze(0)
        # Context words (the words before and after the center word)
        context = torch.cat([
            valid_input_ids[center_pos - self.window_size:center_pos],
            valid_input_ids[center_pos + 1:center_pos + self.window_size + 1]
        ])
        return context, target
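
# Illustrative sketch (dummy token ids, window_size assumed = 3): how __getitem__ slices a
# token sequence into a (context, target) pair around a sampled center position.
def _cbow_pair_example():
    ids = torch.arange(12)             # stand-in for valid_input_ids
    w, center = 3, 5                   # window size and a sampled center position
    target = ids[center].unsqueeze(0)  # tensor([5])
    context = torch.cat([ids[center - w:center], ids[center + 1:center + w + 1]])
    # context == tensor([2, 3, 4, 6, 7, 8]): the 2 * w surrounding tokens, center excluded
    return context, target
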
def Logger(content):
    # Print only when not running under DDP, or from the DDP main process (rank 0)
    if not ddp or dist.get_rank() == 0:
        print(content)


def get_lr(current_step, total_steps, lr):
    # Cosine learning-rate schedule with a floor of lr / 10:
    # get_lr(c, t, l) = l / 10 + 0.5 * l * (1 + cos(pi * c / t))
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
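
# Worked values of the schedule above (illustration), assuming lr = 5e-4 (the default) and
# total_steps = T:
#   step 0:   5e-4 / 10 + 0.5 * 5e-4 * (1 + cos(0))      = 5.5e-4  (1.1 * lr)
#   step T/2: 5e-4 / 10 + 0.5 * 5e-4 * (1 + cos(pi / 2)) = 3.0e-4  (0.6 * lr)
#   step T:   5e-4 / 10 + 0.5 * 5e-4 * (1 + cos(pi))     = 5.0e-5  (0.1 * lr)
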
def train_epoch(epoch, wandb):
    loss_fct = nn.CrossEntropyLoss()
    start_time = time.time()
    total_loss = 0
    total_samples = 0
    for step, (context, target) in enumerate(train_loader):
        try:
            # Move the batch to the target device
            context = context.to(args.device)
            target = target.to(args.device)
            # Update the learning rate
            lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            with ctx:
                # Forward pass
                logits = model(context)  # [batch_size, vocab_size]
                # target is [batch_size, 1]; squeeze to [batch_size] to match what CrossEntropyLoss expects
                loss = loss_fct(logits, target.squeeze())
                loss = loss / args.accumulation_steps
            # Print data types for debugging
            if step == 0 and (not ddp or dist.get_rank() == 0):
                Logger("---- Data Type Check ----")
                Logger(f"context.dtype: {context.dtype}")
                Logger(f"context.shape: {context.shape}")
                Logger(f"target.dtype: {target.dtype}")
                Logger(f"target.shape: {target.shape}")
                if hasattr(model, 'module'):  # DDP case
                    Logger(f"Model parameter dtype: {next(model.module.parameters()).dtype}")
                else:  # Non-DDP case
                    Logger(f"Model parameter dtype: {next(model.parameters()).dtype}")
                Logger(f"logits.dtype: {logits.dtype}")
                Logger(f"logits.shape: {logits.shape}")
                Logger(f"loss.dtype: {loss.dtype}")
                Logger("-------------------------")
            scaler.scale(loss).backward()
            if (step + 1) % args.accumulation_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
            total_loss += loss.item() * args.accumulation_steps
            total_samples += 1
            # Logging
            if step % args.log_interval == 0:
                spend_time = time.time() - start_time
                avg_loss = total_loss / total_samples if total_samples > 0 else 0
                Logger(
                    'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                        epoch + 1,
                        args.epochs,
                        step,
                        iter_per_epoch,
                        avg_loss,
                        optimizer.param_groups[-1]['lr'],
                        # Estimated minutes remaining in this epoch: projected epoch time minus elapsed time
                        spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
                if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                    wandb.log({"loss": avg_loss,
                               "lr": optimizer.param_groups[-1]['lr'],
                               "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
        except Exception as e:
            print(f"Error occurred: {str(e)}")
            import traceback
            traceback.print_exc()
            # Checkpoint path used when an error occurs
            save_path = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_ERROR.pth'
            if os.path.exists(save_path):
                os.remove(save_path)
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.embeddings.state_dict()
            else:
                state_dict = model.embeddings.state_dict()
            torch.save(state_dict, save_path)
            # Report any NaN gradients to help diagnose the failure, then abort
            for name, param in model.named_parameters():
                if param.grad is not None and torch.isnan(param.grad).any():
                    print(f"NaN gradient in parameter: {name}")
            for name, param in model.named_parameters():
                if param.grad is not None and torch.isnan(param.grad).any():
                    print(f"Parameter {name} values: {param.data}")
                    print(f"Parameter {name} gradients: {param.grad}")
            raise ValueError("NaN gradient detected")
    # Save the model once at the end of each epoch
    if not ddp or dist.get_rank() == 0:
        model.eval()
        ckp = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_epoch{epoch + 1}.pth'
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            embedding_state_dict = model.module.embeddings.state_dict()
        else:
            embedding_state_dict = model.embeddings.state_dict()
        torch.save(embedding_state_dict, ckp)
        Logger(f"Saved word2vec embedding for epoch {epoch + 1} to {ckp}")
        model.train()
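
# Note (illustration): the optimizer steps once every accumulation_steps micro-batches, so the
# effective batch size per update is batch_size * accumulation_steps (times the number of DDP
# processes, if any). With the defaults below that is 256 * 8 = 2048 samples per update.
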
def init_model(lm_config_params: LMConfig):
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    # Update vocab_size in lm_config if the tokenizer has a different one
    if tokenizer.vocab_size != lm_config_params.vocab_size:
        Logger(f"Updating lm_config.vocab_size from {lm_config_params.vocab_size} to {tokenizer.vocab_size} based on tokenizer.")
        lm_config_params.vocab_size = tokenizer.vocab_size
    # Build the word2vec CBOW model
    model = CBOWModel(lm_config_params).to(args.device)
    # Report the parameter count
    Logger(f'CBOW Model total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} Million')
    return model, tokenizer
def init_distributed_mode():
    if not ddp:
        return  # Nothing to do unless distributed data parallel (DDP) is enabled
    global ddp_local_rank, DEVICE  # Declared global so they can be read outside this function
    dist.init_process_group(backend="nccl")  # Initialize the process group with the NCCL backend (NVIDIA's library for GPU-to-GPU communication)
    ddp_rank = int(os.environ["RANK"])  # Global rank of this process, from the environment
    ddp_local_rank = int(os.environ["LOCAL_RANK"])  # Local rank of this process on its node
    ddp_world_size = int(os.environ["WORLD_SIZE"])  # Total number of processes in the group
    DEVICE = f"cuda:{ddp_local_rank}"  # Pick the GPU that corresponds to the local rank
    torch.cuda.set_device(DEVICE)  # Bind this process to that GPU
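
# Illustrative sketch (not called anywhere in this script): how a checkpoint written by
# train_epoch can be reloaded elsewhere. The checkpoint holds only the state_dict of the
# nn.Embedding layer ({'weight': ...}); ckpt_path is whatever file a training run produced.
def load_trained_embedding(ckpt_path, vocab_size, dim):
    emb = nn.Embedding(vocab_size, dim)
    emb.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    emb.weight.requires_grad_(False)  # typically frozen when reused downstream
    return emb
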
# torchrun --nproc_per_node 2 train_embedding.py
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind Word2Vec Embedding Training")
    parser.add_argument("--out_dir", type=str, default="out_word2vec")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", default=False, action="store_true")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Word2Vec-Training")
    parser.add_argument("--num_workers", type=int, default=32)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--log_interval", type=int, default=100)
    parser.add_argument("--save_interval", type=int, default=100)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--dim', default=768, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
    parser.add_argument('--vocab_size', default=6400, type=int)
    parser.add_argument('--window_size', default=5, type=int)
    args = parser.parse_args()

    # Create an LMConfig with the parameters relevant to embedding training
    lm_config = LMConfig(
        dim=args.dim,
        vocab_size=args.vocab_size,  # Will be updated from the tokenizer
        max_seq_len=args.max_seq_len,
        n_layers=1,    # Minimal
        n_heads=1,     # Minimal
        n_kv_heads=1   # Minimal
    )
    args.save_dir = args.out_dir
    os.makedirs(args.save_dir, exist_ok=True)
    tokens_per_iter = args.batch_size * lm_config.max_seq_len
    print(f"tokens_per_iter: {tokens_per_iter}")
    device_type = "cuda" if "cuda" in args.device else "cpu"
    # Determine the torch dtype
    pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
    args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
    ddp = int(os.environ.get("RANK", -1)) != -1  # Is this a DDP run?
    ddp_local_rank, DEVICE = 0, "cuda:0"  # Default values, overwritten under DDP
    base_seed = 1337
    torch.manual_seed(base_seed)
    torch.cuda.manual_seed(base_seed)
    if ddp:
        init_distributed_mode()  # Sets DEVICE and ddp_local_rank
        args.device = torch.device(DEVICE)  # Keep args.device in sync with the DDP device
        rank = dist.get_rank()
        torch.manual_seed(base_seed + rank)
        # Also seed CUDA with the per-rank seed (use seed_all for DDP)
        torch.cuda.manual_seed_all(base_seed + rank)
    if args.use_wandb and (not ddp or dist.get_rank() == 0):  # Only rank 0 initializes wandb under DDP
        import wandb
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
    else:
        wandb = None
    model, tokenizer = init_model(lm_config)  # Pass the lm_config instance
    # Re-sync lm_config.vocab_size with the tokenizer so the save-path names stay consistent
    if lm_config.vocab_size != tokenizer.vocab_size:
        lm_config.vocab_size = tokenizer.vocab_size
    args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
    if wandb is not None and (not ddp or dist.get_rank() == 0):
        wandb.config.update({'vocab_size': lm_config.vocab_size, 'wandb_run_name': args.wandb_run_name}, allow_val_change=True)
    # Collate function that pads variable-length context sequences
    def collate_cbow_batch(batch):
        # Split the batch into contexts and targets
        contexts, targets = zip(*batch)
        # Longest context length in this batch
        max_len = max([ctx.size(0) for ctx in contexts])
        # Allocate the padded tensor
        padded_contexts = torch.zeros(len(contexts), max_len, dtype=torch.long)
        # Copy each context into place
        for i, ctx in enumerate(contexts):
            ctx_len = ctx.size(0)
            padded_contexts[i, :ctx_len] = ctx
        # Stack the targets into a single tensor
        stacked_targets = torch.stack(targets)
        return padded_contexts, stacked_targets
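
    # Note (illustration): with the length checks in CBOWDataset.__getitem__, every context is
    # normally exactly 2 * window_size tokens long, so the zero-padding above is defensive;
    # e.g. contexts of lengths 10 and 8 in one batch would both land in a [batch, 10] LongTensor.
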
    # Create the Word2Vec CBOW dataset
    train_ds = CBOWDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len, window_size=args.window_size)
    train_sampler = DistributedSampler(train_ds, shuffle=True, seed=base_seed) if ddp else None
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=True,
        shuffle=(train_sampler is None),
        num_workers=args.num_workers,
        sampler=train_sampler,
        collate_fn=collate_cbow_batch
    )
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    if ddp:
        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
    iter_per_epoch = len(train_loader)
    Logger(f"Starting Word2Vec CBOW training for {args.epochs} epochs with {iter_per_epoch} iterations per epoch.")
    for epoch in range(args.epochs):
        if ddp:
            train_sampler.set_epoch(epoch)
        train_epoch(epoch, wandb)

    if wandb is not None and (not ddp or dist.get_rank() == 0):
        wandb.finish()
    Logger("Word2Vec embedding training finished.")