diff --git a/eval_model.py b/eval_model.py index 4e2e454..7e30a54 100644 --- a/eval_model.py +++ b/eval_model.py @@ -58,6 +58,13 @@ def load_model(model_path, model_type, device, config_params=None): from model.model_no_feed import MiniMindLM elif model_type == "model_memory": from model.model_memory import MiniMindLM + elif model_type.startswith("model_memory_"): + # 支持通用的model_memory_X_X_X格式 + try: + module = __import__(f"model.{model_type}", fromlist=["MiniMindLM"]) + MiniMindLM = getattr(module, "MiniMindLM") + except (ImportError, AttributeError) as e: + raise ValueError(f"无法导入模型类型 {model_type}: {e}") else: raise ValueError(f"不支持的模型类型: {model_type}") @@ -254,6 +261,12 @@ def evaluate_sample(model, tokenizer, text, input_length=100, predict_length=100 ground_truth_text: 真实文本 loss: 预测损失(如果可计算) """ + # 添加与训练时一致的BOS/EOS token处理 + if not text.startswith(tokenizer.bos_token): + text = f"{tokenizer.bos_token}{text}" + if not text.endswith(tokenizer.eos_token): + text = f"{text}{tokenizer.eos_token}" + # 对文本进行分词 tokens = tokenizer.encode(text, add_special_tokens=False) @@ -347,11 +360,10 @@ def evaluate_sample(model, tokenizer, text, input_length=100, predict_length=100 def main(): parser = argparse.ArgumentParser(description='评估预训练模型') - parser.add_argument('--model_path', type=str, default='out/experiment_1_4_0/pretrain_512.pth', + parser.add_argument('--model_path', type=str, default='out/experiment_1_4_1/pretrain_512.pth', help='模型权重文件路径') - parser.add_argument('--model_type', type=str, default='model', - choices=['model', 'model_original', 'model_no_feed', 'model_memory'], - help='模型类型') + parser.add_argument('--model_type', type=str, default='model_memory', + help='模型类型 (支持model, model_original, model_no_feed, model_memory, model_memory_X_X_X等)') parser.add_argument('--data_path', type=str, default='dataset/stable/eval_data.json', help='评估数据集路径') parser.add_argument('--num_samples', type=int, default=20, @@ -427,8 +439,8 @@ def main(): 'n_routed_experts': args.n_routed_experts, } - # 只有model、model_no_feed和model_memory需要KnowledgeDataset参数 - if args.model_type in ['model', 'model_no_feed', 'model_memory']: + # 只有model、model_no_feed和model_memory系列需要KnowledgeDataset参数 + if args.model_type in ['model', 'model_no_feed', 'model_memory'] or args.model_type.startswith('model_memory_'): config_params.update({ 'knowledge_num': args.knowledge_num, 'knowledge_length': args.knowledge_length, diff --git a/model/model_memory.py b/model/model_memory.py index 115d21f..d62443b 100644 --- a/model/model_memory.py +++ b/model/model_memory.py @@ -153,6 +153,8 @@ class MemoryGate(nn.Module): Returns: memory_indices: [batch_size, seq_len, num_selected] memory_scores: [batch_size, seq_len, num_selected] + balance_loss: 平衡损失(KL散度 + 基尼系数) + stats: 监控统计信息字典 """ bsz, seq_len, _ = x.shape @@ -186,80 +188,132 @@ class MemoryGate(nn.Module): memory_scores = F.softmax(final_scores, dim=-1) memory_scores = self.dropout(memory_scores) - return memory_indices, memory_scores + # 计算平衡损失和监控统计 + balance_loss, stats = self._compute_balance_loss_and_stats(memory_indices, memory_scores) + + return memory_indices, memory_scores, balance_loss, stats + + def _compute_balance_loss_and_stats(self, memory_indices, memory_scores): + """ + 计算平衡损失和监控统计信息 + + Args: + memory_indices: [batch_size, seq_len, num_selected] + memory_scores: [batch_size, seq_len, num_selected] + + Returns: + balance_loss: 标量张量 + stats: 统计信息字典 + """ + bsz, seq_len, num_selected = memory_indices.shape + device = memory_indices.device + + # 1. 
计算记忆选择分布 + # 将所有选择的记忆索引展平 + flat_indices = memory_indices.view(-1) # [batch_size * seq_len * num_selected] + + # 统计每个记忆条目被选中的次数 + memory_counts = torch.zeros(self.knowledge_num, device=device) + memory_counts.scatter_add_(0, flat_indices, torch.ones_like(flat_indices, dtype=torch.float)) + + # 计算选择概率分布 + total_selections = bsz * seq_len * num_selected + memory_probs = memory_counts / total_selections + + # 2. 计算KL散度损失(与均匀分布的KL散度) + uniform_prob = 1.0 / self.knowledge_num + # 避免log(0)的问题 + memory_probs_safe = memory_probs + 1e-10 + kl_loss = F.kl_div( + torch.log(memory_probs_safe), + torch.full_like(memory_probs, uniform_prob), + reduction='sum' + ) + + # 3. 计算基尼系数损失(衡量分布不平等程度) + sorted_probs, _ = torch.sort(memory_probs) + n = self.knowledge_num + index = torch.arange(1, n + 1, device=device, dtype=torch.float) + gini_coeff = (2 * torch.sum(index * sorted_probs) / (n * torch.sum(sorted_probs))) - (n + 1) / n + gini_loss = gini_coeff # 基尼系数越大,分布越不均匀 + + # 4. 组合平衡损失 + balance_loss = 0.5 * kl_loss + 0.5 * gini_loss + + # 5. 计算监控统计信息 + with torch.no_grad(): + # 记忆覆盖率:被选中的记忆条目占总数的比例 + coverage_rate = (memory_counts > 0).float().mean().item() + + # 热点记忆:选择次数前10%的记忆条目 + top10_threshold = torch.quantile(memory_counts, 0.9) + hot_memories = (memory_counts >= top10_threshold).sum().item() + + # 死记忆:从未被选中的记忆条目 + dead_memories = (memory_counts == 0).sum().item() + + # 记忆选择方差(衡量不平衡程度) + selection_variance = memory_counts.var().item() + + stats = { + 'gini_coefficient': gini_coeff.item(), + 'kl_divergence': kl_loss.item(), + 'coverage_rate': coverage_rate, + 'hot_memories': hot_memories, + 'dead_memories': dead_memories, + 'selection_variance': selection_variance, + 'max_selections': memory_counts.max().item(), + 'min_selections': memory_counts.min().item(), + } + + return balance_loss, stats -class CrossAttentionMemory(nn.Module): - """Cross attention using selected memory as K and V""" +class GatedMemoryFusion(nn.Module): + """Gated MLP fusion for concatenated h_attn and selected memories""" def __init__(self, config: LMConfig): super().__init__() self.config = config - self.n_heads = config.n_heads - self.head_dim = config.dim // config.n_heads self.dim = config.dim self.knowledge_dim = config.knowledge_dim + self.num_selected = getattr(config, 'num_selected', 16) - # Q从self-attention输出计算 - self.wq = nn.Linear(config.dim, config.dim, bias=False) + # 输入维度:dim (h_attn) + num_selected * knowledge_dim (选中的记忆) + concat_dim = self.dim + self.num_selected * self.knowledge_dim - # K,V从记忆数据计算 - self.wk = nn.Linear(config.knowledge_dim, config.dim, bias=False) - self.wv = nn.Linear(config.knowledge_dim, config.dim, bias=False) + # 类似SwiGLU的门控MLP结构 + self.gate_proj = nn.Linear(concat_dim, self.dim, bias=False) + self.up_proj = nn.Linear(concat_dim, self.dim, bias=False) + self.down_proj = nn.Linear(self.dim, self.dim, bias=False) - # 输出投影 - self.wo = nn.Linear(config.dim, config.dim, bias=False) self.dropout = nn.Dropout(config.dropout) - def forward(self, x: torch.Tensor, memory_data: torch.Tensor, memory_scores: torch.Tensor): + def forward(self, h_attn: torch.Tensor, selected_memories: torch.Tensor, memory_scores: torch.Tensor): """ Args: - x: [batch_size, seq_len, dim] - Query from self attention - memory_data: [batch_size, seq_len, num_selected, knowledge_dim] - Selected memory data - memory_scores: [batch_size, seq_len, num_selected] - Memory selection weights + h_attn: [batch_size, seq_len, dim] - Self attention output + selected_memories: [batch_size, seq_len, num_selected, knowledge_dim] - Selected 
memory data + memory_scores: [batch_size, seq_len, num_selected] - Memory selection weights (not used in concatenation approach) Returns: output: [batch_size, seq_len, dim] """ - bsz, seq_len, _ = x.shape - num_selected = memory_data.shape[2] + bsz, seq_len, _ = h_attn.shape - # 计算Query - q = self.wq(x) # [batch, seq_len, dim] - q = q.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2) # [batch, n_heads, seq_len, head_dim] + # 将选中的记忆展平为一维向量 + # [batch, seq_len, num_selected, knowledge_dim] -> [batch, seq_len, num_selected * knowledge_dim] + memory_flat = selected_memories.view(bsz, seq_len, -1) - # 对选中的记忆数据计算K和V - memory_flat = memory_data.view(bsz * seq_len * num_selected, self.knowledge_dim) - k_flat = self.wk(memory_flat) # [batch * seq_len * num_selected, dim] - v_flat = self.wv(memory_flat) # [batch * seq_len * num_selected, dim] + # 拼接h_attn和记忆信息 + concat_input = torch.cat([h_attn, memory_flat], dim=-1) # [batch, seq_len, dim + num_selected * knowledge_dim] - # 重塑K和V - k = k_flat.view(bsz, seq_len, num_selected, self.n_heads, self.head_dim).permute(0, 3, 1, 2, 4) # [batch, n_heads, seq_len, num_selected, head_dim] - v = v_flat.view(bsz, seq_len, num_selected, self.n_heads, self.head_dim).permute(0, 3, 1, 2, 4) # [batch, n_heads, seq_len, num_selected, head_dim] + # 门控MLP处理(类似SwiGLU) + gate = F.silu(self.gate_proj(concat_input)) # [batch, seq_len, dim] + up = self.up_proj(concat_input) # [batch, seq_len, dim] + fusion_output = gate * up # Element-wise multiplication - # 扩展Q以匹配记忆维度进行交叉注意力 - q_expanded = q.unsqueeze(3) # [batch, n_heads, seq_len, 1, head_dim] - - # 计算注意力分数 - # q_expanded: [batch, n_heads, seq_len, 1, head_dim] - # k: [batch, n_heads, seq_len, num_selected, head_dim] - scores = torch.matmul(q_expanded, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # [batch, n_heads, seq_len, 1, num_selected] - scores = scores.squeeze(3) # [batch, n_heads, seq_len, num_selected] - - # 应用记忆选择权重 - memory_scores_expanded = memory_scores.unsqueeze(1).expand(-1, self.n_heads, -1, -1) # [batch, n_heads, seq_len, num_selected] - scores = scores + memory_scores_expanded.log() # 在log空间相加 - - # Softmax归一化 - attn_weights = F.softmax(scores, dim=-1) # [batch, n_heads, seq_len, num_selected] - attn_weights = self.dropout(attn_weights) - - # 应用注意力权重到V - # attn_weights: [batch, n_heads, seq_len, num_selected] - # v: [batch, n_heads, seq_len, num_selected, head_dim] - output = torch.einsum('bhlk,bhlkd->bhld', attn_weights, v) # [batch, n_heads, seq_len, head_dim] - - # 重塑输出 - output = output.transpose(1, 2).reshape(bsz, seq_len, self.dim) # [batch, seq_len, dim] - output = self.wo(output) + # 输出投影 + output = self.down_proj(fusion_output) # [batch, seq_len, dim] + output = self.dropout(output) return output @@ -279,7 +333,7 @@ class MiniMindBlock(nn.Module): # 记忆相关模块 self.memory_gate = MemoryGate(config) - self.cross_attention_memory = CrossAttentionMemory(config) + self.gated_memory_fusion = GatedMemoryFusion(config) def forward(self, x, pos_cis, memory_bank): """ @@ -287,16 +341,21 @@ class MiniMindBlock(nn.Module): x: [batch_size, seq_len, dim] pos_cis: positional encoding memory_bank: [knowledge_num, knowledge_dim] - shared memory bank + + Returns: + out: [batch_size, seq_len, dim] + balance_loss: 该层的平衡损失 + layer_stats: 该层的监控统计信息 """ # Self attention h_attn = self.attention(self.attention_norm(x), pos_cis) h = x + h_attn # 使用h_attn作为门控和交叉注意力的输入(核心:self attention的输出) - h_for_memory = self.memory_norm(h) + h_for_memory = self.memory_norm(h_attn) # 门控选择记忆 - memory_indices, memory_scores = 
self.memory_gate(h_for_memory) + memory_indices, memory_scores, balance_loss, layer_stats = self.memory_gate(h_for_memory) # 根据索引获取记忆数据 bsz, seq_len, num_selected = memory_indices.shape @@ -304,14 +363,13 @@ class MiniMindBlock(nn.Module): selected_memory = memory_bank[memory_indices_flat] # [batch * seq_len * num_selected, knowledge_dim] selected_memory = selected_memory.view(bsz, seq_len, num_selected, -1) # [batch, seq_len, num_selected, knowledge_dim] - h = x + selected_memory - # 交叉注意力:Q来自h_attn,K和V来自选中的记忆 - memory_output = self.cross_attention_memory(x, selected_memory, memory_scores) + # 门控MLP融合:串型连接h_attn和选中的记忆 + memory_output = self.gated_memory_fusion(h_for_memory, selected_memory, memory_scores) # 残差连接 out = h + memory_output - return out + return out, balance_loss, layer_stats class MiniMindLM(PreTrainedModel): @@ -337,7 +395,58 @@ class MiniMindLM(PreTrainedModel): requires_grad=True ) + # 记录上一步的记忆库状态,用于计算更新统计 + self.register_buffer('prev_memory_bank', torch.zeros_like(self.memory_bank), persistent=False) + self.OUT = CausalLMOutputWithPast() + + def get_memory_update_stats(self): + """ + 计算记忆库更新统计信息 + + Returns: + update_stats: 包含更新统计的字典 + """ + with torch.no_grad(): + if hasattr(self, 'prev_memory_bank') and self.prev_memory_bank.numel() > 0: + # 计算L2距离变化 + l2_distance = torch.norm(self.memory_bank - self.prev_memory_bank, p=2, dim=-1) + avg_l2_distance = l2_distance.mean().item() + max_l2_distance = l2_distance.max().item() + + # 计算余弦相似度 + cos_sim = F.cosine_similarity( + self.memory_bank.view(-1), + self.prev_memory_bank.view(-1), + dim=0 + ).item() + + # 计算更新率(发生显著变化的记忆条目比例) + threshold = 0.01 # 更新阈值 + updated_memories = (l2_distance > threshold).sum().item() + update_rate = updated_memories / self.memory_bank.size(0) + + update_stats = { + 'memory_avg_l2_change': avg_l2_distance, + 'memory_max_l2_change': max_l2_distance, + 'memory_cosine_similarity': cos_sim, + 'memory_update_rate': update_rate, + 'memory_updated_count': updated_memories + } + else: + # 第一次调用时的默认值 + update_stats = { + 'memory_avg_l2_change': 0.0, + 'memory_max_l2_change': 0.0, + 'memory_cosine_similarity': 1.0, + 'memory_update_rate': 0.0, + 'memory_updated_count': 0 + } + + # 更新prev_memory_bank + self.prev_memory_bank.copy_(self.memory_bank) + + return update_stats def forward(self, input_ids: Optional[torch.Tensor] = None, @@ -347,16 +456,26 @@ class MiniMindLM(PreTrainedModel): h = self.dropout(self.tok_embeddings(input_ids)) pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)] - for layer in self.layers: - h = layer(h, pos_cis, self.memory_bank) + # 收集所有层的平衡损失和统计信息 + total_balance_loss = 0 + all_layer_stats = {} + + for layer_idx, layer in enumerate(self.layers): + h, balance_loss, layer_stats = layer(h, pos_cis, self.memory_bank) + total_balance_loss += balance_loss + # 为每层的统计信息添加前缀 + for key, value in layer_stats.items(): + all_layer_stats[f'layer_{layer_idx}_{key}'] = value logits = self.output(self.norm(h)) - # 统一不使用 aux_loss - aux_loss = 0 + # 使用总的平衡损失作为aux_loss + aux_loss = total_balance_loss + self.OUT.__setitem__('last_hidden_state', h) self.OUT.__setitem__('logits', logits) self.OUT.__setitem__('aux_loss', aux_loss) + self.OUT.__setitem__('layer_stats', all_layer_stats) # 添加层级统计信息 self.OUT.__setitem__('past_key_values', None) # 不支持KV cache return self.OUT diff --git a/pyproject.toml b/pyproject.toml index d9002d1..e22456e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,6 +143,7 @@ dependencies = [ "smmap==5.0.2", "sniffio==1.3.1", "streamlit==1.30.0", + 
"superclaude>=3.0.0.2", "swankit==0.2.4", "swanlab==0.6.4", "sympy==1.13.3", diff --git a/run_file/experiment_1_4_4.sh b/run_file/experiment_1_4_4.sh new file mode 100644 index 0000000..3dcd0d4 --- /dev/null +++ b/run_file/experiment_1_4_4.sh @@ -0,0 +1,335 @@ +#!/bin/bash + +# ============================================================================ +# MiniMind 实验脚本 - Experiment 1.4.4 +# ============================================================================ +# +# 🎯 实验目标: +# 基于实验1.4.2的model_memory架构,深度验证记忆库机制,实现平衡损失和四维度监控体系 +# +# 使用方法: +# bash run_file/experiment_1_4_4.sh +# ============================================================================ + +# ---------------------------------------------------------------------------- +# 🧑‍🔬 实验基本信息 +# ---------------------------------------------------------------------------- +EXPERIMENT_VERSION="1.4.4" +EXPERIMENT_DESCRIPTION="model_memory平衡损失与四维度监控实验" +RESEARCHER_NAME="AI Assistant" +EXPERIMENT_DATE="$(date '+%Y-%m-%d %H:%M:%S')" + +# ---------------------------------------------------------------------------- +# 🤖 环境配置 +# ---------------------------------------------------------------------------- + +# 调试和监控环境变量 +export NCCL_DEBUG=INFO +export PYTHONFAULTHANDLER=1 +export CUDA_LAUNCH_BLOCKING=1 + +# SwanLab 配置 +export SWANLAB_PROJECT="MiniMind-Experiment-1.4.4" + +# 日志配置 +LOG_DIR="out/experiment_${EXPERIMENT_VERSION}" +mkdir -p "$LOG_DIR" +LOG_FILE="$LOG_DIR/experiment.log" + +# ---------------------------------------------------------------------------- +# 🤖 硬件配置 +# ---------------------------------------------------------------------------- +CUDA_VISIBLE_DEVICES="0" +NUM_PROCESSES="1" +MIXED_PRECISION="bf16" +MAIN_PROCESS_PORT="29500" + +# ---------------------------------------------------------------------------- +# 🤖 模型架构参数 +# ---------------------------------------------------------------------------- +MODEL_TYPE="model_memory" +MODEL_SIZE="50.0" +DIM="512" +N_LAYERS="8" +N_HEADS="32" +MAX_SEQ_LEN="512" +USE_MOE="false" + +# 知识库配置(使用更小的记忆库以适应实验需求) +KNOWLEDGE_NUM="65536" # 256x256 = 65536,确保是完全平方数 +KNOWLEDGE_LENGTH="32" +KNOWLEDGE_DIM="128" +DISABLE_DB="false" + +# ---------------------------------------------------------------------------- +# 🤖 训练超参数 +# ---------------------------------------------------------------------------- +EPOCHS="3" +EMBEDDING_EPOCH="2" +BATCH_SIZE="128" +ACCUMULATION_STEPS="8" +LEARNING_RATE="2e-4" +DTYPE="bfloat16" +GRAD_CLIP="1.0" +WARMUP_ITERS="0" + +# 平衡损失配置 +BALANCE_LOSS_COEF="0.1" + +# 数据和缓存路径 +DATA_PATH="/home/pci/ycz/Code/Minimind/dataset/stable/merged_pretrain.jsonl" +DATABASE_INIT_PATH="/home/pci/ycz/Code/Minimind/dataset/stable/sentence_trex_data.json" +CLUSTER_CACHE_PATH="/home/pci/ycz/Code/Minimind/cache/cluster_tokens_single.pt" +VAL_DATA_PATH="dataset/stable/eval_data.json" + +# 训练配置(合并log_interval和profile参数) +NUM_WORKERS="1" +LOG_INTERVAL="100" +VAL_INTERVAL="100" +SAVE_INTERVAL="10000" + +# 性能分析配置 +USE_PROFILE="true" +PROFILE_INTERVAL="10" +MEMORY_MONITOR_INTERVAL="100" + +# 高级功能 +USE_FLASH_ATTN="true" +FAST_CLUSTERING="true" + +# ---------------------------------------------------------------------------- +# 🤖 预检查函数 +# ---------------------------------------------------------------------------- +check_environment() { + echo "🔍 环境检查中..." + + # 检查GPU可用性 + if ! nvidia-smi &> /dev/null; then + echo "❌ 错误: 未检测到GPU或nvidia-smi不可用" + exit 1 + fi + + # 检查CUDA设备 + if ! 
nvidia-smi -i "$CUDA_VISIBLE_DEVICES" &> /dev/null; then + echo "❌ 错误: GPU $CUDA_VISIBLE_DEVICES 不可用" + exit 1 + fi + + # 检查Python环境 + if ! .venv/bin/python -c "import torch; print(f'PyTorch: {torch.__version__}')" 2>/dev/null; then + echo "❌ 错误: PyTorch未正确安装" + exit 1 + fi + + # 检查数据文件 + if [[ ! -f "$DATA_PATH" ]]; then + echo "❌ 错误: 训练数据文件不存在: $DATA_PATH" + exit 1 + fi + + if [[ ! -f "$DATABASE_INIT_PATH" ]]; then + echo "❌ 错误: 数据库初始化文件不存在: $DATABASE_INIT_PATH" + exit 1 + fi + + echo "✅ 环境检查通过" +} + +# ---------------------------------------------------------------------------- +# 🤖 实验信息记录 +# ---------------------------------------------------------------------------- +log_experiment_info() { + echo "📝 记录实验信息..." + cat > "$LOG_DIR/experiment_info.txt" << EOF +======================================== +MiniMind 实验信息 +======================================== +实验版本: $EXPERIMENT_VERSION +实验描述: $EXPERIMENT_DESCRIPTION +研究者: $RESEARCHER_NAME +开始时间: $EXPERIMENT_DATE +======================================== +硬件配置: +GPU设备: $CUDA_VISIBLE_DEVICES +进程数: $NUM_PROCESSES +混合精度: $MIXED_PRECISION +======================================== +模型配置: +模型类型: $MODEL_TYPE +模型大小: $MODEL_SIZE MB +维度: $DIM +层数: $N_LAYERS +注意力头数: $N_HEADS +最大序列长度: $MAX_SEQ_LEN +知识库大小: $KNOWLEDGE_NUM +知识长度: $KNOWLEDGE_LENGTH +知识维度: $KNOWLEDGE_DIM +======================================== +训练配置: +训练轮次: $EPOCHS +批次大小: $BATCH_SIZE +学习率: $LEARNING_RATE +梯度累积: $ACCUMULATION_STEPS +数据类型: $DTYPE +平衡损失系数: $BALANCE_LOSS_COEF +======================================== +数据路径: +训练数据: $DATA_PATH +验证数据: $VAL_DATA_PATH +数据库初始化: $DATABASE_INIT_PATH +聚类缓存: $CLUSTER_CACHE_PATH +======================================== +EOF +} + +# ---------------------------------------------------------------------------- +# 🤖 主执行函数 +# ---------------------------------------------------------------------------- +run_experiment() { + echo "🚀 开始执行实验 $EXPERIMENT_VERSION" + echo "📄 实验描述: $EXPERIMENT_DESCRIPTION" + echo "⏰ 开始时间: $EXPERIMENT_DATE" + + # 构建训练命令 + local train_cmd="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES .venv/bin/python train_pretrain_accelerate.py" + + # 添加训练参数 + train_cmd+=" --out_dir \"$LOG_DIR\"" + train_cmd+=" --epochs $EPOCHS" + train_cmd+=" --embedding_epoch $EMBEDDING_EPOCH" + train_cmd+=" --batch_size $BATCH_SIZE" + train_cmd+=" --learning_rate $LEARNING_RATE" + train_cmd+=" --dtype $DTYPE" + train_cmd+=" --num_workers $NUM_WORKERS" + train_cmd+=" --accumulation_steps $ACCUMULATION_STEPS" + train_cmd+=" --grad_clip $GRAD_CLIP" + train_cmd+=" --warmup_iters $WARMUP_ITERS" + train_cmd+=" --log_interval $LOG_INTERVAL" + train_cmd+=" --val_interval $VAL_INTERVAL" + train_cmd+=" --save_interval $SAVE_INTERVAL" + train_cmd+=" --dim $DIM" + train_cmd+=" --n_layers $N_LAYERS" + train_cmd+=" --n_heads $N_HEADS" + train_cmd+=" --max_seq_len $MAX_SEQ_LEN" + train_cmd+=" --data_path \"$DATA_PATH\"" + train_cmd+=" --val_data_path \"$VAL_DATA_PATH\"" + train_cmd+=" --knowledge_num $KNOWLEDGE_NUM" + train_cmd+=" --knowledge_length $KNOWLEDGE_LENGTH" + train_cmd+=" --database_init_path \"$DATABASE_INIT_PATH\"" + train_cmd+=" --memory_monitor_interval $MEMORY_MONITOR_INTERVAL" + train_cmd+=" --model_type \"$MODEL_TYPE\"" + train_cmd+=" --model_size $MODEL_SIZE" + train_cmd+=" --balance_loss_coef $BALANCE_LOSS_COEF" + + # 可选参数 + if [[ "$USE_PROFILE" == "true" ]]; then + train_cmd+=" --profile" + train_cmd+=" --profile_interval $PROFILE_INTERVAL" + fi + + if [[ "$USE_FLASH_ATTN" == "true" ]]; then + train_cmd+=" --use_flash_attn" + fi + + if [[ "$FAST_CLUSTERING" == 
"true" ]]; then + train_cmd+=" --fast_clustering" + fi + + if [[ "$CLUSTER_CACHE_PATH" != "None" ]]; then + train_cmd+=" --cluster_cache_path \"$CLUSTER_CACHE_PATH\"" + fi + + # SwanLab配置 + train_cmd+=" --use_swanlab" + train_cmd+=" --swanlab_project \"$SWANLAB_PROJECT\"" + train_cmd+=" --swanlab_online True" + + echo "📋 执行命令:" + echo "$train_cmd" + echo + + # 记录命令到日志文件 + echo "执行命令: $train_cmd" >> "$LOG_FILE" + echo "开始时间: $(date)" >> "$LOG_FILE" + + # 使用nohup执行训练(后台运行,输出写入日志文件) + echo "🔄 使用nohup后台运行训练,输出将写入日志文件: $LOG_FILE" + + # 创建训练脚本 + train_script="/tmp/train_${EXPERIMENT_VERSION}.sh" + cat > "$train_script" << EOF +#!/bin/bash +cd /home/pci/ycz/Code/pretrain-worktree +source /home/pci/ycz/Code/pretrain-worktree/.venv/bin/activate +$train_cmd +echo "结束时间: \$(date)" +echo "退出代码: \$?" +EOF + chmod +x "$train_script" + + # 使用nohup后台运行 + nohup bash "$train_script" >> "$LOG_FILE" 2>&1 & + local train_pid=$! + + echo "🔥 训练进程已启动,PID: $train_pid" + echo "训练PID: $train_pid" >> "$LOG_FILE" + echo "训练脚本: $train_script" >> "$LOG_FILE" + + # 等待几秒确保进程启动 + sleep 5 + + # 检查进程是否还在运行 + if kill -0 $train_pid 2>/dev/null; then + echo "✅ 训练进程正在后台运行" + echo "📋 实时查看日志: tail -f $LOG_FILE" + echo "📋 检查进程状态: ps -p $train_pid" + echo "🛑 停止训练: kill $train_pid" + echo "📈 SwanLab: https://swanlab.cn/project/$SWANLAB_PROJECT" + echo "" + echo "训练正在后台运行,可以安全关闭终端。" + else + echo "❌ 训练进程启动失败" + echo "📋 查看日志: $LOG_FILE" + exit 1 + fi +} + +# ---------------------------------------------------------------------------- +# 🤖 清理函数 +# ---------------------------------------------------------------------------- +cleanup() { + echo "🧹 清理临时文件..." + # 删除临时验证文件 + rm -f /tmp/temp_val.jsonl +} + +# ---------------------------------------------------------------------------- +# 🤖 信号处理 +# ---------------------------------------------------------------------------- +trap cleanup EXIT +trap 'echo "❌ 实验被中断"; cleanup; exit 130' INT TERM + +# ---------------------------------------------------------------------------- +# 🤖 主程序入口 +# ---------------------------------------------------------------------------- +main() { + echo "============================================================================" + echo "🧠 MiniMind 预训练实验 1.4.4" + echo "🎯 深度验证记忆库机制 - 平衡损失与四维度监控" + echo "============================================================================" + + # 执行检查和初始化 + check_environment + log_experiment_info + + # 运行实验 + run_experiment + + echo "============================================================================" + echo "✅ 实验 $EXPERIMENT_VERSION 启动完成" + echo "📅 启动时间: $(date)" + echo "============================================================================" +} + +# 执行主程序 +main "$@" \ No newline at end of file diff --git a/train_pretrain_accelerate.py b/train_pretrain_accelerate.py index 74d3307..d92b548 100644 --- a/train_pretrain_accelerate.py +++ b/train_pretrain_accelerate.py @@ -24,7 +24,7 @@ from sklearn.metrics.pairwise import cosine_similarity import swanlab # 替换wandb导入 import gc # 添加垃圾回收模块 import psutil # 添加系统资源监控模块 - +import json # 添加JSON支持 from model.LMConfig import LMConfig from model.dataset import PretrainDataset @@ -98,6 +98,86 @@ def Logger(msg, accelerator=None): def format_time(seconds): return str(datetime.timedelta(seconds=int(seconds))) +def create_validation_dataset(val_data_path, tokenizer, max_length, num_samples=200): + """ + 创建验证数据集 + + Args: + val_data_path: 验证数据文件路径 + tokenizer: tokenizer实例 + max_length: 最大序列长度 + num_samples: 验证样本数量 + + Returns: + val_dataset: 验证数据集 + """ + if not 
os.path.exists(val_data_path): + Logger(f"警告:验证数据文件不存在: {val_data_path},跳过验证评估") + return None + + # 读取验证数据 + val_data = [] + with open(val_data_path, 'r', encoding='utf-8') as f: + for i, line in enumerate(f): + if i >= num_samples: # 限制验证样本数量 + break + line = line.strip() + if line: + try: + sample = json.loads(line) + val_data.append(sample['text']) + except json.JSONDecodeError: + continue + + # 创建临时验证文件 + temp_val_file = "/tmp/temp_val.jsonl" + with open(temp_val_file, 'w', encoding='utf-8') as f: + for text in val_data: + f.write(json.dumps({'text': text}) + '\n') + + # 使用PretrainDataset创建验证集 + val_dataset = PretrainDataset(temp_val_file, tokenizer, max_length=max_length) + Logger(f"创建验证数据集成功,包含 {len(val_data)} 个样本") + + return val_dataset + +def validate_model(model, val_loader, loss_fct, ctx, accelerator): + """ + 执行模型验证 + + Args: + model: 模型实例 + val_loader: 验证数据加载器 + loss_fct: 损失函数 + ctx: 上下文管理器 + accelerator: Accelerator实例 + + Returns: + avg_val_loss: 平均验证损失 + """ + model.eval() + total_loss = 0 + num_batches = 0 + + with torch.no_grad(): + for batch in val_loader: + X, Y, loss_mask = batch + + with ctx: + res = model(X) + loss = loss_fct( + res.logits.view(-1, res.logits.size(-1)), + Y.view(-1) + ).view(Y.size()) + loss = (loss * loss_mask).sum() / loss_mask.sum() + + total_loss += loss.item() + num_batches += 1 + + model.train() + avg_val_loss = total_loss / num_batches if num_batches > 0 else float('inf') + return avg_val_loss + # 获取学习率函数 def get_lr(it, num_iters, learning_rate): # 余弦学习率衰减 @@ -541,7 +621,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non return model, tokenizer -def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer): +def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer, val_loader=None): loss_fct = nn.CrossEntropyLoss(reduction='none') epoch_start_time = time.time() total_steps_in_epoch = len(train_loader) @@ -644,13 +724,22 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a unwrapped_model.freeze_embedding = True Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator) res = model(X, step=step) - loss = loss_fct( + + # 计算主要损失(交叉熵损失) + ce_loss = loss_fct( res.logits.view(-1, res.logits.size(-1)), Y.view(-1) ).view(Y.size()) - loss = (loss * loss_mask).sum() / loss_mask.sum() - # 移除辅助损失计算,统一不使用 aux_loss - loss = loss / args.accumulation_steps + ce_loss = (ce_loss * loss_mask).sum() / loss_mask.sum() + + # 获取平衡损失(如果模型支持) + balance_loss = 0 + if hasattr(res, 'aux_loss') and res.aux_loss is not None: + balance_loss = res.aux_loss + + # 计算总损失 + total_loss = ce_loss + args.balance_loss_coef * balance_loss + loss = total_loss / args.accumulation_steps # 计时前向传播结束 (只在主进程进行) if args.profile and accelerator.is_main_process and forward_end is not None: @@ -685,8 +774,8 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a if args.profile and accelerator.is_main_process and optimizer_end is not None: optimizer_end.record() - # 打印训练信息 (只在主进程进行) - if (step + 1) % args.log_interval == 0 and accelerator.is_main_process: + # 验证评估和日志记录 (只在主进程进行) + if (step + 1) % args.val_interval == 0 and accelerator.is_main_process: current_time = time.time() # 记录日志输出时的详细内存状态 @@ -809,19 +898,72 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a tokens_per_sec = 
tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0 last_log_time = current_time # 更新上次日志时间 + # 执行验证评估 + val_loss = None + if val_loader is not None: + try: + val_loss = validate_model(model, val_loader, loss_fct, ctx, accelerator) + Logger(f"验证损失: {val_loss:.4f}", accelerator) + except Exception as e: + Logger(f"验证评估失败: {e}", accelerator) + val_loss = None + + # 获取记忆库更新统计(如果模型支持) + memory_update_stats = {} + if hasattr(model, 'get_memory_update_stats'): + try: + unwrapped_model = accelerator.unwrap_model(model) + if hasattr(unwrapped_model, 'get_memory_update_stats'): + memory_update_stats = unwrapped_model.get_memory_update_stats() + except Exception as e: + Logger(f"获取记忆更新统计失败: {e}", accelerator) + + # 获取层级统计信息(如果模型支持) + layer_stats = {} + if hasattr(res, 'layer_stats') and res.layer_stats is not None: + layer_stats = res.layer_stats + + # 构建日志字典 log_dict = { "epoch": epoch + 1, "step": step + 1, "total_steps_in_epoch": total_steps_in_epoch, - "loss": loss.item() * args.accumulation_steps, + "train/loss_ce": ce_loss.item(), + "train/loss_balance": balance_loss.item() if isinstance(balance_loss, torch.Tensor) else balance_loss, + "train/loss_total": total_loss.item(), "lr": current_lr, "tokens_per_sec": tokens_per_sec, "epoch_time_left_seconds": epoch_remaining_time, "total_time_left_seconds": total_remaining_time } + # 添加验证损失 + if val_loss is not None: + log_dict["val/loss"] = val_loss + + # 添加记忆库更新统计 + log_dict.update(memory_update_stats) + + # 添加层级统计信息(选择性添加关键指标) + if layer_stats: + # 计算所有层的平均统计 + avg_gini = np.mean([v for k, v in layer_stats.items() if k.endswith('_gini_coefficient')]) + avg_coverage = np.mean([v for k, v in layer_stats.items() if k.endswith('_coverage_rate')]) + total_dead = sum([v for k, v in layer_stats.items() if k.endswith('_dead_memories')]) + total_hot = sum([v for k, v in layer_stats.items() if k.endswith('_hot_memories')]) + + log_dict.update({ + 'memory/avg_gini_coefficient': avg_gini, + 'memory/avg_coverage_rate': avg_coverage, + 'memory/total_dead_memories': total_dead, + 'memory/total_hot_memories': total_hot, + }) + Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, " - f"Loss: {log_dict['loss']:.4f}, " + f"CE Loss: {log_dict['train/loss_ce']:.4f}, " + f"Balance Loss: {log_dict['train/loss_balance']:.4f}, " + f"Total Loss: {log_dict['train/loss_total']:.4f}, " + f"Val Loss: {log_dict.get('val/loss', 'N/A')}, " f"LR: {log_dict['lr']:.6f}, " f"Speed: {log_dict['tokens_per_sec']:.2f} tokens/sec | " f"Epoch Time Left: {format_time(epoch_remaining_time)} | " @@ -832,7 +974,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a # 保存模型 (只在主进程进行) loss_total = loss.item() * args.accumulation_steps - if epoch > 1 and best_loss > loss_total and accelerator.is_main_process: + if epoch >= 0 and best_loss > loss_total and accelerator.is_main_process: best_loss = loss_total # 使用函数开始处定义的moe_path变量 ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth' @@ -913,6 +1055,9 @@ def main(): parser.add_argument("--model_type", type=str, default="model", help="使用什么模型训练") #model,model_original,model_no_feed parser.add_argument("--model_size", type=float, default=50.0, help="模型大小") parser.add_argument("--swanlab_online", type=bool, default=False, help="是否使用在线SwanLab服务") + parser.add_argument("--balance_loss_coef", type=float, default=0.01, help="平衡损失系数") + parser.add_argument("--val_data_path", type=str, default="dataset/stable/eval_data.json", help="验证数据集路径") + 
parser.add_argument("--val_interval", type=int, default=100, help="验证评估间隔") args = parser.parse_args() ######################################################### @@ -1053,6 +1198,19 @@ def main(): prefetch_factor=2 if args.num_workers > 0 else None ) + # 创建验证数据集和加载器 + val_loader = None + val_ds = create_validation_dataset(args.val_data_path, tokenizer, lm_config.max_seq_len) + if val_ds is not None: + val_loader = DataLoader( + val_ds, + batch_size=args.batch_size // 2, # 验证时使用较小批次 + pin_memory=True, + drop_last=False, + shuffle=False, + num_workers=0, # 验证时不使用多进程 + ) + ######################################################### # 创建优化器 ######################################################### @@ -1072,9 +1230,14 @@ def main(): ######################################################### # 准备训练 ######################################################### - model, optimizer, train_loader, scheduler = accelerator.prepare( - model, optimizer, train_loader, scheduler - ) + if val_loader is not None: + model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare( + model, optimizer, train_loader, val_loader, scheduler + ) + else: + model, optimizer, train_loader, scheduler = accelerator.prepare( + model, optimizer, train_loader, scheduler + ) ######################################################### # 训练循环 @@ -1082,7 +1245,7 @@ def main(): overall_start_time = time.time() # Record overall start time for epoch in range(args.epochs): Logger(f"开始第{epoch+1}轮训练", accelerator) - train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer) # Pass tokenizer + train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer, val_loader) # Pass tokenizer and val_loader # 每个epoch结束后进行内存清理 Logger(f"第{epoch+1}轮训练完成,进行内存清理", accelerator)