Experiment 1.4.5: Use a VQ-VAE-style EMA to update the database
parent 9244d47c39
commit a7fe947a35
@@ -393,16 +393,7 @@ Loss: 2.0430

### ✅ **[AI-completed]** Improvement suggestions

**Short-term optimizations** (next experiment):

- `Remove or substantially lower the Balance Loss coefficient to restore the natural pattern of memory selection`
- `Compare different balance_loss_coef values (0.001, 0.01, 0.05) and their effect on performance`

**Mid-term improvements** (next 3-5 experiments):

- `Improve memory-bank quality and the initialization strategy to make memory retrieval more effective`
- `Explore smarter memory-balancing strategies that trade off diversity against efficiency`

**Long-term research directions**:

- `Study adaptive memory-selection mechanisms that adjust the degree of balancing to the task`
- `Explore hybrid architectures that combine the memory bank with a conventional FFN`
- Constrain the Memory Bank with a VQ-VAE-style mechanism (see the sketch after this list).
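The last bullet is what this commit implements. As a rough, self-contained illustration of the idea (not the repository's code; `codebook`, `selected_idx`, and `h` are made-up names), a VQ-VAE-style EMA update of a memory bank looks like this:

```python
import torch

def ema_codebook_update(codebook, ema_count, ema_sum, h, selected_idx, decay=0.999):
    """One VQ-VAE-style EMA step: pull the selected codebook rows toward the
    mean of the hidden states that selected them (toy sketch, not library code)."""
    with torch.no_grad():
        dim = codebook.size(1)
        flat_idx = selected_idx.reshape(-1)   # [N] long indices into the codebook
        flat_h = h.reshape(-1, dim)           # [N, dim], one hidden state per selection
        # Per-entry selection counts and feature sums for this batch
        count = torch.zeros_like(ema_count).scatter_add_(
            0, flat_idx, torch.ones_like(flat_idx, dtype=ema_count.dtype))
        summed = torch.zeros_like(ema_sum).scatter_add_(
            0, flat_idx.unsqueeze(1).expand(-1, dim), flat_h)
        # EMA of counts and sums, then renormalize the rows that were used
        ema_count.mul_(decay).add_(count, alpha=1 - decay)
        ema_sum.mul_(decay).add_(summed, alpha=1 - decay)
        used = ema_count > 0
        codebook[used] = ema_sum[used] / ema_count[used].unsqueeze(1)
    return codebook
```

Note that the classic VQ-VAE recipe keeps EMAs of the counts and sums; the implementation added in this commit instead applies the decay directly to the memory rows, `new = γ * old + (1 - γ) * batch mean`, a simpler variant of the same idea.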
---
@@ -42,6 +42,12 @@ class LMConfig(PretrainedConfig):
        knowledge_length: int = 8,
        knowledge_dim: int = 128,
        ####################################################
        # EMA update related configurations (inspired by VQ-VAE)
        ####################################################
        use_ema_update: bool = True,     # whether to update memory_bank via EMA
        ema_decay: float = 0.999,        # EMA decay rate, analogous to γ in VQ-VAE
        ema_update_freq: int = 1,        # EMA update frequency (update once every N training steps)
        ####################################################
        # Triple extraction related configurations
        ####################################################
        max_subject_len: int = 8,
@@ -83,6 +89,12 @@ class LMConfig(PretrainedConfig):
        self.knowledge_length = knowledge_length
        self.knowledge_dim = knowledge_dim
        ####################################################
        # EMA update related configurations (inspired by VQ-VAE)
        ####################################################
        self.use_ema_update = use_ema_update
        self.ema_decay = ema_decay
        self.ema_update_freq = ema_update_freq
        ####################################################
        # Triple extraction related configurations
        ####################################################
        self.max_subject_len = max_subject_len
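For reference, the update rule that these three settings control (restating what the code later in this commit implements, with γ = `ema_decay` and n_i the number of times memory entry i was selected in the current update step) is:

```latex
\bar{h}_i = \frac{1}{n_i} \sum_{t\,:\,\mathrm{sel}(t) = i} h_t,
\qquad
m_i \leftarrow \gamma \, m_i + (1 - \gamma)\, \bar{h}_i
\quad \text{for every } i \text{ with } n_i > 0 .
```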
@@ -335,17 +335,19 @@ class MiniMindBlock(nn.Module):
        self.memory_gate = MemoryGate(config)
        self.gated_memory_fusion = GatedMemoryFusion(config)

    def forward(self, x, pos_cis, memory_bank):
    def forward(self, x, pos_cis, memory_bank, collect_ema_stats=False):
        """
        Args:
            x: [batch_size, seq_len, dim]
            pos_cis: positional encoding
            memory_bank: [knowledge_num, knowledge_dim] - shared memory bank
            collect_ema_stats: whether to collect statistics for the EMA update

        Returns:
            out: [batch_size, seq_len, dim]
            balance_loss: balance loss of this layer
            layer_stats: monitoring statistics of this layer
            ema_stats: EMA update statistics (if collect_ema_stats=True)
        """
        # Self attention
        h_attn = self.attention(self.attention_norm(x), pos_cis)
@@ -369,6 +371,19 @@
        # Residual connection
        out = h + memory_output

        # Collect EMA update statistics (only during training and when enabled)
        ema_stats = None
        if collect_ema_stats and self.training:
            ema_stats = {
                'memory_indices': memory_indices,      # [batch, seq_len, num_selected]
                'memory_scores': memory_scores,        # [batch, seq_len, num_selected]
                'h_for_memory': h_for_memory,          # [batch, seq_len, dim]
                'selected_memory': selected_memory,    # [batch, seq_len, num_selected, knowledge_dim]
            }

        if collect_ema_stats:
            return out, balance_loss, layer_stats, ema_stats
        else:
            return out, balance_loss, layer_stats
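A brief usage sketch of the changed return signature (hedged: `block`, `x`, `pos_cis`, and `memory_bank` are assumed to be an already constructed MiniMindBlock and its usual inputs):

```python
# Training: also ask the block for EMA statistics (4-tuple return).
out, balance_loss, layer_stats, ema_stats = block(
    x, pos_cis, memory_bank, collect_ema_stats=True)

# ema_stats is None unless the block is in training mode.
if ema_stats is not None:
    print(ema_stats['memory_indices'].shape)  # [batch, seq_len, num_selected]

# Inference / default: the original 3-tuple return is unchanged.
out, balance_loss, layer_stats = block(x, pos_cis, memory_bank)
```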
@@ -390,10 +405,25 @@ class MiniMindLM(PreTrainedModel):
                             persistent=False)

        # Initialize the shared memory bank
        # VQ-VAE style: memory_bank acts as the codebook and is updated via EMA rather than by gradients
        if params.use_ema_update:
            self.memory_bank = nn.Parameter(
                torch.randn(params.knowledge_num, params.knowledge_dim),
                requires_grad=True
                requires_grad=False  # disable gradient updates; use EMA updates instead
            )
        else:
            self.memory_bank = nn.Parameter(
                torch.randn(params.knowledge_num, params.knowledge_dim),
                requires_grad=True  # conventional gradient updates
            )

        # Buffers for the EMA update
        if params.use_ema_update:
            # Per-entry update statistics for the memory bank
            self.register_buffer('ema_update_count', torch.zeros(params.knowledge_num), persistent=False)
            self.register_buffer('ema_sum_buffer', torch.zeros_like(self.memory_bank), persistent=False)
            # Counter for the EMA update frequency
            self.register_buffer('ema_step_counter', torch.zeros(1, dtype=torch.long), persistent=False)

            # Previous memory-bank state, used to compute update statistics
            self.register_buffer('prev_memory_bank', torch.zeros_like(self.memory_bank), persistent=False)
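The design choice above is to keep memory_bank an nn.Parameter so it still lands in state_dict and checkpoints, but with requires_grad=False so an optimizer that filters on requires_grad ignores it, and to mutate it in place under torch.no_grad(). A minimal, self-contained sketch of that pattern (toy sizes, hypothetical ToyMemory name):

```python
import torch
import torch.nn as nn

class ToyMemory(nn.Module):
    def __init__(self, num_entries=4, dim=3):
        super().__init__()
        # Saved in state_dict, but skipped by optimizers that filter on requires_grad.
        self.memory_bank = nn.Parameter(torch.randn(num_entries, dim), requires_grad=False)

    @torch.no_grad()
    def ema_step(self, idx, values, decay=0.999):
        # In-place EMA move of the selected rows toward `values`.
        self.memory_bank[idx] = decay * self.memory_bank[idx] + (1 - decay) * values

toy = ToyMemory()
print('memory_bank' in toy.state_dict())                     # True: still checkpointed
print(len([p for p in toy.parameters() if p.requires_grad]))  # 0: nothing for the optimizer
toy.ema_step(torch.tensor([0, 2]), torch.zeros(2, 3))
```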
@@ -453,15 +483,23 @@
                **args):
        """Forward pass without KV cache support"""
        start_pos = args.get('start_pos', 0)
        collect_ema_stats = args.get('collect_ema_stats', self.params.use_ema_update and self.training)

        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]

        # Collect balance losses and statistics from every layer
        total_balance_loss = 0
        all_layer_stats = {}
        all_ema_stats = {}

        for layer_idx, layer in enumerate(self.layers):
            h, balance_loss, layer_stats = layer(h, pos_cis, self.memory_bank)
            if collect_ema_stats:
                h, balance_loss, layer_stats, ema_stats = layer(h, pos_cis, self.memory_bank, collect_ema_stats=True)
                all_ema_stats[f'layer_{layer_idx}'] = ema_stats
            else:
                h, balance_loss, layer_stats = layer(h, pos_cis, self.memory_bank, collect_ema_stats=False)

            total_balance_loss += balance_loss
            # Prefix each layer's statistics with its layer index
            for key, value in layer_stats.items():
@@ -476,6 +514,7 @@
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('layer_stats', all_layer_stats)  # attach per-layer statistics
        self.OUT.__setitem__('ema_stats', all_ema_stats if collect_ema_stats else None)  # attach EMA statistics
        self.OUT.__setitem__('past_key_values', None)  # KV cache is not supported
        return self.OUT
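The new output fields can be read off the returned object like this (a sketch assuming a constructed model and tokenized input_ids; field names follow the code above):

```python
out = model(input_ids)  # collect_ema_stats defaults to use_ema_update and model.training
print(out.logits.shape)
print(out.aux_loss)

if out.ema_stats is not None:
    layer0 = out.ema_stats['layer_0']
    print(layer0['memory_indices'].shape)  # [batch, seq_len, num_selected]
    print(layer0['h_for_memory'].shape)    # [batch, seq_len, dim]
```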
@@ -537,3 +576,131 @@
                yield input_ids[:, start:]
                if input_ids_next.item() == eos_token_id:
                    break

    def apply_ema_update(self, ema_stats):
        """
        Apply a VQ-VAE-style EMA update to memory_bank.

        Args:
            ema_stats: EMA statistics collected during the forward pass, formatted as:
                       {'layer_0': {'memory_indices': ..., 'h_for_memory': ...}, 'layer_1': ...}
        """
        if not self.params.use_ema_update:
            return {}

        # Increment the EMA step counter
        self.ema_step_counter += 1

        # Check whether an EMA update is due at this step
        if self.ema_step_counter % self.params.ema_update_freq != 0:
            return {'ema_update_applied': False, 'reason': 'frequency_check_failed'}

        with torch.no_grad():
            device = self.memory_bank.device
            knowledge_num, knowledge_dim = self.memory_bank.shape
            # Reset the accumulation buffers
            self.ema_sum_buffer.zero_()
            self.ema_update_count.zero_()

            total_selections = 0
            total_layers = 0

            # Gather the EMA statistics from every layer
            for layer_name, layer_ema_stats in ema_stats.items():
                if layer_ema_stats is None:
                    continue

                total_layers += 1
                memory_indices = layer_ema_stats['memory_indices']  # [batch, seq_len, num_selected]
                h_for_memory = layer_ema_stats['h_for_memory']      # [batch, seq_len, dim]

                bsz, seq_len, num_selected = memory_indices.shape
                total_selections += bsz * seq_len * num_selected

                # Project h_for_memory down to knowledge_dim (if the dimensions do not match)
                if h_for_memory.size(-1) != knowledge_dim:
                    # Simple projection: truncate or zero-pad
                    if h_for_memory.size(-1) > knowledge_dim:
                        # Truncate to knowledge_dim
                        h_proj = h_for_memory[..., :knowledge_dim]
                    else:
                        # Zero-pad up to knowledge_dim
                        pad_size = knowledge_dim - h_for_memory.size(-1)
                        h_proj = F.pad(h_for_memory, (0, pad_size), 'constant', 0)
                else:
                    h_proj = h_for_memory

                # Flatten the indices and the corresponding h_for_memory
                flat_indices = memory_indices.view(-1)  # [batch * seq_len * num_selected]

                # Replicate h_for_memory once per selection slot
                # [batch, seq_len, num_selected] -> [batch, seq_len, num_selected, dim]
                h_expanded = h_proj.unsqueeze(2).expand(-1, -1, num_selected, -1)
                flat_h = h_expanded.reshape(-1, knowledge_dim)  # [batch * seq_len * num_selected, knowledge_dim]

                # Make sure dtypes and devices match
                flat_indices = flat_indices.long().to(device)  # indices must be int64
                flat_h = flat_h.to(dtype=self.ema_sum_buffer.dtype, device=device)  # match dtype and device

                # Accumulate the h_for_memory values for each memory entry
                # scatter_add_: adds flat_h into the matching rows of ema_sum_buffer
                self.ema_sum_buffer.scatter_add_(0, flat_indices.unsqueeze(1).expand(-1, knowledge_dim), flat_h)

                # Count how many times each memory entry was selected
                count_ones = torch.ones_like(flat_indices, dtype=self.ema_update_count.dtype, device=device)
                self.ema_update_count.scatter_add_(0, flat_indices, count_ones)
            # Compute the averages and apply the EMA update
            # Guard against division by zero
            non_zero_mask = self.ema_update_count > 0
            avg_h_for_selected = torch.zeros_like(self.memory_bank)

            if non_zero_mask.any():
                # Average h_for_memory over the selected memory entries
                avg_h_for_selected[non_zero_mask] = (
                    self.ema_sum_buffer[non_zero_mask] / self.ema_update_count[non_zero_mask].unsqueeze(1)
                )

                # Match dtypes and apply the EMA update: new = γ * old + (1 - γ) * new_avg
                # Only the selected memory entries are updated
                old_memory = self.memory_bank[non_zero_mask]
                new_avg = avg_h_for_selected[non_zero_mask].to(dtype=old_memory.dtype)

                self.memory_bank[non_zero_mask] = (
                    self.params.ema_decay * old_memory +
                    (1 - self.params.ema_decay) * new_avg
                )

            # Compute update statistics
            updated_memories = non_zero_mask.sum().item()
            update_ratio = updated_memories / knowledge_num

            # Statistics on the magnitude of the EMA update
            if hasattr(self, 'prev_memory_bank_ema') and self.prev_memory_bank_ema.numel() > 0:
                l2_changes = torch.norm(self.memory_bank[non_zero_mask] - self.prev_memory_bank_ema[non_zero_mask], p=2, dim=1)
                avg_change = l2_changes.mean().item() if len(l2_changes) > 0 else 0.0
                max_change = l2_changes.max().item() if len(l2_changes) > 0 else 0.0
            else:
                avg_change = 0.0
                max_change = 0.0

            # Save the current memory_bank state for the next comparison
            if not hasattr(self, 'prev_memory_bank_ema'):
                self.register_buffer('prev_memory_bank_ema', torch.zeros_like(self.memory_bank), persistent=False)
            self.prev_memory_bank_ema.copy_(self.memory_bank)

        update_stats = {
            'ema_update_applied': True,
            'ema_step': self.ema_step_counter.item(),
            'total_selections': total_selections,
            'total_layers': total_layers,
            'updated_memories': updated_memories,
            'update_ratio': update_ratio,
            'avg_ema_change': avg_change,
            'max_ema_change': max_change,
            'ema_decay': self.params.ema_decay,
            'selected_memory_coverage': (self.ema_update_count > 0).float().mean().item(),
        }

        return update_stats
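To make the accumulation above concrete, here is a small self-contained demo of the same scatter_add_ pattern with toy sizes (not the repository's tensors): per-entry sums and counts are gathered, averaged, and blended into the bank with decay γ.

```python
import torch

knowledge_num, knowledge_dim, decay = 4, 2, 0.999
memory_bank = torch.zeros(knowledge_num, knowledge_dim)
ema_sum = torch.zeros(knowledge_num, knowledge_dim)
ema_count = torch.zeros(knowledge_num)

# Three "selections": entries 1, 1 and 3 were chosen by these hidden states.
flat_indices = torch.tensor([1, 1, 3])
flat_h = torch.tensor([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]])

# Accumulate per-entry feature sums and selection counts.
ema_sum.scatter_add_(0, flat_indices.unsqueeze(1).expand(-1, knowledge_dim), flat_h)
ema_count.scatter_add_(0, flat_indices, torch.ones_like(flat_indices, dtype=ema_count.dtype))

# Average the selected entries and apply the EMA blend.
mask = ema_count > 0
avg = torch.zeros_like(memory_bank)
avg[mask] = ema_sum[mask] / ema_count[mask].unsqueeze(1)
memory_bank[mask] = decay * memory_bank[mask] + (1 - decay) * avg[mask]

print(ema_count)       # tensor([0., 2., 0., 1.])
print(avg[1], avg[3])  # tensor([2., 0.]) and tensor([0., 2.])
```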
run_file/experiment_1_4_5.sh (new file, 352 lines)
@@ -0,0 +1,352 @@
#!/bin/bash

# ============================================================================
# MiniMind experiment script - Experiment 1.4.5
# ============================================================================
#
# 🎯 Experiment goal:
#   Building on Experiment 1.4.4, implement a VQ-VAE-style EMA update mechanism to replace gradient updates of memory_bank
#
# Usage:
#   bash run_file/experiment_1_4_5.sh
# ============================================================================

# ----------------------------------------------------------------------------
# 🧑🔬 Basic experiment info
# ----------------------------------------------------------------------------
EXPERIMENT_VERSION="1.4.5"
EXPERIMENT_DESCRIPTION="VQ-VAE-style EMA update mechanism experiment"
RESEARCHER_NAME="AI Assistant"
EXPERIMENT_DATE="$(date '+%Y-%m-%d %H:%M:%S')"

# ----------------------------------------------------------------------------
# 🤖 Environment configuration
# ----------------------------------------------------------------------------

# Debugging and monitoring environment variables
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
export CUDA_LAUNCH_BLOCKING=1

# SwanLab configuration
export SWANLAB_PROJECT="MiniMind-Experiment-1.4.5"

# Logging configuration
LOG_DIR="out/experiment_${EXPERIMENT_VERSION}"
mkdir -p "$LOG_DIR"
LOG_FILE="$LOG_DIR/experiment.log"
# ----------------------------------------------------------------------------
# 🤖 Hardware configuration
# ----------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES="0"
NUM_PROCESSES="1"
MIXED_PRECISION="bf16"
MAIN_PROCESS_PORT="29500"

# ----------------------------------------------------------------------------
# 🤖 Model architecture parameters
# ----------------------------------------------------------------------------
MODEL_TYPE="model_memory"
MODEL_SIZE="50.0"
DIM="512"
N_LAYERS="8"
N_HEADS="32"
MAX_SEQ_LEN="512"
USE_MOE="false"

# Knowledge-base configuration (larger scale to stress-test the EMA mechanism)
KNOWLEDGE_NUM="1048576"  # 1024x1024 = 1048576; larger scale for testing EMA
KNOWLEDGE_LENGTH="32"
KNOWLEDGE_DIM="128"
DISABLE_DB="false"

# ----------------------------------------------------------------------------
# 🤖 Training hyperparameters
# ----------------------------------------------------------------------------
EPOCHS="3"
EMBEDDING_EPOCH="2"
BATCH_SIZE="96"
ACCUMULATION_STEPS="8"
LEARNING_RATE="2e-4"
DTYPE="bfloat16"
GRAD_CLIP="1.0"
WARMUP_ITERS="0"

# Balance-loss configuration (carried over from the successful 1.4.4 setup)
BALANCE_LOSS_COEF="0.1"

# Data and cache paths
DATA_PATH="/home/pci/ycz/Code/Minimind/dataset/stable/merged_pretrain.jsonl"
DATABASE_INIT_PATH="/home/pci/ycz/Code/Minimind/dataset/stable/sentence_trex_data.json"
CLUSTER_CACHE_PATH="None"  # disable the clustering cache to test the EMA effect
VAL_DATA_PATH="dataset/stable/eval_data.json"

# Training configuration
NUM_WORKERS="1"
LOG_INTERVAL="100"
VAL_INTERVAL="100"
SAVE_INTERVAL="10000"

# Profiling configuration
USE_PROFILE="true"
PROFILE_INTERVAL="10"
MEMORY_MONITOR_INTERVAL="100"

# Advanced features
USE_FLASH_ATTN="true"
FAST_CLUSTERING="true"
# ----------------------------------------------------------------------------
# 🤖 Pre-flight checks
# ----------------------------------------------------------------------------
check_environment() {
    echo "🔍 Running environment checks..."

    # Check GPU availability
    if ! nvidia-smi &> /dev/null; then
        echo "❌ Error: no GPU detected or nvidia-smi is unavailable"
        exit 1
    fi

    # Check the CUDA device
    if ! nvidia-smi -i "$CUDA_VISIBLE_DEVICES" &> /dev/null; then
        echo "❌ Error: GPU $CUDA_VISIBLE_DEVICES is not available"
        exit 1
    fi

    # Check the Python environment
    if ! .venv/bin/python -c "import torch; print(f'PyTorch: {torch.__version__}')" 2>/dev/null; then
        echo "❌ Error: PyTorch is not installed correctly"
        exit 1
    fi

    # Check the data files
    if [[ ! -f "$DATA_PATH" ]]; then
        echo "❌ Error: training data file not found: $DATA_PATH"
        exit 1
    fi

    if [[ ! -f "$DATABASE_INIT_PATH" ]]; then
        echo "❌ Error: database init file not found: $DATABASE_INIT_PATH"
        exit 1
    fi

    # Check the EMA-related model implementation
    if ! .venv/bin/python -c "from model.model_memory import *; print('EMA model implementation check passed')" 2>/dev/null; then
        echo "❌ Error: the EMA model implementation has problems"
        exit 1
    fi

    echo "✅ Environment checks passed"
}
# ----------------------------------------------------------------------------
# 🤖 Experiment info logging
# ----------------------------------------------------------------------------
log_experiment_info() {
    echo "📝 Recording experiment info..."
    cat > "$LOG_DIR/experiment_info.txt" << EOF
========================================
MiniMind Experiment Info
========================================
Experiment version: $EXPERIMENT_VERSION
Experiment description: $EXPERIMENT_DESCRIPTION
Researcher: $RESEARCHER_NAME
Start time: $EXPERIMENT_DATE
========================================
Hardware configuration:
GPU device: $CUDA_VISIBLE_DEVICES
Number of processes: $NUM_PROCESSES
Mixed precision: $MIXED_PRECISION
========================================
Model configuration:
Model type: $MODEL_TYPE
Model size: $MODEL_SIZE MB
Hidden dim: $DIM
Number of layers: $N_LAYERS
Attention heads: $N_HEADS
Max sequence length: $MAX_SEQ_LEN
Knowledge-base size: $KNOWLEDGE_NUM
Knowledge length: $KNOWLEDGE_LENGTH
Knowledge dim: $KNOWLEDGE_DIM
========================================
Training configuration:
Epochs: $EPOCHS
Batch size: $BATCH_SIZE
Learning rate: $LEARNING_RATE
Gradient accumulation steps: $ACCUMULATION_STEPS
Dtype: $DTYPE
Balance-loss coefficient: $BALANCE_LOSS_COEF
========================================
EMA configuration:
Use EMA updates: yes (VQ-VAE style)
EMA decay rate: 0.999 (default)
EMA update frequency: 1 (every step)
========================================
Data paths:
Training data: $DATA_PATH
Validation data: $VAL_DATA_PATH
Database init: $DATABASE_INIT_PATH
Clustering cache: $CLUSTER_CACHE_PATH
========================================
EOF
}
# ----------------------------------------------------------------------------
# 🤖 Main execution function
# ----------------------------------------------------------------------------
run_experiment() {
    echo "🚀 Starting experiment $EXPERIMENT_VERSION"
    echo "📄 Description: $EXPERIMENT_DESCRIPTION"
    echo "⏰ Start time: $EXPERIMENT_DATE"

    # Build the training command
    local train_cmd="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES .venv/bin/python train_pretrain_accelerate.py"

    # Training arguments
    train_cmd+=" --out_dir \"$LOG_DIR\""
    train_cmd+=" --epochs $EPOCHS"
    train_cmd+=" --embedding_epoch $EMBEDDING_EPOCH"
    train_cmd+=" --batch_size $BATCH_SIZE"
    train_cmd+=" --learning_rate $LEARNING_RATE"
    train_cmd+=" --dtype $DTYPE"
    train_cmd+=" --num_workers $NUM_WORKERS"
    train_cmd+=" --accumulation_steps $ACCUMULATION_STEPS"
    train_cmd+=" --grad_clip $GRAD_CLIP"
    train_cmd+=" --warmup_iters $WARMUP_ITERS"
    train_cmd+=" --log_interval $LOG_INTERVAL"
    train_cmd+=" --val_interval $VAL_INTERVAL"
    train_cmd+=" --save_interval $SAVE_INTERVAL"
    train_cmd+=" --dim $DIM"
    train_cmd+=" --n_layers $N_LAYERS"
    train_cmd+=" --n_heads $N_HEADS"
    train_cmd+=" --max_seq_len $MAX_SEQ_LEN"
    train_cmd+=" --data_path \"$DATA_PATH\""
    train_cmd+=" --val_data_path \"$VAL_DATA_PATH\""
    train_cmd+=" --knowledge_num $KNOWLEDGE_NUM"
    train_cmd+=" --knowledge_length $KNOWLEDGE_LENGTH"
    train_cmd+=" --database_init_path \"$DATABASE_INIT_PATH\""
    train_cmd+=" --memory_monitor_interval $MEMORY_MONITOR_INTERVAL"
    train_cmd+=" --model_type \"$MODEL_TYPE\""
    train_cmd+=" --model_size $MODEL_SIZE"
    train_cmd+=" --balance_loss_coef $BALANCE_LOSS_COEF"

    # Optional arguments
    if [[ "$USE_PROFILE" == "true" ]]; then
        train_cmd+=" --profile"
        train_cmd+=" --profile_interval $PROFILE_INTERVAL"
    fi

    if [[ "$USE_FLASH_ATTN" == "true" ]]; then
        train_cmd+=" --use_flash_attn"
    fi

    if [[ "$FAST_CLUSTERING" == "true" ]]; then
        train_cmd+=" --fast_clustering"
    fi

    if [[ "$CLUSTER_CACHE_PATH" != "None" ]]; then
        train_cmd+=" --cluster_cache_path \"$CLUSTER_CACHE_PATH\""
    fi

    # SwanLab configuration
    train_cmd+=" --use_swanlab"
    train_cmd+=" --swanlab_project \"$SWANLAB_PROJECT\""
    train_cmd+=" --swanlab_online True"

    echo "📋 Command to execute:"
    echo "$train_cmd"
    echo

    # Record the command in the log file
    echo "Command: $train_cmd" >> "$LOG_FILE"
    echo "Start time: $(date)" >> "$LOG_FILE"

    # Run training with nohup (in the background, output goes to the log file)
    echo "🔄 Running training in the background with nohup; output is written to: $LOG_FILE"

    # Create the training script
    train_script="/tmp/train_${EXPERIMENT_VERSION}.sh"
    cat > "$train_script" << EOF
#!/bin/bash
cd /home/pci/ycz/Code/pretrain-worktree
source /home/pci/ycz/Code/pretrain-worktree/.venv/bin/activate
$train_cmd
exit_code=\$?
echo "End time: \$(date)"
echo "Exit code: \$exit_code"
EOF
    chmod +x "$train_script"

    # Launch in the background with nohup
    nohup bash "$train_script" >> "$LOG_FILE" 2>&1 &
    local train_pid=$!

    echo "🔥 Training process started, PID: $train_pid"
    echo "Training PID: $train_pid" >> "$LOG_FILE"
    echo "Training script: $train_script" >> "$LOG_FILE"

    # Wait a few seconds to make sure the process has started
    sleep 5

    # Check whether the process is still running
    if kill -0 $train_pid 2>/dev/null; then
        echo "✅ Training process is running in the background"
        echo "📋 Follow the log: tail -f $LOG_FILE"
        echo "📋 Check the process: ps -p $train_pid"
        echo "🛑 Stop training: kill $train_pid"
        echo "📈 SwanLab: https://swanlab.cn/project/$SWANLAB_PROJECT"
        echo ""
        echo "🧠 Testing the VQ-VAE-style EMA update mechanism..."
        echo "   - memory_bank is updated via EMA instead of gradients"
        echo "   - EMA decay rate: 0.999"
        echo "   - updated every training step"
        echo "   - expected: more stable training and better memory representations"
        echo ""
        echo "Training is running in the background; it is safe to close this terminal."
    else
        echo "❌ Failed to start the training process"
        echo "📋 Check the log: $LOG_FILE"
        exit 1
    fi
}
# ----------------------------------------------------------------------------
# 🤖 Cleanup function
# ----------------------------------------------------------------------------
cleanup() {
    echo "🧹 Cleaning up temporary files..."
    # Remove the temporary validation file
    rm -f /tmp/temp_val.jsonl
}

# ----------------------------------------------------------------------------
# 🤖 Signal handling
# ----------------------------------------------------------------------------
trap cleanup EXIT
trap 'echo "❌ Experiment interrupted"; cleanup; exit 130' INT TERM

# ----------------------------------------------------------------------------
# 🤖 Main entry point
# ----------------------------------------------------------------------------
main() {
    echo "============================================================================"
    echo "🧠 MiniMind pretraining experiment 1.4.5"
    echo "🎯 VQ-VAE-style EMA update mechanism - replaces gradient updates of memory_bank"
    echo "============================================================================"

    # Checks and initialization
    check_environment
    log_experiment_info

    # Run the experiment
    run_experiment

    echo "============================================================================"
    echo "✅ Experiment $EXPERIMENT_VERSION launched"
    echo "📅 Launch time: $(date)"
    echo "============================================================================"
}

# Run the main program
main "$@"
@@ -770,6 +770,20 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
        # But to be safe, we still call it explicitly
        optimizer.zero_grad()

        # VQ-VAE-style EMA update (only runs when enabled)
        if hasattr(res, 'ema_stats') and res.ema_stats is not None:
            unwrapped_model = accelerator.unwrap_model(model)
            if hasattr(unwrapped_model, 'apply_ema_update'):
                ema_update_stats = unwrapped_model.apply_ema_update(res.ema_stats)
                # Log EMA update statistics
                if step % args.log_interval == 0 and accelerator.is_main_process and ema_update_stats.get('ema_update_applied', False):
                    total_memories = args.knowledge_num
                    Logger(f"EMA Update - Step: {ema_update_stats['ema_step']}, "
                           f"Updated memories: {ema_update_stats['updated_memories']}/{total_memories} "
                           f"({ema_update_stats['update_ratio']:.4f}), "
                           f"Avg change: {ema_update_stats['avg_ema_change']:.6f}, "
                           f"Coverage: {ema_update_stats['selected_memory_coverage']:.4f}", accelerator)

        # End timing of the optimizer step (main process only)
        if args.profile and accelerator.is_main_process and optimizer_end is not None:
            optimizer_end.record()
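Schematically, the order of operations in this step is as follows (a sketch with assumed names such as `loss_fn`, `labels`, and `balance_loss_coef`; not a drop-in replacement for the loop above):

```python
# One training step with EMA maintenance of the memory bank.
res = model(input_ids, collect_ema_stats=True)   # forward also gathers per-layer EMA stats
loss = loss_fn(res.logits, labels) + balance_loss_coef * res.aux_loss
accelerator.backward(loss)                       # gradients flow to everything except memory_bank
optimizer.step()
optimizer.zero_grad()

# After the gradient step, update the memory bank outside the autograd graph.
if res.ema_stats is not None:
    ema_update_stats = accelerator.unwrap_model(model).apply_ema_update(res.ema_stats)
```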
@@ -1021,7 +1035,7 @@ def main():
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument("--embedding_epoch", type=int, default=2, help="number of epochs for embedding training")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_swanlab", default=True, action="store_true")  # replaces the wandb argument
@@ -1052,7 +1066,7 @@
    parser.add_argument("--recompute_clusters", action="store_true", default=False, help="force recomputation of the clusters, ignoring the cache file")
    parser.add_argument("--memory_monitor", action="store_true", default=False, help="enable memory monitoring")
    parser.add_argument("--memory_monitor_interval", type=int, default=10, help="memory monitoring interval (steps)")
    parser.add_argument("--model_type", type=str, default="model", help="which model implementation to train")  # model, model_original, model_no_feed
    parser.add_argument("--model_type", type=str, default="model_memory", help="which model implementation to train")  # model, model_original, model_no_feed
    parser.add_argument("--model_size", type=float, default=50.0, help="model size")
    parser.add_argument("--swanlab_online", type=bool, default=False, help="whether to use the online SwanLab service")
    parser.add_argument("--balance_loss_coef", type=float, default=0.01, help="balance loss coefficient")
@@ -1214,6 +1228,16 @@ def main():
    #########################################################
    # Create the optimizer
    #########################################################
    # If EMA updates are enabled, filter out the memory_bank parameter (it no longer receives gradient updates)
    if hasattr(model.params, 'use_ema_update') and model.params.use_ema_update:
        # Only include parameters with requires_grad=True
        optimizer_params = [p for p in model.parameters() if p.requires_grad]
        Logger(f"EMA update mode: optimizer covers {len(optimizer_params)} parameters (memory_bank filtered out)")
        Logger(f"Total parameters: {sum(p.numel() for p in model.parameters())} | Trainable parameters: {sum(p.numel() for p in optimizer_params)}")
        optimizer = optim.AdamW(optimizer_params, lr=args.learning_rate)
    else:
        # Conventional mode: all parameters are updated by gradients
        Logger("Conventional gradient-update mode: optimizer covers all model parameters")
        optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    #########################################################