import os
# Set the environment variable to keep wandb offline
os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
import platform
import argparse
import time
import math
import warnings
import pandas as pd
import torch
import torch.distributed as dist
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, DistributedSampler, Dataset
from contextlib import nullcontext
import random
import numpy as np
import json

from transformers import AutoTokenizer

# Removed: from model.model import MiniMindLM
from model.LMConfig import LMConfig
# from model.dataset import PretrainDataset

warnings.filterwarnings('ignore')


# Define a Word2Vec-style CBOW model
class CBOWModel(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.embedding_dim = config.dim
        
        # Input embeddings (context words)
        self.embeddings = nn.Embedding(config.vocab_size, config.dim)
        
        # Output weights for target prediction
        self.output_weights = nn.Linear(config.dim, config.vocab_size, bias=False)
        
        # Initialize weights
        self.init_weights()
        
    def init_weights(self):
        # Xavier initialization for better convergence
        nn.init.xavier_uniform_(self.embeddings.weight)
        nn.init.xavier_uniform_(self.output_weights.weight)
    
    def forward(self, context_words):
        # context_words shape: [batch_size, context_size]; context_size may vary across batches
        
        # Get embeddings for all context words
        embeds = self.embeddings(context_words)  # [batch_size, context_size, embedding_dim]
        
        # Average the context word embeddings along context dimension
        embeds = torch.mean(embeds, dim=1)  # [batch_size, embedding_dim]
        
        # Predict the target word
        output = self.output_weights(embeds)  # [batch_size, vocab_size]
        
        return output
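
# A minimal shape sketch (kept as comments so the script's behavior is unchanged):
# with a hypothetical batch of 2 contexts of 2 * window_size = 10 tokens each,
#   model = CBOWModel(config)                               # config.dim, config.vocab_size as configured
#   context = torch.randint(0, config.vocab_size, (2, 10))
#   logits = model(context)                                 # embeddings [2, 10, dim] -> mean [2, dim] -> logits [2, vocab_size]
# i.e. one score per vocabulary entry for each averaged context window.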


# Word2Vec CBOW dataset
class CBOWDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512, window_size=5):
        super().__init__()
        self.tokenizer = tokenizer
        self.window_size = window_size
        self.max_length = max_length
        self.samples = self.load_data(data_path)
        
    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, index):
        sample = self.samples[index]
        
        # Build the input text wrapped in BOS/EOS tokens
        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Get the token ids
        input_ids = encoding.input_ids.squeeze()
        # Filter out padding positions
        attention_mask = encoding.attention_mask.squeeze()
        valid_indices = torch.where(attention_mask == 1)[0]
        valid_input_ids = input_ids[valid_indices]
        
        # Make sure there are enough tokens for CBOW training
        if len(valid_input_ids) <= 2 * self.window_size + 1:
            # Not enough tokens: fall back to a randomly chosen different sample
            return self.__getitem__(random.randint(0, len(self.samples) - 1))
        
        # Randomly pick a center position (excluding the special tokens at both ends),
        # making sure there are at least window_size tokens on each side of the center
        min_center_pos = self.window_size + 1  # skip the leading BOS token
        max_center_pos = len(valid_input_ids) - self.window_size - 1  # skip the trailing EOS token
        
        if max_center_pos <= min_center_pos:
            return self.__getitem__(random.randint(0, len(self.samples) - 1))
            
        center_pos = random.randint(min_center_pos, max_center_pos)
        
        # Target word (the center word)
        target = valid_input_ids[center_pos].unsqueeze(0)
        
        # Context words (window_size tokens on each side of the center word)
        context = torch.cat([
            valid_input_ids[center_pos - self.window_size:center_pos],
            valid_input_ids[center_pos + 1:center_pos + self.window_size + 1]
        ])
        
        return context, target
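
# Worked example (token ids are illustrative): with window_size = 2 and valid ids
# [bos, 11, 12, 13, 14, 15, eos], choosing center_pos = 3 (token 13) produces
#   context = [11, 12, 14, 15]   # the two tokens on each side of the center
#   target  = [13]               # the center token the model learns to predict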


def Logger(content):
    # Print only on the main process (or when DDP is not in use)
    if not ddp or dist.get_rank() == 0:
        print(content)


def get_lr(current_step, total_steps, lr):
    # Cosine learning-rate schedule with a floor of lr / 10:
    # get_lr(c, t, l) = l / 10 + 0.5 * l * (1 + cos(pi * c / t))
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
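
# A few values of the schedule above (illustrative numbers): with l = 5e-4 and t = 1000 steps,
#   get_lr(0,    1000, 5e-4) = 5e-5 + 5e-4   = 5.5e-4   (peak at the first step)
#   get_lr(500,  1000, 5e-4) = 5e-5 + 2.5e-4 = 3.0e-4   (halfway through)
#   get_lr(1000, 1000, 5e-4) = 5e-5 + 0      = 5.0e-5   (floor of l / 10 at the end)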


def train_epoch(epoch, wandb):
    loss_fct = nn.CrossEntropyLoss()
    start_time = time.time()
    total_loss = 0
    total_samples = 0
    
    for step, (context, target) in enumerate(train_loader):
        try:
            # Move the batch to the target device
            context = context.to(args.device)
            target = target.to(args.device)

            # Update the learning rate for this step
            lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            with ctx:
                # Forward pass
                logits = model(context)  # [batch_size, vocab_size]
                # target has shape [batch_size, 1]; squeeze it to [batch_size] as CrossEntropyLoss expects
                loss = loss_fct(logits, target.squeeze())
                loss = loss / args.accumulation_steps
            
            # Print data types for debugging
            if step == 0 and (not ddp or dist.get_rank() == 0):
                Logger("---- Data Type Check ----")
                Logger(f"context.dtype: {context.dtype}")
                Logger(f"context.shape: {context.shape}")
                Logger(f"target.dtype: {target.dtype}")
                Logger(f"target.shape: {target.shape}")
                if hasattr(model, 'module'):  # DDP case
                    Logger(f"Model parameter dtype: {next(model.module.parameters()).dtype}")
                else:  # Non-DDP case
                    Logger(f"Model parameter dtype: {next(model.parameters()).dtype}")
                Logger(f"logits.dtype: {logits.dtype}")
                Logger(f"logits.shape: {logits.shape}")
                Logger(f"loss.dtype: {loss.dtype}")
                Logger("-------------------------")

            scaler.scale(loss).backward()

            if (step + 1) % args.accumulation_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

                scaler.step(optimizer)
                scaler.update()

                optimizer.zero_grad(set_to_none=True)
            
            total_loss += loss.item() * args.accumulation_steps
            total_samples += 1

            # Logging
            if step % args.log_interval == 0:
                spend_time = time.time() - start_time
                avg_loss = total_loss / total_samples if total_samples > 0 else 0
                Logger(
                    'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min'.format(
                        epoch + 1,
                        args.epochs,
                        step,
                        iter_per_epoch,
                        avg_loss,
                        optimizer.param_groups[-1]['lr'],
                        spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

                if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                    wandb.log({"loss": avg_loss,
                               "lr": optimizer.param_groups[-1]['lr'],
                               "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

        except Exception as e:
            print(f"Error occurred: {str(e)}")
            import traceback
            traceback.print_exc()
            # Save an emergency copy of the embedding weights before re-raising
            save_path = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_ERROR.pth'
            if os.path.exists(save_path):
                os.remove(save_path)
            
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.embeddings.state_dict()
            else:
                state_dict = model.embeddings.state_dict()
            torch.save(state_dict, save_path)

            # Report any parameters whose gradients contain NaN values
            for name, param in model.named_parameters():
                if param.grad is not None and torch.isnan(param.grad).any():
                    print(f"NaN gradient in parameter: {name}")
                    print(f"Parameter {name} values: {param.data}")
                    print(f"Parameter {name} gradients: {param.grad}")
            
            raise
    
    # Save model once at the end of each epoch
    if not ddp or dist.get_rank() == 0:
        model.eval()
        ckp = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_epoch{epoch+1}.pth'
        
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            embedding_state_dict = model.module.embeddings.state_dict()
        else:
            embedding_state_dict = model.embeddings.state_dict()

        torch.save(embedding_state_dict, ckp)
        Logger(f"Saved word2vec embedding for epoch {epoch+1} to {ckp}")
        model.train()


def init_model(lm_config_params: LMConfig):
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    # Update vocab_size in lm_config if tokenizer has a different one
    if tokenizer.vocab_size != lm_config_params.vocab_size:
        Logger(f"Updating lm_config.vocab_size from {lm_config_params.vocab_size} to {tokenizer.vocab_size} based on tokenizer.")
        lm_config_params.vocab_size = tokenizer.vocab_size

    # Build the word2vec CBOW model
    model = CBOWModel(lm_config_params).to(args.device)
    # Print the model parameter count
    Logger(f'CBOW Model total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} Million')
    return model, tokenizer
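
# Rough parameter count, assuming the defaults below are used unchanged (dim=768 and a
# tokenizer vocabulary of 6400): the input embeddings (6400 * 768) plus the bias-free
# output projection (768 * 6400) give 2 * 6400 * 768 = 9,830,400 parameters, i.e. ~9.830 Million.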


def init_distributed_mode():
    if not ddp: return  # If distributed data parallel (DDP) is not enabled, do nothing.
    global ddp_local_rank, DEVICE  # Declare these as globals so they are visible outside this function.

    dist.init_process_group(backend="nccl")  # Initialize the process group with the NCCL backend, NVIDIA's optimized library for GPU-to-GPU communication.
    ddp_rank = int(os.environ["RANK"])  # Global rank of this process, read from the environment.
    ddp_local_rank = int(os.environ["LOCAL_RANK"])  # Local rank of this process on its node.
    ddp_world_size = int(os.environ["WORLD_SIZE"])  # Total number of processes in the group.
    DEVICE = f"cuda:{ddp_local_rank}"  # Select the GPU device based on the local rank.
    torch.cuda.set_device(DEVICE)  # Bind this process to that GPU.


# torchrun --nproc_per_node 2 train_embedding.py
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind Word2Vec Embedding Training")
    parser.add_argument("--out_dir", type=str, default="out_word2vec")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", default=False, action="store_true")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Word2Vec-Training")
    parser.add_argument("--num_workers", type=int, default=32)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--log_interval", type=int, default=100)
    parser.add_argument("--save_interval", type=int, default=100)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--dim', default=768, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
    parser.add_argument('--vocab_size', default=6400, type=int)
    parser.add_argument('--window_size', default=5, type=int)


    args = parser.parse_args()

    # Create LMConfig with relevant parameters for embedding
    lm_config = LMConfig(
        dim=args.dim, 
        vocab_size=args.vocab_size, # Will be updated by tokenizer
        max_seq_len=args.max_seq_len,
        n_layers=1,   # Minimal
        n_heads=1,    # Minimal
        n_kv_heads=1  # Minimal
    )
    args.save_dir = args.out_dir
    os.makedirs(args.save_dir, exist_ok=True)
    tokens_per_iter = args.batch_size * lm_config.max_seq_len
    print(f"tokens_per_iter: {tokens_per_iter}")
    device_type = "cuda" if "cuda" in args.device else "cpu"

    # Determine the torch dtype
    pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]

    args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"

    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)

    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
    ddp_local_rank, DEVICE = 0, "cuda:0" # Default values, will be overwritten in DDP

    base_seed = 1337
    torch.manual_seed(base_seed)
    torch.cuda.manual_seed(base_seed)

    if ddp:
        init_distributed_mode() # This sets DEVICE and ddp_local_rank
        args.device = torch.device(DEVICE) # Ensure args.device is updated
        rank = dist.get_rank()
        torch.manual_seed(base_seed + rank)
        # Also seed the CUDA RNGs
        torch.cuda.manual_seed_all(base_seed + rank) # Use seed_all for DDP

    if args.use_wandb and (not ddp or dist.get_rank() == 0): # Check rank for DDP wandb init
        import wandb

        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
    else:
        wandb = None

    model, tokenizer = init_model(lm_config) # Pass the lm_config instance
    
    # Update lm_config vocab_size again after tokenizer to ensure consistency for save path name
    if lm_config.vocab_size != tokenizer.vocab_size:
        lm_config.vocab_size = tokenizer.vocab_size
        args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
        if wandb is not None and (not ddp or dist.get_rank() == 0):
            wandb.config.update({'vocab_size': lm_config.vocab_size, 'wandb_run_name': args.wandb_run_name}, allow_val_change=True)

    # Collate function that pads contexts of different lengths within a batch
    def collate_cbow_batch(batch):
        # Unpack contexts and targets
        contexts, targets = zip(*batch)
        
        # Longest context length in the current batch
        max_len = max([ctx.size(0) for ctx in contexts])
        
        # Create a zero-padded tensor for the contexts
        padded_contexts = torch.zeros(len(contexts), max_len, dtype=torch.long)
        
        # Copy each context into the padded tensor
        for i, ctx in enumerate(contexts):
            ctx_len = ctx.size(0)
            padded_contexts[i, :ctx_len] = ctx
        
        # Stack the targets into a single tensor
        stacked_targets = torch.stack(targets)
        
        return padded_contexts, stacked_targets
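
    # Worked example (lengths are illustrative): if a batch held contexts of lengths 10 and 6,
    # padded_contexts would have shape [2, 10]; the shorter context fills positions 0..5 of its
    # row and positions 6..9 stay 0, while the targets stack into a tensor of shape [2, 1].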

    # Create Word2Vec CBOW dataset
    train_ds = CBOWDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len, window_size=args.window_size)
    train_sampler = DistributedSampler(train_ds, shuffle=True, seed=base_seed) if ddp else None
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=True,
        shuffle=(train_sampler is None),
        num_workers=args.num_workers,
        sampler=train_sampler,
        collate_fn=collate_cbow_batch
    )

    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    if ddp:
        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
        
    iter_per_epoch = len(train_loader)
    Logger(f"Starting Word2Vec CBOW training for {args.epochs} epochs with {iter_per_epoch} iterations per epoch.")
    for epoch in range(args.epochs):
        if ddp:
            train_sampler.set_epoch(epoch)
        train_epoch(epoch, wandb)

    if wandb is not None and (not ddp or dist.get_rank() == 0):
        wandb.finish()
    
    Logger("Word2Vec embedding training finished.")