Experiment 1.4.10

This commit is contained in:
Aurora 2025-09-11 00:10:08 +08:00
parent e06f94b5f5
commit 3d69ae866e
3 changed files with 13 additions and 8 deletions

View File

@ -466,7 +466,9 @@ class MiniMindBlock(nn.Module):
diversity_loss = self.compute_diversity_loss(candidate_memories)
# 🔥 使用selection_weights进行加权选择最终记忆
selected_memory = (candidate_memories * selection_weights.unsqueeze(-1)).sum(dim=2) # [batch, seq_len, dim]
batch, seq_len, num_candidates, dim = candidate_memories.shape
selected_memory = (candidate_memories * selection_weights.unsqueeze(-1)) # [batch, seq_len, num_candidates, dim]
selected_memory = weighted_memories.reshape(batch_size, seq_len * num_candidates, dim)
# 门控MLP融合只融合选中的单个最佳记忆
memory_output = self.gated_memory_fusion(h_for_memory, selected_memory)

View File

@ -63,7 +63,7 @@ USE_MOE="false"
# 🔥 知识库配置优化版16个候选项
KNOWLEDGE_NUM="1048576" # 1M entries
KNOWLEDGE_LENGTH="8" # 保持8个token长度
KNOWLEDGE_LENGTH="16" # 从8增加到16个token长度
KNOWLEDGE_DIM="128" # 保持兼容性
DISABLE_DB="false"
@ -72,7 +72,7 @@ DISABLE_DB="false"
# ----------------------------------------------------------------------------
EPOCHS="3"
EMBEDDING_EPOCH="2"
BATCH_SIZE="24" # 🔥 显存优化: 从48减少到24 (减少50%)
BATCH_SIZE="48" # 🔥 恢复为48 (此前显存优化曾减少到24)
ACCUMULATION_STEPS="16" # 🔥 显存优化: 从8增加到16 (有效批次: 48*16*4=3072)
LEARNING_RATE="2e-4" # 保持学习率稳定
DTYPE="bfloat16"
@ -81,8 +81,8 @@ WARMUP_ITERS="0"
# 🔥 三损失系统配置 (balance loss已移除, 保持与1.4.10一致)
# BALANCE_LOSS_COEF removed - using only three losses now
SIMILARITY_LOSS_COEF="0.8" # 相似度损失系数(核心损失)
DIVERSITY_LOSS_COEF="0.2" # 多样性损失系数(避免候选重复)
SIMILARITY_LOSS_COEF="6" # 相似度损失系数(核心损失)
DIVERSITY_LOSS_COEF="4" # 多样性损失系数(避免候选重复)
# 数据和缓存路径
DATA_PATH="dataset/stable/merged_pretrain.jsonl"
@ -91,7 +91,7 @@ CLUSTER_CACHE_PATH="None" # 禁用聚类缓存
VAL_DATA_PATH="dataset/stable/eval_data.json"
# 训练配置
NUM_WORKERS="8"
NUM_WORKERS="4"
LOG_INTERVAL="100"
VAL_INTERVAL="100"
SAVE_INTERVAL="10000"

View File

@ -1052,6 +1052,9 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
total_loss = (ce_loss +
similarity_coef * similarity_loss +
diversity_coef * diversity_loss)
total_loaa = (ce_loss +
(similarity_loss) / (ce_loss/similarity_loss).detach() +
(diversity_coef) / (diversity_coef/similarity_loss).detach())
loss = total_loss / args.accumulation_steps
# 计时前向传播结束 (只在主进程进行)
@ -1388,8 +1391,8 @@ def main():
parser.add_argument("--model_type", type=str, default="model_memory", help="使用什么模型训练") #model,model_original,model_no_feed
parser.add_argument("--model_size", type=float, default=50.0, help="模型大小")
parser.add_argument("--swanlab_online", type=bool, default=False, help="是否使用在线SwanLab服务")
parser.add_argument("--similarity_loss_coef", type=float, default=0.1, help="相似度损失系数实验1.4.9")
parser.add_argument("--diversity_loss_coef", type=float, default=0.05, help="多样性损失系数实验1.4.9")
parser.add_argument("--similarity_loss_coef", type=float, default=3, help="相似度损失系数实验1.4.9")
parser.add_argument("--diversity_loss_coef", type=float, default=3, help="多样性损失系数实验1.4.9")
parser.add_argument("--val_data_path", type=str, default="/home/zym/Code/stable/eval_data.json", help="验证数据集路径")
parser.add_argument("--val_interval", type=int, default=100, help="验证评估间隔")
parser.add_argument("--freeze_ratio", type=float, default=0.2, help="冻结率")