update

parent 8379b45d80
commit e06f94b5f5
@@ -191,82 +191,8 @@ class MemoryGate(nn.Module):
         # Return the candidates for subsequent similarity-based selection
         # Note: the candidates are returned here; similarity selection and diversity-loss computation happen in MiniMindBlock
-        return candidate_indices, candidate_scores, None, {}
+        return candidate_indices, candidate_scores
-
-    def _compute_balance_loss_and_stats(self, memory_indices, memory_scores):
-        """
-        Compute the balance loss and monitoring statistics.
-
-        Args:
-            memory_indices: [batch_size, seq_len, num_selected]
-            memory_scores: [batch_size, seq_len, num_selected]
-
-        Returns:
-            balance_loss: scalar tensor
-            stats: dictionary of statistics
-        """
-        bsz, seq_len, num_selected = memory_indices.shape
-        device = memory_indices.device
-
-        # 1. Compute the memory-selection distribution
-        # Flatten all selected memory indices
-        flat_indices = memory_indices.view(-1)  # [batch_size * seq_len * num_selected]
-
-        # Count how many times each memory entry was selected
-        memory_counts = torch.zeros(self.knowledge_num, device=device)
-        memory_counts.scatter_add_(0, flat_indices, torch.ones_like(flat_indices, dtype=torch.float))
-
-        # Turn the counts into a selection probability distribution
-        total_selections = bsz * seq_len * num_selected
-        memory_probs = memory_counts / total_selections
-
-        # 2. KL-divergence loss (against the uniform distribution)
-        uniform_prob = 1.0 / self.knowledge_num
-        # Avoid log(0)
-        memory_probs_safe = memory_probs + 1e-10
-        kl_loss = F.kl_div(
-            torch.log(memory_probs_safe),
-            torch.full_like(memory_probs, uniform_prob),
-            reduction='sum'
-        )
-
-        # 3. Gini-coefficient loss (measures how unequal the distribution is)
-        sorted_probs, _ = torch.sort(memory_probs)
-        n = self.knowledge_num
-        index = torch.arange(1, n + 1, device=device, dtype=torch.float)
-        gini_coeff = (2 * torch.sum(index * sorted_probs) / (n * torch.sum(sorted_probs))) - (n + 1) / n
-        gini_loss = gini_coeff  # the larger the Gini coefficient, the more unequal the distribution
-
-        # 4. Combine into the balance loss
-        balance_loss = 0.5 * kl_loss + 0.5 * gini_loss
-
-        # 5. Monitoring statistics
-        with torch.no_grad():
-            # Coverage rate: fraction of memory entries selected at least once
-            coverage_rate = (memory_counts > 0).float().mean().item()
-
-            # Hot memories: entries in the top 10% by selection count
-            top10_threshold = torch.quantile(memory_counts, 0.9)
-            hot_memories = (memory_counts >= top10_threshold).sum().item()
-
-            # Dead memories: entries never selected
-            dead_memories = (memory_counts == 0).sum().item()
-
-            # Selection variance (a measure of imbalance)
-            selection_variance = memory_counts.var().item()
-
-            stats = {
-                'gini_coefficient': gini_coeff.item(),
-                'kl_divergence': kl_loss.item(),
-                'coverage_rate': coverage_rate,
-                'hot_memories': hot_memories,
-                'dead_memories': dead_memories,
-                'selection_variance': selection_variance,
-                'max_selections': memory_counts.max().item(),
-                'min_selections': memory_counts.min().item(),
-            }
-
-        return balance_loss, stats
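For reference, the two terms that this removed balance loss combined can be checked in isolation. A minimal standalone sketch on a toy 4-entry distribution (values illustrative, no model state assumed):

```python
import torch
import torch.nn.functional as F

# Toy selection distribution over 4 memory entries (illustrative values).
memory_probs = torch.tensor([0.0, 0.1, 0.2, 0.7])

# KL term: F.kl_div takes log-probabilities as input and probabilities as target,
# i.e. it sums target * (log(target) - input) = KL(uniform || memory_probs).
uniform = torch.full_like(memory_probs, 1.0 / 4)
kl = F.kl_div(torch.log(memory_probs + 1e-10), uniform, reduction='sum')

# Gini term: 2 * sum(i * p_i) / (n * sum(p_i)) - (n + 1) / n, with p sorted ascending.
sorted_probs, _ = torch.sort(memory_probs)
n = sorted_probs.numel()
index = torch.arange(1, n + 1, dtype=torch.float)
gini = (2 * torch.sum(index * sorted_probs) / (n * torch.sum(sorted_probs))) - (n + 1) / n

balance_loss = 0.5 * kl + 0.5 * gini
print(kl.item(), gini.item())  # both are 0 for a perfectly uniform distribution
```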
 
 
 class GatedMemoryFusion(nn.Module):
@@ -405,7 +331,7 @@ class MiniMindBlock(nn.Module):
 
     def forward(self, x, pos_cis, memory_bank, tok_embeddings, collect_ema_stats=False):
         """
-        Experiment 1.4.9: Gumbel-Softmax + diversity loss + differentiable similarity loss
+        Experiment: Gumbel-Softmax + diversity loss + differentiable similarity loss (balance loss removed)
 
         Args:
             x: [batch_size, seq_len, dim]
@@ -416,10 +342,9 @@ class MiniMindBlock(nn.Module):
 
         Returns:
             out: [batch_size, seq_len, dim]
-            balance_loss: this layer's balance loss (computed from the candidates)
             similarity_loss: similarity loss (differentiable)
             diversity_loss: diversity loss
-            layer_stats: this layer's monitoring statistics
+            layer_stats: this layer's monitoring statistics (balance-related entries removed)
             ema_stats: EMA update statistics (if collect_ema_stats=True)
             cosine_stats: cosine-similarity statistics between the query vectors and the candidate memory entries
         """
@@ -431,7 +356,7 @@ class MiniMindBlock(nn.Module):
         h_for_memory = self.memory_norm(h_attn)
 
         # 🔥 New architecture: generate 32 candidates
-        candidate_indices, candidate_scores, _, _ = self.memory_gate(h_for_memory)
+        candidate_indices, candidate_scores = self.memory_gate(h_for_memory)
         # candidate_indices: [batch, seq_len, num_candidates]
         # candidate_scores: [batch, seq_len, num_candidates]
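Since the block now routes all candidate scores through a Gumbel-Softmax pick, a minimal standalone sketch of that selection pattern may help; the tensor names and sizes here are illustrative, not the repository's actual API:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes matching the comments above.
batch, seq_len, num_candidates, dim = 2, 8, 16, 32
candidate_scores = torch.randn(batch, seq_len, num_candidates)

# Gumbel-Softmax: a differentiable, near-one-hot choice over the candidate axis.
# hard=True returns a one-hot forward pass but keeps soft gradients (straight-through).
selection_weights = F.gumbel_softmax(candidate_scores, tau=1.0, hard=True, dim=-1)

candidate_memories = torch.randn(batch, seq_len, num_candidates, dim)
# The weighted sum collapses the candidate axis; with hard=True it picks exactly one memory.
selected_memory = (selection_weights.unsqueeze(-1) * candidate_memories).sum(dim=2)
print(selected_memory.shape)  # torch.Size([2, 8, 32])
```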
@@ -474,8 +399,8 @@ class MiniMindBlock(nn.Module):
         # Residual connection
         out = h + memory_output
 
-        # 🔥 Compute the balance loss and statistics (based on the candidate-selection distribution)
-        balance_loss, layer_stats = self._compute_candidate_balance_stats(candidate_indices, selection_weights)
+        # 🔥 Compute simplified statistics (balance-related parts removed)
+        layer_stats = self._compute_selection_stats(candidate_indices, selection_weights)
 
         # 🔥 Compute detailed similarity statistics
         cosine_stats = {
@@ -501,23 +426,21 @@ class MiniMindBlock(nn.Module):
         }
 
         if collect_ema_stats:
-            return out, balance_loss, similarity_loss, diversity_loss, layer_stats, ema_stats, cosine_stats
+            return out, similarity_loss, diversity_loss, layer_stats, ema_stats, cosine_stats
         else:
-            return out, balance_loss, similarity_loss, diversity_loss, layer_stats, cosine_stats
+            return out, similarity_loss, diversity_loss, layer_stats, cosine_stats
 
-    def _compute_candidate_balance_stats(self, candidate_indices, selection_weights):
+    def _compute_selection_stats(self, candidate_indices, selection_weights):
         """
-        Compute the balance loss and statistics based on candidate selection
+        Compute simplified statistics based on candidate selection (balance loss removed)
 
         Args:
             candidate_indices: [batch_size, seq_len, num_candidates]
             selection_weights: [batch_size, seq_len, num_candidates] - Gumbel-Softmax weights
 
         Returns:
-            balance_loss: scalar tensor
             stats: dictionary of statistics
         """
         bsz, seq_len, num_candidates = candidate_indices.shape
         device = candidate_indices.device
 
         # Use a weighted count of how often each memory entry is selected
@@ -528,30 +451,7 @@ class MiniMindBlock(nn.Module):
         memory_counts = torch.zeros(self.config.knowledge_num, device=device)
         memory_counts.scatter_add_(0, flat_indices, flat_weights)
 
         # Turn the counts into a selection probability distribution
         total_selections = memory_counts.sum()
         memory_probs = memory_counts / (total_selections + 1e-10)
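A minimal sketch of the weighted counting used above, with toy sizes (the `knowledge_num`, index, and weight values are made up for illustration):

```python
import torch

knowledge_num = 8
flat_indices = torch.tensor([1, 3, 3, 5])            # selected memory ids (duplicates allowed)
flat_weights = torch.tensor([0.9, 0.4, 0.6, 1.0])    # Gumbel-Softmax weights per selection

memory_counts = torch.zeros(knowledge_num)
# scatter_add_ accumulates each weight into its memory id; id 3 receives 0.4 + 0.6.
memory_counts.scatter_add_(0, flat_indices, flat_weights)
memory_probs = memory_counts / (memory_counts.sum() + 1e-10)
print(memory_probs)  # ≈ [0, 0.31, 0, 0.34, 0, 0.34, 0, 0]
```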
-
-        # KL-divergence loss (against the uniform distribution)
-        uniform_prob = 1.0 / self.config.knowledge_num
-        memory_probs_safe = memory_probs + 1e-10
-        kl_loss = F.kl_div(
-            torch.log(memory_probs_safe),
-            torch.full_like(memory_probs, uniform_prob),
-            reduction='sum'
-        )
-
-        # Gini-coefficient loss
-        sorted_probs, _ = torch.sort(memory_probs)
-        n = self.config.knowledge_num
-        index = torch.arange(1, n + 1, device=device, dtype=torch.float)
-        gini_coeff = (2 * torch.sum(index * sorted_probs) / (n * torch.sum(sorted_probs))) - (n + 1) / n
-        gini_loss = gini_coeff
-
-        # Combine into the balance loss
-        balance_loss = 0.5 * kl_loss + 0.5 * gini_loss
-
-        # Compute the statistics
+        # Compute the statistics (balance loss excluded)
         with torch.no_grad():
             coverage_rate = (memory_counts > 0.01).float().mean().item()  # fraction of memories with selection weight > 0.01
             top10_threshold = torch.quantile(memory_counts, 0.9)
@@ -560,8 +460,6 @@ class MiniMindBlock(nn.Module):
             selection_variance = memory_counts.var().item()
 
             stats = {
-                'gini_coefficient': gini_coeff.item(),
-                'kl_divergence': kl_loss.item(),
                 'coverage_rate': coverage_rate,
                 'hot_memories': hot_memories,
                 'dead_memories': dead_memories,
@@ -570,7 +468,7 @@ class MiniMindBlock(nn.Module):
                 'min_selections': memory_counts.min().item(),
             }
 
-        return balance_loss, stats
+        return stats
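The surviving statistics can also be verified in isolation. A toy sketch (the counts are illustrative):

```python
import torch

memory_counts = torch.tensor([0.0, 0.0, 0.9, 1.0, 0.0, 2.1, 0.0, 0.0])
coverage_rate = (memory_counts > 0.01).float().mean().item()   # 0.375
top10_threshold = torch.quantile(memory_counts, 0.9)           # 90th-percentile count
hot_memories = (memory_counts >= top10_threshold).sum().item() # 1
dead_memories = (memory_counts == 0).sum().item()              # 5
print(coverage_rate, hot_memories, dead_memories)
```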
 
 
 class MiniMindLM(PreTrainedModel):
@@ -689,23 +587,23 @@ class MiniMindLM(PreTrainedModel):
         h = self.dropout(self.tok_embeddings(input_ids))
         pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
 
-        # Collect losses and statistics from all layers - experiment 1.4.9: four-loss system
-        total_balance_loss = 0
+        # Collect losses and statistics from all layers - experiment: two-loss system (balance_loss removed)
         total_similarity_loss = 0
         total_diversity_loss = 0
         all_layer_stats = {}
         all_ema_stats = {}
         all_cosine_stats = {}
         n = 0
 
         for layer_idx, layer in enumerate(self.layers):
             n = n+1
             if collect_ema_stats:
-                h, balance_loss, similarity_loss, diversity_loss, layer_stats, ema_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True)
+                h, similarity_loss, diversity_loss, layer_stats, ema_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True)
                 all_ema_stats[f'layer_{layer_idx}'] = ema_stats
             else:
-                h, balance_loss, similarity_loss, diversity_loss, layer_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False)
+                h, similarity_loss, diversity_loss, layer_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False)
 
-            # Accumulate the four losses
-            total_balance_loss += balance_loss
+            # Accumulate the two losses (balance_loss removed)
             total_similarity_loss += similarity_loss
             total_diversity_loss += diversity_loss
 
@@ -719,11 +617,10 @@ class MiniMindLM(PreTrainedModel):
 
         logits = self.output(self.norm(h))
 
-        # 🔥 The new four-loss structure
+        # 🔥 The new two-loss structure (balance_loss removed)
         aux_loss = {
-            'balance_loss': total_balance_loss,
-            'similarity_loss': total_similarity_loss,
-            'diversity_loss': total_diversity_loss,
+            'similarity_loss': total_similarity_loss / n,
+            'diversity_loss': total_diversity_loss / n,
         }
 
         self.OUT.__setitem__('last_hidden_state', h)
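A small sketch of the layer-averaged auxiliary-loss dictionary built above (per-layer values hypothetical); dividing by the layer count n keeps the loss coefficients independent of model depth:

```python
import torch

# Hypothetical per-layer (similarity, diversity) losses for a 2-layer model.
layer_losses = [(torch.tensor(0.50), torch.tensor(0.10)),
                (torch.tensor(0.40), torch.tensor(0.14))]
n = len(layer_losses)
total_similarity_loss = sum(s for s, _ in layer_losses)
total_diversity_loss = sum(d for _, d in layer_losses)
aux_loss = {
    'similarity_loss': total_similarity_loss / n,  # 0.45
    'diversity_loss': total_diversity_loss / n,    # 0.12
}
```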
@@ -10,7 +10,7 @@
 # 2. Hardened DeepSpeed: parameter + optimizer CPU offload + async I/O optimization
 #
 # 📝 Optimization-strategy notes:
-# - No gradient checkpointing: avoids numerical-stability side effects on the four-loss system and Gumbel-Softmax
+# - No gradient checkpointing: avoids numerical-stability side effects on the three-loss system and Gumbel-Softmax
 # - Focus on safe optimizations: reduce GPU memory usage while preserving training quality
 #
 # Usage:
@@ -21,7 +21,7 @@
 # 🧑‍🔬 Basic experiment information
 # ----------------------------------------------------------------------------
 EXPERIMENT_VERSION="1.4.10_optimized"
-EXPERIMENT_DESCRIPTION="Four-loss system, optimized - two safe GPU-memory optimization strategies"
+EXPERIMENT_DESCRIPTION="Three-loss system, optimized - two safe GPU-memory optimization strategies"
 RESEARCHER_NAME="AI Assistant"
 EXPERIMENT_DATE="$(date '+%Y-%m-%d %H:%M:%S')"
 
@@ -80,7 +80,7 @@ GRAD_CLIP="1.0"
 WARMUP_ITERS="0"
 
 # 🔥 Four-loss system configuration (kept consistent with 1.4.10)
-BALANCE_LOSS_COEF="0.01"        # balance-loss coefficient
+# BALANCE_LOSS_COEF removed - using only three losses now
 SIMILARITY_LOSS_COEF="0.8"      # similarity-loss coefficient (core loss)
 DIVERSITY_LOSS_COEF="0.2"       # diversity-loss coefficient (avoids duplicate candidates)
@@ -165,7 +165,7 @@ GPU devices: $CUDA_VISIBLE_DEVICES
 Mixed precision: $MIXED_PRECISION
 ========================================
 Model configuration:
-  Model type: $MODEL_TYPE (Token-based Memory + four-loss system)
+  Model type: $MODEL_TYPE (Token-based Memory + three-loss system)
   Model size: $MODEL_SIZE MB
   Dimension: $DIM
   Layers: $N_LAYERS
@@ -184,8 +184,7 @@ GPU devices: $CUDA_VISIBLE_DEVICES
 Effective batch size: $((BATCH_SIZE * ACCUMULATION_STEPS * 4))
 Data type: $DTYPE
 ========================================
-🔥 Four-loss system configuration:
-  Balance-loss coefficient: $BALANCE_LOSS_COEF (memory-selection balance)
+🔥 Three-loss system configuration (Balance Loss removed):
   Similarity-loss coefficient: $SIMILARITY_LOSS_COEF (semantic-matching optimization)
   Diversity-loss coefficient: $DIVERSITY_LOSS_COEF (candidate-set diversity)
 ========================================
@@ -255,8 +254,7 @@ run_experiment() {
     train_cmd+=" --model_size $MODEL_SIZE"
     train_cmd+=" --freeze_ratio $FREEZE_RATIO"
 
-    # 🔥 Four-loss system parameters
-    train_cmd+=" --balance_loss_coef $BALANCE_LOSS_COEF"
+    # 🔥 Three-loss system parameters (Balance Loss removed)
     train_cmd+=" --similarity_loss_coef $SIMILARITY_LOSS_COEF"
     train_cmd+=" --diversity_loss_coef $DIVERSITY_LOSS_COEF"
@@ -325,9 +323,9 @@ EOF
     echo "🛑 Stop training: kill $train_pid"
     echo "📈 SwanLab: https://swanlab.cn/project/$SWANLAB_PROJECT"
     echo ""
-    echo "🧠 Memory-optimized four-loss system now under test..."
+    echo "🧠 Memory-optimized three-loss system now under test..."
     echo " 🔥 Two safe optimization strategies enabled"
-    echo " 🔥 Loss structure: CE + Balance + Similarity + Diversity"
+    echo " 🔥 Loss structure: CE + Similarity + Diversity (three-loss system)"
     echo " 🔥 Candidate mechanism: 16 candidates → Gumbel-Softmax picks the single best"
     echo " 🔥 Numerical stability: fully preserved, no gradient-checkpointing interference"
     echo " 🔥 DeepSpeed optimization: parameter + optimizer CPU offload"
@@ -343,7 +341,7 @@ EOF
     echo "🎯 Expected improvements:"
     echo " - GPU memory usage: fits an A800 80GB (the original version cannot run)"
     echo " - Training stability: the optimized version is more stable"
-    echo " - Four-loss convergence: consistent with the original expectations"
+    echo " - Three-loss convergence: consistent with the original expectations"
     echo " - Generation quality: preserves the original target quality"
     echo ""
     echo "⏱️ Estimated training time: 18-20 hours (no gradient-checkpointing recomputation)"
@@ -351,7 +349,7 @@
     echo ""
     echo "🔍 Key monitoring metrics:"
     echo " - GPU memory usage: should stay below 70GB"
-    echo " - Four-loss convergence: compare against the original 1.4.10"
+    echo " - Three-loss convergence: compare against the original 1.4.10"
     echo " - Training stability: no OOM errors"
     echo " - Optimization validation: memory-selection quality"
     echo ""
@@ -383,7 +381,7 @@ trap 'echo "❌ Experiment interrupted"; cleanup; exit 130' INT TERM
 main() {
     echo "============================================================================"
     echo "🧠 MiniMind pretraining experiment 1.4.10, optimized"
-    echo "🎯 Four-loss system + two safe GPU-memory optimization strategies"
+    echo "🎯 Three-loss system + two safe GPU-memory optimization strategies"
     echo "============================================================================"
     echo ""
     echo "🔥 Core optimization strategies:"
@@ -395,7 +393,7 @@ main() {
     echo "🎯 GPU-memory optimization goals:"
     echo " ✓ Original 1.4.10: needs 80GB+ → optimized: 45-55GB (conservative estimate)"
     echo " ✓ A800 80GB compatibility: from unrunnable → fully compatible"
-    echo " ✓ Training quality preserved: the four-loss system is fully functional and numerically stable"
+    echo " ✓ Training quality preserved: the three-loss system is fully functional and numerically stable"
     echo " ✓ Convergence behavior: fully consistent with original 1.4.10 expectations"
     echo ""
     echo "🔧 Technical implementation details:"
@@ -1013,21 +1013,26 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
             ).view(Y.size())
             ce_loss = (ce_loss * loss_mask).sum() / loss_mask.sum()
 
-            # 🔥 Experiment 1.4.9: four-loss handling
-            balance_loss = 0
+            # 🔥 Experiment: three-loss handling + distributed-training aggregation fix
             similarity_loss = 0
             diversity_loss = 0
 
             if hasattr(res, 'aux_loss') and res.aux_loss is not None:
                 aux_loss = res.aux_loss
                 if isinstance(aux_loss, dict):
-                    # The new four-loss structure
-                    balance_loss = aux_loss.get('balance_loss', 0)
+                    # Three-loss structure (balance_loss removed)
                     similarity_loss = aux_loss.get('similarity_loss', 0)
                     diversity_loss = aux_loss.get('diversity_loss', 0)
 
+                    # 🔥 Fix the loss-aggregation problem in distributed training:
+                    # aggregate the losses across GPUs to obtain the global mean
+                    if isinstance(similarity_loss, torch.Tensor):
+                        similarity_loss = accelerator.gather(similarity_loss).mean()
+                    if isinstance(diversity_loss, torch.Tensor):
+                        diversity_loss = accelerator.gather(diversity_loss).mean()
                 else:
-                    # Backward compatibility: the old single aux_loss
-                    balance_loss = aux_loss
+                    # Backward compatibility: the old single aux_loss is now ignored (since balance_loss was removed)
+                    pass
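The gather-then-mean pattern above can be sketched on its own. This assumes a huggingface `accelerate` setup and a 0-dim loss tensor, which `gather` stacks across processes:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
# A per-process scalar loss (illustrative value).
similarity_loss = torch.tensor(0.42, device=accelerator.device)
# gather() collects the value from every process; .mean() then yields the global
# average, so every rank logs and weights the same aggregated loss.
similarity_loss = accelerator.gather(similarity_loss).mean()
```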
 
             # Fetch cosine-similarity statistics (if the model supports them)
             cosine_stats = {}
@@ -1039,14 +1044,12 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                 if selected_similarities:
                     avg_selected_similarity = np.mean(selected_similarities)
 
-            # 🔥 Four-loss system: CE + Balance + Similarity + Diversity
+            # 🔥 Three-loss system: CE + Similarity + Diversity (Balance Loss removed)
             # Loss coefficients are adjustable via command-line arguments
-            balance_coef = getattr(args, 'balance_loss_coef', 0.01)
             similarity_coef = getattr(args, 'similarity_loss_coef', 0.1)
             diversity_coef = getattr(args, 'diversity_loss_coef', 0.05)
 
             total_loss = (ce_loss +
-                          balance_coef * balance_loss +
                           similarity_coef * similarity_loss +
                           diversity_coef * diversity_loss)
             loss = total_loss / args.accumulation_steps
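With the script's coefficients (0.8 and 0.2), the combined objective works out as in this toy sketch (the loss values themselves are hypothetical):

```python
# CE dominates; the auxiliary terms act as small, weighted corrections.
ce_loss, similarity_loss, diversity_loss = 2.31, 0.45, 0.12
similarity_coef, diversity_coef = 0.8, 0.2
total_loss = ce_loss + similarity_coef * similarity_loss + diversity_coef * diversity_loss
print(total_loss)       # 2.31 + 0.36 + 0.024 = 2.694
accumulation_steps = 8  # hypothetical; scales the backward pass for gradient accumulation
loss = total_loss / accumulation_steps
```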
@@ -1247,13 +1250,12 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                     layer_stats = res.layer_stats
 
 
-                # 🔥 Build the four-loss-system log dictionary
+                # 🔥 Build the three-loss-system log dictionary
                 log_dict = {
                     "epoch": epoch + 1,
                     "step": step + 1,
                     "total_steps_in_epoch": total_steps_in_epoch,
                     "train/loss_ce": ce_loss.item(),
-                    "train/loss_balance": balance_loss.item() if isinstance(balance_loss, torch.Tensor) else balance_loss,
                     "train/loss_similarity": similarity_loss.item() if isinstance(similarity_loss, torch.Tensor) else similarity_loss,
                     "train/loss_diversity": diversity_loss.item() if isinstance(diversity_loss, torch.Tensor) else diversity_loss,
                     "train/loss_total": total_loss.item(),
@@ -1287,10 +1289,9 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                     'train/avg_selected_similarity': avg_selected_similarity,  # 🔥 use the similarity of the selected memories
                 })
 
-                # 🔥 Four-loss-system console output
+                # 🔥 Three-loss-system console output
                 Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
                        f"CE: {log_dict['train/loss_ce']:.4f}, "
-                       f"Bal: {log_dict['train/loss_balance']:.4f}, "
                        f"Sim: {log_dict['train/loss_similarity']:.4f}, "
                        f"Div: {log_dict['train/loss_diversity']:.4f}, "
                        f"Total: {log_dict['train/loss_total']:.4f}, "
@@ -1387,7 +1388,6 @@ def main():
     parser.add_argument("--model_type", type=str, default="model_memory", help="which model to train")  # model, model_original, model_no_feed
     parser.add_argument("--model_size", type=float, default=50.0, help="model size")
     parser.add_argument("--swanlab_online", type=bool, default=False, help="whether to use the online SwanLab service")
-    parser.add_argument("--balance_loss_coef", type=float, default=0.01, help="balance-loss coefficient")
     parser.add_argument("--similarity_loss_coef", type=float, default=0.1, help="similarity-loss coefficient (experiment 1.4.9)")
     parser.add_argument("--diversity_loss_coef", type=float, default=0.05, help="diversity-loss coefficient (experiment 1.4.9)")
     parser.add_argument("--val_data_path", type=str, default="/home/zym/Code/stable/eval_data.json", help="path to the validation dataset")