From 253576967c990aba78031367067c9619cc2dd8bc Mon Sep 17 00:00:00 2001
From: iomgaa <iomgaaycz@gmail.com>
Date: Thu, 8 May 2025 21:11:05 +0800
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86train=5Fembedding?=
 =?UTF-8?q?=E7=94=A8=E4=BA=8E=E9=A2=84=E8=AE=AD=E7=BB=83=E5=B5=8C=E5=85=A5?=
 =?UTF-8?q?=E6=A8=A1=E5=9E=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 train_embedding.py | 306 ++++++++++++++++++++++++++++++---------------
 1 file changed, 208 insertions(+), 98 deletions(-)

diff --git a/train_embedding.py b/train_embedding.py
index fbb363a..7a4493d 100644
--- a/train_embedding.py
+++ b/train_embedding.py
@@ -12,31 +12,122 @@ import torch.distributed as dist
 from torch import optim, nn
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader, DistributedSampler, Dataset
 from contextlib import nullcontext
+import random
+import numpy as np
+import json
 
 from transformers import AutoTokenizer
 
 # Removed: from model.model import MiniMindLM
 from model.LMConfig import LMConfig
-from model.dataset import PretrainDataset
+# from model.dataset import PretrainDataset
 
 warnings.filterwarnings('ignore')
 
 
-# Define a simple model for pretraining embeddings
-class EmbeddingPretrainer(nn.Module):
+# Define a Word2Vec-style CBOW model
+class CBOWModel(nn.Module):
     def __init__(self, config: LMConfig):
         super().__init__()
-        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
-        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
-        # Tie weights (optional but common)
-        # self.tok_embeddings.weight = self.lm_head.weight
+        self.vocab_size = config.vocab_size
+        self.embedding_dim = config.dim
+        
+        # Input embeddings (context words)
+        self.embeddings = nn.Embedding(config.vocab_size, config.dim)
+        
+        # Output weights for target prediction
+        self.output_weights = nn.Linear(config.dim, config.vocab_size, bias=False)
+        
+        # Initialize weights
+        self.init_weights()
+        
+    def init_weights(self):
+        # Xavier initialization for better convergence
+        nn.init.xavier_uniform_(self.embeddings.weight)
+        nn.init.xavier_uniform_(self.output_weights.weight)
+    
+    def forward(self, context_words):
+        # context_words shape: [batch_size, context_size],context_size可变
+        
+        # Get embeddings for all context words
+        embeds = self.embeddings(context_words)  # [batch_size, context_size, embedding_dim]
+        
+        # Average the context word embeddings along context dimension
+        embeds = torch.mean(embeds, dim=1)  # [batch_size, embedding_dim]
+        
+        # Predict the target word
+        output = self.output_weights(embeds)  # [batch_size, vocab_size]
+        
+        return output
 
-    def forward(self, input_ids):
-        hidden_states = self.tok_embeddings(input_ids)
-        logits = self.lm_head(hidden_states)
-        return logits
+
+# Word2Vec CBOW dataset
+class CBOWDataset(Dataset):
+    def __init__(self, data_path, tokenizer, max_length=512, window_size=5):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.window_size = window_size
+        self.max_length = max_length
+        self.samples = self.load_data(data_path)
+        
+    def load_data(self, path):
+        samples = []
+        with open(path, 'r', encoding='utf-8') as f:
+            for line_num, line in enumerate(f, 1):
+                data = json.loads(line.strip())
+                samples.append(data)
+        return samples
+    
+    def __len__(self):
+        return len(self.samples)
+    
+    def __getitem__(self, index):
+        sample = self.samples[index]
+        
+        # 构建输入文本
+        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
+        encoding = self.tokenizer(
+            text,
+            max_length=self.max_length,
+            padding='max_length',
+            truncation=True,
+            return_tensors='pt'
+        )
+        
+        # 获取token ids
+        input_ids = encoding.input_ids.squeeze()
+        # 过滤掉padding
+        attention_mask = encoding.attention_mask.squeeze()
+        valid_indices = torch.where(attention_mask == 1)[0]
+        valid_input_ids = input_ids[valid_indices]
+        
+        # 确保有足够的token进行CBOW训练
+        if len(valid_input_ids) <= 2 * self.window_size + 1:
+            # 如果token不足,随机选择一个不同的样本
+            return self.__getitem__(random.randint(0, len(self.samples) - 1))
+        
+        # 随机选择一个中心位置(不包括首尾的特殊token)
+        # 确保中心位置两边都有至少window_size个token
+        min_center_pos = self.window_size + 1  # 避开起始token
+        max_center_pos = len(valid_input_ids) - self.window_size - 1  # 避开结束token
+        
+        if max_center_pos <= min_center_pos:
+            return self.__getitem__(random.randint(0, len(self.samples) - 1))
+            
+        center_pos = random.randint(min_center_pos, max_center_pos)
+        
+        # 目标词(中心词)
+        target = valid_input_ids[center_pos].unsqueeze(0)
+        
+        # 上下文词(中心词前后的词)
+        context = torch.cat([
+            valid_input_ids[center_pos - self.window_size:center_pos],
+            valid_input_ids[center_pos + 1:center_pos + self.window_size + 1]
+        ])
+        
+        return context, target
 
 
 def Logger(content):
