diff --git a/1-pretrain.py b/1-pretrain.py
index fabb445..ed70c4d 100644
--- a/1-pretrain.py
+++ b/1-pretrain.py
@@ -14,56 +14,62 @@ from model.model import Transformer
 from model.LMConfig import LMConfig
 from model.dataset import PretrainDataset
 
+# Suppress warning messages
 warnings.filterwarnings('ignore')
 
-
+# Logging helper: print only from the main process (rank 0)
 def Logger(content):
     if not ddp or dist.get_rank() == 0:
         print(content)
 
-
+# Learning-rate schedule: cosine annealing over the total iteration count
 def get_lr(it, all):
-    warmup_iters = 0
-    lr_decay_iters = all
-    min_lr = learning_rate / 10
+    warmup_iters = 0  # number of warmup iterations
+    lr_decay_iters = all  # total iterations over which the LR decays
+    min_lr = learning_rate / 10  # minimum learning rate
 
+    # Linear warmup while the iteration count is below warmup_iters
     if it < warmup_iters:
         return learning_rate * it / warmup_iters
+    # Past the decay horizon, hold at the minimum learning rate
     if it > lr_decay_iters:
         return min_lr
+    # Otherwise interpolate with a cosine-annealing coefficient
     decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
     assert 0 <= decay_ratio <= 1
     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
     return min_lr + coeff * (learning_rate - min_lr)
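+    # Sanity check of the schedule (a sketch, assuming learning_rate = 2e-4 and
+    # 1000 total iterations):
+    #   get_lr(0, 1000)    -> 2.0e-4  (decay starts at the full learning rate)
+    #   get_lr(500, 1000)  -> 1.1e-4  (midpoint: min_lr + 0.5 * (learning_rate - min_lr))
+    #   get_lr(1000, 1000) -> 2.0e-5  (floor: min_lr = learning_rate / 10)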
 
-
+# Train the model for one epoch
 def train_epoch(epoch, accumulation_steps=8):
-    start_time = time.time()
-    for step, (X, Y) in enumerate(train_loader):
-        X = X.to(device)
-        Y = Y.to(device)
+    start_time = time.time()  # record the start time
+    for step, (X, Y) in enumerate(train_loader):  # iterate over the data loader
+        X = X.to(device)  # move inputs to the device
+        Y = Y.to(device)  # move targets to the device
 
-        lr = get_lr(epoch * iter_per_epoch + step, epochs * iter_per_epoch)
+        lr = get_lr(epoch * iter_per_epoch + step, epochs * iter_per_epoch)  # current learning rate
         for param_group in optimizer.param_groups:
-            param_group['lr'] = lr
+            param_group['lr'] = lr  # apply it to every parameter group
 
-        with ctx:
-            out = model(X, Y)
-            loss = out.last_loss / accumulation_steps
+        with ctx:  # autocast context (mixed precision on GPU, no-op on CPU)
+            out = model(X, Y)  # forward pass
+            loss = out.last_loss / accumulation_steps  # scale the loss for gradient accumulation
 
-        scaler.scale(loss).backward()
+        scaler.scale(loss).backward()  # backward pass on the scaled loss
 
+        # Take an optimizer step every accumulation_steps micro-batches
         if (step + 1) % accumulation_steps == 0:
-            scaler.unscale_(optimizer)
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            scaler.unscale_(optimizer)  # unscale gradients before clipping
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip the gradient norm to 1.0
 
-            scaler.step(optimizer)
-            scaler.update()
+            scaler.step(optimizer)  # update the model parameters
+            scaler.update()  # update the loss scale
 
-            optimizer.zero_grad(set_to_none=True)
+            optimizer.zero_grad(set_to_none=True)  # reset gradients
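+            # With the defaults here (batch_size = 64, accumulation_steps = 8), each
+            # optimizer step aggregates 64 * 8 = 512 sequences, i.e. an effective
+            # batch of 512 * max_seq_len tokens per update.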
 
+        # Log training progress every 100 steps
         if step % 100 == 0:
-            spend_time = time.time() - start_time
+            spend_time = time.time() - start_time  # elapsed time so far
             Logger(
                 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
                     epoch,
@@ -74,26 +80,27 @@ def train_epoch(epoch, accumulation_steps=8):
                     optimizer.param_groups[-1]['lr'],
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
 
+        # Save a checkpoint every 1000 steps (main process only)
         if (step + 1) % 1000 == 0 and (not ddp or dist.get_rank() == 0):
-            model.eval()
+            model.eval()  # switch to eval mode
             # torch.save(model.state_dict(), '{}/iter_{}.pth'.format(save_dir, int(step + epoch * iter_per_epoch)))
-            moe_path = '_moe' if lm_config.use_moe else ''
+            moe_path = '_moe' if lm_config.use_moe else ''  # suffix the checkpoint name when MoE is enabled
             ckp = f'{save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
 
             if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                state_dict = model.module.state_dict()
+                state_dict = model.module.state_dict()  # unwrap the DDP-wrapped module
             else:
                 state_dict = model.state_dict()
 
-            torch.save(state_dict, ckp)
-            model.train()
-
+            torch.save(state_dict, ckp)  # save the checkpoint
+            model.train()  # switch back to training mode
 
+# Build and initialize the model
 def init_model():
     def count_parameters(model):
-        return sum(p.numel() for p in model.parameters() if p.requires_grad)
+        return sum(p.numel() for p in model.parameters() if p.requires_grad)  # count trainable parameters
 
-    # model init
+    # Initialize the model
     model = Transformer(lm_config).to(device)
     moe_path = '_moe' if lm_config.use_moe else ''
     # ckp = f'{save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
@@ -105,57 +112,57 @@ def init_model():
     #         state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
     # model.load_state_dict(state_dict, strict=False)
 
-    Logger(f'LLM总参数量:{count_parameters(model) / 1e6:.3f} 百万')
+    Logger(f'LLM总参数量:{count_parameters(model) / 1e6:.3f} 百万')  # log the total parameter count (in millions)
     return model
 
-
+# Set up the distributed (DDP) training environment
 def init_distributed_mode():
     if not ddp: return
     global ddp_local_rank, DEVICE
 
-    dist.init_process_group(backend="nccl")
-    ddp_rank = int(os.environ["RANK"])
-    ddp_local_rank = int(os.environ["LOCAL_RANK"])
-    ddp_world_size = int(os.environ["WORLD_SIZE"])
-    DEVICE = f"cuda:{ddp_local_rank}"
-    torch.cuda.set_device(DEVICE)
+    dist.init_process_group(backend="nccl")  # initialize the process group with the NCCL backend
+    ddp_rank = int(os.environ["RANK"])  # global rank of this process
+    ddp_local_rank = int(os.environ["LOCAL_RANK"])  # local rank on this node
+    ddp_world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
+    DEVICE = f"cuda:{ddp_local_rank}"  # CUDA device assigned to this process
+    torch.cuda.set_device(DEVICE)  # bind this process to its device
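+    # Under the launch line below (torchrun --nproc_per_node 2), torchrun sets
+    # these env vars per process: RANK = 0/1, LOCAL_RANK = 0/1, WORLD_SIZE = 2,
+    # so each process pins itself to cuda:0 or cuda:1.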
 
 
 # torchrun --nproc_per_node 2 1-pretrain.py
 # I/O
 if __name__ == "__main__":
     # -----------------------------------------------------------------------------
-    lm_config = LMConfig()
-    max_seq_len = lm_config.max_seq_len
-    out_dir = 'out'
-    epochs = 20
-    batch_size = 64
-    learning_rate = 2e-4
-    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-    dtype = 'bfloat16'
-    save_dir = os.path.join(out_dir)
-    os.makedirs(save_dir, exist_ok=True)
-    os.makedirs(out_dir, exist_ok=True)
-    tokens_per_iter = batch_size * max_seq_len
-    torch.manual_seed(1337)
-    device_type = device if "cuda" in device else "cpu"
+    lm_config = LMConfig()  # build the model configuration
+    max_seq_len = lm_config.max_seq_len  # maximum sequence length
+    out_dir = 'out'  # output directory
+    epochs = 20  # number of training epochs
+    batch_size = 64  # batch size
+    learning_rate = 2e-4  # initial learning rate
+    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # first GPU if available, else CPU
+    dtype = 'bfloat16'  # training dtype
+    save_dir = os.path.join(out_dir)  # checkpoint directory
+    os.makedirs(save_dir, exist_ok=True)  # create the checkpoint directory
+    os.makedirs(out_dir, exist_ok=True)  # create the output directory
+    tokens_per_iter = batch_size * max_seq_len  # tokens processed per iteration
+    torch.manual_seed(1337)  # fix the random seed
+    device_type = device if "cuda" in device else "cpu"  # device type string for autocast
     ctx = (
-        nullcontext()
+        nullcontext()  # no-op context on CPU
         if device_type == "cpu"
-        else torch.cuda.amp.autocast()
+        else torch.cuda.amp.autocast()  # mixed-precision autocast on GPU
     )
-    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
-    ddp_local_rank, DEVICE = 0, "cuda:0"
+    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a DDP run?
+    ddp_local_rank, DEVICE = 0, "cuda:0"  # defaults for the single-process case
     if ddp:
-        init_distributed_mode()
-        device = torch.device(DEVICE)
+        init_distributed_mode()  # set up the distributed environment
+        device = torch.device(DEVICE)  # use this process's device
     # -----------------------------------------------------------------------------
 
     # -----init dataloader------
-    data_path_list = ['./dataset/pretrain_data.bin']
-    train_ds = PretrainDataset(data_path_list, max_length=max_seq_len, memmap=True)
-    train_sampler = DistributedSampler(train_ds) if ddp else None
-    num_workers = 8  # 可以根据系统的 CPU 核心数来调整
+    data_path_list = ['./dataset/pretrain_data.bin']  # pretraining data path
+    train_ds = PretrainDataset(data_path_list, max_length=max_seq_len, memmap=True)  # build the dataset
+    train_sampler = DistributedSampler(train_ds) if ddp else None  # shard the data across processes under DDP
+    num_workers = 8  # data-loader workers; tune to the available CPU cores
     train_loader = DataLoader(
         train_ds,
         batch_size=batch_size,
@@ -164,27 +171,27 @@ if __name__ == "__main__":
         shuffle=False,
         num_workers=num_workers,
         sampler=train_sampler
-    )
+    )  # build the data loader
 
     # init model
-    model = init_model()
+    model = init_model()  # build the model
 
-    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == dtype))
+    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == dtype))  # gradient scaler (dtype == dtype is always True, so scaling stays enabled)
     # optimizer
-    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+    optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # Adam optimizer
     # compile the model
     if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
         Logger("compiling the model... (takes a ~minute)")
         unoptimized_model = model
-        model = torch.compile(model)
+        model = torch.compile(model)  # compile the model (unreachable while the guard above is 'if False')
 
     if ddp:
         # Ignore the freqs_cis buffer so that DDP does not broadcast it at
         # construction time since NCCL does not support ComplexFloat
-        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
-        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
+        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}  # tell DDP to skip this complex-valued buffer
+        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])  # wrap the model with DDP
 
     # training loop
-    iter_per_epoch = len(train_loader)
-    for epoch in range(epochs):
-        train_epoch(epoch)
+    iter_per_epoch = len(train_loader)  # iterations per epoch
+    for epoch in range(epochs):  # loop over epochs
+        train_epoch(epoch)  # train for one epoch
\ No newline at end of file
diff --git a/model/LMConfig.py b/model/LMConfig.py
index bf0e4b9..f216b48 100644
--- a/model/LMConfig.py
+++ b/model/LMConfig.py
@@ -1,58 +1,58 @@
 from transformers import PretrainedConfig
 from typing import List
 
-
+# LMConfig: model hyperparameters, subclassing Hugging Face's PretrainedConfig
 class LMConfig(PretrainedConfig):
-    model_type = "minimind"
+    model_type = "minimind"  # 设置模型类型为 "minimind"
 
     def __init__(
             self,
-            dim: int = 512,
-            n_layers: int = 8,
-            n_heads: int = 16,
-            n_kv_heads: int = 8,
-            vocab_size: int = 6400,
-            hidden_dim: int = None,
-            multiple_of: int = 64,
-            norm_eps: float = 1e-5,
-            max_seq_len: int = 512,
-            dropout: float = 0.0,
-            flash_attn: bool = True,
+            dim: int = 512,  # model dimension, default 512
+            n_layers: int = 8,  # number of Transformer layers, default 8
+            n_heads: int = 16,  # number of attention heads, default 16
+            n_kv_heads: int = 8,  # number of KV heads, default 8
+            vocab_size: int = 6400,  # vocabulary size, default 6400
+            hidden_dim: int = None,  # FFN hidden dimension; derived from dim when None
+            multiple_of: int = 64,  # round the FFN hidden dimension up to a multiple of this
+            norm_eps: float = 1e-5,  # RMSNorm epsilon, default 1e-5
+            max_seq_len: int = 512,  # maximum sequence length, default 512
+            dropout: float = 0.0,  # dropout probability, default 0.0
+            flash_attn: bool = True,  # whether to use Flash Attention, default True
             ####################################################
-            # Here are the specific configurations of MOE
-            # When use_moe is false, the following is invalid
+            # MoE (Mixture of Experts) specific configuration;
+            # ignored when use_moe is False
             ####################################################
-            use_moe: bool = False,
-            num_experts_per_tok=2,
-            n_routed_experts=4,
-            n_shared_experts: bool = True,
-            scoring_func='softmax',
-            aux_loss_alpha=0.01,
-            seq_aux=True,
-            norm_topk_prob=True,
+            use_moe: bool = False,  # whether to use MoE, default False
+            num_experts_per_tok=2,  # experts selected per token, default 2
+            n_routed_experts=4,  # total number of routed experts, default 4
+            n_shared_experts: bool = True,  # whether to use shared experts, default True
+            scoring_func='softmax',  # gating score function, default 'softmax'
+            aux_loss_alpha=0.01,  # auxiliary-loss weight, default 0.01
+            seq_aux=True,  # compute the auxiliary loss per sequence, default True
+            norm_topk_prob=True,  # renormalize the top-k probabilities, default True
             **kwargs,
     ):
-        self.dim = dim
-        self.n_layers = n_layers
-        self.n_heads = n_heads
-        self.n_kv_heads = n_kv_heads
-        self.vocab_size = vocab_size
-        self.hidden_dim = hidden_dim
-        self.multiple_of = multiple_of
-        self.norm_eps = norm_eps
-        self.max_seq_len = max_seq_len
-        self.dropout = dropout
-        self.flash_attn = flash_attn
+        self.dim = dim  # model dimension
+        self.n_layers = n_layers  # number of Transformer layers
+        self.n_heads = n_heads  # number of attention heads
+        self.n_kv_heads = n_kv_heads  # number of KV heads
+        self.vocab_size = vocab_size  # vocabulary size
+        self.hidden_dim = hidden_dim  # FFN hidden dimension
+        self.multiple_of = multiple_of  # FFN rounding multiple
+        self.norm_eps = norm_eps  # RMSNorm epsilon
+        self.max_seq_len = max_seq_len  # maximum sequence length
+        self.dropout = dropout  # dropout probability
+        self.flash_attn = flash_attn  # whether to use Flash Attention
         ####################################################
-        # Here are the specific configurations of MOE
-        # When use_moe is false, the following is invalid
+        # MoE (Mixture of Experts) specific configuration;
+        # ignored when use_moe is False
         ####################################################
-        self.use_moe = use_moe
-        self.num_experts_per_tok = num_experts_per_tok  # 每个token选择的专家数量
-        self.n_routed_experts = n_routed_experts  # 总的专家数量
-        self.n_shared_experts = n_shared_experts  # 共享专家
-        self.scoring_func = scoring_func  # 评分函数,默认为'softmax'
-        self.aux_loss_alpha = aux_loss_alpha  # 辅助损失的alpha参数
-        self.seq_aux = seq_aux  # 是否在序列级别上计算辅助损失
-        self.norm_topk_prob = norm_topk_prob  # 是否标准化top-k概率
-        super().__init__(**kwargs)
+        self.use_moe = use_moe  # whether to use MoE
+        self.num_experts_per_tok = num_experts_per_tok  # experts selected per token
+        self.n_routed_experts = n_routed_experts  # total number of routed experts
+        self.n_shared_experts = n_shared_experts  # whether to use shared experts
+        self.scoring_func = scoring_func  # gating score function
+        self.aux_loss_alpha = aux_loss_alpha  # auxiliary-loss weight
+        self.seq_aux = seq_aux  # compute the auxiliary loss per sequence
+        self.norm_topk_prob = norm_topk_prob  # renormalize the top-k probabilities
+        super().__init__(**kwargs)  # delegate remaining kwargs to PretrainedConfig
\ No newline at end of file
diff --git a/model/dataset.py b/model/dataset.py
index ef58956..82c9be8 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -9,79 +9,79 @@ import torch
 from sklearn.model_selection import train_test_split
 import os
 
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
+os.environ["TOKENIZERS_PARALLELISM"] = "false"  # 禁用 tokenizer 的并行处理
 
+# PretrainDataset: fixed-length token blocks for pretraining
 class PretrainDataset(Dataset):
     def __init__(self, data_path_lst, max_length=512, memmap=False):
         super().__init__()
-        #
+        # Memory-mapped loading (avoids reading the whole file into RAM)
         if memmap:
             with open(data_path_lst[0], 'r') as f:
-                nbytes = f.seek(0, 2)
-                flen = f.tell() // np.dtype('uint16').itemsize
-            self.data = np.memmap(data_path_lst[0], dtype=np.dtype('uint16'), shape=(flen // max_length, max_length))
+                nbytes = f.seek(0, 2)  # total file size in bytes
+                flen = f.tell() // np.dtype('uint16').itemsize  # number of uint16 tokens in the file
+            self.data = np.memmap(data_path_lst[0], dtype=np.dtype('uint16'), shape=(flen // max_length, max_length))  # map the file as (num_samples, max_length)
         else:
             data_lst = []
             for data_path in data_path_lst:
                 with open(data_path, 'rb') as f:
-                    data = np.fromfile(f, dtype=np.uint16)
+                    data = np.fromfile(f, dtype=np.uint16)  # read tokens from the file
                     data_lst.append(data)
-            data = np.concatenate(data_lst)
-            data = data[:max_length * int(len(data) / max_length)]
-            # np.random.shuffle(data)
-            self.data = data.reshape(-1, max_length)
-        #
+            data = np.concatenate(data_lst)  # concatenate all files
+            data = data[:max_length * int(len(data) / max_length)]  # truncate to a multiple of max_length
+            # np.random.shuffle(data)  # optional shuffling (disabled)
+            self.data = data.reshape(-1, max_length)  # reshape to (num_samples, max_length)
+        # Report the data shape
         print("memmap:{} train data.shape:{}".format(memmap, self.data.shape))
         print("downloading finished.....")
 
     def __len__(self):
-        return self.data.shape[0]
+        return self.data.shape[0]  # number of samples
 
     def __getitem__(self, index: int):
-        #
+        # Fetch the sample at this index
         sample = self.data[index]
-        X = np.array(sample[:-1]).astype(np.int64)
-        Y = np.array(sample[1:]).astype(np.int64)
-
-        return torch.from_numpy(X), torch.from_numpy(Y)
+        X = np.array(sample[:-1]).astype(np.int64)  # inputs: every token but the last
+        Y = np.array(sample[1:]).astype(np.int64)  # targets: every token but the first
 
+        return torch.from_numpy(X), torch.from_numpy(Y)  # return PyTorch tensors
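+    # Shift illustration (with the default max_length = 512): a row [t0, ..., t511]
+    # yields X = [t0, ..., t510] and Y = [t1, ..., t511], so position i is trained
+    # to predict token i + 1 (next-token prediction).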
 
+# SFTDataset: prompt/answer pairs for supervised fine-tuning
 class SFTDataset(Dataset):
     def __init__(self, df, tokenizer, max_length=1024, prompt_max_len=512, answer_max_len=256):
         super().__init__()
-        self.df = df
-        self.max_length = max_length
-        self.prompt_max_len = prompt_max_len
-        self.answer_max_len = answer_max_len
+        self.df = df  # the dataframe of samples
+        self.max_length = max_length  # maximum sequence length
+        self.prompt_max_len = prompt_max_len  # maximum prompt length
+        self.answer_max_len = answer_max_len  # maximum answer length
         #
-        self.tokenizer = tokenizer
-        self.padding = 0  # self.tokenizer.special_tokens['<pad>']
-        self.bos_id = self.tokenizer('<s>assistant').data['input_ids']
+        self.tokenizer = tokenizer  # the tokenizer
+        self.padding = 0  # padding token id
+        self.bos_id = self.tokenizer('<s>assistant').data['input_ids']  # token ids marking the start of the assistant's answer
 
     def __len__(self):
-        return self.df.shape[0]
+        return self.df.shape[0]  # number of samples
 
     def find_sublist_index(self, main_list, sub_list) -> int:
         last_index = -1
         for i in range(len(main_list) - len(sub_list) + 1):
             if main_list[i:i + len(sub_list)] == sub_list:
                 last_index = i
-        return last_index
+        return last_index  # index of the last occurrence of sub_list in main_list, or -1
 
     def safe_eval(self, s):
         try:
             res = eval(s)
         except Exception as e:
             return []
-        return res
+        return res  # eval that falls back to [] on failure
 
     def __getitem__(self, index: int):
-        #
+        # Fetch the sample at this index
         sample = self.df.iloc[index]
-        history = self.safe_eval(sample['history'])
-        q = str(sample['q'])
-        a = str(sample['a'])
+        history = self.safe_eval(sample['history'])  # conversation history
+        q = str(sample['q'])  # the question
+        a = str(sample['a'])  # the answer
 
         messages = []
         for history_message in history:
@@ -102,29 +102,29 @@ class SFTDataset(Dataset):
             messages,
             tokenize=False,
             add_generation_prompt=True
-        )
-        input_id = self.tokenizer(new_prompt).data['input_ids'][:self.max_length]
+        )  # render the chat template into a prompt string
+        input_id = self.tokenizer(new_prompt).data['input_ids'][:self.max_length]  # tokenize and truncate to max_length
 
         # 实际长度
         question_length = self.find_sublist_index(input_id, self.bos_id) + len(self.bos_id)
         # 没满最大长度的剩余部分
         padding_len = self.max_length - len(input_id)
-        input_id = input_id + [self.padding] * padding_len
+        input_id = input_id + [self.padding] * padding_len  # pad to max_length
         mask_len = len(input_id) - question_length - padding_len
         # 0表示不计算损失
         loss_mask = [0] * question_length + [1] * (mask_len) + [0] * padding_len
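+        # e.g. with max_length = 32, a 20-token prompt+answer and question_length = 10:
+        #   padding_len = 12, mask_len = 32 - 10 - 12 = 10,
+        #   loss_mask = [0]*10 + [1]*10 + [0]*12  (loss only on the answer tokens)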
 
         input_id = np.array(input_id)
-        X = np.array(input_id[:-1]).astype(np.int64)
-        Y = np.array(input_id[1:]).astype(np.int64)
-        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
+        X = np.array(input_id[:-1]).astype(np.int64)  # inputs: every token but the last
+        Y = np.array(input_id[1:]).astype(np.int64)  # targets: every token but the first
+        loss_mask = np.array(loss_mask[1:]).astype(np.int64)  # loss mask shifted to align with Y
 
         X_tensor = torch.from_numpy(X)
         Y_tensor = torch.from_numpy(Y)
         loss_mask_tensor = torch.from_numpy(loss_mask)
 
-        return X_tensor, Y_tensor, loss_mask_tensor
-
+        return X_tensor, Y_tensor, loss_mask_tensor  # return PyTorch tensors
 
+# Script entry point (no-op)
 if __name__ == "__main__":
-    pass
+    pass
\ No newline at end of file
diff --git a/model/model.py b/model/model.py
index 4901bd7..9716a87 100644
--- a/model/model.py
+++ b/model/model.py
@@ -10,29 +10,29 @@ from torch import nn
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
-
+# RMSNorm: root-mean-square normalization, like LayerNorm but without mean-centering
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float):
         super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
+        self.eps = eps  # epsilon to avoid division by zero
+        self.weight = nn.Parameter(torch.ones(dim))  # learnable scale
 
     def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)  # scale by the reciprocal RMS
 
     def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
+        output = self._norm(x.float()).type_as(x)  # normalize in float32, cast back
+        return output * self.weight  # apply the learnable scale
 
+# Precompute the complex-valued rotary position encodings (RoPE)
 def precompute_pos_cis(dim: int, end: int, theta: float = 10000.0):
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-    t = torch.arange(end, device=freqs.device)  # type: ignore
-    freqs = torch.outer(t, freqs).float()  # type: ignore
-    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # per-dimension rotation frequencies
+    t = torch.arange(end, device=freqs.device)  # position indices
+    freqs = torch.outer(t, freqs).float()  # rotation angle per (position, frequency)
+    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # unit complex numbers e^(i*angle), complex64
     return pos_cis
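+    # Shape sketch (assuming the config defaults): dim = head_dim = 512 // 16 = 32
+    # and end = max_seq_len = 512 give a (512, 16) complex64 tensor, one unit
+    # rotation per (position, frequency) pair.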
 
-
+# Apply the rotary position embedding to queries and keys
 def apply_rotary_emb(xq, xk, pos_cis):
     def unite_shape(pos_cis, x):
         ndim = x.ndim
@@ -41,14 +41,14 @@ def apply_rotary_emb(xq, xk, pos_cis):
         shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
         return pos_cis.view(*shape)
 
-    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-    pos_cis = unite_shape(pos_cis, xq_)
-    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
-    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
-    return xq_out.type_as(xq), xk_out.type_as(xk)
-
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # view xq as complex pairs
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # view xk as complex pairs
+    pos_cis = unite_shape(pos_cis, xq_)  # broadcast pos_cis to the query shape
+    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)  # rotate queries, flatten back to real
+    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)  # rotate keys, flatten back to real
+    return xq_out.type_as(xq), xk_out.type_as(xk)  # restore the original dtypes
 
+# repeat_kv: expand the KV heads to match the number of query heads (grouped-query attention)
 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
     bs, slen, n_kv_heads, head_dim = x.shape
@@ -60,130 +60,130 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
         .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
     )
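+    # e.g. with n_heads = 16 and n_kv_heads = 8 (the config defaults), n_rep = 2:
+    # K/V expand from (bs, slen, 8, head_dim) to (bs, slen, 16, head_dim).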
 
-
+# Attention: multi-head self-attention with rotary embeddings, optional Flash Attention and a KV cache
 class Attention(nn.Module):
     def __init__(self, args: LMConfig):
         super().__init__()
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        assert args.n_heads % self.n_kv_heads == 0
-        self.n_local_heads = args.n_heads
-        self.n_local_kv_heads = self.n_kv_heads
-        self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // args.n_heads
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
-        self.k_cache, self.v_cache = None, None
-        self.attn_dropout = nn.Dropout(args.dropout)
-        self.resid_dropout = nn.Dropout(args.dropout)
-        self.dropout = args.dropout
-        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
+        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads  # number of KV heads
+        assert args.n_heads % self.n_kv_heads == 0  # n_heads must be divisible by n_kv_heads
+        self.n_local_heads = args.n_heads  # number of query heads
+        self.n_local_kv_heads = self.n_kv_heads  # number of local KV heads
+        self.n_rep = self.n_local_heads // self.n_local_kv_heads  # how many times each KV head is repeated
+        self.head_dim = args.dim // args.n_heads  # dimension per head
+        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)  # query projection
+        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)  # key projection
+        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)  # value projection
+        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)  # output projection
+        self.k_cache, self.v_cache = None, None  # KV cache
+        self.attn_dropout = nn.Dropout(args.dropout)  # attention dropout
+        self.resid_dropout = nn.Dropout(args.dropout)  # residual dropout
+        self.dropout = args.dropout  # dropout probability
+        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn  # use Flash Attention when available and enabled
 
         if not self.flash:
             # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
-            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
-            mask = torch.triu(mask, diagonal=1)
-            self.register_buffer("mask", mask)
+            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))  # causal mask filled with -inf
+            mask = torch.triu(mask, diagonal=1)  # keep -inf strictly above the diagonal
+            self.register_buffer("mask", mask)  # register as a non-trainable buffer
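+            # e.g. the seqlen-3 slice of this mask is
+            #   [[0, -inf, -inf],
+            #    [0,    0, -inf],
+            #    [0,    0,    0]]
+            # so each position attends only to itself and earlier positions.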
 
     def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, use_kv_cache=False):
         bsz, seqlen, _ = x.shape
-        if use_kv_cache and self.eval():
+        if use_kv_cache and self.eval():  # when using the KV cache (note: self.eval() returns the module, so this condition is effectively just use_kv_cache)
             if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
-                xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+                xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  # fresh Q, K, V projections
             else:
-                token = x[:, -1:, :]
-                xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(token)), dim=1)
-                xk = torch.cat((self.k_cache, self.wk(token)), dim=1)
-                xv = torch.cat((self.v_cache, self.wv(token)), dim=1)
+                token = x[:, -1:, :]  # only the newest token
+                xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(token)), dim=1)  # Q: zeros for cached positions, projection for the new token
+                xk = torch.cat((self.k_cache, self.wk(token)), dim=1)  # append the new K to the cache
+                xv = torch.cat((self.v_cache, self.wv(token)), dim=1)  # append the new V to the cache
 
-            self.k_cache, self.v_cache = xk, xv
+            self.k_cache, self.v_cache = xk, xv  # refresh the KV cache
         else:
-            xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+            xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  # Q, K, V projections
 
-        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)  # split Q into heads
+        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)  # split K into KV heads
+        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)  # split V into KV heads
 
-        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
+        xq, xk = apply_rotary_emb(xq, xk, pos_cis)  # apply rotary position embeddings
 
-        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+        xk = repeat_kv(xk, self.n_rep)  # repeat K across query heads
+        xv = repeat_kv(xv, self.n_rep)  # repeat V across query heads
 
-        xq = xq.transpose(1, 2)
-        xk = xk.transpose(1, 2)
-        xv = xv.transpose(1, 2)
+        xq = xq.transpose(1, 2)  # -> (bs, n_heads, seqlen, head_dim)
+        xk = xk.transpose(1, 2)  # -> (bs, n_heads, seqlen, head_dim)
+        xv = xv.transpose(1, 2)  # -> (bs, n_heads, seqlen, head_dim)
 
         if self.flash:
             output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
                                                                       dropout_p=self.dropout if self.training else 0.0,
-                                                                      is_causal=True)
+                                                                      is_causal=True)  # fused scaled-dot-product attention
         else:
-            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
+            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)  # scaled attention scores
             assert hasattr(self, 'mask')
-            scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
-            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
-            scores = self.attn_dropout(scores)
-            output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)
+            scores = scores + self.mask[:, :, :seqlen, :seqlen]  # apply the causal mask
+            scores = F.softmax(scores.float(), dim=-1).type_as(xq)  # softmax in float32, cast back
+            scores = self.attn_dropout(scores)  # attention dropout
+            output = torch.matmul(scores, xv)  # weighted sum of the values
 
-        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-
-        output = self.wo(output)
-        output = self.resid_dropout(output)
-        return output
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)  # merge heads: (bs, seqlen, dim)
 
+        output = self.wo(output)  # output projection
+        output = self.resid_dropout(output)  # residual dropout
+        return output
 
+# FeedForward: SwiGLU feed-forward network
 class FeedForward(nn.Module):
     def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
         super().__init__()
         if hidden_dim is None:
-            hidden_dim = 4 * dim
-            hidden_dim = int(2 * hidden_dim / 3)
-            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
-        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
-        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
-        self.dropout = nn.Dropout(dropout)
+            hidden_dim = 4 * dim  # start from 4x the model dimension
+            hidden_dim = int(2 * hidden_dim / 3)  # scale down by 2/3 (SwiGLU convention)
+            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # round up to a multiple of multiple_of
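+            # e.g. with the default dim = 512: 4 * 512 = 2048 -> int(2 * 2048 / 3) = 1365
+            # -> rounded up to a multiple of 64 gives hidden_dim = 1408.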
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection
+        self.dropout = nn.Dropout(dropout)  # dropout
 
     def forward(self, x):
-        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
-
+        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))  # SwiGLU: silu(w1(x)) * w3(x), projected back down
 
+# MoEGate: gating network that routes each token to its top-k experts
 class MoEGate(nn.Module):
     def __init__(self, config: LMConfig):
         super().__init__()
         self.config = config
-        self.top_k = config.num_experts_per_tok
-        self.n_routed_experts = config.n_routed_experts
+        self.top_k = config.num_experts_per_tok  # experts selected per token
+        self.n_routed_experts = config.n_routed_experts  # number of routed experts
 
-        self.scoring_func = config.scoring_func
-        self.alpha = config.aux_loss_alpha
-        self.seq_aux = config.seq_aux
+        self.scoring_func = config.scoring_func  # gating score function
+        self.alpha = config.aux_loss_alpha  # auxiliary-loss weight
+        self.seq_aux = config.seq_aux  # sequence-level auxiliary loss
 
-        self.norm_topk_prob = config.norm_topk_prob
-        self.gating_dim = config.dim
-        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
-        self.reset_parameters()
+        self.norm_topk_prob = config.norm_topk_prob  # renormalize top-k probabilities
+        self.gating_dim = config.dim  # gating input dimension
+        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))  # expert scoring weights
+        self.reset_parameters()  # initialize them
 
     def reset_parameters(self) -> None:
         import torch.nn.init as init
-        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # Kaiming-uniform initialization
 
     def forward(self, hidden_states):
         bsz, seq_len, h = hidden_states.shape
 
-        hidden_states = hidden_states.view(-1, h)
-        logits = F.linear(hidden_states, self.weight, None)
+        hidden_states = hidden_states.view(-1, h)  # flatten to (bsz * seq_len, h)
+        logits = F.linear(hidden_states, self.weight, None)  # expert logits per token
         if self.scoring_func == 'softmax':
-            scores = logits.softmax(dim=-1)
+            scores = logits.softmax(dim=-1)  # softmax routing scores
         else:
             raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
 
-        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
+        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)  # pick the top-k experts per token
 
         if self.top_k > 1 and self.norm_topk_prob:
-            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
-            topk_weight = topk_weight / denominator
+            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20  # normalization denominator
+            topk_weight = topk_weight / denominator  # renormalize so the top-k weights sum to 1
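+            # e.g. routing scores [0.4, 0.3, 0.2, 0.1] over 4 experts with top_k = 2
+            # select experts 0 and 1; renormalizing gives weights [4/7, 3/7].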
 
         if self.training and self.alpha > 0.0:
             scores_for_aux = scores
@@ -204,9 +204,9 @@ class MoEGate(nn.Module):
                 aux_loss = (Pi * fi).sum() * self.alpha
         else:
             aux_loss = None
-        return topk_idx, topk_weight, aux_loss
-
+        return topk_idx, topk_weight, aux_loss  # expert indices, weights, and the auxiliary loss
 
+# MOEFeedForward: Mixture-of-Experts feed-forward layer
 class MOEFeedForward(nn.Module):
     def __init__(self, config: LMConfig):
         super().__init__()
@@ -219,16 +219,16 @@ class MOEFeedForward(nn.Module):
                 dropout=config.dropout,
             )
             for _ in range(config.n_routed_experts)
-        ])
+        ])  # the routed experts
 
-        self.gate = MoEGate(config)
+        self.gate = MoEGate(config)  # the gating network
         if config.n_shared_experts is not None:
             self.shared_experts = FeedForward(
                 dim=config.dim,
                 hidden_dim=config.hidden_dim,
                 multiple_of=config.multiple_of,
                 dropout=config.dropout,
-            )
+            )  # shared expert applied to every token
 
     def forward(self, x):
         identity = x
@@ -281,35 +281,46 @@ class MOEFeedForward(nn.Module):
 
         return expert_cache
 
-
+# TransformerBlock: one Transformer layer (pre-normed self-attention plus feed-forward, with residuals)
 class TransformerBlock(nn.Module):
     def __init__(self, layer_id: int, args: LMConfig):
         super().__init__()
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.dim // args.n_heads
-        self.attention = Attention(args)
+        self.attention = Attention(args)  # self-attention
 
         self.layer_id = layer_id
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)  # pre-attention RMSNorm
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)  # pre-FFN RMSNorm
 
         if args.use_moe:
-            self.feed_forward = MOEFeedForward(args)
+            self.feed_forward = MOEFeedForward(args)  # MoE feed-forward
         else:
             self.feed_forward = FeedForward(
                 dim=args.dim,
                 hidden_dim=args.hidden_dim,
                 multiple_of=args.multiple_of,
                 dropout=args.dropout,
-            )
+            )  # dense feed-forward
 
     def forward(self, x, pos_cis, use_kv_cache=False):
-        h = x + self.attention(self.attention_norm(x), pos_cis, use_kv_cache)
-        out = h + self.feed_forward(self.ffn_norm(h))
-        return out
+        h = x + self.attention(self.attention_norm(x), pos_cis, use_kv_cache)  # attention with residual connection
+        out = h + self.feed_forward(self.ffn_norm(h))  # feed-forward with residual connection
+        return out
 
 
+# Transformer: the full causal language model
 class Transformer(PreTrainedModel):
     config_class = LMConfig
     last_loss: Optional[torch.Tensor]
@@ -322,99 +333,99 @@ class Transformer(PreTrainedModel):
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
 
-        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
-        self.dropout = nn.Dropout(params.dropout)
-        self.layers = torch.nn.ModuleList()
+        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)  # token embedding layer
+        self.dropout = nn.Dropout(params.dropout)  # embedding dropout
+        self.layers = torch.nn.ModuleList()  # stack of Transformer blocks
         for layer_id in range(self.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
-        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
-        self.tok_embeddings.weight = self.output.weight
-        pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
-        self.register_buffer("pos_cis", pos_cis, persistent=False)
+            self.layers.append(TransformerBlock(layer_id, params))  # one block per layer
+        self.norm = RMSNorm(params.dim, eps=params.norm_eps)  # final RMSNorm
+        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)  # LM head
+        self.tok_embeddings.weight = self.output.weight  # tie the embedding and LM-head weights
+        pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)  # precompute rotary position encodings
+        self.register_buffer("pos_cis", pos_cis, persistent=False)  # non-persistent buffer (excluded from checkpoints)
 
-        self.apply(self._init_weights)
+        self.apply(self._init_weights)  # initialize the weights
 
         for pn, p in self.named_parameters():
             if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
-                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers))
+                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers))  # scaled init for residual projections (GPT-2 style)
 
-        self.last_loss = None
-        self.OUT = CausalLMOutputWithPast()
+        self.last_loss = None  # most recent loss
+        self.OUT = CausalLMOutputWithPast()  # reusable output container
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)  # normal init for linear weights
             if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
+                torch.nn.init.zeros_(module.bias)  # zero the biases
         elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)  # normal init for embeddings
 
     def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
                 use_kv_cache=False, **keyargs):
         if 'input_ids' in keyargs:
-            tokens = keyargs['input_ids']
+            tokens = keyargs['input_ids']  # accept HF-style input_ids
         if 'attention_mask' in keyargs:
-            targets = keyargs['attention_mask']
+            targets = keyargs['attention_mask']  # the attention_mask argument is used as the targets here
 
-        _bsz, seqlen = tokens.shape
-        h = self.tok_embeddings(tokens)
-        h = self.dropout(h)
-        pos_cis = self.pos_cis[:seqlen]
+        _bsz, seqlen = tokens.shape  # batch size and sequence length
+        h = self.tok_embeddings(tokens)  # token embeddings
+        h = self.dropout(h)  # embedding dropout
+        pos_cis = self.pos_cis[:seqlen]  # position encodings for this sequence length
         for idx, layer in enumerate(self.layers):
-            h = layer(h, pos_cis, use_kv_cache)
+            h = layer(h, pos_cis, use_kv_cache)  # run each Transformer block
 
-        h = self.norm(h)
+        h = self.norm(h)  # final normalization
 
         if targets is not None:
-            logits = self.output(h)
-            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+            logits = self.output(h)  # logits over the full sequence
+            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)  # cross-entropy loss
         else:
-            logits = self.output(h[:, [-1], :])
-            self.last_loss = None
+            logits = self.output(h[:, [-1], :])  # logits for the last position only
+            self.last_loss = None  # no targets, no loss
 
-        self.OUT.__setitem__('logits', logits)
-        self.OUT.__setitem__('last_loss', self.last_loss)
+        self.OUT.__setitem__('logits', logits)  # store the logits on the output
+        self.OUT.__setitem__('last_loss', self.last_loss)  # store the loss on the output
 
-        return self.OUT
+        return self.OUT
 
-    @torch.inference_mode()
+    @torch.inference_mode()  # disable autograd for generation
     def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
                  use_kv_cache=True):
-        index = idx.shape[1]
-        while idx.shape[1] < max_new_tokens - 1:
-            inference_res = self(idx, use_kv_cache=use_kv_cache)
-            logits = inference_res.logits
-            logits = logits[:, -1, :]
+        index = idx.shape[1]  # prompt length; new tokens are yielded from here on
+        while idx.shape[1] < max_new_tokens - 1:  # bound on the total sequence length, prompt included
+            inference_res = self(idx, use_kv_cache=use_kv_cache)  # forward pass
+            logits = inference_res.logits  # get the logits
+            logits = logits[:, -1, :]  # logits at the last position
 
-            for token in set(idx.tolist()[0]):
+            for token in set(idx.tolist()[0]):  # penalize tokens already in the sequence
                 logits[:, token] /= repetition_penalty
 
-            if temperature == 0.0:
+            if temperature == 0.0:  # greedy decoding at temperature 0
                 _, idx_next = torch.topk(logits, k=1, dim=-1)
             else:
-                logits = logits / temperature
-                if top_k is not None:
+                logits = logits / temperature  # temperature scaling
+                if top_k is not None:  # top-k filtering
                     v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-                    logits[logits < v[:, [-1]]] = -float('Inf')
+                    logits[logits < v[:, [-1]]] = -float('Inf')  # mask logits below the k-th largest
 
-                probs = F.softmax(logits, dim=-1)
-                idx_next = torch.multinomial(probs, num_samples=1, generator=None)
+                probs = F.softmax(logits, dim=-1)  # sampling distribution
+                idx_next = torch.multinomial(probs, num_samples=1, generator=None)  # sample the next token
 
-            if idx_next == eos:
+            if idx_next == eos:  # stop at the end-of-sequence token
                 break
 
-            idx = torch.cat((idx, idx_next), dim=1)
-            if stream:
-                yield idx[:, index:]
+            idx = torch.cat((idx, idx_next), dim=1)  # append the new token
+            if stream:  # streaming mode
+                yield idx[:, index:]  # yield everything generated so far
 
-        if not stream:
-            yield idx[:, index:]
+        if not stream:  # non-streaming: yield all generated tokens at once
+            yield idx[:, index:]
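+    # Usage sketch (assumes a trained model, a tokenizer with encode/decode, and
+    # an eos id of 2; none of these are defined in this file):
+    #   ids = torch.tensor([tokenizer.encode('hello')], device=device)
+    #   for chunk in model.generate(ids, eos=2, max_new_tokens=100):
+    #       print(tokenizer.decode(chunk[0].tolist()), end='')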
 
-    @torch.inference_mode()
+    @torch.inference_mode()  # disable autograd for evaluation
     def eval_answer(self, idx):
-        idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
-        inference_res = self(idx_cond)
-        logits = inference_res.logits
-        logits = logits[:, -1, :]
-        return logits
+        idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]  # crop to the last max_seq_len tokens
+        inference_res = self(idx_cond)  # forward pass
+        logits = inference_res.logits  # get the logits
+        logits = logits[:, -1, :]  # logits at the last position
+        return logits
\ No newline at end of file