Experiment 1.4.4: load balancing is effective
Commit e61d92c4bc (parent fcdbd220a8)
@@ -58,6 +58,13 @@ def load_model(model_path, model_type, device, config_params=None):
         from model.model_no_feed import MiniMindLM
     elif model_type == "model_memory":
         from model.model_memory import MiniMindLM
+    elif model_type.startswith("model_memory_"):
+        # Support the generic model_memory_X_X_X naming scheme
+        try:
+            module = __import__(f"model.{model_type}", fromlist=["MiniMindLM"])
+            MiniMindLM = getattr(module, "MiniMindLM")
+        except (ImportError, AttributeError) as e:
+            raise ValueError(f"Failed to import model type {model_type}: {e}")
     else:
         raise ValueError(f"Unsupported model type: {model_type}")
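As a side note, the `__import__(..., fromlist=[...])` pattern above can be written more readably with importlib. A minimal equivalent sketch (my assumption, not what the commit uses; it relies on the same model/<model_type>.py layout shown in the diff):

import importlib

def resolve_model_class(model_type: str):
    """Return the MiniMindLM class defined in model/<model_type>.py."""
    try:
        module = importlib.import_module(f"model.{model_type}")
        return getattr(module, "MiniMindLM")
    except (ImportError, AttributeError) as e:
        raise ValueError(f"Failed to import model type {model_type}: {e}")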
@@ -254,6 +261,12 @@ def evaluate_sample(model, tokenizer, text, input_length=100, predict_length=100
         ground_truth_text: the ground-truth text
         loss: the prediction loss (if computable)
     """
+    # Apply BOS/EOS token handling consistent with training
+    if not text.startswith(tokenizer.bos_token):
+        text = f"{tokenizer.bos_token}{text}"
+    if not text.endswith(tokenizer.eos_token):
+        text = f"{text}{tokenizer.eos_token}"
+
     # Tokenize the text
     tokens = tokenizer.encode(text, add_special_tokens=False)
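As a concrete illustration with a hypothetical tokenizer whose bos_token is <s> and eos_token is </s>: the input "hello" is reframed as "<s>hello</s>" before encoding, so evaluation sees exactly the token framing the model was pretrained with.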
@@ -347,11 +360,10 @@ def evaluate_sample(model, tokenizer, text, input_length=100, predict_length=100
 
 def main():
     parser = argparse.ArgumentParser(description='Evaluate a pretrained model')
-    parser.add_argument('--model_path', type=str, default='out/experiment_1_4_0/pretrain_512.pth',
+    parser.add_argument('--model_path', type=str, default='out/experiment_1_4_1/pretrain_512.pth',
                         help='Path to the model weights file')
-    parser.add_argument('--model_type', type=str, default='model',
-                        choices=['model', 'model_original', 'model_no_feed', 'model_memory'],
-                        help='Model type')
+    parser.add_argument('--model_type', type=str, default='model_memory',
+                        help='Model type (supports model, model_original, model_no_feed, model_memory, model_memory_X_X_X, etc.)')
     parser.add_argument('--data_path', type=str, default='dataset/stable/eval_data.json',
                         help='Path to the evaluation dataset')
     parser.add_argument('--num_samples', type=int, default=20,
@@ -427,8 +439,8 @@ def main():
         'n_routed_experts': args.n_routed_experts,
     }
 
-    # Only model, model_no_feed and model_memory need the KnowledgeDataset parameters
-    if args.model_type in ['model', 'model_no_feed', 'model_memory']:
+    # Only model, model_no_feed and the model_memory family need the KnowledgeDataset parameters
+    if args.model_type in ['model', 'model_no_feed', 'model_memory'] or args.model_type.startswith('model_memory_'):
         config_params.update({
             'knowledge_num': args.knowledge_num,
             'knowledge_length': args.knowledge_length,
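For reference, with the defaults from run_file/experiment_1_4_4.sh further down, the resulting config_params comes out roughly as follows (a sketch; keys not touched by this hunk are elided):

config_params = {
    'dim': 512, 'n_layers': 8, 'n_heads': 32, 'max_seq_len': 512,
    'n_routed_experts': args.n_routed_experts,
    # added only for model, model_no_feed and the model_memory family:
    'knowledge_num': 65536,
    'knowledge_length': 32,
}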
@@ -153,6 +153,8 @@ class MemoryGate(nn.Module):
         Returns:
             memory_indices: [batch_size, seq_len, num_selected]
             memory_scores: [batch_size, seq_len, num_selected]
+            balance_loss: the balance loss (KL divergence + Gini coefficient)
+            stats: a dict of monitoring statistics
         """
         bsz, seq_len, _ = x.shape
@@ -186,80 +188,132 @@ class MemoryGate(nn.Module):
         memory_scores = F.softmax(final_scores, dim=-1)
         memory_scores = self.dropout(memory_scores)
 
-        return memory_indices, memory_scores
+        # Compute the balance loss and monitoring statistics
+        balance_loss, stats = self._compute_balance_loss_and_stats(memory_indices, memory_scores)
+
+        return memory_indices, memory_scores, balance_loss, stats
+
+    def _compute_balance_loss_and_stats(self, memory_indices, memory_scores):
+        """
+        Compute the balance loss and monitoring statistics
+
+        Args:
+            memory_indices: [batch_size, seq_len, num_selected]
+            memory_scores: [batch_size, seq_len, num_selected]
+
+        Returns:
+            balance_loss: a scalar tensor
+            stats: a dict of statistics
+        """
+        bsz, seq_len, num_selected = memory_indices.shape
+        device = memory_indices.device
+
+        # 1. Build the memory-selection distribution
+        # Flatten all selected memory indices
+        flat_indices = memory_indices.view(-1)  # [batch_size * seq_len * num_selected]
+
+        # Count how often each memory entry was selected
+        memory_counts = torch.zeros(self.knowledge_num, device=device)
+        memory_counts.scatter_add_(0, flat_indices, torch.ones_like(flat_indices, dtype=torch.float))
+
+        # Selection probability distribution
+        total_selections = bsz * seq_len * num_selected
+        memory_probs = memory_counts / total_selections
+
+        # 2. KL-divergence loss (against the uniform distribution)
+        uniform_prob = 1.0 / self.knowledge_num
+        # Avoid log(0)
+        memory_probs_safe = memory_probs + 1e-10
+        kl_loss = F.kl_div(
+            torch.log(memory_probs_safe),
+            torch.full_like(memory_probs, uniform_prob),
+            reduction='sum'
+        )
+
+        # 3. Gini-coefficient loss (measures how unequal the distribution is)
+        sorted_probs, _ = torch.sort(memory_probs)
+        n = self.knowledge_num
+        index = torch.arange(1, n + 1, device=device, dtype=torch.float)
+        gini_coeff = (2 * torch.sum(index * sorted_probs) / (n * torch.sum(sorted_probs))) - (n + 1) / n
+        gini_loss = gini_coeff  # the larger the Gini coefficient, the less uniform the distribution
+
+        # 4. Combined balance loss
+        balance_loss = 0.5 * kl_loss + 0.5 * gini_loss
+
+        # 5. Monitoring statistics
+        with torch.no_grad():
+            # Coverage rate: fraction of memory entries selected at least once
+            coverage_rate = (memory_counts > 0).float().mean().item()
+
+            # Hot memories: entries in the top 10% by selection count
+            top10_threshold = torch.quantile(memory_counts, 0.9)
+            hot_memories = (memory_counts >= top10_threshold).sum().item()
+
+            # Dead memories: entries never selected
+            dead_memories = (memory_counts == 0).sum().item()
+
+            # Selection variance (another imbalance measure)
+            selection_variance = memory_counts.var().item()
+
+            stats = {
+                'gini_coefficient': gini_coeff.item(),
+                'kl_divergence': kl_loss.item(),
+                'coverage_rate': coverage_rate,
+                'hot_memories': hot_memories,
+                'dead_memories': dead_memories,
+                'selection_variance': selection_variance,
+                'max_selections': memory_counts.max().item(),
+                'min_selections': memory_counts.min().item(),
+            }
+
+        return balance_loss, stats
+
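A quick standalone check of the two balance-loss components on a toy selection histogram, a sketch built directly from the formulas above (knowledge_num is tiny here purely for illustration; a perfectly uniform histogram drives both terms toward 0):

import torch
import torch.nn.functional as F

knowledge_num = 8
# Synthetic selection counts: heavily skewed toward entries 0 and 6
memory_counts = torch.tensor([10., 0., 0., 2., 1., 0., 30., 5.])
memory_probs = memory_counts / memory_counts.sum()

# KL divergence against the uniform distribution (input must be log-probs)
uniform = torch.full_like(memory_probs, 1.0 / knowledge_num)
kl = F.kl_div(torch.log(memory_probs + 1e-10), uniform, reduction='sum')

# Gini coefficient over the sorted probabilities
sorted_p, _ = torch.sort(memory_probs)
idx = torch.arange(1, knowledge_num + 1, dtype=torch.float)
gini = (2 * torch.sum(idx * sorted_p) / (knowledge_num * sorted_p.sum())) - (knowledge_num + 1) / knowledge_num

print(f"KL={kl.item():.4f}, Gini={gini.item():.4f}")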
-class CrossAttentionMemory(nn.Module):
-    """Cross attention using selected memory as K and V"""
+class GatedMemoryFusion(nn.Module):
+    """Gated MLP fusion for concatenated h_attn and selected memories"""
     def __init__(self, config: LMConfig):
         super().__init__()
         self.config = config
-        self.n_heads = config.n_heads
-        self.head_dim = config.dim // config.n_heads
         self.dim = config.dim
         self.knowledge_dim = config.knowledge_dim
+        self.num_selected = getattr(config, 'num_selected', 16)
 
-        # Q is computed from the self-attention output
-        self.wq = nn.Linear(config.dim, config.dim, bias=False)
+        # Input width: dim (h_attn) + num_selected * knowledge_dim (the selected memories)
+        concat_dim = self.dim + self.num_selected * self.knowledge_dim
 
-        # K and V are computed from the memory data
-        self.wk = nn.Linear(config.knowledge_dim, config.dim, bias=False)
-        self.wv = nn.Linear(config.knowledge_dim, config.dim, bias=False)
+        # SwiGLU-style gated MLP
+        self.gate_proj = nn.Linear(concat_dim, self.dim, bias=False)
+        self.up_proj = nn.Linear(concat_dim, self.dim, bias=False)
+        self.down_proj = nn.Linear(self.dim, self.dim, bias=False)
 
-        # Output projection
-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
         self.dropout = nn.Dropout(config.dropout)
 
-    def forward(self, x: torch.Tensor, memory_data: torch.Tensor, memory_scores: torch.Tensor):
+    def forward(self, h_attn: torch.Tensor, selected_memories: torch.Tensor, memory_scores: torch.Tensor):
         """
         Args:
-            x: [batch_size, seq_len, dim] - Query from self attention
-            memory_data: [batch_size, seq_len, num_selected, knowledge_dim] - Selected memory data
-            memory_scores: [batch_size, seq_len, num_selected] - Memory selection weights
+            h_attn: [batch_size, seq_len, dim] - Self attention output
+            selected_memories: [batch_size, seq_len, num_selected, knowledge_dim] - Selected memory data
+            memory_scores: [batch_size, seq_len, num_selected] - Memory selection weights (not used in concatenation approach)
         Returns:
             output: [batch_size, seq_len, dim]
         """
-        bsz, seq_len, _ = x.shape
-        num_selected = memory_data.shape[2]
+        bsz, seq_len, _ = h_attn.shape
 
-        # Compute the query
-        q = self.wq(x)  # [batch, seq_len, dim]
-        q = q.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)  # [batch, n_heads, seq_len, head_dim]
+        # Flatten the selected memories into one vector per position
+        # [batch, seq_len, num_selected, knowledge_dim] -> [batch, seq_len, num_selected * knowledge_dim]
+        memory_flat = selected_memories.view(bsz, seq_len, -1)
 
-        # Compute K and V from the selected memory data
-        memory_flat = memory_data.view(bsz * seq_len * num_selected, self.knowledge_dim)
-        k_flat = self.wk(memory_flat)  # [batch * seq_len * num_selected, dim]
-        v_flat = self.wv(memory_flat)  # [batch * seq_len * num_selected, dim]
+        # Concatenate h_attn with the memory information
+        concat_input = torch.cat([h_attn, memory_flat], dim=-1)  # [batch, seq_len, dim + num_selected * knowledge_dim]
 
-        # Reshape K and V
-        k = k_flat.view(bsz, seq_len, num_selected, self.n_heads, self.head_dim).permute(0, 3, 1, 2, 4)  # [batch, n_heads, seq_len, num_selected, head_dim]
-        v = v_flat.view(bsz, seq_len, num_selected, self.n_heads, self.head_dim).permute(0, 3, 1, 2, 4)  # [batch, n_heads, seq_len, num_selected, head_dim]
+        # Gated MLP (SwiGLU-style)
+        gate = F.silu(self.gate_proj(concat_input))  # [batch, seq_len, dim]
+        up = self.up_proj(concat_input)  # [batch, seq_len, dim]
+        fusion_output = gate * up  # element-wise multiplication
 
-        # Expand Q over the memory axis for cross attention
-        q_expanded = q.unsqueeze(3)  # [batch, n_heads, seq_len, 1, head_dim]
+        # Output projection
+        output = self.down_proj(fusion_output)  # [batch, seq_len, dim]
+        output = self.dropout(output)
 
-        # Attention scores
-        # q_expanded: [batch, n_heads, seq_len, 1, head_dim]
-        # k: [batch, n_heads, seq_len, num_selected, head_dim]
-        scores = torch.matmul(q_expanded, k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # [batch, n_heads, seq_len, 1, num_selected]
-        scores = scores.squeeze(3)  # [batch, n_heads, seq_len, num_selected]
-
-        # Apply the memory selection weights
-        memory_scores_expanded = memory_scores.unsqueeze(1).expand(-1, self.n_heads, -1, -1)  # [batch, n_heads, seq_len, num_selected]
-        scores = scores + memory_scores_expanded.log()  # add in log space
-
-        # Softmax normalization
-        attn_weights = F.softmax(scores, dim=-1)  # [batch, n_heads, seq_len, num_selected]
-        attn_weights = self.dropout(attn_weights)
-
-        # Apply the attention weights to V
-        # attn_weights: [batch, n_heads, seq_len, num_selected]
-        # v: [batch, n_heads, seq_len, num_selected, head_dim]
-        output = torch.einsum('bhlk,bhlkd->bhld', attn_weights, v)  # [batch, n_heads, seq_len, head_dim]
-
-        # Reshape the output
-        output = output.transpose(1, 2).reshape(bsz, seq_len, self.dim)  # [batch, seq_len, dim]
-        output = self.wo(output)
-
         return output
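A shape walk-through of the gated fusion as a standalone sketch. The sizes come from this commit's run script (dim=512, knowledge_dim=128) and the getattr default above (num_selected=16); batch and sequence sizes are arbitrary:

import torch
import torch.nn as nn
import torch.nn.functional as F

bsz, seq_len, dim, kdim, nsel = 2, 8, 512, 128, 16
concat_dim = dim + nsel * kdim                      # 512 + 16*128 = 2560

h_attn = torch.randn(bsz, seq_len, dim)
selected = torch.randn(bsz, seq_len, nsel, kdim)

gate_proj = nn.Linear(concat_dim, dim, bias=False)
up_proj = nn.Linear(concat_dim, dim, bias=False)
down_proj = nn.Linear(dim, dim, bias=False)

# Flatten memories, concatenate with h_attn, then SwiGLU-style gating
x = torch.cat([h_attn, selected.view(bsz, seq_len, -1)], dim=-1)  # [2, 8, 2560]
out = down_proj(F.silu(gate_proj(x)) * up_proj(x))                # [2, 8, 512]
print(out.shape)  # torch.Size([2, 8, 512])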
@@ -279,7 +333,7 @@ class MiniMindBlock(nn.Module):
 
         # Memory-related modules
         self.memory_gate = MemoryGate(config)
-        self.cross_attention_memory = CrossAttentionMemory(config)
+        self.gated_memory_fusion = GatedMemoryFusion(config)
 
     def forward(self, x, pos_cis, memory_bank):
         """
@@ -287,16 +341,21 @@ class MiniMindBlock(nn.Module):
             x: [batch_size, seq_len, dim]
             pos_cis: positional encoding
             memory_bank: [knowledge_num, knowledge_dim] - shared memory bank
+
+        Returns:
+            out: [batch_size, seq_len, dim]
+            balance_loss: this layer's balance loss
+            layer_stats: this layer's monitoring statistics
         """
         # Self attention
         h_attn = self.attention(self.attention_norm(x), pos_cis)
         h = x + h_attn
 
         # Use h_attn as the input to the gate and memory fusion (key point: the self-attention output)
-        h_for_memory = self.memory_norm(h)
+        h_for_memory = self.memory_norm(h_attn)
 
         # Gate selects the memories
-        memory_indices, memory_scores = self.memory_gate(h_for_memory)
+        memory_indices, memory_scores, balance_loss, layer_stats = self.memory_gate(h_for_memory)
 
         # Fetch the memory data by index
         bsz, seq_len, num_selected = memory_indices.shape
@@ -304,14 +363,13 @@ class MiniMindBlock(nn.Module):
         selected_memory = memory_bank[memory_indices_flat]  # [batch * seq_len * num_selected, knowledge_dim]
         selected_memory = selected_memory.view(bsz, seq_len, num_selected, -1)  # [batch, seq_len, num_selected, knowledge_dim]
 
-        h = x + selected_memory
-        # Cross attention: Q from h_attn, K and V from the selected memories
-        memory_output = self.cross_attention_memory(x, selected_memory, memory_scores)
+        # Gated MLP fusion: concatenate h_attn with the selected memories
+        memory_output = self.gated_memory_fusion(h_for_memory, selected_memory, memory_scores)
 
         # Residual connection
         out = h + memory_output
 
-        return out
+        return out, balance_loss, layer_stats
 
 
 class MiniMindLM(PreTrainedModel):
@@ -337,7 +395,58 @@ class MiniMindLM(PreTrainedModel):
             requires_grad=True
         )
+
+        # Remember the previous memory-bank state to compute update statistics
+        self.register_buffer('prev_memory_bank', torch.zeros_like(self.memory_bank), persistent=False)
+
         self.OUT = CausalLMOutputWithPast()
 
+    def get_memory_update_stats(self):
+        """
+        Compute memory-bank update statistics
+
+        Returns:
+            update_stats: a dict with the update statistics
+        """
+        with torch.no_grad():
+            if hasattr(self, 'prev_memory_bank') and self.prev_memory_bank.numel() > 0:
+                # L2 distance change
+                l2_distance = torch.norm(self.memory_bank - self.prev_memory_bank, p=2, dim=-1)
+                avg_l2_distance = l2_distance.mean().item()
+                max_l2_distance = l2_distance.max().item()
+
+                # Cosine similarity
+                cos_sim = F.cosine_similarity(
+                    self.memory_bank.view(-1),
+                    self.prev_memory_bank.view(-1),
+                    dim=0
+                ).item()
+
+                # Update rate (fraction of entries with a significant change)
+                threshold = 0.01  # update threshold
+                updated_memories = (l2_distance > threshold).sum().item()
+                update_rate = updated_memories / self.memory_bank.size(0)
+
+                update_stats = {
+                    'memory_avg_l2_change': avg_l2_distance,
+                    'memory_max_l2_change': max_l2_distance,
+                    'memory_cosine_similarity': cos_sim,
+                    'memory_update_rate': update_rate,
+                    'memory_updated_count': updated_memories
+                }
+            else:
+                # Defaults for the first call
+                update_stats = {
+                    'memory_avg_l2_change': 0.0,
+                    'memory_max_l2_change': 0.0,
+                    'memory_cosine_similarity': 1.0,
+                    'memory_update_rate': 0.0,
+                    'memory_updated_count': 0
+                }
+
+            # Refresh prev_memory_bank
+            self.prev_memory_bank.copy_(self.memory_bank)
+
+        return update_stats
+
     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
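A hedged sketch of how this hook is meant to be consumed (the training loop later in this commit does essentially this through accelerator.unwrap_model at each validation interval; step, args, accelerator and swanlab_run are assumed to be in scope):

if accelerator.is_main_process and (step + 1) % args.val_interval == 0:
    stats = accelerator.unwrap_model(model).get_memory_update_stats()
    swanlab_run.log(stats)  # memory_avg_l2_change, memory_cosine_similarity, ...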
@@ -347,16 +456,26 @@ class MiniMindLM(PreTrainedModel):
         h = self.dropout(self.tok_embeddings(input_ids))
         pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
 
-        for layer in self.layers:
-            h = layer(h, pos_cis, self.memory_bank)
+        # Collect balance losses and statistics from every layer
+        total_balance_loss = 0
+        all_layer_stats = {}
+
+        for layer_idx, layer in enumerate(self.layers):
+            h, balance_loss, layer_stats = layer(h, pos_cis, self.memory_bank)
+            total_balance_loss += balance_loss
+            # Prefix each layer's statistics with its index
+            for key, value in layer_stats.items():
+                all_layer_stats[f'layer_{layer_idx}_{key}'] = value
 
         logits = self.output(self.norm(h))
 
-        # Uniformly skip aux_loss
-        aux_loss = 0
+        # Use the total balance loss as aux_loss
+        aux_loss = total_balance_loss
 
         self.OUT.__setitem__('last_hidden_state', h)
         self.OUT.__setitem__('logits', logits)
         self.OUT.__setitem__('aux_loss', aux_loss)
+        self.OUT.__setitem__('layer_stats', all_layer_stats)  # attach per-layer statistics
         self.OUT.__setitem__('past_key_values', None)  # KV cache not supported
         return self.OUT
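For concreteness, with n_layers=8 the aggregated dict carries eight prefixed copies of every MemoryGate statistic. A sketch of its shape, with purely illustrative placeholder values rather than measured ones:

# Illustrative placeholders only, not measured values:
all_layer_stats = {
    'layer_0_gini_coefficient': 0.42,
    'layer_0_coverage_rate': 0.37,
    # ... one entry per statistic per layer ...
    'layer_7_dead_memories': 1021,
    'layer_7_max_selections': 96.0,
}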
@@ -143,6 +143,7 @@ dependencies = [
     "smmap==5.0.2",
     "sniffio==1.3.1",
     "streamlit==1.30.0",
+    "superclaude>=3.0.0.2",
     "swankit==0.2.4",
     "swanlab==0.6.4",
     "sympy==1.13.3",
|||||||
335
run_file/experiment_1_4_4.sh
Normal file
335
run_file/experiment_1_4_4.sh
Normal file
@ -0,0 +1,335 @@
#!/bin/bash

# ============================================================================
# MiniMind experiment script - Experiment 1.4.4
# ============================================================================
#
# 🎯 Experiment goal:
# Building on experiment 1.4.2's model_memory architecture, validate the
# memory-bank mechanism in depth by adding a balance loss and the
# four-dimension monitoring suite
#
# Usage:
#   bash run_file/experiment_1_4_4.sh
# ============================================================================

# ----------------------------------------------------------------------------
# 🧑‍🔬 Basic experiment info
# ----------------------------------------------------------------------------
EXPERIMENT_VERSION="1.4.4"
EXPERIMENT_DESCRIPTION="model_memory balance-loss and four-dimension monitoring experiment"
RESEARCHER_NAME="AI Assistant"
EXPERIMENT_DATE="$(date '+%Y-%m-%d %H:%M:%S')"

# ----------------------------------------------------------------------------
# 🤖 Environment configuration
# ----------------------------------------------------------------------------

# Debugging and monitoring environment variables
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
export CUDA_LAUNCH_BLOCKING=1

# SwanLab configuration
export SWANLAB_PROJECT="MiniMind-Experiment-1.4.4"

# Logging configuration
LOG_DIR="out/experiment_${EXPERIMENT_VERSION}"
mkdir -p "$LOG_DIR"
LOG_FILE="$LOG_DIR/experiment.log"

# ----------------------------------------------------------------------------
# 🤖 Hardware configuration
# ----------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES="0"
NUM_PROCESSES="1"
MIXED_PRECISION="bf16"
MAIN_PROCESS_PORT="29500"

# ----------------------------------------------------------------------------
# 🤖 Model architecture parameters
# ----------------------------------------------------------------------------
MODEL_TYPE="model_memory"
MODEL_SIZE="50.0"
DIM="512"
N_LAYERS="8"
N_HEADS="32"
MAX_SEQ_LEN="512"
USE_MOE="false"

# Knowledge-base configuration (a smaller memory bank to fit this experiment)
KNOWLEDGE_NUM="65536"   # 256x256 = 65536; must be a perfect square
KNOWLEDGE_LENGTH="32"
KNOWLEDGE_DIM="128"
DISABLE_DB="false"

# ----------------------------------------------------------------------------
# 🤖 Training hyperparameters
# ----------------------------------------------------------------------------
EPOCHS="3"
EMBEDDING_EPOCH="2"
BATCH_SIZE="128"
ACCUMULATION_STEPS="8"
LEARNING_RATE="2e-4"
DTYPE="bfloat16"
GRAD_CLIP="1.0"
WARMUP_ITERS="0"

# Balance-loss configuration
BALANCE_LOSS_COEF="0.1"

# Data and cache paths
DATA_PATH="/home/pci/ycz/Code/Minimind/dataset/stable/merged_pretrain.jsonl"
DATABASE_INIT_PATH="/home/pci/ycz/Code/Minimind/dataset/stable/sentence_trex_data.json"
CLUSTER_CACHE_PATH="/home/pci/ycz/Code/Minimind/cache/cluster_tokens_single.pt"
VAL_DATA_PATH="dataset/stable/eval_data.json"

# Training configuration (merges the log_interval and profile parameters)
NUM_WORKERS="1"
LOG_INTERVAL="100"
VAL_INTERVAL="100"
SAVE_INTERVAL="10000"

# Profiling configuration
USE_PROFILE="true"
PROFILE_INTERVAL="10"
MEMORY_MONITOR_INTERVAL="100"

# Advanced features
USE_FLASH_ATTN="true"
FAST_CLUSTERING="true"

# ----------------------------------------------------------------------------
# 🤖 Pre-flight checks
# ----------------------------------------------------------------------------
check_environment() {
    echo "🔍 Checking environment..."

    # GPU availability
    if ! nvidia-smi &> /dev/null; then
        echo "❌ Error: no GPU detected or nvidia-smi unavailable"
        exit 1
    fi

    # CUDA device
    if ! nvidia-smi -i "$CUDA_VISIBLE_DEVICES" &> /dev/null; then
        echo "❌ Error: GPU $CUDA_VISIBLE_DEVICES is unavailable"
        exit 1
    fi

    # Python environment
    if ! .venv/bin/python -c "import torch; print(f'PyTorch: {torch.__version__}')" 2>/dev/null; then
        echo "❌ Error: PyTorch is not installed correctly"
        exit 1
    fi

    # Data files
    if [[ ! -f "$DATA_PATH" ]]; then
        echo "❌ Error: training data file not found: $DATA_PATH"
        exit 1
    fi

    if [[ ! -f "$DATABASE_INIT_PATH" ]]; then
        echo "❌ Error: database init file not found: $DATABASE_INIT_PATH"
        exit 1
    fi

    echo "✅ Environment checks passed"
}

# ----------------------------------------------------------------------------
# 🤖 Experiment info logging
# ----------------------------------------------------------------------------
log_experiment_info() {
    echo "📝 Recording experiment info..."
    cat > "$LOG_DIR/experiment_info.txt" << EOF
========================================
MiniMind experiment info
========================================
Experiment version: $EXPERIMENT_VERSION
Description: $EXPERIMENT_DESCRIPTION
Researcher: $RESEARCHER_NAME
Start time: $EXPERIMENT_DATE
========================================
Hardware:
GPU device: $CUDA_VISIBLE_DEVICES
Processes: $NUM_PROCESSES
Mixed precision: $MIXED_PRECISION
========================================
Model:
Model type: $MODEL_TYPE
Model size: $MODEL_SIZE MB
Dim: $DIM
Layers: $N_LAYERS
Attention heads: $N_HEADS
Max sequence length: $MAX_SEQ_LEN
Knowledge bank size: $KNOWLEDGE_NUM
Knowledge length: $KNOWLEDGE_LENGTH
Knowledge dim: $KNOWLEDGE_DIM
========================================
Training:
Epochs: $EPOCHS
Batch size: $BATCH_SIZE
Learning rate: $LEARNING_RATE
Gradient accumulation: $ACCUMULATION_STEPS
Dtype: $DTYPE
Balance loss coefficient: $BALANCE_LOSS_COEF
========================================
Data paths:
Training data: $DATA_PATH
Validation data: $VAL_DATA_PATH
Database init: $DATABASE_INIT_PATH
Cluster cache: $CLUSTER_CACHE_PATH
========================================
EOF
}

# ----------------------------------------------------------------------------
# 🤖 Main run function
# ----------------------------------------------------------------------------
run_experiment() {
    echo "🚀 Starting experiment $EXPERIMENT_VERSION"
    echo "📄 Description: $EXPERIMENT_DESCRIPTION"
    echo "⏰ Start time: $EXPERIMENT_DATE"

    # Build the training command
    local train_cmd="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES .venv/bin/python train_pretrain_accelerate.py"

    # Training arguments
    train_cmd+=" --out_dir \"$LOG_DIR\""
    train_cmd+=" --epochs $EPOCHS"
    train_cmd+=" --embedding_epoch $EMBEDDING_EPOCH"
    train_cmd+=" --batch_size $BATCH_SIZE"
    train_cmd+=" --learning_rate $LEARNING_RATE"
    train_cmd+=" --dtype $DTYPE"
    train_cmd+=" --num_workers $NUM_WORKERS"
    train_cmd+=" --accumulation_steps $ACCUMULATION_STEPS"
    train_cmd+=" --grad_clip $GRAD_CLIP"
    train_cmd+=" --warmup_iters $WARMUP_ITERS"
    train_cmd+=" --log_interval $LOG_INTERVAL"
    train_cmd+=" --val_interval $VAL_INTERVAL"
    train_cmd+=" --save_interval $SAVE_INTERVAL"
    train_cmd+=" --dim $DIM"
    train_cmd+=" --n_layers $N_LAYERS"
    train_cmd+=" --n_heads $N_HEADS"
    train_cmd+=" --max_seq_len $MAX_SEQ_LEN"
    train_cmd+=" --data_path \"$DATA_PATH\""
    train_cmd+=" --val_data_path \"$VAL_DATA_PATH\""
    train_cmd+=" --knowledge_num $KNOWLEDGE_NUM"
    train_cmd+=" --knowledge_length $KNOWLEDGE_LENGTH"
    train_cmd+=" --database_init_path \"$DATABASE_INIT_PATH\""
    train_cmd+=" --memory_monitor_interval $MEMORY_MONITOR_INTERVAL"
    train_cmd+=" --model_type \"$MODEL_TYPE\""
    train_cmd+=" --model_size $MODEL_SIZE"
    train_cmd+=" --balance_loss_coef $BALANCE_LOSS_COEF"

    # Optional arguments
    if [[ "$USE_PROFILE" == "true" ]]; then
        train_cmd+=" --profile"
        train_cmd+=" --profile_interval $PROFILE_INTERVAL"
    fi

    if [[ "$USE_FLASH_ATTN" == "true" ]]; then
        train_cmd+=" --use_flash_attn"
    fi

    if [[ "$FAST_CLUSTERING" == "true" ]]; then
        train_cmd+=" --fast_clustering"
    fi

    if [[ "$CLUSTER_CACHE_PATH" != "None" ]]; then
        train_cmd+=" --cluster_cache_path \"$CLUSTER_CACHE_PATH\""
    fi

    # SwanLab configuration
    train_cmd+=" --use_swanlab"
    train_cmd+=" --swanlab_project \"$SWANLAB_PROJECT\""
    train_cmd+=" --swanlab_online True"

    echo "📋 Command:"
    echo "$train_cmd"
    echo

    # Record the command in the log file
    echo "Command: $train_cmd" >> "$LOG_FILE"
    echo "Start time: $(date)" >> "$LOG_FILE"

    # Run the training via nohup (background; output goes to the log file)
    echo "🔄 Running training in the background via nohup; output goes to: $LOG_FILE"

    # Create the training script
    train_script="/tmp/train_${EXPERIMENT_VERSION}.sh"
    cat > "$train_script" << EOF
#!/bin/bash
cd /home/pci/ycz/Code/pretrain-worktree
source /home/pci/ycz/Code/pretrain-worktree/.venv/bin/activate
$train_cmd
echo "End time: \$(date)"
echo "Exit code: \$?"
EOF
    chmod +x "$train_script"

    # Run in the background with nohup
    nohup bash "$train_script" >> "$LOG_FILE" 2>&1 &
    local train_pid=$!

    echo "🔥 Training process started, PID: $train_pid"
    echo "Training PID: $train_pid" >> "$LOG_FILE"
    echo "Training script: $train_script" >> "$LOG_FILE"

    # Wait a few seconds to make sure the process started
    sleep 5

    # Check the process is still alive
    if kill -0 $train_pid 2>/dev/null; then
        echo "✅ Training is running in the background"
        echo "📋 Tail the log: tail -f $LOG_FILE"
        echo "📋 Check the process: ps -p $train_pid"
        echo "🛑 Stop training: kill $train_pid"
        echo "📈 SwanLab: https://swanlab.cn/project/$SWANLAB_PROJECT"
        echo ""
        echo "Training is running in the background; it is safe to close this terminal."
    else
        echo "❌ Failed to start the training process"
        echo "📋 See the log: $LOG_FILE"
        exit 1
    fi
}

# ----------------------------------------------------------------------------
# 🤖 Cleanup
# ----------------------------------------------------------------------------
cleanup() {
    echo "🧹 Cleaning up temporary files..."
    # Remove the temporary validation file
    rm -f /tmp/temp_val.jsonl
}

# ----------------------------------------------------------------------------
# 🤖 Signal handling
# ----------------------------------------------------------------------------
trap cleanup EXIT
trap 'echo "❌ Experiment interrupted"; cleanup; exit 130' INT TERM

# ----------------------------------------------------------------------------
# 🤖 Entry point
# ----------------------------------------------------------------------------
main() {
    echo "============================================================================"
    echo "🧠 MiniMind pretraining experiment 1.4.4"
    echo "🎯 In-depth validation of the memory-bank mechanism - balance loss and four-dimension monitoring"
    echo "============================================================================"

    # Checks and initialization
    check_environment
    log_experiment_info

    # Run the experiment
    run_experiment

    echo "============================================================================"
    echo "✅ Experiment $EXPERIMENT_VERSION launch complete"
    echo "📅 Launch time: $(date)"
    echo "============================================================================"
}

# Run the entry point
main "$@"
@@ -24,7 +24,7 @@ from sklearn.metrics.pairwise import cosine_similarity
 import swanlab  # replaces the wandb import
 import gc  # garbage-collection module
 import psutil  # system resource monitoring
+import json  # JSON support
 
 from model.LMConfig import LMConfig
 from model.dataset import PretrainDataset
@@ -98,6 +98,86 @@ def Logger(msg, accelerator=None):
 def format_time(seconds):
     return str(datetime.timedelta(seconds=int(seconds)))
 
+def create_validation_dataset(val_data_path, tokenizer, max_length, num_samples=200):
+    """
+    Create the validation dataset
+
+    Args:
+        val_data_path: path to the validation data file
+        tokenizer: the tokenizer instance
+        max_length: maximum sequence length
+        num_samples: number of validation samples
+
+    Returns:
+        val_dataset: the validation dataset
+    """
+    if not os.path.exists(val_data_path):
+        Logger(f"Warning: validation data file not found: {val_data_path}; skipping validation")
+        return None
+
+    # Read the validation data
+    val_data = []
+    with open(val_data_path, 'r', encoding='utf-8') as f:
+        for i, line in enumerate(f):
+            if i >= num_samples:  # cap the number of validation samples
+                break
+            line = line.strip()
+            if line:
+                try:
+                    sample = json.loads(line)
+                    val_data.append(sample['text'])
+                except json.JSONDecodeError:
+                    continue
+
+    # Write a temporary validation file
+    temp_val_file = "/tmp/temp_val.jsonl"
+    with open(temp_val_file, 'w', encoding='utf-8') as f:
+        for text in val_data:
+            f.write(json.dumps({'text': text}) + '\n')
+
+    # Build the validation set with PretrainDataset
+    val_dataset = PretrainDataset(temp_val_file, tokenizer, max_length=max_length)
+    Logger(f"Validation dataset created with {len(val_data)} samples")
+
+    return val_dataset
+
+def validate_model(model, val_loader, loss_fct, ctx, accelerator):
+    """
+    Run model validation
+
+    Args:
+        model: the model instance
+        val_loader: validation data loader
+        loss_fct: loss function
+        ctx: context manager
+        accelerator: the Accelerator instance
+
+    Returns:
+        avg_val_loss: average validation loss
+    """
+    model.eval()
+    total_loss = 0
+    num_batches = 0
+
+    with torch.no_grad():
+        for batch in val_loader:
+            X, Y, loss_mask = batch
+
+            with ctx:
+                res = model(X)
+                loss = loss_fct(
+                    res.logits.view(-1, res.logits.size(-1)),
+                    Y.view(-1)
+                ).view(Y.size())
+                loss = (loss * loss_mask).sum() / loss_mask.sum()
+
+            total_loss += loss.item()
+            num_batches += 1
+
+    model.train()
+    avg_val_loss = total_loss / num_batches if num_batches > 0 else float('inf')
+    return avg_val_loss
+
 # Learning-rate schedule helper
 def get_lr(it, num_iters, learning_rate):
     # Cosine learning-rate decay
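A minimal end-to-end sketch of the two helpers above (an assumed usage, not code from this commit; the path mirrors VAL_DATA_PATH and max_length mirrors the script's MAX_SEQ_LEN):

val_ds = create_validation_dataset("dataset/stable/eval_data.json", tokenizer, max_length=512)
if val_ds is not None:
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=0)
    avg_val_loss = validate_model(model, val_loader, loss_fct, ctx, accelerator)
    print(f"avg val loss: {avg_val_loss:.4f}")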
@@ -541,7 +621,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
 
     return model, tokenizer
 
-def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer):
+def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer, val_loader=None):
     loss_fct = nn.CrossEntropyLoss(reduction='none')
     epoch_start_time = time.time()
     total_steps_in_epoch = len(train_loader)
@@ -644,13 +724,22 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                     unwrapped_model.freeze_embedding = True
                     Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
                 res = model(X, step=step)
-                loss = loss_fct(
+
+                # Main loss (cross entropy)
+                ce_loss = loss_fct(
                     res.logits.view(-1, res.logits.size(-1)),
                     Y.view(-1)
                 ).view(Y.size())
-                loss = (loss * loss_mask).sum() / loss_mask.sum()
-                # Drop the auxiliary loss computation; aux_loss is uniformly unused
-                loss = loss / args.accumulation_steps
+                ce_loss = (ce_loss * loss_mask).sum() / loss_mask.sum()
+
+                # Balance loss (if the model provides one)
+                balance_loss = 0
+                if hasattr(res, 'aux_loss') and res.aux_loss is not None:
+                    balance_loss = res.aux_loss
+
+                # Total loss
+                total_loss = ce_loss + args.balance_loss_coef * balance_loss
+                loss = total_loss / args.accumulation_steps
 
             # End of forward-pass timing (main process only)
             if args.profile and accelerator.is_main_process and forward_end is not None:
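Spelled out, the update is loss = (ce_loss + balance_loss_coef * balance_loss) / accumulation_steps. With the values the shell script above sets (BALANCE_LOSS_COEF=0.1, ACCUMULATION_STEPS=8), a cross entropy of 2.50 and a balance loss of 0.80 give (2.50 + 0.08) / 8 ≈ 0.3225 per micro-step; the logged losses below stay unscaled.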
@@ -685,8 +774,8 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
             if args.profile and accelerator.is_main_process and optimizer_end is not None:
                 optimizer_end.record()
 
-            # Print training info (main process only)
-            if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
+            # Validation evaluation and logging (main process only)
+            if (step + 1) % args.val_interval == 0 and accelerator.is_main_process:
                 current_time = time.time()
 
                 # Record the detailed memory state at log time
@@ -809,19 +898,72 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                 tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
                 last_log_time = current_time  # update the last log time
 
+                # Run validation
+                val_loss = None
+                if val_loader is not None:
+                    try:
+                        val_loss = validate_model(model, val_loader, loss_fct, ctx, accelerator)
+                        Logger(f"Validation loss: {val_loss:.4f}", accelerator)
+                    except Exception as e:
+                        Logger(f"Validation failed: {e}", accelerator)
+                        val_loss = None
+
+                # Memory-bank update statistics (if the model supports them)
+                memory_update_stats = {}
+                if hasattr(model, 'get_memory_update_stats'):
+                    try:
+                        unwrapped_model = accelerator.unwrap_model(model)
+                        if hasattr(unwrapped_model, 'get_memory_update_stats'):
+                            memory_update_stats = unwrapped_model.get_memory_update_stats()
+                    except Exception as e:
+                        Logger(f"Failed to fetch memory update stats: {e}", accelerator)
+
+                # Per-layer statistics (if the model provides them)
+                layer_stats = {}
+                if hasattr(res, 'layer_stats') and res.layer_stats is not None:
+                    layer_stats = res.layer_stats
+
+                # Build the log dict
                 log_dict = {
                     "epoch": epoch + 1,
                     "step": step + 1,
                     "total_steps_in_epoch": total_steps_in_epoch,
-                    "loss": loss.item() * args.accumulation_steps,
+                    "train/loss_ce": ce_loss.item(),
+                    "train/loss_balance": balance_loss.item() if isinstance(balance_loss, torch.Tensor) else balance_loss,
+                    "train/loss_total": total_loss.item(),
                     "lr": current_lr,
                     "tokens_per_sec": tokens_per_sec,
                     "epoch_time_left_seconds": epoch_remaining_time,
                     "total_time_left_seconds": total_remaining_time
                 }
 
+                # Attach the validation loss
+                if val_loss is not None:
+                    log_dict["val/loss"] = val_loss
+
+                # Attach the memory-bank update statistics
+                log_dict.update(memory_update_stats)
+
+                # Attach per-layer statistics (key metrics only)
+                if layer_stats:
+                    # Average the statistics over all layers
+                    avg_gini = np.mean([v for k, v in layer_stats.items() if k.endswith('_gini_coefficient')])
+                    avg_coverage = np.mean([v for k, v in layer_stats.items() if k.endswith('_coverage_rate')])
+                    total_dead = sum([v for k, v in layer_stats.items() if k.endswith('_dead_memories')])
+                    total_hot = sum([v for k, v in layer_stats.items() if k.endswith('_hot_memories')])
+
+                    log_dict.update({
+                        'memory/avg_gini_coefficient': avg_gini,
+                        'memory/avg_coverage_rate': avg_coverage,
+                        'memory/total_dead_memories': total_dead,
+                        'memory/total_hot_memories': total_hot,
+                    })
+
                 Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
-                       f"Loss: {log_dict['loss']:.4f}, "
+                       f"CE Loss: {log_dict['train/loss_ce']:.4f}, "
+                       f"Balance Loss: {log_dict['train/loss_balance']:.4f}, "
+                       f"Total Loss: {log_dict['train/loss_total']:.4f}, "
+                       f"Val Loss: {log_dict.get('val/loss', 'N/A')}, "
                        f"LR: {log_dict['lr']:.6f}, "
                        f"Speed: {log_dict['tokens_per_sec']:.2f} tokens/sec | "
                        f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
@@ -832,7 +974,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
 
             # Save the model (main process only)
             loss_total = loss.item() * args.accumulation_steps
-            if epoch > 1 and best_loss > loss_total and accelerator.is_main_process:
+            if epoch >= 0 and best_loss > loss_total and accelerator.is_main_process:
                 best_loss = loss_total
                 # Use the moe_path variable defined at the top of the function
                 ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
@@ -913,6 +1055,9 @@ def main():
     parser.add_argument("--model_type", type=str, default="model", help="Which model to train")  # model, model_original, model_no_feed
     parser.add_argument("--model_size", type=float, default=50.0, help="Model size")
     parser.add_argument("--swanlab_online", type=bool, default=False, help="Whether to use the online SwanLab service")
+    parser.add_argument("--balance_loss_coef", type=float, default=0.01, help="Balance loss coefficient")
+    parser.add_argument("--val_data_path", type=str, default="dataset/stable/eval_data.json", help="Validation dataset path")
+    parser.add_argument("--val_interval", type=int, default=100, help="Validation interval")
     args = parser.parse_args()
 
     #########################################################
@@ -1053,6 +1198,19 @@ def main():
         prefetch_factor=2 if args.num_workers > 0 else None
     )
 
+    # Create the validation dataset and loader
+    val_loader = None
+    val_ds = create_validation_dataset(args.val_data_path, tokenizer, lm_config.max_seq_len)
+    if val_ds is not None:
+        val_loader = DataLoader(
+            val_ds,
+            batch_size=args.batch_size // 2,  # smaller batches for validation
+            pin_memory=True,
+            drop_last=False,
+            shuffle=False,
+            num_workers=0,  # no worker processes for validation
+        )
+
     #########################################################
     # Create the optimizer
     #########################################################
@@ -1072,9 +1230,14 @@ def main():
     #########################################################
     # Prepare for training
     #########################################################
-    model, optimizer, train_loader, scheduler = accelerator.prepare(
-        model, optimizer, train_loader, scheduler
-    )
+    if val_loader is not None:
+        model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
+            model, optimizer, train_loader, val_loader, scheduler
+        )
+    else:
+        model, optimizer, train_loader, scheduler = accelerator.prepare(
+            model, optimizer, train_loader, scheduler
+        )
 
     #########################################################
     # Training loop
@@ -1082,7 +1245,7 @@ def main():
     overall_start_time = time.time()  # Record overall start time
     for epoch in range(args.epochs):
         Logger(f"Starting epoch {epoch+1}", accelerator)
-        train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer)  # Pass tokenizer
+        train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer, val_loader)  # Pass tokenizer and val_loader
 
         # Memory cleanup at the end of each epoch
         Logger(f"Epoch {epoch+1} finished; cleaning up memory", accelerator)