@@ -52,14 +143,16 @@ def get_lr(current_step, total_steps, lr):
 
 
 def train_epoch(epoch, wandb):
-    loss_fct = nn.CrossEntropyLoss(reduction='none', ignore_index=0) # Assuming 0 is pad_token_id
+    loss_fct = nn.CrossEntropyLoss()
     start_time = time.time()
-    for step, (X, Y, loss_mask) in enumerate(train_loader):
+    total_loss = 0
+    total_samples = 0
+    
+    for step, (context, target) in enumerate(train_loader):
         try:
             # 将数据加载到设备上
-            X = X.to(args.device)
-            Y = Y.to(args.device)
-            loss_mask = loss_mask.to(args.device)
+            context = context.to(args.device)
+            target = target.to(args.device)
 
             # 更新学习率
             lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
@@ -67,28 +160,28 @@ def train_epoch(epoch, wandb):
                 param_group['lr'] = lr
 
             with ctx:
-                logits = model(X) # Model returns logits directly
-                loss = loss_fct(
-                    logits.view(-1, logits.size(-1)),
-                    Y.view(-1)
-                ).view(Y.size())
-                loss = (loss * loss_mask).sum() / loss_mask.sum()
-                # Removed: loss += res.aux_loss 
+                # Forward pass
+                logits = model(context)  # [batch_size, vocab_size]
+                # target是[batch_size, 1],需要squeeze成[batch_size]来匹配CrossEntropyLoss的预期
+                loss = loss_fct(logits, target.squeeze())
                 loss = loss / args.accumulation_steps
             
             # Print data types for debugging
-            if step == 0 and (not ddp or dist.get_rank() == 0): # Print only for the first step of the first epoch on the main process
+            if step == 0 and (not ddp or dist.get_rank() == 0):
                 Logger("---- Data Type Check ----")
-                Logger(f"X.dtype: {X.dtype}")
-                if hasattr(model, 'module'): # DDP case
+                Logger(f"context.dtype: {context.dtype}")
+                Logger(f"context.shape: {context.shape}")
+                Logger(f"target.dtype: {target.dtype}")
+                Logger(f"target.shape: {target.shape}")
+                if hasattr(model, 'module'):  # DDP case
                     Logger(f"Model parameter dtype: {next(model.module.parameters()).dtype}")
-                else: # Non-DDP case
+                else:  # Non-DDP case
                     Logger(f"Model parameter dtype: {next(model.parameters()).dtype}")
-                Logger(f"logits.dtype: {logits.dtype}") # Changed from res.logits.dtype
+                Logger(f"logits.dtype: {logits.dtype}")
+                Logger(f"logits.shape: {logits.shape}")
                 Logger(f"loss.dtype: {loss.dtype}")
                 Logger("-------------------------")
 
-
             scaler.scale(loss).backward()
 
             if (step + 1) % args.accumulation_steps == 0:
@@ -99,52 +192,43 @@ def train_epoch(epoch, wandb):
                 scaler.update()
 
                 optimizer.zero_grad(set_to_none=True)
