iomgaa 2025-09-07 22:41:30 +08:00
parent 8379b45d80
commit e06f94b5f5
3 changed files with 47 additions and 152 deletions

View File

@ -191,82 +191,8 @@ class MemoryGate(nn.Module):
# Return the candidates for subsequent similarity-based selection
# Note: only candidates are returned here; similarity selection and diversity-loss computation happen in MiniMindBlock
return candidate_indices, candidate_scores, None, {}
return candidate_indices, candidate_scores
def _compute_balance_loss_and_stats(self, memory_indices, memory_scores):
"""
计算平衡损失和监控统计信息
Args:
memory_indices: [batch_size, seq_len, num_selected]
memory_scores: [batch_size, seq_len, num_selected]
Returns:
balance_loss: 标量张量
stats: 统计信息字典
"""
bsz, seq_len, num_selected = memory_indices.shape
device = memory_indices.device
# 1. Compute the memory-selection distribution
# Flatten all selected memory indices
flat_indices = memory_indices.view(-1) # [batch_size * seq_len * num_selected]
# Count how many times each memory entry was selected
memory_counts = torch.zeros(self.knowledge_num, device=device)
memory_counts.scatter_add_(0, flat_indices, torch.ones_like(flat_indices, dtype=torch.float))
# Compute the selection probability distribution
total_selections = bsz * seq_len * num_selected
memory_probs = memory_counts / total_selections
# 2. KL-divergence loss: KL divergence from the uniform distribution
uniform_prob = 1.0 / self.knowledge_num
# Avoid log(0)
memory_probs_safe = memory_probs + 1e-10
kl_loss = F.kl_div(
torch.log(memory_probs_safe),
torch.full_like(memory_probs, uniform_prob),
reduction='sum'
)
# 3. Gini-coefficient loss (measures how unequal the distribution is)
sorted_probs, _ = torch.sort(memory_probs)
n = self.knowledge_num
index = torch.arange(1, n + 1, device=device, dtype=torch.float)
gini_coeff = (2 * torch.sum(index * sorted_probs) / (n * torch.sum(sorted_probs))) - (n + 1) / n
gini_loss = gini_coeff # the larger the Gini coefficient, the more uneven the distribution
# 4. Combine into the balance loss
balance_loss = 0.5 * kl_loss + 0.5 * gini_loss
# 5. Compute monitoring statistics
with torch.no_grad():
# Memory coverage: fraction of memory entries that were ever selected
coverage_rate = (memory_counts > 0).float().mean().item()
# Hot memories: entries in the top 10% by selection count
top10_threshold = torch.quantile(memory_counts, 0.9)
hot_memories = (memory_counts >= top10_threshold).sum().item()
# Dead memories: entries that were never selected
dead_memories = (memory_counts == 0).sum().item()
# Variance of selection counts (measures imbalance)
selection_variance = memory_counts.var().item()
stats = {
'gini_coefficient': gini_coeff.item(),
'kl_divergence': kl_loss.item(),
'coverage_rate': coverage_rate,
'hot_memories': hot_memories,
'dead_memories': dead_memories,
'selection_variance': selection_variance,
'max_selections': memory_counts.max().item(),
'min_selections': memory_counts.min().item(),
}
return balance_loss, stats
class GatedMemoryFusion(nn.Module):
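For reference, the _compute_balance_loss_and_stats method deleted above combined a KL term against the uniform distribution with a Gini-coefficient penalty, G = (2·Σᵢ i·p₍ᵢ₎)/(n·Σᵢ p₍ᵢ₎) − (n+1)/n over sorted selection probabilities. A minimal standalone sketch of the same computation on toy shapes (names and sizes here are illustrative, not the repo's):

import torch
import torch.nn.functional as F

def balance_loss_sketch(memory_indices, knowledge_num):
    # memory_indices: [batch, seq_len, num_selected] integer tensor
    flat = memory_indices.reshape(-1)
    counts = torch.zeros(knowledge_num)
    counts.scatter_add_(0, flat, torch.ones_like(flat, dtype=torch.float))
    probs = counts / flat.numel()
    # KL divergence from the uniform distribution (F.kl_div expects log-probs as input)
    uniform = torch.full_like(probs, 1.0 / knowledge_num)
    kl = F.kl_div(torch.log(probs + 1e-10), uniform, reduction='sum')
    # Gini coefficient of the selection distribution
    sorted_probs, _ = torch.sort(probs)
    idx = torch.arange(1, knowledge_num + 1, dtype=torch.float)
    gini = (2 * torch.sum(idx * sorted_probs) / (knowledge_num * torch.sum(sorted_probs))) - (knowledge_num + 1) / knowledge_num
    return 0.5 * kl + 0.5 * gini

print(balance_loss_sketch(torch.randint(0, 64, (2, 8, 4)), 64))  # toy: 64 memory entries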
@ -405,7 +331,7 @@ class MiniMindBlock(nn.Module):
def forward(self, x, pos_cis, memory_bank, tok_embeddings, collect_ema_stats=False):
"""
Experiment 1.4.9: Gumbel-Softmax + diversity loss + differentiable similarity loss
Experiment: Gumbel-Softmax + diversity loss + differentiable similarity loss (balance loss removed)
Args:
x: [batch_size, seq_len, dim]
@ -416,10 +342,9 @@ class MiniMindBlock(nn.Module):
Returns:
out: [batch_size, seq_len, dim]
balance_loss: this layer's balance loss (computed from the candidates)
similarity_loss: similarity loss (differentiable)
diversity_loss: diversity loss
layer_stats: this layer's monitoring statistics
layer_stats: this layer's monitoring statistics (no longer includes balance-related entries)
ema_stats: EMA update statistics (if collect_ema_stats=True)
cosine_stats: cosine-similarity statistics between query vectors and candidate memory entries
"""
@ -431,7 +356,7 @@ class MiniMindBlock(nn.Module):
h_for_memory = self.memory_norm(h_attn)
# 🔥 New architecture: generate 32 candidates
candidate_indices, candidate_scores, _, _ = self.memory_gate(h_for_memory)
candidate_indices, candidate_scores = self.memory_gate(h_for_memory)
# candidate_indices: [batch, seq_len, num_candidates]
# candidate_scores: [batch, seq_len, num_candidates]
@ -474,8 +399,8 @@ class MiniMindBlock(nn.Module):
# Residual connection
out = h + memory_output
# 🔥 Compute the balance loss and statistics (based on the candidate selection distribution)
balance_loss, layer_stats = self._compute_candidate_balance_stats(candidate_indices, selection_weights)
# 🔥 Compute simplified statistics (balance-related parts removed)
layer_stats = self._compute_selection_stats(candidate_indices, selection_weights)
# 🔥 Compute detailed similarity statistics
cosine_stats = {
@ -501,23 +426,21 @@ class MiniMindBlock(nn.Module):
}
if collect_ema_stats:
return out, balance_loss, similarity_loss, diversity_loss, layer_stats, ema_stats, cosine_stats
return out, similarity_loss, diversity_loss, layer_stats, ema_stats, cosine_stats
else:
return out, balance_loss, similarity_loss, diversity_loss, layer_stats, cosine_stats
return out, similarity_loss, diversity_loss, layer_stats, cosine_stats
def _compute_candidate_balance_stats(self, candidate_indices, selection_weights):
def _compute_selection_stats(self, candidate_indices, selection_weights):
"""
计算基于候选项选择的平衡损失和统计信息
计算基于候选项选择的简化统计信息移除balance损失
Args:
candidate_indices: [batch_size, seq_len, num_candidates]
selection_weights: [batch_size, seq_len, num_candidates] - Gumbel-Softmax weights
Returns:
balance_loss: scalar tensor
stats: dict of statistics
"""
bsz, seq_len, num_candidates = candidate_indices.shape
device = candidate_indices.device
# Use the weights to tally the selection probability of each memory entry
@ -528,30 +451,7 @@ class MiniMindBlock(nn.Module):
memory_counts = torch.zeros(self.config.knowledge_num, device=device)
memory_counts.scatter_add_(0, flat_indices, flat_weights)
# Compute the selection probability distribution
total_selections = memory_counts.sum()
memory_probs = memory_counts / (total_selections + 1e-10)
# KL-divergence loss: KL divergence from the uniform distribution
uniform_prob = 1.0 / self.config.knowledge_num
memory_probs_safe = memory_probs + 1e-10
kl_loss = F.kl_div(
torch.log(memory_probs_safe),
torch.full_like(memory_probs, uniform_prob),
reduction='sum'
)
# Gini-coefficient loss
sorted_probs, _ = torch.sort(memory_probs)
n = self.config.knowledge_num
index = torch.arange(1, n + 1, device=device, dtype=torch.float)
gini_coeff = (2 * torch.sum(index * sorted_probs) / (n * torch.sum(sorted_probs))) - (n + 1) / n
gini_loss = gini_coeff
# Combine into the balance loss
balance_loss = 0.5 * kl_loss + 0.5 * gini_loss
# Compute statistics
# Compute statistics (balance loss excluded)
with torch.no_grad():
coverage_rate = (memory_counts > 0.01).float().mean().item() # fraction of memories with selection probability > 1%
top10_threshold = torch.quantile(memory_counts, 0.9)
@ -560,8 +460,6 @@ class MiniMindBlock(nn.Module):
selection_variance = memory_counts.var().item()
stats = {
'gini_coefficient': gini_coeff.item(),
'kl_divergence': kl_loss.item(),
'coverage_rate': coverage_rate,
'hot_memories': hot_memories,
'dead_memories': dead_memories,
@ -570,7 +468,7 @@ class MiniMindBlock(nn.Module):
'min_selections': memory_counts.min().item(),
}
return balance_loss, stats
return stats
class MiniMindLM(PreTrainedModel):
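The surviving _compute_selection_stats keeps only the monitoring side of the old code: weighted selection counts plus coverage/hot/dead statistics, with no loss term. A toy sketch of the same bookkeeping (shapes and the dead-memory threshold are assumptions, since part of that hunk is elided):

import torch

def selection_stats_sketch(candidate_indices, selection_weights, knowledge_num):
    # Accumulate Gumbel-Softmax weights per memory entry
    counts = torch.zeros(knowledge_num)
    counts.scatter_add_(0, candidate_indices.reshape(-1), selection_weights.reshape(-1))
    top10 = torch.quantile(counts, 0.9)
    return {
        'coverage_rate': (counts > 0.01).float().mean().item(),  # entries with >1% mass
        'hot_memories': (counts >= top10).sum().item(),
        'dead_memories': (counts == 0).sum().item(),             # assumed threshold
        'selection_variance': counts.var().item(),
        'max_selections': counts.max().item(),
        'min_selections': counts.min().item(),
    }

idx = torch.randint(0, 64, (2, 8, 16))             # [batch, seq, num_candidates]
w = torch.softmax(torch.randn(2, 8, 16), dim=-1)   # stand-in for Gumbel-Softmax weights
print(selection_stats_sketch(idx, w, 64))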
@ -689,23 +587,23 @@ class MiniMindLM(PreTrainedModel):
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
# Collect losses and statistics from all layers - Experiment 1.4.9: four-loss system
total_balance_loss = 0
# Collect losses and statistics from all layers - Experiment: two-loss system (balance_loss removed)
total_similarity_loss = 0
total_diversity_loss = 0
all_layer_stats = {}
all_ema_stats = {}
all_cosine_stats = {}
n = 0
for layer_idx, layer in enumerate(self.layers):
n = n+1
if collect_ema_stats:
h, balance_loss, similarity_loss, diversity_loss, layer_stats, ema_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True)
h, similarity_loss, diversity_loss, layer_stats, ema_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True)
all_ema_stats[f'layer_{layer_idx}'] = ema_stats
else:
h, balance_loss, similarity_loss, diversity_loss, layer_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False)
h, similarity_loss, diversity_loss, layer_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False)
# Accumulate the four losses
total_balance_loss += balance_loss
# Accumulate the two losses (balance_loss removed)
total_similarity_loss += similarity_loss
total_diversity_loss += diversity_loss
@ -719,11 +617,10 @@ class MiniMindLM(PreTrainedModel):
logits = self.output(self.norm(h))
# 🔥 New four-loss structure
# 🔥 New two-loss structure (balance_loss removed)
aux_loss = {
'balance_loss': total_balance_loss,
'similarity_loss': total_similarity_loss,
'diversity_loss': total_diversity_loss,
'similarity_loss': total_similarity_loss / n,
'diversity_loss': total_diversity_loss / n,
}
self.OUT.__setitem__('last_hidden_state', h)
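Note the new division by n above: the per-layer similarity and diversity losses are now averaged over the n layers instead of summed, so the scale of aux_loss no longer grows with model depth. A toy illustration of the difference:

layer_losses = [0.6, 0.5, 0.4, 0.5]   # hypothetical per-layer similarity losses
n = len(layer_losses)
print(sum(layer_losses))      # old behavior: 2.0, grows with depth
print(sum(layer_losses) / n)  # new behavior: 0.5, depth-independent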

View File

@ -10,7 +10,7 @@
# 2. Enhanced DeepSpeed: parameter + optimizer CPU offload + async I/O optimization
#
# 📝 Optimization strategy notes:
# - No gradient checkpointing: avoids disturbing the numerical stability of the four-loss system and Gumbel-Softmax
# - No gradient checkpointing: avoids disturbing the numerical stability of the three-loss system and Gumbel-Softmax
# - Focus on safe optimizations: reduce VRAM usage while preserving training quality
#
# Usage:
@ -21,7 +21,7 @@
# 🧑‍🔬 Basic experiment info
# ----------------------------------------------------------------------------
EXPERIMENT_VERSION="1.4.10_optimized"
EXPERIMENT_DESCRIPTION="Four-loss system, optimized - implements two major safe VRAM optimization strategies"
EXPERIMENT_DESCRIPTION="Three-loss system, optimized - implements two major safe VRAM optimization strategies"
RESEARCHER_NAME="AI Assistant"
EXPERIMENT_DATE="$(date '+%Y-%m-%d %H:%M:%S')"
@ -80,7 +80,7 @@ GRAD_CLIP="1.0"
WARMUP_ITERS="0"
# 🔥 Four-loss system config (kept consistent with 1.4.10)
BALANCE_LOSS_COEF="0.01" # balance loss coefficient
# BALANCE_LOSS_COEF removed - using only three losses now
SIMILARITY_LOSS_COEF="0.8" # similarity loss coefficient (the core loss)
DIVERSITY_LOSS_COEF="0.2" # diversity loss coefficient (avoids duplicate candidates)
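With these coefficients, the objective assembled in the training script (third file below) reduces to:

\mathcal{L}_{\mathrm{total}} = \mathcal{L}_{\mathrm{CE}} + \lambda_{\mathrm{sim}}\,\mathcal{L}_{\mathrm{sim}} + \lambda_{\mathrm{div}}\,\mathcal{L}_{\mathrm{div}} = \mathcal{L}_{\mathrm{CE}} + 0.8\,\mathcal{L}_{\mathrm{sim}} + 0.2\,\mathcal{L}_{\mathrm{div}}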
@ -165,7 +165,7 @@ GPU device: $CUDA_VISIBLE_DEVICES
Mixed precision: $MIXED_PRECISION
========================================
Model configuration:
Model type: $MODEL_TYPE (Token-based Memory + four-loss system)
Model type: $MODEL_TYPE (Token-based Memory + three-loss system)
Model size: $MODEL_SIZE MB
Dimension: $DIM
Layers: $N_LAYERS
@ -184,8 +184,7 @@ GPU device: $CUDA_VISIBLE_DEVICES
Effective batch size: $((BATCH_SIZE * ACCUMULATION_STEPS * 4))
Data type: $DTYPE
========================================
🔥 Four-loss system configuration:
Balance loss coefficient: $BALANCE_LOSS_COEF (memory-selection balance)
🔥 Three-loss system configuration (Balance Loss removed):
Similarity loss coefficient: $SIMILARITY_LOSS_COEF (semantic-matching optimization)
Diversity loss coefficient: $DIVERSITY_LOSS_COEF (candidate-set diversity)
========================================
@ -255,8 +254,7 @@ run_experiment() {
train_cmd+=" --model_size $MODEL_SIZE"
train_cmd+=" --freeze_ratio $FREEZE_RATIO"
# 🔥 Four-loss system parameters
train_cmd+=" --balance_loss_coef $BALANCE_LOSS_COEF"
# 🔥 Three-loss system parameters (Balance Loss removed)
train_cmd+=" --similarity_loss_coef $SIMILARITY_LOSS_COEF"
train_cmd+=" --diversity_loss_coef $DIVERSITY_LOSS_COEF"
@ -325,9 +323,9 @@ EOF
echo "🛑 Stop training: kill $train_pid"
echo "📈 SwanLab: https://swanlab.cn/project/$SWANLAB_PROJECT"
echo ""
echo "🧠 VRAM-optimized four-loss system under test..."
echo "🧠 VRAM-optimized three-loss system under test..."
echo " 🔥 Two major safe optimization strategies enabled"
echo " 🔥 Loss structure: CE + Balance + Similarity + Diversity"
echo " 🔥 Loss structure: CE + Similarity + Diversity (three-loss system)"
echo " 🔥 Candidate mechanism: 16 candidates → Gumbel-Softmax picks the best 1"
echo " 🔥 Numerical stability: fully preserved, no gradient-checkpointing interference"
echo " 🔥 DeepSpeed optimization: parameter + optimizer CPU offload"
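The candidate mechanism echoed above (16 candidates scored, Gumbel-Softmax picks one) can be sketched in a few lines; the shapes and temperature are illustrative, not the repo's actual settings:

import torch
import torch.nn.functional as F

scores = torch.randn(2, 8, 16)  # [batch, seq, num_candidates]
# hard=True yields a one-hot choice in the forward pass with
# straight-through (soft) gradients in the backward pass
weights = F.gumbel_softmax(scores, tau=1.0, hard=True, dim=-1)
selected = weights.argmax(dim=-1)       # index of the chosen candidate
print(weights.shape, selected.shape)    # torch.Size([2, 8, 16]) torch.Size([2, 8])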
@ -343,7 +341,7 @@ EOF
echo "🎯 Expected improvements:"
echo " - VRAM usage: fits an A800 80GB (the original could not run)"
echo " - Training stability: the optimized version is more stable"
echo " - Four-loss convergence: matches the original expectations"
echo " - Three-loss convergence: matches the original expectations"
echo " - Generation quality: preserves the original target quality"
echo ""
echo "⏱️ Estimated training time: 18-20 hours (no gradient-checkpointing recomputation)"
@ -351,7 +349,7 @@ EOF
echo ""
echo "🔍 Key monitoring metrics:"
echo " - GPU VRAM usage: should stay below 70GB"
echo " - Four-loss convergence: compare against the original 1.4.10"
echo " - Three-loss convergence: compare against the original 1.4.10"
echo " - Training stability: no OOM errors"
echo " - Optimization validation: memory-selection quality"
echo ""
trap 'echo "❌ Experiment interrupted"; cleanup; exit 130' INT TERM
main() {
echo "============================================================================"
echo "🧠 MiniMind pretraining experiment 1.4.10, optimized version"
echo "🎯 Four-loss system + two major safe VRAM optimization strategies"
echo "🎯 Three-loss system + two major safe VRAM optimization strategies"
echo "============================================================================"
echo ""
echo "🔥 Core optimization strategies:"
@ -395,7 +393,7 @@ main() {
echo "🎯 VRAM optimization targets:"
echo " ✓ Original 1.4.10: needs 80GB+ → optimized: 45-55GB (conservative estimate)"
echo " ✓ A800 80GB compatibility: from unable to run → fully compatible"
echo " ✓ Training quality preserved: four-loss system fully functional, numerically stable"
echo " ✓ Training quality preserved: three-loss system fully functional, numerically stable"
echo " ✓ Consistent convergence: fully matches original 1.4.10 expectations"
echo ""
echo "🔧 Technical implementation details:"

View File

@ -1013,21 +1013,26 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
).view(Y.size())
ce_loss = (ce_loss * loss_mask).sum() / loss_mask.sum()
# 🔥 Experiment 1.4.9: four-loss system handling
balance_loss = 0
# 🔥 Experiment: three-loss system handling + distributed-training aggregation fix
similarity_loss = 0
diversity_loss = 0
if hasattr(res, 'aux_loss') and res.aux_loss is not None:
aux_loss = res.aux_loss
if isinstance(aux_loss, dict):
# New four-loss structure
balance_loss = aux_loss.get('balance_loss', 0)
# Three-loss structure (balance_loss removed)
similarity_loss = aux_loss.get('similarity_loss', 0)
diversity_loss = aux_loss.get('diversity_loss', 0)
# 🔥 Fix the loss-aggregation problem in distributed training
# Gather the auxiliary losses (similarity, diversity) across GPUs to obtain their global means
if isinstance(similarity_loss, torch.Tensor):
similarity_loss = accelerator.gather(similarity_loss).mean()
if isinstance(diversity_loss, torch.Tensor):
diversity_loss = accelerator.gather(diversity_loss).mean()
else:
# Backward compatibility with the old single aux_loss
balance_loss = aux_loss
# Backward compatibility with the old single aux_loss: ignored now that balance_loss is removed
pass
# Get cosine-similarity statistics (if the model supports them)
cosine_stats = {}
@ -1039,14 +1044,12 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
if selected_similarities:
avg_selected_similarity = np.mean(selected_similarities)
# 🔥 Four-loss system: CE + Balance + Similarity + Diversity
# 🔥 Three-loss system: CE + Similarity + Diversity (Balance Loss removed)
# Loss coefficients are adjustable via command-line arguments
balance_coef = getattr(args, 'balance_loss_coef', 0.01)
similarity_coef = getattr(args, 'similarity_loss_coef', 0.1)
diversity_coef = getattr(args, 'diversity_loss_coef', 0.05)
total_loss = (ce_loss +
balance_coef * balance_loss +
similarity_coef * similarity_loss +
diversity_coef * diversity_loss)
loss = total_loss / args.accumulation_steps
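A worked example of the combination above with the parser defaults (similarity 0.1, diversity 0.05; the shell script instead passes 0.8 / 0.2), using made-up loss values:

import torch

ce_loss = torch.tensor(2.30)           # hypothetical values
similarity_loss = torch.tensor(0.45)
diversity_loss = torch.tensor(0.12)
total_loss = ce_loss + 0.1 * similarity_loss + 0.05 * diversity_loss
print(total_loss)  # tensor(2.3510)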
@ -1247,13 +1250,12 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
layer_stats = res.layer_stats
# 🔥 Build the log dict for the four-loss system
# 🔥 Build the log dict for the three-loss system
log_dict = {
"epoch": epoch + 1,
"step": step + 1,
"total_steps_in_epoch": total_steps_in_epoch,
"train/loss_ce": ce_loss.item(),
"train/loss_balance": balance_loss.item() if isinstance(balance_loss, torch.Tensor) else balance_loss,
"train/loss_similarity": similarity_loss.item() if isinstance(similarity_loss, torch.Tensor) else similarity_loss,
"train/loss_diversity": diversity_loss.item() if isinstance(diversity_loss, torch.Tensor) else diversity_loss,
"train/loss_total": total_loss.item(),
@ -1287,10 +1289,9 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
'train/avg_selected_similarity': avg_selected_similarity, # 🔥 uses the similarity of the selected memories
})
# 🔥 Console output for the four-loss system
# 🔥 Console output for the three-loss system
Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
f"CE: {log_dict['train/loss_ce']:.4f}, "
f"Bal: {log_dict['train/loss_balance']:.4f}, "
f"Sim: {log_dict['train/loss_similarity']:.4f}, "
f"Div: {log_dict['train/loss_diversity']:.4f}, "
f"Total: {log_dict['train/loss_total']:.4f}, "
@ -1387,7 +1388,6 @@ def main():
parser.add_argument("--model_type", type=str, default="model_memory", help="which model to train")  # model, model_original, model_no_feed
parser.add_argument("--model_size", type=float, default=50.0, help="model size")
parser.add_argument("--swanlab_online", type=bool, default=False, help="whether to use the online SwanLab service")
parser.add_argument("--balance_loss_coef", type=float, default=0.01, help="balance loss coefficient")
parser.add_argument("--similarity_loss_coef", type=float, default=0.1, help="similarity loss coefficient (experiment 1.4.9)")
parser.add_argument("--diversity_loss_coef", type=float, default=0.05, help="diversity loss coefficient (experiment 1.4.9)")
parser.add_argument("--val_data_path", type=str, default="/home/zym/Code/stable/eval_data.json", help="path to the validation dataset")