Experiment 1.4.5: Use VQ-VAE-style EMA to update the database

This commit is contained in:
Yu Chengzhang 2025-08-09 10:47:35 +08:00
parent 9244d47c39
commit a7fe947a35
5 changed files with 567 additions and 21 deletions

View File

@@ -393,16 +393,7 @@ Loss: 2.0430
### ✅ **[AI-completed]** Improvement suggestions
**Short-term optimizations** (next experiment):
- `Remove or sharply reduce the Balance Loss coefficient to restore the natural pattern of memory selection`
- `Compare the impact of different balance_loss_coef values (0.001, 0.01, 0.05) on performance`
**Mid-term improvements** (next 3-5 experiments):
- `Improve memory bank quality and the initialization strategy to make memory retrieval more effective`
- `Explore smarter memory balancing strategies that trade off diversity against efficiency`
**Long-term research directions**:
- `Study adaptive memory selection mechanisms that adapt the degree of balancing to task demands`
- `Explore hybrid architectures that combine the memory bank with a conventional FFN, taking the best of both`
- Constrain the Memory Bank with a VQ-VAE-style mechanism (see the update equations below).
---
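For reference, the standard VQ-VAE EMA codebook update that this experiment adapts (decay $\gamma$; $n_i^{(t)}$ is the number of vectors assigned to entry $i$ at step $t$, and $z_{i,j}^{(t)}$ are those vectors) is:

$$
N_i^{(t)} = \gamma\,N_i^{(t-1)} + (1-\gamma)\,n_i^{(t)}, \qquad
m_i^{(t)} = \gamma\,m_i^{(t-1)} + (1-\gamma)\sum_{j} z_{i,j}^{(t)}, \qquad
e_i^{(t)} = \frac{m_i^{(t)}}{N_i^{(t)}}
$$

The implementation in this commit uses a simplified per-step variant: it averages the hidden states routed to each selected entry within the current batch and blends them in directly as $e_i \leftarrow \gamma\,e_i + (1-\gamma)\,\bar{h}_i$, without carrying running counts across steps.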

View File

@@ -42,6 +42,12 @@ class LMConfig(PretrainedConfig):
                 knowledge_length: int = 8,
                 knowledge_dim: int = 128,
                 ####################################################
                 # EMA update related configurations (inspired by VQ-VAE)
                 ####################################################
                 use_ema_update: bool = True,    # whether to update memory_bank via EMA
                 ema_decay: float = 0.999,       # EMA decay rate, analogous to γ in VQ-VAE
                 ema_update_freq: int = 1,       # EMA update frequency: once every N training steps
                 ####################################################
                 # Triple extraction related configurations
                 ####################################################
                 max_subject_len: int = 8,
@@ -83,6 +89,12 @@ class LMConfig(PretrainedConfig):
        self.knowledge_length = knowledge_length
        self.knowledge_dim = knowledge_dim
        ####################################################
        # EMA update related configurations (inspired by VQ-VAE)
        ####################################################
        self.use_ema_update = use_ema_update
        self.ema_decay = ema_decay
        self.ema_update_freq = ema_update_freq
        ####################################################
        # Triple extraction related configurations
        ####################################################
        self.max_subject_len = max_subject_len
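A minimal sketch of how the new fields are meant to be set (the import path and the other argument names are assumed from the repo layout and the CLI flags used later in this commit):

```python
from model.LMConfig import LMConfig  # assumed module path

config = LMConfig(
    dim=512, n_layers=8, n_heads=32, max_seq_len=512,
    knowledge_num=1048576, knowledge_length=32, knowledge_dim=128,
    use_ema_update=True,   # freeze memory_bank w.r.t. gradients
    ema_decay=0.999,       # γ: fraction of the old entry kept per update
    ema_update_freq=1,     # apply the EMA update every training step
)
```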

View File

@@ -335,17 +335,19 @@ class MiniMindBlock(nn.Module):
        self.memory_gate = MemoryGate(config)
        self.gated_memory_fusion = GatedMemoryFusion(config)

    def forward(self, x, pos_cis, memory_bank):
    def forward(self, x, pos_cis, memory_bank, collect_ema_stats=False):
        """
        Args:
            x: [batch_size, seq_len, dim]
            pos_cis: positional encoding
            memory_bank: [knowledge_num, knowledge_dim] - shared memory bank
            collect_ema_stats: whether to collect statistics for the EMA update
        Returns:
            out: [batch_size, seq_len, dim]
            balance_loss: balance loss of this layer
            layer_stats: monitoring statistics for this layer
            ema_stats: EMA update statistics (only when collect_ema_stats=True)
        """
        # Self attention
        h_attn = self.attention(self.attention_norm(x), pos_cis)
@@ -369,7 +371,20 @@
        # Residual connection
        out = h + memory_output

        return out, balance_loss, layer_stats
        # Collect EMA update statistics (only in training and when enabled)
        ema_stats = None
        if collect_ema_stats and self.training:
            ema_stats = {
                'memory_indices': memory_indices,    # [batch, seq_len, num_selected]
                'memory_scores': memory_scores,      # [batch, seq_len, num_selected]
                'h_for_memory': h_for_memory,        # [batch, seq_len, dim]
                'selected_memory': selected_memory,  # [batch, seq_len, num_selected, knowledge_dim]
            }

        if collect_ema_stats:
            return out, balance_loss, layer_stats, ema_stats
        else:
            return out, balance_loss, layer_stats
class MiniMindLM(PreTrainedModel):
@@ -390,10 +405,25 @@
                             persistent=False)

        # Initialize the shared memory bank
        self.memory_bank = nn.Parameter(
            torch.randn(params.knowledge_num, params.knowledge_dim),
            requires_grad=True
        )
        # VQ-VAE style: memory_bank serves as the codebook and is maintained by EMA rather than by gradients
        if params.use_ema_update:
            self.memory_bank = nn.Parameter(
                torch.randn(params.knowledge_num, params.knowledge_dim),
                requires_grad=False  # disable gradient updates; the EMA update maintains this tensor
            )
        else:
            self.memory_bank = nn.Parameter(
                torch.randn(params.knowledge_num, params.knowledge_dim),
                requires_grad=True  # conventional gradient updates
            )

        # Buffers for the EMA update
        if params.use_ema_update:
            # Per-entry update statistics for the memory bank
            self.register_buffer('ema_update_count', torch.zeros(params.knowledge_num), persistent=False)
            self.register_buffer('ema_sum_buffer', torch.zeros_like(self.memory_bank), persistent=False)
            # Step counter for the EMA update frequency
            self.register_buffer('ema_step_counter', torch.zeros(1, dtype=torch.long), persistent=False)
            # Snapshot of the previous memory bank state, used to compute update statistics
            self.register_buffer('prev_memory_bank', torch.zeros_like(self.memory_bank), persistent=False)
@@ -453,15 +483,23 @@
                **args):
        """Forward pass without KV cache support"""
        start_pos = args.get('start_pos', 0)
        collect_ema_stats = args.get('collect_ema_stats', self.params.use_ema_update and self.training)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]

        # Collect balance losses and statistics from every layer
        total_balance_loss = 0
        all_layer_stats = {}
        all_ema_stats = {}

        for layer_idx, layer in enumerate(self.layers):
            h, balance_loss, layer_stats = layer(h, pos_cis, self.memory_bank)
            if collect_ema_stats:
                h, balance_loss, layer_stats, ema_stats = layer(h, pos_cis, self.memory_bank, collect_ema_stats=True)
                all_ema_stats[f'layer_{layer_idx}'] = ema_stats
            else:
                h, balance_loss, layer_stats = layer(h, pos_cis, self.memory_bank, collect_ema_stats=False)
            total_balance_loss += balance_loss
            # Prefix each layer's statistics with its layer index
            for key, value in layer_stats.items():
@@ -476,6 +514,7 @@
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('layer_stats', all_layer_stats)  # per-layer statistics
        self.OUT.__setitem__('ema_stats', all_ema_stats if collect_ema_stats else None)  # EMA statistics
        self.OUT.__setitem__('past_key_values', None)  # KV cache is not supported
        return self.OUT
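Putting the pieces together, a minimal sketch of one training step (simplified from the training loop later in this commit; the loss wiring and the `input_ids`/`labels` variables are illustrative assumptions):

```python
import torch.nn.functional as F

model.train()
out = model(input_ids)  # collect_ema_stats defaults to True while training

logits = out.logits
ce_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
loss = ce_loss + out.aux_loss   # aux_loss carries the balance loss
loss.backward()
optimizer.step()                # memory_bank is excluded from the optimizer
optimizer.zero_grad()

# The codebook is refreshed outside of autograd:
if out.ema_stats is not None:
    stats = model.apply_ema_update(out.ema_stats)  # e.g. stats['update_ratio']
```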
@@ -536,4 +575,132 @@
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break
                break
    def apply_ema_update(self, ema_stats):
        """
        Apply a VQ-VAE-style EMA update to memory_bank.

        Args:
            ema_stats: EMA statistics collected during the forward pass, of the form
                {'layer_0': {'memory_indices': ..., 'h_for_memory': ...}, 'layer_1': ...}
        """
        if not self.params.use_ema_update:
            return {}

        # Advance the EMA step counter
        self.ema_step_counter += 1

        # Check whether an EMA update is due at this step
        if self.ema_step_counter % self.params.ema_update_freq != 0:
            return {'ema_update_applied': False, 'reason': 'frequency_check_failed'}

        with torch.no_grad():
            device = self.memory_bank.device
            knowledge_num, knowledge_dim = self.memory_bank.shape

            # Reset the accumulation buffers
            self.ema_sum_buffer.zero_()
            self.ema_update_count.zero_()

            total_selections = 0
            total_layers = 0

            # Aggregate the EMA statistics from all layers
            for layer_name, layer_ema_stats in ema_stats.items():
                if layer_ema_stats is None:
                    continue
                total_layers += 1

                memory_indices = layer_ema_stats['memory_indices']  # [batch, seq_len, num_selected]
                h_for_memory = layer_ema_stats['h_for_memory']      # [batch, seq_len, dim]

                bsz, seq_len, num_selected = memory_indices.shape
                total_selections += bsz * seq_len * num_selected

                # Project h_for_memory down to knowledge_dim if the dimensions differ
                if h_for_memory.size(-1) != knowledge_dim:
                    # Use a simple projection (truncation or zero-padding)
                    if h_for_memory.size(-1) > knowledge_dim:
                        # Truncate to knowledge_dim
                        h_proj = h_for_memory[..., :knowledge_dim]
                    else:
                        # Zero-pad up to knowledge_dim
                        pad_size = knowledge_dim - h_for_memory.size(-1)
                        h_proj = F.pad(h_for_memory, (0, pad_size), 'constant', 0)
                else:
                    h_proj = h_for_memory

                # Flatten the indices and the corresponding h_for_memory
                flat_indices = memory_indices.view(-1)  # [batch * seq_len * num_selected]

                # Replicate h_for_memory once per selected slot:
                # [batch, seq_len, num_selected] -> [batch, seq_len, num_selected, dim]
                h_expanded = h_proj.unsqueeze(2).expand(-1, -1, num_selected, -1)
                flat_h = h_expanded.reshape(-1, knowledge_dim)  # [batch * seq_len * num_selected, knowledge_dim]

                # Ensure dtypes and devices match
                flat_indices = flat_indices.long().to(device)  # indices must be int64
                flat_h = flat_h.to(dtype=self.ema_sum_buffer.dtype, device=device)

                # Accumulate the h_for_memory values for each memory entry
                # scatter_add_: adds flat_h into the rows of ema_sum_buffer selected by flat_indices
                self.ema_sum_buffer.scatter_add_(0, flat_indices.unsqueeze(1).expand(-1, knowledge_dim), flat_h)

                # Count how often each memory entry was selected
                count_ones = torch.ones_like(flat_indices, dtype=self.ema_update_count.dtype, device=device)
                self.ema_update_count.scatter_add_(0, flat_indices, count_ones)

            # Compute the per-entry mean and apply the EMA update
            # Guard against division by zero
            non_zero_mask = self.ema_update_count > 0
            avg_h_for_selected = torch.zeros_like(self.memory_bank)
            if non_zero_mask.any():
                # Mean h_for_memory over all selections of each chosen entry
                avg_h_for_selected[non_zero_mask] = (
                    self.ema_sum_buffer[non_zero_mask] / self.ema_update_count[non_zero_mask].unsqueeze(1)
                )

                # EMA update with matching dtypes: new = γ * old + (1 - γ) * new_avg
                # Only entries that were selected this step are updated
                old_memory = self.memory_bank[non_zero_mask]
                new_avg = avg_h_for_selected[non_zero_mask].to(dtype=old_memory.dtype)
                self.memory_bank[non_zero_mask] = (
                    self.params.ema_decay * old_memory +
                    (1 - self.params.ema_decay) * new_avg
                )

            # Compute update statistics
            updated_memories = non_zero_mask.sum().item()
            update_ratio = updated_memories / knowledge_num

            # Magnitude statistics of the EMA update
            if hasattr(self, 'prev_memory_bank_ema') and self.prev_memory_bank_ema.numel() > 0:
                l2_changes = torch.norm(self.memory_bank[non_zero_mask] - self.prev_memory_bank_ema[non_zero_mask], p=2, dim=1)
                avg_change = l2_changes.mean().item() if len(l2_changes) > 0 else 0.0
                max_change = l2_changes.max().item() if len(l2_changes) > 0 else 0.0
            else:
                avg_change = 0.0
                max_change = 0.0

            # Save the current memory_bank state for the next comparison
            if not hasattr(self, 'prev_memory_bank_ema'):
                self.register_buffer('prev_memory_bank_ema', torch.zeros_like(self.memory_bank), persistent=False)
            self.prev_memory_bank_ema.copy_(self.memory_bank)

            update_stats = {
                'ema_update_applied': True,
                'ema_step': self.ema_step_counter.item(),
                'total_selections': total_selections,
                'total_layers': total_layers,
                'updated_memories': updated_memories,
                'update_ratio': update_ratio,
                'avg_ema_change': avg_change,
                'max_ema_change': max_change,
                'ema_decay': self.params.ema_decay,
                'selected_memory_coverage': (self.ema_update_count > 0).float().mean().item(),
            }

        return update_stats
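To make the scatter-based accumulation concrete, here is a self-contained toy run of the same pattern (shapes shrunk; all names here are illustrative, not from the repo):

```python
import torch

knowledge_num, knowledge_dim, gamma = 4, 3, 0.999
memory_bank = torch.randn(knowledge_num, knowledge_dim)

# Three selections: entry 1 is picked twice, entry 3 once
flat_indices = torch.tensor([1, 1, 3])
flat_h = torch.randn(3, knowledge_dim)

sum_buf = torch.zeros(knowledge_num, knowledge_dim)
count = torch.zeros(knowledge_num)

# Row-wise accumulation, exactly as in apply_ema_update
sum_buf.scatter_add_(0, flat_indices.unsqueeze(1).expand(-1, knowledge_dim), flat_h)
count.scatter_add_(0, flat_indices, torch.ones(3))

mask = count > 0  # only entries 1 and 3 move this step
avg = sum_buf[mask] / count[mask].unsqueeze(1)
memory_bank[mask] = gamma * memory_bank[mask] + (1 - gamma) * avg
```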

View File

@@ -0,0 +1,352 @@
#!/bin/bash
# ============================================================================
# MiniMind experiment script - Experiment 1.4.5
# ============================================================================
#
# 🎯 Experiment goal:
#   Building on experiment 1.4.4, implement a VQ-VAE-style EMA update
#   mechanism that replaces gradient updates of memory_bank
#
# Usage:
#   bash run_file/experiment_1_4_5.sh
# ============================================================================

# ----------------------------------------------------------------------------
# 🧑‍🔬 Experiment metadata
# ----------------------------------------------------------------------------
EXPERIMENT_VERSION="1.4.5"
EXPERIMENT_DESCRIPTION="VQ-VAE-style EMA update mechanism experiment"
RESEARCHER_NAME="AI Assistant"
EXPERIMENT_DATE="$(date '+%Y-%m-%d %H:%M:%S')"

# ----------------------------------------------------------------------------
# 🤖 Environment configuration
# ----------------------------------------------------------------------------
# Debugging and monitoring environment variables
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
export CUDA_LAUNCH_BLOCKING=1

# SwanLab configuration
export SWANLAB_PROJECT="MiniMind-Experiment-1.4.5"

# Logging configuration
LOG_DIR="out/experiment_${EXPERIMENT_VERSION}"
mkdir -p "$LOG_DIR"
LOG_FILE="$LOG_DIR/experiment.log"

# ----------------------------------------------------------------------------
# 🤖 Hardware configuration
# ----------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES="0"
NUM_PROCESSES="1"
MIXED_PRECISION="bf16"
MAIN_PROCESS_PORT="29500"

# ----------------------------------------------------------------------------
# 🤖 Model architecture parameters
# ----------------------------------------------------------------------------
MODEL_TYPE="model_memory"
MODEL_SIZE="50.0"
DIM="512"
N_LAYERS="8"
N_HEADS="32"
MAX_SEQ_LEN="512"
USE_MOE="false"

# Knowledge base configuration (larger scale to stress-test the EMA mechanism)
KNOWLEDGE_NUM="1048576"   # 1024 x 1024 = 1,048,576 entries
KNOWLEDGE_LENGTH="32"
KNOWLEDGE_DIM="128"
DISABLE_DB="false"

# ----------------------------------------------------------------------------
# 🤖 Training hyperparameters
# ----------------------------------------------------------------------------
EPOCHS="3"
EMBEDDING_EPOCH="2"
BATCH_SIZE="96"
ACCUMULATION_STEPS="8"
LEARNING_RATE="2e-4"
DTYPE="bfloat16"
GRAD_CLIP="1.0"
WARMUP_ITERS="0"

# Balance loss configuration (carried over from the successful 1.4.4 setup)
BALANCE_LOSS_COEF="0.1"

# Data and cache paths
DATA_PATH="/home/pci/ycz/Code/Minimind/dataset/stable/merged_pretrain.jsonl"
DATABASE_INIT_PATH="/home/pci/ycz/Code/Minimind/dataset/stable/sentence_trex_data.json"
CLUSTER_CACHE_PATH="None"  # disable the clustering cache to isolate the EMA effect
VAL_DATA_PATH="dataset/stable/eval_data.json"

# Training configuration
NUM_WORKERS="1"
LOG_INTERVAL="100"
VAL_INTERVAL="100"
SAVE_INTERVAL="10000"

# Profiling configuration
USE_PROFILE="true"
PROFILE_INTERVAL="10"
MEMORY_MONITOR_INTERVAL="100"

# Advanced features
USE_FLASH_ATTN="true"
FAST_CLUSTERING="true"
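A note on scale (an editorial aside, not part of the script): with `KNOWLEDGE_NUM=1048576` and `KNOWLEDGE_DIM=128`, each full-size copy of the memory bank occupies, assuming fp32 storage,

$$
1{,}048{,}576 \times 128 \times 4\ \text{B} = 512\ \text{MiB},
$$

and since `use_ema_update=True` also allocates `ema_sum_buffer`, `prev_memory_bank`, and (lazily) `prev_memory_bank_ema` at the same shape, the bank plus its EMA buffers come to roughly 2 GiB.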
# ----------------------------------------------------------------------------
# 🤖 Pre-flight checks
# ----------------------------------------------------------------------------
check_environment() {
    echo "🔍 Checking environment..."

    # Check GPU availability
    if ! nvidia-smi &> /dev/null; then
        echo "❌ Error: no GPU detected or nvidia-smi is unavailable"
        exit 1
    fi

    # Check the CUDA device
    if ! nvidia-smi -i "$CUDA_VISIBLE_DEVICES" &> /dev/null; then
        echo "❌ Error: GPU $CUDA_VISIBLE_DEVICES is not available"
        exit 1
    fi

    # Check the Python environment
    if ! .venv/bin/python -c "import torch; print(f'PyTorch: {torch.__version__}')" 2>/dev/null; then
        echo "❌ Error: PyTorch is not installed correctly"
        exit 1
    fi

    # Check the data files
    if [[ ! -f "$DATA_PATH" ]]; then
        echo "❌ Error: training data file not found: $DATA_PATH"
        exit 1
    fi

    if [[ ! -f "$DATABASE_INIT_PATH" ]]; then
        echo "❌ Error: database init file not found: $DATABASE_INIT_PATH"
        exit 1
    fi

    # Check the EMA model implementation
    if ! .venv/bin/python -c "from model.model_memory import *; print('EMA model implementation check passed')" 2>/dev/null; then
        echo "❌ Error: the EMA model implementation has a problem"
        exit 1
    fi

    echo "✅ Environment check passed"
}
# ----------------------------------------------------------------------------
# 🤖 Experiment info logging
# ----------------------------------------------------------------------------
log_experiment_info() {
    echo "📝 Recording experiment info..."
    cat > "$LOG_DIR/experiment_info.txt" << EOF
========================================
MiniMind experiment info
========================================
Experiment version: $EXPERIMENT_VERSION
Description: $EXPERIMENT_DESCRIPTION
Researcher: $RESEARCHER_NAME
Start time: $EXPERIMENT_DATE
========================================
Hardware configuration:
  GPU device: $CUDA_VISIBLE_DEVICES
  Processes: $NUM_PROCESSES
  Mixed precision: $MIXED_PRECISION
========================================
Model configuration:
  Model type: $MODEL_TYPE
  Model size: $MODEL_SIZE MB
  Dim: $DIM
  Layers: $N_LAYERS
  Attention heads: $N_HEADS
  Max sequence length: $MAX_SEQ_LEN
  Knowledge bank size: $KNOWLEDGE_NUM
  Knowledge length: $KNOWLEDGE_LENGTH
  Knowledge dim: $KNOWLEDGE_DIM
========================================
Training configuration:
  Epochs: $EPOCHS
  Batch size: $BATCH_SIZE
  Learning rate: $LEARNING_RATE
  Gradient accumulation: $ACCUMULATION_STEPS
  Dtype: $DTYPE
  Balance loss coefficient: $BALANCE_LOSS_COEF
========================================
EMA configuration:
  EMA updates: yes (VQ-VAE style)
  EMA decay: 0.999 (default)
  EMA update frequency: 1 (every step)
========================================
Data paths:
  Training data: $DATA_PATH
  Validation data: $VAL_DATA_PATH
  Database init: $DATABASE_INIT_PATH
  Clustering cache: $CLUSTER_CACHE_PATH
========================================
EOF
}
# ----------------------------------------------------------------------------
# 🤖 Main execution
# ----------------------------------------------------------------------------
run_experiment() {
    echo "🚀 Starting experiment $EXPERIMENT_VERSION"
    echo "📄 Description: $EXPERIMENT_DESCRIPTION"
    echo "⏰ Start time: $EXPERIMENT_DATE"

    # Build the training command
    local train_cmd="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES .venv/bin/python train_pretrain_accelerate.py"

    # Training arguments
    train_cmd+=" --out_dir \"$LOG_DIR\""
    train_cmd+=" --epochs $EPOCHS"
    train_cmd+=" --embedding_epoch $EMBEDDING_EPOCH"
    train_cmd+=" --batch_size $BATCH_SIZE"
    train_cmd+=" --learning_rate $LEARNING_RATE"
    train_cmd+=" --dtype $DTYPE"
    train_cmd+=" --num_workers $NUM_WORKERS"
    train_cmd+=" --accumulation_steps $ACCUMULATION_STEPS"
    train_cmd+=" --grad_clip $GRAD_CLIP"
    train_cmd+=" --warmup_iters $WARMUP_ITERS"
    train_cmd+=" --log_interval $LOG_INTERVAL"
    train_cmd+=" --val_interval $VAL_INTERVAL"
    train_cmd+=" --save_interval $SAVE_INTERVAL"
    train_cmd+=" --dim $DIM"
    train_cmd+=" --n_layers $N_LAYERS"
    train_cmd+=" --n_heads $N_HEADS"
    train_cmd+=" --max_seq_len $MAX_SEQ_LEN"
    train_cmd+=" --data_path \"$DATA_PATH\""
    train_cmd+=" --val_data_path \"$VAL_DATA_PATH\""
    train_cmd+=" --knowledge_num $KNOWLEDGE_NUM"
    train_cmd+=" --knowledge_length $KNOWLEDGE_LENGTH"
    train_cmd+=" --database_init_path \"$DATABASE_INIT_PATH\""
    train_cmd+=" --memory_monitor_interval $MEMORY_MONITOR_INTERVAL"
    train_cmd+=" --model_type \"$MODEL_TYPE\""
    train_cmd+=" --model_size $MODEL_SIZE"
    train_cmd+=" --balance_loss_coef $BALANCE_LOSS_COEF"

    # Optional arguments
    if [[ "$USE_PROFILE" == "true" ]]; then
        train_cmd+=" --profile"
        train_cmd+=" --profile_interval $PROFILE_INTERVAL"
    fi

    if [[ "$USE_FLASH_ATTN" == "true" ]]; then
        train_cmd+=" --use_flash_attn"
    fi

    if [[ "$FAST_CLUSTERING" == "true" ]]; then
        train_cmd+=" --fast_clustering"
    fi

    if [[ "$CLUSTER_CACHE_PATH" != "None" ]]; then
        train_cmd+=" --cluster_cache_path \"$CLUSTER_CACHE_PATH\""
    fi

    # SwanLab arguments
    train_cmd+=" --use_swanlab"
    train_cmd+=" --swanlab_project \"$SWANLAB_PROJECT\""
    train_cmd+=" --swanlab_online True"

    echo "📋 Command to execute:"
    echo "$train_cmd"
    echo

    # Record the command in the log file
    echo "Command: $train_cmd" >> "$LOG_FILE"
    echo "Start time: $(date)" >> "$LOG_FILE"

    # Run training in the background via nohup; output goes to the log file
    echo "🔄 Running training in the background with nohup; output will be written to: $LOG_FILE"

    # Create the training wrapper script
    # (the exit code is captured immediately after the training command,
    # before the echo can overwrite $?)
    train_script="/tmp/train_${EXPERIMENT_VERSION}.sh"
    cat > "$train_script" << EOF
#!/bin/bash
cd /home/pci/ycz/Code/pretrain-worktree
source /home/pci/ycz/Code/pretrain-worktree/.venv/bin/activate
$train_cmd
exit_code=\$?
echo "End time: \$(date)"
echo "Exit code: \$exit_code"
EOF
    chmod +x "$train_script"

    # Launch in the background with nohup
    nohup bash "$train_script" >> "$LOG_FILE" 2>&1 &
    local train_pid=$!

    echo "🔥 Training process started, PID: $train_pid"
    echo "Training PID: $train_pid" >> "$LOG_FILE"
    echo "Training script: $train_script" >> "$LOG_FILE"

    # Wait a few seconds to make sure the process starts
    sleep 5

    # Check that the process is still running
    if kill -0 $train_pid 2>/dev/null; then
        echo "✅ Training process is running in the background"
        echo "📋 Tail the log: tail -f $LOG_FILE"
        echo "📋 Check process status: ps -p $train_pid"
        echo "🛑 Stop training: kill $train_pid"
        echo "📈 SwanLab: https://swanlab.cn/project/$SWANLAB_PROJECT"
        echo ""
        echo "🧠 Testing the VQ-VAE-style EMA update mechanism..."
        echo "   - memory_bank is updated via EMA instead of gradients"
        echo "   - EMA decay: 0.999"
        echo "   - update frequency: every step"
        echo "   - expected: more stable training and better memory representations"
        echo ""
        echo "Training is running in the background; it is safe to close this terminal."
    else
        echo "❌ Failed to start the training process"
        echo "📋 See the log: $LOG_FILE"
        exit 1
    fi
}
# ----------------------------------------------------------------------------
# 🤖 Cleanup
# ----------------------------------------------------------------------------
cleanup() {
    echo "🧹 Cleaning up temporary files..."
    # Remove the temporary validation file
    rm -f /tmp/temp_val.jsonl
}

# ----------------------------------------------------------------------------
# 🤖 Signal handling
# ----------------------------------------------------------------------------
trap cleanup EXIT
trap 'echo "❌ Experiment interrupted"; cleanup; exit 130' INT TERM

# ----------------------------------------------------------------------------
# 🤖 Entry point
# ----------------------------------------------------------------------------
main() {
    echo "============================================================================"
    echo "🧠 MiniMind pretraining experiment 1.4.5"
    echo "🎯 VQ-VAE-style EMA update mechanism - replacing gradient updates of memory_bank"
    echo "============================================================================"

    # Run checks and initialization
    check_environment
    log_experiment_info

    # Run the experiment
    run_experiment

    echo "============================================================================"
    echo "✅ Experiment $EXPERIMENT_VERSION launched"
    echo "📅 Launch time: $(date)"
    echo "============================================================================"
}

# Run main
main "$@"

View File

@@ -769,6 +769,20 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                # With DeepSpeed, zero_grad() is called automatically after step(),
                # but we still call it explicitly to be safe
                optimizer.zero_grad()

                # VQ-VAE-style EMA update (only runs when enabled)
                if hasattr(res, 'ema_stats') and res.ema_stats is not None:
                    unwrapped_model = accelerator.unwrap_model(model)
                    if hasattr(unwrapped_model, 'apply_ema_update'):
                        ema_update_stats = unwrapped_model.apply_ema_update(res.ema_stats)
                        # Log EMA update statistics
                        if step % args.log_interval == 0 and accelerator.is_main_process and ema_update_stats.get('ema_update_applied', False):
                            total_memories = args.knowledge_num
                            Logger(f"EMA Update - Step: {ema_update_stats['ema_step']}, "
                                   f"Updated memories: {ema_update_stats['updated_memories']}/{total_memories} "
                                   f"({ema_update_stats['update_ratio']:.4f}), "
                                   f"Avg change: {ema_update_stats['avg_ema_change']:.6f}, "
                                   f"Coverage: {ema_update_stats['selected_memory_coverage']:.4f}", accelerator)

            # End timing of the optimizer step (main process only)
            if args.profile and accelerator.is_main_process and optimizer_end is not None:
@@ -1021,7 +1035,7 @@ def main():
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument("--embedding_epoch", type=int, default=2, help="number of epochs for embedding training")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_swanlab", default=True, action="store_true")  # replaces the wandb flag
@@ -1052,7 +1066,7 @@
    parser.add_argument("--recompute_clusters", action="store_true", default=False, help="force recomputation of clusters, ignoring cached files")
    parser.add_argument("--memory_monitor", action="store_true", default=False, help="enable memory monitoring")
    parser.add_argument("--memory_monitor_interval", type=int, default=10, help="memory monitoring interval (steps)")
    parser.add_argument("--model_type", type=str, default="model", help="which model to train")  # model, model_original, model_no_feed
    parser.add_argument("--model_type", type=str, default="model_memory", help="which model to train")  # model, model_original, model_no_feed
    parser.add_argument("--model_size", type=float, default=50.0, help="model size")
    parser.add_argument("--swanlab_online", type=bool, default=False, help="whether to use the online SwanLab service")
    parser.add_argument("--balance_loss_coef", type=float, default=0.01, help="balance loss coefficient")
@@ -1214,7 +1228,17 @@
    #########################################################
    # Create the optimizer
    #########################################################
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    # With EMA updates enabled, filter out memory_bank since it no longer receives gradient updates
    if hasattr(model.params, 'use_ema_update') and model.params.use_ema_update:
        # Keep only parameters with requires_grad=True
        optimizer_params = [p for p in model.parameters() if p.requires_grad]
        Logger(f"EMA update mode: optimizer covers {len(optimizer_params)} parameters (memory_bank filtered out)")
        Logger(f"Total parameters: {sum(p.numel() for p in model.parameters())} | Trainable parameters: {sum(p.numel() for p in optimizer_params)}")
        optimizer = optim.AdamW(optimizer_params, lr=args.learning_rate)
    else:
        # Conventional mode: all parameters are updated by gradients
        Logger("Conventional gradient-update mode: optimizer covers all model parameters")
        optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
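A hypothetical sanity check (illustrative, not part of this commit) to confirm that the frozen codebook really stays out of the optimizer:

```python
# Illustrative only: the EMA-maintained codebook must not be optimized.
assert not model.memory_bank.requires_grad
assert all(p is not model.memory_bank
           for group in optimizer.param_groups
           for p in group['params'])
```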
    #########################################################
    # Create the learning-rate scheduler