+            
+            total_loss += loss.item() * args.accumulation_steps
+            total_samples += 1
 
             # 打印日志
             if step % args.log_interval == 0:
                 spend_time = time.time() - start_time
+                avg_loss = total_loss / total_samples if total_samples > 0 else 0
                 Logger(
                     'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                         epoch + 1,
                         args.epochs,
                         step,
                         iter_per_epoch,
-                        loss.item() * args.accumulation_steps,
+                        avg_loss,
                         optimizer.param_groups[-1]['lr'],
                         spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
 
                 if (wandb is not None) and (not ddp or dist.get_rank() == 0):
-                    wandb.log({"loss": loss.item() * args.accumulation_steps,
+                    wandb.log({"loss": avg_loss,
                                "lr": optimizer.param_groups[-1]['lr'],
                                "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
 
-            # 保存模型
-            if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
-                model.eval()
-                # Modified checkpoint path and content
-                ckp = f'{args.save_dir}/pretrained_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}.pth'
-                
-                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                    embedding_state_dict = model.module.tok_embeddings.state_dict()
-                else:
-                    embedding_state_dict = model.tok_embeddings.state_dict()
-
-                torch.save(embedding_state_dict, ckp)
-                Logger(f"Saved pretrained embedding to {ckp}")
-                model.train()
-
         except Exception as e:
             print(f"Error occurred: {str(e)}")
+            import traceback
+            traceback.print_exc()
             # Modified checkpoint path for error
-            save_path = f'{args.save_dir}/pretrained_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_ERROR.pth'
-            if  os.path.exists(save_path):
+            save_path = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_ERROR.pth'
+            if os.path.exists(save_path):
                 os.remove(save_path)
             
             if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                state_dict = model.module.tok_embeddings.state_dict()
+                state_dict = model.module.embeddings.state_dict()
             else:
-                state_dict = model.tok_embeddings.state_dict()
-            torch.save(state_dict, save_path) # Save embedding state dict on error
+                state_dict = model.embeddings.state_dict()
+            torch.save(state_dict, save_path)
 
             for name, param in model.named_parameters():
                 if param.grad is not None and torch.isnan(param.grad).any():
@@ -156,9 +240,23 @@ def train_epoch(epoch, wandb):
                     print(f"Parameter {name} gradients: {param.grad}")
             
             raise ValueError("NaN gradient detected")
+    
+    # Save model once at the end of each epoch
+    if not ddp or dist.get_rank() == 0:
+        model.eval()
+        ckp = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_epoch{epoch+1}.pth'
+        
+        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+            embedding_state_dict = model.module.embeddings.state_dict()
+        else:
+            embedding_state_dict = model.embeddings.state_dict()
+
+        torch.save(embedding_state_dict, ckp)
+        Logger(f"Saved word2vec embedding for epoch {epoch+1} to {ckp}")
+        model.train()
 
 
-def init_model(lm_config_params: LMConfig): # Renamed for clarity
+def init_model(lm_config_params: LMConfig):
     # 加载tokenizer
     tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
     # Update vocab_size in lm_config if tokenizer has a different one
@@ -166,10 +264,10 @@ def init_model(lm_config_params: LMConfig): # Renamed for clarity
         Logger(f"Updating lm_config.vocab_size from {lm_config_params.vocab_size} to {tokenizer.vocab_size} based on tokenizer.")
         lm_config_params.vocab_size = tokenizer.vocab_size
 
-    # 加载模型
-    model = EmbeddingPretrainer(lm_config_params).to(args.device) # Use EmbeddingPretrainer
+    # 加载word2vec CBOW模型
+    model = CBOWModel(lm_config_params).to(args.device)
     # 打印模型参数
-    Logger(f'EmbeddingPretrainer total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} Million')
+    Logger(f'CBOW Model total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} Million')
     return model, tokenizer
 
 
@@ -187,32 +285,27 @@ def init_distributed_mode():
 
 # torchrun --nproc_per_node 2 train_embedding.py
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="MiniMind Embedding Pretraining") # Changed description
-    parser.add_argument("--out_dir", type=str, default="out_embedding") # Changed default out_dir
-    # 若要以最快速度实现zero则epochs设置为1轮;否则应当利用有限的数据训练2~6个epochs。
+    parser = argparse.ArgumentParser(description="MiniMind Word2Vec Embedding Training")
+    parser.add_argument("--out_dir", type=str, default="out_word2vec")
     parser.add_argument("--epochs", type=int, default=3)
-    parser.add_argument("--batch_size", type=int, default=32) # Smaller batch size might be needed if memory is an issue
+    parser.add_argument("--batch_size", type=int, default=256)
     parser.add_argument("--learning_rate", type=float, default=5e-4)
-    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用,则使用GPU,否则使用CPU。
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
     parser.add_argument("--dtype", type=str, default="bfloat16")
     parser.add_argument("--use_wandb", default=False, action="store_true")
-    parser.add_argument("--wandb_project", type=str, default="MiniMind-Embedding-Pretrain") # Changed project name
-    parser.add_argument("--num_workers", type=int, default=8)
+    parser.add_argument("--wandb_project", type=str, default="MiniMind-Word2Vec-Training")
+    parser.add_argument("--num_workers", type=int, default=32)
     parser.add_argument("--ddp", action="store_true")
-    parser.add_argument("--accumulation_steps", type=int, default=8) #梯度累积步数,用于控制梯度更新频率。
-    parser.add_argument("--grad_clip", type=float, default=1.0) #梯度裁剪阈值,用于防止梯度爆炸。
-    # parser.add_argument("--warmup_iters", type=int, default=0) #预热迭代次数,用于控制学习率预热过程。 (Can be kept or removed)
-    parser.add_argument("--log_interval", type=int, default=100) #日志打印间隔,用于控制日志打印的频率。
-    parser.add_argument("--save_interval", type=int, default=100) #模型保存间隔,用于控制模型保存的频率。
-    parser.add_argument('--local_rank', type=int, default=-1) #本地进程编号,用于分布式训练。
-    parser.add_argument('--dim', default=768, type=int) #模型维度,用于控制模型的大小。
-    # Removed n_layers, use_moe as they are not relevant for EmbeddingPretrainer
-    # parser.add_argument('--n_layers', default=8, type=int) 
-    parser.add_argument('--max_seq_len', default=512, type=int) #最大序列长度,用于控制输入序列的最大长度。
-    # parser.add_argument('--use_moe', default=False, type=bool) 
-    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。
-    # Add vocab_size to args, though it will be overridden by tokenizer if different
+    parser.add_argument("--accumulation_steps", type=int, default=8)
+    parser.add_argument("--grad_clip", type=float, default=1.0)
+    parser.add_argument("--log_interval", type=int, default=100)
+    parser.add_argument("--save_interval", type=int, default=100)
+    parser.add_argument('--local_rank', type=int, default=-1)
+    parser.add_argument('--dim', default=768, type=int)
+    parser.add_argument('--max_seq_len', default=512, type=int)
+    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
     parser.add_argument('--vocab_size', default=6400, type=int)
+    parser.add_argument('--window_size', default=5, type=int)
 
 
     args = parser.parse_args()
@@ -222,24 +315,21 @@ if __name__ == "__main__":
         dim=args.dim, 
         vocab_size=args.vocab_size, # Will be updated by tokenizer
         max_seq_len=args.max_seq_len,
-        # n_layers, n_heads, etc. are not directly used by EmbeddingPretrainer but LMConfig requires them
-        # We can set them to default or minimal values if they cause issues, or modify LMConfig
-        # For now, using defaults from LMConfig definition for unneeded params.
         n_layers=1, # Minimal
         n_heads=1, # Minimal
         n_kv_heads=1 #Minimal
     ) 
-    args.save_dir = os.path.join(args.out_dir) #创建保存目录。
-    os.makedirs(args.save_dir, exist_ok=True) #创建保存目录。
-    os.makedirs(args.out_dir, exist_ok=True) #创建输出目录。
-    tokens_per_iter = args.batch_size * lm_config.max_seq_len #计算每个迭代步骤的token数量。
+    args.save_dir = os.path.join(args.out_dir)
+    os.makedirs(args.save_dir, exist_ok=True)
+    os.makedirs(args.out_dir, exist_ok=True)
+    tokens_per_iter = args.batch_size * lm_config.max_seq_len
     print(f"tokens_per_iter: {tokens_per_iter}")
-    device_type = "cuda" if "cuda" in args.device else "cpu" #确定设备类型。
+    device_type = "cuda" if "cuda" in args.device else "cpu"
 
     # Determine the torch dtype
     pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
 
-    args.wandb_run_name = f"MiniMind-Embedding-Pretrain-Dim-{args.dim}-Vocab-{lm_config.vocab_size}" # Updated run name
+    args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
 
     ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
 
@@ -270,39 +360,59 @@ if __name__ == "__main__":
     # Update lm_config vocab_size again after tokenizer to ensure consistency for save path name
     if lm_config.vocab_size != tokenizer.vocab_size:
         lm_config.vocab_size = tokenizer.vocab_size
-        args.wandb_run_name = f"MiniMind-Embedding-Pretrain-Dim-{args.dim}-Vocab-{lm_config.vocab_size}"
+        args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
         if wandb is not None and (not ddp or dist.get_rank() == 0):
             wandb.config.update({'vocab_size': lm_config.vocab_size, 'wandb_run_name': args.wandb_run_name}, allow_val_change=True)
 
+    # 添加collate函数处理不同长度的序列
+    def collate_cbow_batch(batch):
+        # 提取context和target
+        contexts, targets = zip(*batch)
+        
+        # 获取当前批次中最长的context长度
+        max_len = max([ctx.size(0) for ctx in contexts])
+        
+        # 创建填充后的tensor
+        padded_contexts = torch.zeros(len(contexts), max_len, dtype=torch.long)
+        
+        # 填充每个context
+        for i, ctx in enumerate(contexts):
+            ctx_len = ctx.size(0)
+            padded_contexts[i, :ctx_len] = ctx
+        
+        # 将targets stack成一个tensor
+        stacked_targets = torch.stack(targets)
+        
+        return padded_contexts, stacked_targets
 
-    train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
-    train_sampler = DistributedSampler(train_ds, shuffle=True, seed=base_seed) if ddp else None # Added shuffle and seed
+    # Create Word2Vec CBOW dataset
+    train_ds = CBOWDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len, window_size=args.window_size)
+    train_sampler = DistributedSampler(train_ds, shuffle=True, seed=base_seed) if ddp else None
     train_loader = DataLoader(
         train_ds,
         batch_size=args.batch_size,
         pin_memory=True,
-        drop_last=True, # Set to True for more stable training step counts
-        shuffle=(train_sampler is None), # Shuffle only if not using DDP sampler
+        drop_last=True,
+        shuffle=(train_sampler is None),
         num_workers=args.num_workers,
-        sampler=train_sampler
+        sampler=train_sampler,
+        collate_fn=collate_cbow_batch
     )
 
-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16'])) # bfloat16 also uses scaler
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
 
     if ddp:
-        # model._ddp_params_and_buffers_to_ignore = {"pos_cis"} # Not relevant for EmbeddingPretrainer
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
         
-    # torch.autograd.set_detect_anomaly(True) # Can be enabled for debugging
     iter_per_epoch = len(train_loader)
-    Logger(f"Starting training for {args.epochs} epochs with {iter_per_epoch} iterations per epoch.")
+    Logger(f"Starting Word2Vec CBOW training for {args.epochs} epochs with {iter_per_epoch} iterations per epoch.")
     for epoch in range(args.epochs):
         if ddp:
-            train_sampler.set_epoch(epoch) # Important for DDP shuffling
+            train_sampler.set_epoch(epoch)
         train_epoch(epoch, wandb)
 
-    if wandb is not None and (not ddp or dist.get_rank() == 0) :
+    if wandb is not None and (not ddp or dist.get_rank() == 0):
         wandb.finish()
     
-    Logger("Embedding pretraining finished.") 
\ No newline at end of file
+    Logger("Word2Vec embedding training finished.") 
\ No newline at end of file