diff --git a/experiment/EXPERIMENT_1_4_4.md b/experiment/EXPERIMENT_1_4_4.md index 9e55e3e..78e5f14 100644 --- a/experiment/EXPERIMENT_1_4_4.md +++ b/experiment/EXPERIMENT_1_4_4.md @@ -393,16 +393,7 @@ Loss: 2.0430 ### ✅ **[AI完成]** 改进建议 **短期优化** (下个实验): -- `移除或大幅降低Balance Loss系数,恢复记忆选择的自然模式` -- `对比不同balance_loss_coef取值(0.001, 0.01, 0.05)对性能的影响` - -**中期改进** (未来3-5个实验): -- `优化记忆库质量和初始化策略,提升记忆检索的有效性` -- `探索更智能的记忆平衡策略,平衡多样性和效率` - -**长期研究方向**: -- `研究自适应记忆选择机制,根据任务需求动态调整平衡程度` -- `探索记忆库与传统FFN的混合架构,结合两者优势` +- 使用类似vq-vae的方式对Memory Bank进行约束。 --- diff --git a/model/LMConfig.py b/model/LMConfig.py index 505376d..46b3c5c 100644 --- a/model/LMConfig.py +++ b/model/LMConfig.py @@ -42,6 +42,12 @@ class LMConfig(PretrainedConfig): knowledge_length: int = 8, knowledge_dim: int = 128, #################################################### + # EMA update related configurations (inspired by VQ-VAE) + #################################################### + use_ema_update: bool = True, # 是否使用EMA更新memory_bank + ema_decay: float = 0.999, # EMA衰减率,类似VQ-VAE中的γ + ema_update_freq: int = 1, # EMA更新频率(每N个训练步更新一次) + #################################################### # Triple extraction related configurations #################################################### max_subject_len: int = 8, @@ -83,6 +89,12 @@ class LMConfig(PretrainedConfig): self.knowledge_length = knowledge_length self.knowledge_dim = knowledge_dim #################################################### + # EMA update related configurations (inspired by VQ-VAE) + #################################################### + self.use_ema_update = use_ema_update + self.ema_decay = ema_decay + self.ema_update_freq = ema_update_freq + #################################################### # Triple extraction related configurations #################################################### self.max_subject_len = max_subject_len diff --git a/model/model_memory.py b/model/model_memory.py index d62443b..f2edd90 100644 --- a/model/model_memory.py +++ b/model/model_memory.py @@ -335,17 +335,19 @@ class MiniMindBlock(nn.Module): self.memory_gate = MemoryGate(config) self.gated_memory_fusion = GatedMemoryFusion(config) - def forward(self, x, pos_cis, memory_bank): + def forward(self, x, pos_cis, memory_bank, collect_ema_stats=False): """ Args: x: [batch_size, seq_len, dim] pos_cis: positional encoding memory_bank: [knowledge_num, knowledge_dim] - shared memory bank + collect_ema_stats: 是否收集EMA更新统计信息 Returns: out: [batch_size, seq_len, dim] balance_loss: 该层的平衡损失 layer_stats: 该层的监控统计信息 + ema_stats: EMA更新统计信息(如果collect_ema_stats=True) """ # Self attention h_attn = self.attention(self.attention_norm(x), pos_cis) @@ -369,7 +371,20 @@ class MiniMindBlock(nn.Module): # 残差连接 out = h + memory_output - return out, balance_loss, layer_stats + # 收集EMA更新统计信息(仅在训练时且启用时) + ema_stats = None + if collect_ema_stats and self.training: + ema_stats = { + 'memory_indices': memory_indices, # [batch, seq_len, num_selected] + 'memory_scores': memory_scores, # [batch, seq_len, num_selected] + 'h_for_memory': h_for_memory, # [batch, seq_len, dim] + 'selected_memory': selected_memory, # [batch, seq_len, num_selected, knowledge_dim] + } + + if collect_ema_stats: + return out, balance_loss, layer_stats, ema_stats + else: + return out, balance_loss, layer_stats class MiniMindLM(PreTrainedModel): @@ -390,10 +405,25 @@ class MiniMindLM(PreTrainedModel): persistent=False) # 初始化共享记忆库 - self.memory_bank = nn.Parameter( - torch.randn(params.knowledge_num, params.knowledge_dim), - requires_grad=True - ) + # 
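Note on the wiring: the three new `LMConfig` flags are what everything downstream keys off. When `use_ema_update` is true, `memory_bank` becomes a frozen `nn.Parameter` — still saved in `state_dict` and checkpoints, but never touched by autograd — and is filtered out of the optimizer in `train_pretrain_accelerate.py`. A minimal, self-contained sketch of that pattern; `ToyMemoryModel` and its sizes are illustrative, not repo code:

```python
import torch
import torch.nn as nn

class ToyMemoryModel(nn.Module):
    """Illustrative stand-in for MiniMindLM's memory_bank handling (not repo code)."""
    def __init__(self, knowledge_num=16, knowledge_dim=8, use_ema_update=True):
        super().__init__()
        self.proj = nn.Linear(knowledge_dim, knowledge_dim)  # ordinary gradient-trained weight
        # Frozen when EMA is on: kept in state_dict, invisible to autograd.
        self.memory_bank = nn.Parameter(
            torch.randn(knowledge_num, knowledge_dim),
            requires_grad=not use_ema_update,
        )

model = ToyMemoryModel(use_ema_update=True)
# Mirrors the optimizer filtering added in train_pretrain_accelerate.py below:
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=2e-4)
assert not model.memory_bank.requires_grad
assert all(p is not model.memory_bank for p in trainable)
```

Keeping the bank a `Parameter` (rather than a registered buffer) leaves the checkpoint layout identical between the EMA and gradient modes; the price is that the optimizer must explicitly skip it, which the training-script change below takes care of.
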
VQ-VAE风格:memory_bank作为codebook,使用EMA更新而非梯度更新 + if params.use_ema_update: + self.memory_bank = nn.Parameter( + torch.randn(params.knowledge_num, params.knowledge_dim), + requires_grad=False # 禁用梯度更新,使用EMA更新 + ) + else: + self.memory_bank = nn.Parameter( + torch.randn(params.knowledge_num, params.knowledge_dim), + requires_grad=True # 传统梯度更新 + ) + + # EMA更新相关缓冲区 + if params.use_ema_update: + # 记录每个memory条目的更新统计 + self.register_buffer('ema_update_count', torch.zeros(params.knowledge_num), persistent=False) + self.register_buffer('ema_sum_buffer', torch.zeros_like(self.memory_bank), persistent=False) + # EMA更新频率计数器 + self.register_buffer('ema_step_counter', torch.zeros(1, dtype=torch.long), persistent=False) # 记录上一步的记忆库状态,用于计算更新统计 self.register_buffer('prev_memory_bank', torch.zeros_like(self.memory_bank), persistent=False) @@ -453,15 +483,23 @@ class MiniMindLM(PreTrainedModel): **args): """Forward pass without KV cache support""" start_pos = args.get('start_pos', 0) + collect_ema_stats = args.get('collect_ema_stats', self.params.use_ema_update and self.training) + h = self.dropout(self.tok_embeddings(input_ids)) pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)] # 收集所有层的平衡损失和统计信息 total_balance_loss = 0 all_layer_stats = {} + all_ema_stats = {} for layer_idx, layer in enumerate(self.layers): - h, balance_loss, layer_stats = layer(h, pos_cis, self.memory_bank) + if collect_ema_stats: + h, balance_loss, layer_stats, ema_stats = layer(h, pos_cis, self.memory_bank, collect_ema_stats=True) + all_ema_stats[f'layer_{layer_idx}'] = ema_stats + else: + h, balance_loss, layer_stats = layer(h, pos_cis, self.memory_bank, collect_ema_stats=False) + total_balance_loss += balance_loss # 为每层的统计信息添加前缀 for key, value in layer_stats.items(): @@ -476,6 +514,7 @@ class MiniMindLM(PreTrainedModel): self.OUT.__setitem__('logits', logits) self.OUT.__setitem__('aux_loss', aux_loss) self.OUT.__setitem__('layer_stats', all_layer_stats) # 添加层级统计信息 + self.OUT.__setitem__('ema_stats', all_ema_stats if collect_ema_stats else None) # 添加EMA统计信息 self.OUT.__setitem__('past_key_values', None) # 不支持KV cache return self.OUT @@ -536,4 +575,132 @@ class MiniMindLM(PreTrainedModel): input_ids = torch.cat((input_ids, input_ids_next), dim=1) yield input_ids[:, start:] if input_ids_next.item() == eos_token_id: - break \ No newline at end of file + break + + def apply_ema_update(self, ema_stats): + """ + 应用VQ-VAE风格的EMA更新到memory_bank + + Args: + ema_stats: 从forward pass收集的EMA统计信息,格式为: + {'layer_0': {'memory_indices': ..., 'h_for_memory': ...}, 'layer_1': ...} + """ + if not self.params.use_ema_update: + return {} + + # 增加EMA步数计数器 + self.ema_step_counter += 1 + + # 检查是否需要进行EMA更新 + if self.ema_step_counter % self.params.ema_update_freq != 0: + return {'ema_update_applied': False, 'reason': 'frequency_check_failed'} + + with torch.no_grad(): + device = self.memory_bank.device + knowledge_num, knowledge_dim = self.memory_bank.shape + + # 重置累积缓冲区 + self.ema_sum_buffer.zero_() + self.ema_update_count.zero_() + + total_selections = 0 + total_layers = 0 + + # 收集所有层的EMA统计信息 + for layer_name, layer_ema_stats in ema_stats.items(): + if layer_ema_stats is None: + continue + + total_layers += 1 + memory_indices = layer_ema_stats['memory_indices'] # [batch, seq_len, num_selected] + h_for_memory = layer_ema_stats['h_for_memory'] # [batch, seq_len, dim] + + bsz, seq_len, num_selected = memory_indices.shape + total_selections += bsz * seq_len * num_selected + + # 将h_for_memory投影到knowledge_dim维度(如果维度不匹配) + if h_for_memory.size(-1) != 
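For reference, the update rule that `apply_ema_update` implements is a simplified, per-step variant of the VQ-VAE codebook EMA from van den Oord et al.: rather than maintaining running averages of counts and sums across steps, `ema_update_count` and `ema_sum_buffer` are zeroed on every update, and each selected entry is pulled toward the current batch mean of the hidden states routed to it. In symbols, with γ = `ema_decay`:

```latex
% n_i : times entry i was selected in this step   (ema_update_count)
% s_i : sum of (projected) hidden states for i    (ema_sum_buffer)
\bar{h}_i = \frac{s_i}{n_i}, \qquad
e_i \leftarrow \gamma\, e_i + (1-\gamma)\,\bar{h}_i
\quad \text{for every } i \text{ with } n_i > 0
```

With γ = 0.999 each update closes only 0.1% of the gap toward the batch mean, so an entry needs on the order of 1/(1−γ) ≈ 1000 selections to mostly converge to a new target — the stability this experiment is after, at the cost of slow adaptation. Entries with n_i = 0 are left untouched and can go stale, which is why the function also reports `update_ratio` and `selected_memory_coverage`.
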
knowledge_dim: + # 使用简单的线性投影(截断或者填零) + if h_for_memory.size(-1) > knowledge_dim: + # 截断到knowledge_dim + h_proj = h_for_memory[..., :knowledge_dim] + else: + # 用零填充到knowledge_dim + pad_size = knowledge_dim - h_for_memory.size(-1) + h_proj = F.pad(h_for_memory, (0, pad_size), 'constant', 0) + else: + h_proj = h_for_memory + + # 展平索引和对应的h_for_memory + flat_indices = memory_indices.view(-1) # [batch * seq_len * num_selected] + + # 为每个选择位置复制对应的h_for_memory + # [batch, seq_len, num_selected] -> [batch, seq_len, num_selected, dim] + h_expanded = h_proj.unsqueeze(2).expand(-1, -1, num_selected, -1) + flat_h = h_expanded.reshape(-1, knowledge_dim) # [batch * seq_len * num_selected, knowledge_dim] + + # 确保数据类型匹配 + flat_indices = flat_indices.long().to(device) # 索引必须是long类型 + flat_h = flat_h.to(dtype=self.ema_sum_buffer.dtype, device=device) # 数据类型匹配 + + # 累积每个memory条目的h_for_memory值 + # scatter_add_: 将flat_h的值累加到ema_sum_buffer的对应位置 + self.ema_sum_buffer.scatter_add_(0, flat_indices.unsqueeze(1).expand(-1, knowledge_dim), flat_h) + + # 统计每个memory条目被选择的次数 + count_ones = torch.ones_like(flat_indices, dtype=self.ema_update_count.dtype, device=device) + self.ema_update_count.scatter_add_(0, flat_indices, count_ones) + + # 计算平均值并应用EMA更新 + # 防止除零错误 + non_zero_mask = self.ema_update_count > 0 + avg_h_for_selected = torch.zeros_like(self.memory_bank) + + if non_zero_mask.any(): + # 计算被选择memory条目的平均h_for_memory + avg_h_for_selected[non_zero_mask] = ( + self.ema_sum_buffer[non_zero_mask] / self.ema_update_count[non_zero_mask].unsqueeze(1) + ) + + # 确保数据类型匹配并应用EMA更新:new = γ * old + (1-γ) * new_avg + # 只更新被选择的memory条目 + old_memory = self.memory_bank[non_zero_mask] + new_avg = avg_h_for_selected[non_zero_mask].to(dtype=old_memory.dtype) + + self.memory_bank[non_zero_mask] = ( + self.params.ema_decay * old_memory + + (1 - self.params.ema_decay) * new_avg + ) + + # 计算更新统计信息 + updated_memories = non_zero_mask.sum().item() + update_ratio = updated_memories / knowledge_num + + # 计算EMA更新幅度统计 + if hasattr(self, 'prev_memory_bank_ema') and self.prev_memory_bank_ema.numel() > 0: + l2_changes = torch.norm(self.memory_bank[non_zero_mask] - self.prev_memory_bank_ema[non_zero_mask], p=2, dim=1) + avg_change = l2_changes.mean().item() if len(l2_changes) > 0 else 0.0 + max_change = l2_changes.max().item() if len(l2_changes) > 0 else 0.0 + else: + avg_change = 0.0 + max_change = 0.0 + + # 保存当前memory_bank状态用于下次比较 + if not hasattr(self, 'prev_memory_bank_ema'): + self.register_buffer('prev_memory_bank_ema', torch.zeros_like(self.memory_bank), persistent=False) + self.prev_memory_bank_ema.copy_(self.memory_bank) + + update_stats = { + 'ema_update_applied': True, + 'ema_step': self.ema_step_counter.item(), + 'total_selections': total_selections, + 'total_layers': total_layers, + 'updated_memories': updated_memories, + 'update_ratio': update_ratio, + 'avg_ema_change': avg_change, + 'max_ema_change': max_change, + 'ema_decay': self.params.ema_decay, + 'selected_memory_coverage': (self.ema_update_count > 0).float().mean().item(), + } + + return update_stats \ No newline at end of file diff --git a/run_file/experiment_1_4_5.sh b/run_file/experiment_1_4_5.sh new file mode 100644 index 0000000..560a0a4 --- /dev/null +++ b/run_file/experiment_1_4_5.sh @@ -0,0 +1,352 @@ +#!/bin/bash + +# ============================================================================ +# MiniMind 实验脚本 - Experiment 1.4.5 +# ============================================================================ +# +# 🎯 实验目标: +# 基于实验1.4.4,实现VQ-VAE风格的EMA更新机制替代memory_bank的梯度更新 
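The scatter-accumulate core of `apply_ema_update` above can be exercised in isolation. This toy sketch (all shapes, seeds, and values illustrative) shows how `scatter_add_` turns flattened selection indices into per-entry sums and counts, and how the masked EMA step then moves only the rows that were actually selected:

```python
import torch

torch.manual_seed(0)
knowledge_num, knowledge_dim, gamma = 8, 4, 0.999

memory_bank = torch.randn(knowledge_num, knowledge_dim)
# Flattened selections for one step: [batch * seq_len * num_selected]
flat_indices = torch.tensor([1, 1, 3, 5, 1, 3])
flat_h = torch.randn(flat_indices.numel(), knowledge_dim)  # states routed to each pick

sums = torch.zeros(knowledge_num, knowledge_dim)
counts = torch.zeros(knowledge_num)
# Accumulate per-entry sums: index must be expanded to match src's shape.
sums.scatter_add_(0, flat_indices.unsqueeze(1).expand(-1, knowledge_dim), flat_h)
counts.scatter_add_(0, flat_indices, torch.ones_like(flat_indices, dtype=counts.dtype))

mask = counts > 0                                  # entries 1, 3, 5 in this example
mean_h = sums[mask] / counts[mask].unsqueeze(1)    # batch mean per selected entry
memory_bank[mask] = gamma * memory_bank[mask] + (1 - gamma) * mean_h
print(counts)  # tensor([0., 3., 0., 2., 0., 1., 0., 0.])
```

Entries 1, 3, and 5 move; every unselected row stays put, exactly as in the masked update above.
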
+# +# 使用方法: +# bash run_file/experiment_1_4_5.sh +# ============================================================================ + +# ---------------------------------------------------------------------------- +# 🧑‍🔬 实验基本信息 +# ---------------------------------------------------------------------------- +EXPERIMENT_VERSION="1.4.5" +EXPERIMENT_DESCRIPTION="VQ-VAE风格EMA更新机制实验" +RESEARCHER_NAME="AI Assistant" +EXPERIMENT_DATE="$(date '+%Y-%m-%d %H:%M:%S')" + +# ---------------------------------------------------------------------------- +# 🤖 环境配置 +# ---------------------------------------------------------------------------- + +# 调试和监控环境变量 +export NCCL_DEBUG=INFO +export PYTHONFAULTHANDLER=1 +export CUDA_LAUNCH_BLOCKING=1 + +# SwanLab 配置 +export SWANLAB_PROJECT="MiniMind-Experiment-1.4.5" + +# 日志配置 +LOG_DIR="out/experiment_${EXPERIMENT_VERSION}" +mkdir -p "$LOG_DIR" +LOG_FILE="$LOG_DIR/experiment.log" + +# ---------------------------------------------------------------------------- +# 🤖 硬件配置 +# ---------------------------------------------------------------------------- +CUDA_VISIBLE_DEVICES="0" +NUM_PROCESSES="1" +MIXED_PRECISION="bf16" +MAIN_PROCESS_PORT="29500" + +# ---------------------------------------------------------------------------- +# 🤖 模型架构参数 +# ---------------------------------------------------------------------------- +MODEL_TYPE="model_memory" +MODEL_SIZE="50.0" +DIM="512" +N_LAYERS="8" +N_HEADS="32" +MAX_SEQ_LEN="512" +USE_MOE="false" + +# 知识库配置(使用更大规模测试EMA机制) +KNOWLEDGE_NUM="1048576" # 1024x1024 = 1048576,更大规模测试EMA +KNOWLEDGE_LENGTH="32" +KNOWLEDGE_DIM="128" +DISABLE_DB="false" + +# ---------------------------------------------------------------------------- +# 🤖 训练超参数 +# ---------------------------------------------------------------------------- +EPOCHS="3" +EMBEDDING_EPOCH="2" +BATCH_SIZE="96" +ACCUMULATION_STEPS="8" +LEARNING_RATE="2e-4" +DTYPE="bfloat16" +GRAD_CLIP="1.0" +WARMUP_ITERS="0" + +# 平衡损失配置(沿用1.4.4的成功配置) +BALANCE_LOSS_COEF="0.1" + +# 数据和缓存路径 +DATA_PATH="/home/pci/ycz/Code/Minimind/dataset/stable/merged_pretrain.jsonl" +DATABASE_INIT_PATH="/home/pci/ycz/Code/Minimind/dataset/stable/sentence_trex_data.json" +CLUSTER_CACHE_PATH="None" # 禁用聚类缓存以测试EMA效果 +VAL_DATA_PATH="dataset/stable/eval_data.json" + +# 训练配置 +NUM_WORKERS="1" +LOG_INTERVAL="100" +VAL_INTERVAL="100" +SAVE_INTERVAL="10000" + +# 性能分析配置 +USE_PROFILE="true" +PROFILE_INTERVAL="10" +MEMORY_MONITOR_INTERVAL="100" + +# 高级功能 +USE_FLASH_ATTN="true" +FAST_CLUSTERING="true" + +# ---------------------------------------------------------------------------- +# 🤖 预检查函数 +# ---------------------------------------------------------------------------- +check_environment() { + echo "🔍 环境检查中..." + + # 检查GPU可用性 + if ! nvidia-smi &> /dev/null; then + echo "❌ 错误: 未检测到GPU或nvidia-smi不可用" + exit 1 + fi + + # 检查CUDA设备 + if ! nvidia-smi -i "$CUDA_VISIBLE_DEVICES" &> /dev/null; then + echo "❌ 错误: GPU $CUDA_VISIBLE_DEVICES 不可用" + exit 1 + fi + + # 检查Python环境 + if ! .venv/bin/python -c "import torch; print(f'PyTorch: {torch.__version__}')" 2>/dev/null; then + echo "❌ 错误: PyTorch未正确安装" + exit 1 + fi + + # 检查数据文件 + if [[ ! -f "$DATA_PATH" ]]; then + echo "❌ 错误: 训练数据文件不存在: $DATA_PATH" + exit 1 + fi + + if [[ ! -f "$DATABASE_INIT_PATH" ]]; then + echo "❌ 错误: 数据库初始化文件不存在: $DATABASE_INIT_PATH" + exit 1 + fi + + # 检查EMA相关模型实现 + if ! 
.venv/bin/python -c "from model.model_memory import *; print('EMA模型实现检查通过')" 2>/dev/null; then + echo "❌ 错误: EMA模型实现存在问题" + exit 1 + fi + + echo "✅ 环境检查通过" +} + +# ---------------------------------------------------------------------------- +# 🤖 实验信息记录 +# ---------------------------------------------------------------------------- +log_experiment_info() { + echo "📝 记录实验信息..." + cat > "$LOG_DIR/experiment_info.txt" << EOF +======================================== +MiniMind 实验信息 +======================================== +实验版本: $EXPERIMENT_VERSION +实验描述: $EXPERIMENT_DESCRIPTION +研究者: $RESEARCHER_NAME +开始时间: $EXPERIMENT_DATE +======================================== +硬件配置: +GPU设备: $CUDA_VISIBLE_DEVICES +进程数: $NUM_PROCESSES +混合精度: $MIXED_PRECISION +======================================== +模型配置: +模型类型: $MODEL_TYPE +模型大小: $MODEL_SIZE MB +维度: $DIM +层数: $N_LAYERS +注意力头数: $N_HEADS +最大序列长度: $MAX_SEQ_LEN +知识库大小: $KNOWLEDGE_NUM +知识长度: $KNOWLEDGE_LENGTH +知识维度: $KNOWLEDGE_DIM +======================================== +训练配置: +训练轮次: $EPOCHS +批次大小: $BATCH_SIZE +学习率: $LEARNING_RATE +梯度累积: $ACCUMULATION_STEPS +数据类型: $DTYPE +平衡损失系数: $BALANCE_LOSS_COEF +======================================== +EMA配置: +使用EMA更新: 是(VQ-VAE风格) +EMA衰减率: 0.999(默认配置) +EMA更新频率: 1(每步更新) +======================================== +数据路径: +训练数据: $DATA_PATH +验证数据: $VAL_DATA_PATH +数据库初始化: $DATABASE_INIT_PATH +聚类缓存: $CLUSTER_CACHE_PATH +======================================== +EOF +} + +# ---------------------------------------------------------------------------- +# 🤖 主执行函数 +# ---------------------------------------------------------------------------- +run_experiment() { + echo "🚀 开始执行实验 $EXPERIMENT_VERSION" + echo "📄 实验描述: $EXPERIMENT_DESCRIPTION" + echo "⏰ 开始时间: $EXPERIMENT_DATE" + + # 构建训练命令 + local train_cmd="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES .venv/bin/python train_pretrain_accelerate.py" + + # 添加训练参数 + train_cmd+=" --out_dir \"$LOG_DIR\"" + train_cmd+=" --epochs $EPOCHS" + train_cmd+=" --embedding_epoch $EMBEDDING_EPOCH" + train_cmd+=" --batch_size $BATCH_SIZE" + train_cmd+=" --learning_rate $LEARNING_RATE" + train_cmd+=" --dtype $DTYPE" + train_cmd+=" --num_workers $NUM_WORKERS" + train_cmd+=" --accumulation_steps $ACCUMULATION_STEPS" + train_cmd+=" --grad_clip $GRAD_CLIP" + train_cmd+=" --warmup_iters $WARMUP_ITERS" + train_cmd+=" --log_interval $LOG_INTERVAL" + train_cmd+=" --val_interval $VAL_INTERVAL" + train_cmd+=" --save_interval $SAVE_INTERVAL" + train_cmd+=" --dim $DIM" + train_cmd+=" --n_layers $N_LAYERS" + train_cmd+=" --n_heads $N_HEADS" + train_cmd+=" --max_seq_len $MAX_SEQ_LEN" + train_cmd+=" --data_path \"$DATA_PATH\"" + train_cmd+=" --val_data_path \"$VAL_DATA_PATH\"" + train_cmd+=" --knowledge_num $KNOWLEDGE_NUM" + train_cmd+=" --knowledge_length $KNOWLEDGE_LENGTH" + train_cmd+=" --database_init_path \"$DATABASE_INIT_PATH\"" + train_cmd+=" --memory_monitor_interval $MEMORY_MONITOR_INTERVAL" + train_cmd+=" --model_type \"$MODEL_TYPE\"" + train_cmd+=" --model_size $MODEL_SIZE" + train_cmd+=" --balance_loss_coef $BALANCE_LOSS_COEF" + + # 可选参数 + if [[ "$USE_PROFILE" == "true" ]]; then + train_cmd+=" --profile" + train_cmd+=" --profile_interval $PROFILE_INTERVAL" + fi + + if [[ "$USE_FLASH_ATTN" == "true" ]]; then + train_cmd+=" --use_flash_attn" + fi + + if [[ "$FAST_CLUSTERING" == "true" ]]; then + train_cmd+=" --fast_clustering" + fi + + if [[ "$CLUSTER_CACHE_PATH" != "None" ]]; then + train_cmd+=" --cluster_cache_path \"$CLUSTER_CACHE_PATH\"" + fi + + # SwanLab配置 + train_cmd+=" --use_swanlab" + train_cmd+=" 
--swanlab_project \"$SWANLAB_PROJECT\"" + train_cmd+=" --swanlab_online True" + + echo "📋 执行命令:" + echo "$train_cmd" + echo + + # 记录命令到日志文件 + echo "执行命令: $train_cmd" >> "$LOG_FILE" + echo "开始时间: $(date)" >> "$LOG_FILE" + + # 使用nohup执行训练(后台运行,输出写入日志文件) + echo "🔄 使用nohup后台运行训练,输出将写入日志文件: $LOG_FILE" + + # 创建训练脚本 + train_script="/tmp/train_${EXPERIMENT_VERSION}.sh" + cat > "$train_script" << EOF +#!/bin/bash +cd /home/pci/ycz/Code/pretrain-worktree +source /home/pci/ycz/Code/pretrain-worktree/.venv/bin/activate +$train_cmd +echo "结束时间: \$(date)" +echo "退出代码: \$?" +EOF + chmod +x "$train_script" + + # 使用nohup后台运行 + nohup bash "$train_script" >> "$LOG_FILE" 2>&1 & + local train_pid=$! + + echo "🔥 训练进程已启动,PID: $train_pid" + echo "训练PID: $train_pid" >> "$LOG_FILE" + echo "训练脚本: $train_script" >> "$LOG_FILE" + + # 等待几秒确保进程启动 + sleep 5 + + # 检查进程是否还在运行 + if kill -0 $train_pid 2>/dev/null; then + echo "✅ 训练进程正在后台运行" + echo "📋 实时查看日志: tail -f $LOG_FILE" + echo "📋 检查进程状态: ps -p $train_pid" + echo "🛑 停止训练: kill $train_pid" + echo "📈 SwanLab: https://swanlab.cn/project/$SWANLAB_PROJECT" + echo "" + echo "🧠 VQ-VAE风格EMA更新机制正在测试中..." + echo " - memory_bank使用EMA更新而非梯度更新" + echo " - EMA衰减率: 0.999" + echo " - 每步更新频率" + echo " - 预期: 更稳定的训练和更好的记忆表示学习" + echo "" + echo "训练正在后台运行,可以安全关闭终端。" + else + echo "❌ 训练进程启动失败" + echo "📋 查看日志: $LOG_FILE" + exit 1 + fi +} + +# ---------------------------------------------------------------------------- +# 🤖 清理函数 +# ---------------------------------------------------------------------------- +cleanup() { + echo "🧹 清理临时文件..." + # 删除临时验证文件 + rm -f /tmp/temp_val.jsonl +} + +# ---------------------------------------------------------------------------- +# 🤖 信号处理 +# ---------------------------------------------------------------------------- +trap cleanup EXIT +trap 'echo "❌ 实验被中断"; cleanup; exit 130' INT TERM + +# ---------------------------------------------------------------------------- +# 🤖 主程序入口 +# ---------------------------------------------------------------------------- +main() { + echo "============================================================================" + echo "🧠 MiniMind 预训练实验 1.4.5" + echo "🎯 VQ-VAE风格EMA更新机制 - 替代memory_bank梯度更新" + echo "============================================================================" + + # 执行检查和初始化 + check_environment + log_experiment_info + + # 运行实验 + run_experiment + + echo "============================================================================" + echo "✅ 实验 $EXPERIMENT_VERSION 启动完成" + echo "📅 启动时间: $(date)" + echo "============================================================================" +} + +# 执行主程序 +main "$@" \ No newline at end of file diff --git a/train_pretrain_accelerate.py b/train_pretrain_accelerate.py index d92b548..02d1c22 100644 --- a/train_pretrain_accelerate.py +++ b/train_pretrain_accelerate.py @@ -769,6 +769,20 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a # 当使用DeepSpeed时,zero_grad()会在step()之后自动调用 # 但为了安全起见,我们仍然显式调用它 optimizer.zero_grad() + + # VQ-VAE风格的EMA更新(仅在启用时执行) + if hasattr(res, 'ema_stats') and res.ema_stats is not None: + unwrapped_model = accelerator.unwrap_model(model) + if hasattr(unwrapped_model, 'apply_ema_update'): + ema_update_stats = unwrapped_model.apply_ema_update(res.ema_stats) + # 记录EMA更新统计信息 + if step % args.log_interval == 0 and accelerator.is_main_process and ema_update_stats.get('ema_update_applied', False): + total_memories = args.knowledge_num + Logger(f"EMA Update - Step: {ema_update_stats['ema_step']}, " + f"Updated memories: 
{ema_update_stats['updated_memories']}/{total_memories} " + f"({ema_update_stats['update_ratio']:.4f}), " + f"Avg change: {ema_update_stats['avg_ema_change']:.6f}, " + f"Coverage: {ema_update_stats['selected_memory_coverage']:.4f}", accelerator) # 计时优化器步骤结束 (只在主进程进行) if args.profile and accelerator.is_main_process and optimizer_end is not None: @@ -1021,7 +1035,7 @@ def main(): parser.add_argument("--out_dir", type=str, default="out") parser.add_argument("--epochs", type=int, default=4) parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数") - parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--learning_rate", type=float, default=2e-4) parser.add_argument("--dtype", type=str, default="bfloat16") parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数 @@ -1052,7 +1066,7 @@ def main(): parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件") parser.add_argument("--memory_monitor", action="store_true", default=False, help="启用内存监控") parser.add_argument("--memory_monitor_interval", type=int, default=10, help="内存监控间隔(步数)") - parser.add_argument("--model_type", type=str, default="model", help="使用什么模型训练") #model,model_original,model_no_feed + parser.add_argument("--model_type", type=str, default="model_memory", help="使用什么模型训练") #model,model_original,model_no_feed parser.add_argument("--model_size", type=float, default=50.0, help="模型大小") parser.add_argument("--swanlab_online", type=bool, default=False, help="是否使用在线SwanLab服务") parser.add_argument("--balance_loss_coef", type=float, default=0.01, help="平衡损失系数") @@ -1214,7 +1228,17 @@ def main(): ######################################################### # 创建优化器 ######################################################### - optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) + # 如果启用EMA更新,需要过滤掉memory_bank参数(因为它不再需要梯度更新) + if hasattr(model.params, 'use_ema_update') and model.params.use_ema_update: + # 只包含requires_grad=True的参数 + optimizer_params = [p for p in model.parameters() if p.requires_grad] + Logger(f"EMA更新模式:优化器包含 {len(optimizer_params)} 个参数(过滤掉memory_bank)") + Logger(f"总参数:{sum(p.numel() for p in model.parameters())} | 可训练参数:{sum(p.numel() for p in optimizer_params)}") + optimizer = optim.AdamW(optimizer_params, lr=args.learning_rate) + else: + # 传统模式:所有参数都使用梯度更新 + Logger("传统梯度更新模式:优化器包含所有模型参数") + optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) ######################################################### # 创建学习率调度器
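One integration detail worth making explicit: the EMA step mutates `memory_bank` outside autograd, so it must run on the unwrapped model (DDP/DeepSpeed wrappers don't expose custom methods such as `apply_ema_update`) and after `optimizer.step()`, exactly as the training-loop change above does. A compressed, runnable sketch of that ordering — everything here is an illustrative stub except the pattern itself:

```python
import torch
import torch.nn as nn

class StubModel(nn.Module):
    """Stands in for MiniMindLM: one trainable layer plus one frozen bank."""
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)
        self.memory_bank = nn.Parameter(torch.randn(8, 4), requires_grad=False)

    def forward(self, x):
        # Pretend every token selected entry 0; the real model returns per-layer ema_stats.
        return self.layer(x), {'indices': torch.zeros(x.size(0), dtype=torch.long)}

    @torch.no_grad()
    def apply_ema_update(self, stats, gamma=0.999):
        # Toy pull toward a constant target, standing in for the batch-mean update.
        self.memory_bank[0] = gamma * self.memory_bank[0] + (1 - gamma) * 1.0

model = StubModel()
opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-3)

bank_before = model.memory_bank.clone()
out, ema_stats = model(torch.randn(2, 4))
out.sum().backward()
opt.step()
opt.zero_grad()
assert torch.equal(model.memory_bank, bank_before)   # optimizer never touches the bank
model.apply_ema_update(ema_stats)                    # ...only the explicit EMA step does
assert not torch.equal(model.memory_bank, bank_before)
```
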