This commit is contained in:
Yu Chengzhang 2025-09-01 19:16:02 +08:00
parent c4c72ac154
commit 05e93cbc48
2 changed files with 225 additions and 37 deletions

View File

@@ -280,7 +280,7 @@ class GatedMemoryFusion(nn.Module):
# 输入维度dim (h_attn) + num_selected * knowledge_dim (选中的记忆) # 输入维度dim (h_attn) + num_selected * knowledge_dim (选中的记忆)
# 实验1.4.6记忆解码后立即压缩回knowledge_dim避免显存爆炸 # 实验1.4.6记忆解码后立即压缩回knowledge_dim避免显存爆炸
concat_dim = self.dim + self.num_selected * self.knowledge_dim concat_dim = self.dim + self.num_selected * self.dim
# 类似SwiGLU的门控MLP结构 # 类似SwiGLU的门控MLP结构
self.gate_proj = nn.Linear(concat_dim, self.dim, bias=False) self.gate_proj = nn.Linear(concat_dim, self.dim, bias=False)
@@ -305,7 +305,7 @@ class GatedMemoryFusion(nn.Module):
memory_flat = selected_memories.reshape(bsz, seq_len, -1) memory_flat = selected_memories.reshape(bsz, seq_len, -1)
# 拼接h_attn和记忆信息 # 拼接h_attn和记忆信息
concat_input = torch.cat([h_attn, memory_flat], dim=-1) # [batch, seq_len, dim + num_selected * knowledge_dim] concat_input = torch.cat([h_attn, memory_flat], dim=-1) # [batch, seq_len, dim + num_selected * dim]
# 门控MLP处理类似SwiGLU # 门控MLP处理类似SwiGLU
gate = F.silu(self.gate_proj(concat_input)) # [batch, seq_len, dim] gate = F.silu(self.gate_proj(concat_input)) # [batch, seq_len, dim]
@@ -350,6 +350,7 @@ class MiniMindBlock(nn.Module):
balance_loss: 该层的平衡损失 balance_loss: 该层的平衡损失
layer_stats: 该层的监控统计信息 layer_stats: 该层的监控统计信息
ema_stats: EMA更新统计信息如果collect_ema_stats=True ema_stats: EMA更新统计信息如果collect_ema_stats=True
cosine_stats: 查找向量与选中记忆条目的余弦相似度统计信息
""" """
# Self attention # Self attention
h_attn = self.attention(self.attention_norm(x), pos_cis) h_attn = self.attention(self.attention_norm(x), pos_cis)
@@ -373,19 +374,8 @@ class MiniMindBlock(nn.Module):
# 立即压缩knowledge_length * dim -> knowledge_dim 避免显存爆炸 # 立即压缩knowledge_length * dim -> knowledge_dim 避免显存爆炸
# 使用平均池化压缩knowledge_length维度 # 使用平均池化压缩knowledge_length维度
pooled_memory = selected_embeddings.mean(dim=1) # [batch * seq_len * num_selected, dim] pooled_memory = selected_embeddings.mean(dim=1) # [batch * seq_len * num_selected, dim]
# 投影到knowledge_dim维度
if self.dim > self.config.knowledge_dim:
# 截断到knowledge_dim
compressed_memory = pooled_memory[:, :self.config.knowledge_dim]
elif self.dim < self.config.knowledge_dim:
# 填充到knowledge_dim
pad_size = self.config.knowledge_dim - self.dim
compressed_memory = F.pad(pooled_memory, (0, pad_size), 'constant', 0)
else:
compressed_memory = pooled_memory
selected_memory = compressed_memory.view(bsz, seq_len, num_selected, self.config.knowledge_dim) # [batch, seq_len, num_selected, knowledge_dim] selected_memory = pooled_memory.view(bsz, seq_len, num_selected, self.dim) # [batch, seq_len, num_selected, dim]
# 门控MLP融合串型连接h_attn和选中的记忆 # 门控MLP融合串型连接h_attn和选中的记忆
memory_output = self.gated_memory_fusion(h_for_memory, selected_memory, memory_scores) memory_output = self.gated_memory_fusion(h_for_memory, selected_memory, memory_scores)
@@ -393,6 +383,28 @@ class MiniMindBlock(nn.Module):
# 残差连接 # 残差连接
out = h + memory_output out = h + memory_output
# 🔍 新增: 计算查找向量与选中记忆条目的余弦相似度
with torch.no_grad():
# 扩展查找向量维度以匹配selected_memory
h_expanded = h_for_memory.unsqueeze(2).expand(-1, -1, num_selected, -1) # [batch, seq_len, num_selected, dim]
# 计算余弦相似度cosine_sim(query, memory) for each selected memory
cosine_similarities = F.cosine_similarity(
h_expanded, # [batch, seq_len, num_selected, dim]
selected_memory, # [batch, seq_len, num_selected, knowledge_dim]
dim=-1 # 在knowledge_dim维度计算余弦相似度
) # [batch, seq_len, num_selected]
# 计算余弦相似度统计信息
cosine_stats = {
'cosine_similarities': cosine_similarities, # [batch, seq_len, num_selected]
'avg_cosine_similarity': cosine_similarities.mean().item(), # 平均余弦相似度
'max_cosine_similarity': cosine_similarities.max().item(), # 最大余弦相似度
'min_cosine_similarity': cosine_similarities.min().item(), # 最小余弦相似度
'std_cosine_similarity': cosine_similarities.std().item(), # 余弦相似度标准差
}
# 收集EMA更新统计信息仅在训练时且启用时 # 收集EMA更新统计信息仅在训练时且启用时
ema_stats = None ema_stats = None
if collect_ema_stats and self.training: if collect_ema_stats and self.training:
@@ -404,9 +416,9 @@ class MiniMindBlock(nn.Module):
} }
if collect_ema_stats: if collect_ema_stats:
return out, balance_loss, layer_stats, ema_stats return out, balance_loss, layer_stats, ema_stats, cosine_stats
else: else:
return out, balance_loss, layer_stats return out, balance_loss, layer_stats, cosine_stats
class MiniMindLM(PreTrainedModel): class MiniMindLM(PreTrainedModel):
@@ -455,9 +467,8 @@ class MiniMindLM(PreTrainedModel):
if params.freeze_ratio > 0.0: if params.freeze_ratio > 0.0:
freeze_num = int(params.knowledge_num * params.freeze_ratio) freeze_num = int(params.knowledge_num * params.freeze_ratio)
freeze_mask = torch.zeros(params.knowledge_num, dtype=torch.bool) freeze_mask = torch.zeros(params.knowledge_num, dtype=torch.bool)
# 随机选择要冻结的条目 # 固定冻结前面的条目
freeze_indices = torch.randperm(params.knowledge_num)[:freeze_num] freeze_mask[:freeze_num] = True
freeze_mask[freeze_indices] = True
self.register_buffer('freeze_mask', freeze_mask, persistent=False) self.register_buffer('freeze_mask', freeze_mask, persistent=False)
print(f"🔥 Memory bank freezing enabled: {freeze_num}/{params.knowledge_num} entries ({params.freeze_ratio*100:.1f}%) frozen") print(f"🔥 Memory bank freezing enabled: {freeze_num}/{params.knowledge_num} entries ({params.freeze_ratio*100:.1f}%) frozen")
else: else:
@@ -528,18 +539,23 @@ class MiniMindLM(PreTrainedModel):
total_balance_loss = 0 total_balance_loss = 0
all_layer_stats = {} all_layer_stats = {}
all_ema_stats = {} all_ema_stats = {}
all_cosine_stats = {}
for layer_idx, layer in enumerate(self.layers): for layer_idx, layer in enumerate(self.layers):
if collect_ema_stats: if collect_ema_stats:
h, balance_loss, layer_stats, ema_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True) h, balance_loss, layer_stats, ema_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True)
all_ema_stats[f'layer_{layer_idx}'] = ema_stats all_ema_stats[f'layer_{layer_idx}'] = ema_stats
else: else:
h, balance_loss, layer_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False) h, balance_loss, layer_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False)
total_balance_loss += balance_loss total_balance_loss += balance_loss
# 为每层的统计信息添加前缀 # 为每层的统计信息添加前缀
for key, value in layer_stats.items(): for key, value in layer_stats.items():
all_layer_stats[f'layer_{layer_idx}_{key}'] = value all_layer_stats[f'layer_{layer_idx}_{key}'] = value
# 为每层的余弦相似度统计信息添加前缀
for key, value in cosine_stats.items():
all_cosine_stats[f'layer_{layer_idx}_{key}'] = value
logits = self.output(self.norm(h)) logits = self.output(self.norm(h))
@@ -551,6 +567,7 @@ class MiniMindLM(PreTrainedModel):
self.OUT.__setitem__('aux_loss', aux_loss) self.OUT.__setitem__('aux_loss', aux_loss)
self.OUT.__setitem__('layer_stats', all_layer_stats) # 添加层级统计信息 self.OUT.__setitem__('layer_stats', all_layer_stats) # 添加层级统计信息
self.OUT.__setitem__('ema_stats', all_ema_stats if collect_ema_stats else None) # 添加EMA统计信息 self.OUT.__setitem__('ema_stats', all_ema_stats if collect_ema_stats else None) # 添加EMA统计信息
self.OUT.__setitem__('cosine_stats', all_cosine_stats) # 添加余弦相似度统计信息
self.OUT.__setitem__('past_key_values', None) # 不支持KV cache self.OUT.__setitem__('past_key_values', None) # 不支持KV cache
return self.OUT return self.OUT

View File

@@ -281,7 +281,15 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
sentences_data = [] sentences_data = []
for data in database_data: for data in database_data:
sentences_data.append(data['target'][0]['sentence']) # 保存句子和对应的uuid信息
sentence_info = {
'sentence': data['target'][0]['sentence'],
'uuid': data['target'][0]['uuid'],
'subject': data['target'][0].get('subject', ''),
'predicate': data['target'][0].get('predicate', ''),
'object': data['target'][0].get('object', '')
}
sentences_data.append(sentence_info)
# 提取sentences列表 # 提取sentences列表
# sentences_data = database_data.get('sentences', []) # sentences_data = database_data.get('sentences', [])
@@ -289,8 +297,10 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
# 2. 按照importance_score进行排序从高到低 # 2. 按照importance_score进行排序从高到低
try: try:
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True) # 注意现在sentences_data中的每个元素都是字典不再有importance_score字段
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})") # 如果需要按重要性排序,需要从原始数据中获取该信息
sorted_sentences = sentences_data # 暂时不排序,保持原始顺序
Logger(f"Loaded {len(sorted_sentences)} sentences (no importance_score sorting applied)")
except: except:
sorted_sentences = sentences_data sorted_sentences = sentences_data
# 3. 处理每条数据,不进行聚类 # 3. 处理每条数据,不进行聚类
@@ -307,12 +317,14 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
total_sentences = 0 total_sentences = 0
truncated_sentences = 0 truncated_sentences = 0
# 用于记录映射关系的列表
database_mapping = []
for i in range(num_to_process): for i in range(num_to_process):
sentence_data = sorted_sentences[i] sentence_data = sorted_sentences[i]
try: # 现在sentence_data是一个字典包含sentence和uuid
sentence = sentence_data.get('corrected_sentence') sentence = sentence_data['sentence']
except: uuid = sentence_data['uuid']
sentence = sentence_data
# 将句子转换为tokens # 将句子转换为tokens
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False) sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
@@ -333,6 +345,19 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
processed_rows.append(sentence_tokens) processed_rows.append(sentence_tokens)
# 记录映射关系:数据库索引 -> 原始数据信息
mapping_entry = {
'database_index': i, # 在数据库中的索引位置
'uuid': uuid, # 原始uuid
'sentence': sentence, # 原始句子
'subject': sentence_data.get('subject', ''),
'predicate': sentence_data.get('predicate', ''),
'object': sentence_data.get('object', ''),
'token_count': len(sentence_tokens),
'is_truncated': len(tokenizer.encode(sentence, add_special_tokens=False)) > knowledge_length
}
database_mapping.append(mapping_entry)
if (i + 1) % 1000 == 0: if (i + 1) % 1000 == 0:
Logger(f"Processed {i + 1}/{num_to_process} sentences") Logger(f"Processed {i + 1}/{num_to_process} sentences")
@@ -367,6 +392,26 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
Logger(f"Processed results saved to {args.cluster_cache_path}") Logger(f"Processed results saved to {args.cluster_cache_path}")
except Exception as e: except Exception as e:
Logger(f"Failed to save processed results: {e}") Logger(f"Failed to save processed results: {e}")
# 保存数据库映射文件
try:
mapping_file_path = args.cluster_cache_path.replace('.pt', '_mapping.json')
mapping_data = {
'metadata': {
'total_entries': len(database_mapping),
'knowledge_num': knowledge_num,
'knowledge_length': knowledge_length,
'source_file': database_init_path,
'generation_time': time.strftime('%Y-%m-%d %H:%M:%S')
},
'mappings': database_mapping
}
with open(mapping_file_path, 'w', encoding='utf-8') as f:
json.dump(mapping_data, f, ensure_ascii=False, indent=2)
Logger(f"Database mapping saved to {mapping_file_path}")
except Exception as e:
Logger(f"Failed to save database mapping: {e}")
# 4. 初始化模型的knowledge_dataset # 4. 初始化模型的knowledge_dataset
if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'): if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
@@ -482,7 +527,15 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
sentences_data = [] sentences_data = []
for data in database_data: for data in database_data:
sentences_data.append(data['target'][0]['sentence']) # 保存句子和对应的uuid信息
sentence_info = {
'sentence': data['target'][0]['sentence'],
'uuid': data['target'][0]['uuid'],
'subject': data['target'][0].get('subject', ''),
'predicate': data['target'][0].get('predicate', ''),
'object': data['target'][0].get('object', '')
}
sentences_data.append(sentence_info)
# 提取sentences列表 # 提取sentences列表
# sentences_data = database_data.get('sentences', []) # sentences_data = database_data.get('sentences', [])
@@ -490,8 +543,10 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
# 2. 按照importance_score进行排序从高到低 # 2. 按照importance_score进行排序从高到低
try: try:
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True) # 注意现在sentences_data中的每个元素都是字典不再有importance_score字段
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})") # 如果需要按重要性排序,需要从原始数据中获取该信息
sorted_sentences = sentences_data # 暂时不排序,保持原始顺序
Logger(f"Loaded {len(sorted_sentences)} sentences (no importance_score sorting applied)")
except: except:
sorted_sentences = sentences_data sorted_sentences = sentences_data
# 3. 处理每条数据,不进行聚类 # 3. 处理每条数据,不进行聚类
@@ -508,12 +563,14 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
total_sentences = 0 total_sentences = 0
truncated_sentences = 0 truncated_sentences = 0
# 用于记录映射关系的列表
database_mapping = []
for i in range(num_to_process): for i in range(num_to_process):
sentence_data = sorted_sentences[i] sentence_data = sorted_sentences[i]
try: # 现在sentence_data是一个字典包含sentence和uuid
sentence = sentence_data.get('corrected_sentence') sentence = sentence_data['sentence']
except: uuid = sentence_data['uuid']
sentence = sentence_data
# 将句子转换为tokens # 将句子转换为tokens
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False) sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
@@ -534,6 +591,19 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
processed_rows.append(sentence_tokens) processed_rows.append(sentence_tokens)
# 记录映射关系:数据库索引 -> 原始数据信息
mapping_entry = {
'database_index': i, # 在数据库中的索引位置
'uuid': uuid, # 原始uuid
'sentence': sentence, # 原始句子
'subject': sentence_data.get('subject', ''),
'predicate': sentence_data.get('predicate', ''),
'object': sentence_data.get('object', ''),
'token_count': len(sentence_tokens),
'is_truncated': len(tokenizer.encode(sentence, add_special_tokens=False)) > knowledge_length
}
database_mapping.append(mapping_entry)
if (i + 1) % 1000 == 0: if (i + 1) % 1000 == 0:
Logger(f"Processed {i + 1}/{num_to_process} sentences") Logger(f"Processed {i + 1}/{num_to_process} sentences")
@@ -568,6 +638,26 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
Logger(f"Processed results saved to {args.cluster_cache_path}") Logger(f"Processed results saved to {args.cluster_cache_path}")
except Exception as e: except Exception as e:
Logger(f"Failed to save processed results: {e}") Logger(f"Failed to save processed results: {e}")
# 保存数据库映射文件
try:
mapping_file_path = args.cluster_cache_path.replace('.pt', '_mapping.json')
mapping_data = {
'metadata': {
'total_entries': len(database_mapping),
'knowledge_num': knowledge_num,
'knowledge_length': knowledge_length,
'source_file': database_init_path,
'generation_time': time.strftime('%Y-%m-%d %H:%M:%S')
},
'mappings': database_mapping
}
with open(mapping_file_path, 'w', encoding='utf-8') as f:
json.dump(mapping_data, f, ensure_ascii=False, indent=2)
Logger(f"Database mapping saved to {mapping_file_path}")
except Exception as e:
Logger(f"Failed to save database mapping: {e}")
# 4. 初始化模型的knowledge_dataset # 4. 初始化模型的knowledge_dataset
if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'): if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
@@ -646,6 +736,9 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
truncated_sentences = 0 truncated_sentences = 0
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
# 用于记录映射关系的列表
database_mapping = []
# 控制处理的句子数量 # 控制处理的句子数量
num_to_process = min(len(data), knowledge_num) num_to_process = min(len(data), knowledge_num)
Logger(f"Processing {num_to_process} out of {total_sentences} sentences") Logger(f"Processing {num_to_process} out of {total_sentences} sentences")
@@ -655,11 +748,27 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
if idx % 1000 == 0: if idx % 1000 == 0:
Logger(f"Processing sentence {idx+1}/{num_to_process}") Logger(f"Processing sentence {idx+1}/{num_to_process}")
# 获取句子文本 # 获取句子文本和uuid
if isinstance(item, dict): if isinstance(item, dict):
sentence = item.get('sentence', '') or item.get('text', '') or str(item) # 如果是字典格式尝试提取target数组中的数据
if 'target' in item and len(item['target']) > 0:
sentence = item['target'][0].get('sentence', '')
uuid = item['target'][0].get('uuid', '')
subject = item['target'][0].get('subject', '')
predicate = item['target'][0].get('predicate', '')
object_name = item['target'][0].get('object', '')
else:
sentence = item.get('sentence', '') or item.get('text', '') or str(item)
uuid = item.get('uuid', '')
subject = item.get('subject', '')
predicate = item.get('predicate', '')
object_name = item.get('object', '')
else: else:
sentence = str(item) sentence = str(item)
uuid = ''
subject = ''
predicate = ''
object_name = ''
# 使用tokenizer编码句子 # 使用tokenizer编码句子
try: try:
@@ -686,11 +795,38 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
processed_rows.append(tokens) processed_rows.append(tokens)
# 记录映射关系:数据库索引 -> 原始数据信息
mapping_entry = {
'database_index': idx, # 在数据库中的索引位置
'uuid': uuid, # 原始uuid
'sentence': sentence, # 原始句子
'subject': subject,
'predicate': predicate,
'object': object_name,
'token_count': len(tokens),
'is_truncated': len(tokens) > knowledge_length
}
database_mapping.append(mapping_entry)
except Exception as e: except Exception as e:
Logger(f"Error processing sentence {idx}: {e}") Logger(f"Error processing sentence {idx}: {e}")
# 使用空tokens作为fallback # 使用空tokens作为fallback
empty_tokens = [pad_token_id] * knowledge_length empty_tokens = [pad_token_id] * knowledge_length
processed_rows.append(empty_tokens) processed_rows.append(empty_tokens)
# 为失败的句子也记录映射关系
mapping_entry = {
'database_index': idx,
'uuid': uuid,
'sentence': sentence,
'subject': subject,
'predicate': predicate,
'object': object_name,
'token_count': knowledge_length,
'is_truncated': False,
'processing_error': str(e)
}
database_mapping.append(mapping_entry)
# 如果句子数量不足用空token填充剩余位置 # 如果句子数量不足用空token填充剩余位置
while len(processed_rows) < knowledge_num: while len(processed_rows) < knowledge_num:
@@ -721,6 +857,26 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
Logger(f"Processed results saved to {memory_cache_path}") Logger(f"Processed results saved to {memory_cache_path}")
except Exception as e: except Exception as e:
Logger(f"Failed to save processed results: {e}") Logger(f"Failed to save processed results: {e}")
# 保存数据库映射文件
try:
mapping_file_path = memory_cache_path.replace('.pt', '_mapping.json')
mapping_data = {
'metadata': {
'total_entries': len(database_mapping),
'knowledge_num': knowledge_num,
'knowledge_length': knowledge_length,
'source_file': database_init_path,
'generation_time': time.strftime('%Y-%m-%d %H:%M:%S')
},
'mappings': database_mapping
}
with open(mapping_file_path, 'w', encoding='utf-8') as f:
json.dump(mapping_data, f, ensure_ascii=False, indent=2)
Logger(f"Database mapping saved to {mapping_file_path}")
except Exception as e:
Logger(f"Failed to save database mapping: {e}")
# 初始化模型的memory_bank # 初始化模型的memory_bank
if hasattr(model, 'memory_bank'): if hasattr(model, 'memory_bank'):
@@ -1050,6 +1206,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
layer_stats = {} layer_stats = {}
if hasattr(res, 'layer_stats') and res.layer_stats is not None: if hasattr(res, 'layer_stats') and res.layer_stats is not None:
layer_stats = res.layer_stats layer_stats = res.layer_stats
# 获取余弦相似度统计信息(如果模型支持)
cosine_stats = {}
if hasattr(res, 'cosine_stats') and res.cosine_stats is not None:
cosine_stats = res.cosine_stats
# 构建日志字典 # 构建日志字典
log_dict = { log_dict = {
@@ -1072,6 +1233,14 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
# 添加记忆库更新统计 # 添加记忆库更新统计
log_dict.update(memory_update_stats) log_dict.update(memory_update_stats)
# 计算平均余弦相似度
avg_cosine_similarity = 0.0
if cosine_stats:
# 计算所有层的平均余弦相似度
cosine_similarities = [v for k, v in cosine_stats.items() if k.endswith('_avg_cosine_similarity')]
if cosine_similarities:
avg_cosine_similarity = np.mean(cosine_similarities)
# 添加层级统计信息(选择性添加关键指标) # 添加层级统计信息(选择性添加关键指标)
if layer_stats: if layer_stats:
# 计算所有层的平均统计 # 计算所有层的平均统计
@@ -1085,11 +1254,13 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
'memory/avg_coverage_rate': avg_coverage, 'memory/avg_coverage_rate': avg_coverage,
'memory/total_dead_memories': total_dead, 'memory/total_dead_memories': total_dead,
'memory/total_hot_memories': total_hot, 'memory/total_hot_memories': total_hot,
'train/avg_cosine_similarity': avg_cosine_similarity,
}) })
Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, " Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
f"CE Loss: {log_dict['train/loss_ce']:.4f}, " f"CE Loss: {log_dict['train/loss_ce']:.4f}, "
f"Balance Loss: {log_dict['train/loss_balance']:.4f}, " f"Balance Loss: {log_dict['train/loss_balance']:.4f}, "
f"Avg Cosine Sim: {avg_cosine_similarity:.4f}, "
f"Total Loss: {log_dict['train/loss_total']:.4f}, " f"Total Loss: {log_dict['train/loss_total']:.4f}, "
f"Val Loss: {log_dict.get('val/loss', 'N/A')}, " f"Val Loss: {log_dict.get('val/loss', 'N/A')}, "
f"LR: {log_dict['lr']:.6f}, " f"LR: {log_dict['lr']:.6f}, "