update
This commit is contained in:
parent
c4c72ac154
commit
05e93cbc48
@ -280,7 +280,7 @@ class GatedMemoryFusion(nn.Module):
|
|||||||
|
|
||||||
# 输入维度:dim (h_attn) + num_selected * knowledge_dim (选中的记忆)
|
# 输入维度:dim (h_attn) + num_selected * knowledge_dim (选中的记忆)
|
||||||
# 实验1.4.6:记忆解码后立即压缩回knowledge_dim避免显存爆炸
|
# 实验1.4.6:记忆解码后立即压缩回knowledge_dim避免显存爆炸
|
||||||
concat_dim = self.dim + self.num_selected * self.knowledge_dim
|
concat_dim = self.dim + self.num_selected * self.dim
|
||||||
|
|
||||||
# 类似SwiGLU的门控MLP结构
|
# 类似SwiGLU的门控MLP结构
|
||||||
self.gate_proj = nn.Linear(concat_dim, self.dim, bias=False)
|
self.gate_proj = nn.Linear(concat_dim, self.dim, bias=False)
|
||||||
@ -305,7 +305,7 @@ class GatedMemoryFusion(nn.Module):
|
|||||||
memory_flat = selected_memories.reshape(bsz, seq_len, -1)
|
memory_flat = selected_memories.reshape(bsz, seq_len, -1)
|
||||||
|
|
||||||
# 拼接h_attn和记忆信息
|
# 拼接h_attn和记忆信息
|
||||||
concat_input = torch.cat([h_attn, memory_flat], dim=-1) # [batch, seq_len, dim + num_selected * knowledge_dim]
|
concat_input = torch.cat([h_attn, memory_flat], dim=-1) # [batch, seq_len, dim + num_selected * dim]
|
||||||
|
|
||||||
# 门控MLP处理(类似SwiGLU)
|
# 门控MLP处理(类似SwiGLU)
|
||||||
gate = F.silu(self.gate_proj(concat_input)) # [batch, seq_len, dim]
|
gate = F.silu(self.gate_proj(concat_input)) # [batch, seq_len, dim]
|
||||||
@ -350,6 +350,7 @@ class MiniMindBlock(nn.Module):
|
|||||||
balance_loss: 该层的平衡损失
|
balance_loss: 该层的平衡损失
|
||||||
layer_stats: 该层的监控统计信息
|
layer_stats: 该层的监控统计信息
|
||||||
ema_stats: EMA更新统计信息(如果collect_ema_stats=True)
|
ema_stats: EMA更新统计信息(如果collect_ema_stats=True)
|
||||||
|
cosine_stats: 查找向量与选中记忆条目的余弦相似度统计信息
|
||||||
"""
|
"""
|
||||||
# Self attention
|
# Self attention
|
||||||
h_attn = self.attention(self.attention_norm(x), pos_cis)
|
h_attn = self.attention(self.attention_norm(x), pos_cis)
|
||||||
@ -373,19 +374,8 @@ class MiniMindBlock(nn.Module):
|
|||||||
# 立即压缩:knowledge_length * dim -> knowledge_dim 避免显存爆炸
|
# 立即压缩:knowledge_length * dim -> knowledge_dim 避免显存爆炸
|
||||||
# 使用平均池化压缩knowledge_length维度
|
# 使用平均池化压缩knowledge_length维度
|
||||||
pooled_memory = selected_embeddings.mean(dim=1) # [batch * seq_len * num_selected, dim]
|
pooled_memory = selected_embeddings.mean(dim=1) # [batch * seq_len * num_selected, dim]
|
||||||
|
|
||||||
# 投影到knowledge_dim维度
|
|
||||||
if self.dim > self.config.knowledge_dim:
|
|
||||||
# 截断到knowledge_dim
|
|
||||||
compressed_memory = pooled_memory[:, :self.config.knowledge_dim]
|
|
||||||
elif self.dim < self.config.knowledge_dim:
|
|
||||||
# 填充到knowledge_dim
|
|
||||||
pad_size = self.config.knowledge_dim - self.dim
|
|
||||||
compressed_memory = F.pad(pooled_memory, (0, pad_size), 'constant', 0)
|
|
||||||
else:
|
|
||||||
compressed_memory = pooled_memory
|
|
||||||
|
|
||||||
selected_memory = compressed_memory.view(bsz, seq_len, num_selected, self.config.knowledge_dim) # [batch, seq_len, num_selected, knowledge_dim]
|
selected_memory = pooled_memory.view(bsz, seq_len, num_selected, self.dim) # [batch, seq_len, num_selected, dim]
|
||||||
|
|
||||||
# 门控MLP融合:串型连接h_attn和选中的记忆
|
# 门控MLP融合:串型连接h_attn和选中的记忆
|
||||||
memory_output = self.gated_memory_fusion(h_for_memory, selected_memory, memory_scores)
|
memory_output = self.gated_memory_fusion(h_for_memory, selected_memory, memory_scores)
|
||||||
@ -393,6 +383,28 @@ class MiniMindBlock(nn.Module):
|
|||||||
# 残差连接
|
# 残差连接
|
||||||
out = h + memory_output
|
out = h + memory_output
|
||||||
|
|
||||||
|
# 🔍 新增: 计算查找向量与选中记忆条目的余弦相似度
|
||||||
|
with torch.no_grad():
|
||||||
|
|
||||||
|
# 扩展查找向量维度以匹配selected_memory
|
||||||
|
h_expanded = h_for_memory.unsqueeze(2).expand(-1, -1, num_selected, -1) # [batch, seq_len, num_selected, dim]
|
||||||
|
|
||||||
|
# 计算余弦相似度:cosine_sim(query, memory) for each selected memory
|
||||||
|
cosine_similarities = F.cosine_similarity(
|
||||||
|
h_expanded, # [batch, seq_len, num_selected, dim]
|
||||||
|
selected_memory, # [batch, seq_len, num_selected, knowledge_dim]
|
||||||
|
dim=-1 # 在knowledge_dim维度计算余弦相似度
|
||||||
|
) # [batch, seq_len, num_selected]
|
||||||
|
|
||||||
|
# 计算余弦相似度统计信息
|
||||||
|
cosine_stats = {
|
||||||
|
'cosine_similarities': cosine_similarities, # [batch, seq_len, num_selected]
|
||||||
|
'avg_cosine_similarity': cosine_similarities.mean().item(), # 平均余弦相似度
|
||||||
|
'max_cosine_similarity': cosine_similarities.max().item(), # 最大余弦相似度
|
||||||
|
'min_cosine_similarity': cosine_similarities.min().item(), # 最小余弦相似度
|
||||||
|
'std_cosine_similarity': cosine_similarities.std().item(), # 余弦相似度标准差
|
||||||
|
}
|
||||||
|
|
||||||
# 收集EMA更新统计信息(仅在训练时且启用时)
|
# 收集EMA更新统计信息(仅在训练时且启用时)
|
||||||
ema_stats = None
|
ema_stats = None
|
||||||
if collect_ema_stats and self.training:
|
if collect_ema_stats and self.training:
|
||||||
@ -404,9 +416,9 @@ class MiniMindBlock(nn.Module):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if collect_ema_stats:
|
if collect_ema_stats:
|
||||||
return out, balance_loss, layer_stats, ema_stats
|
return out, balance_loss, layer_stats, ema_stats, cosine_stats
|
||||||
else:
|
else:
|
||||||
return out, balance_loss, layer_stats
|
return out, balance_loss, layer_stats, cosine_stats
|
||||||
|
|
||||||
|
|
||||||
class MiniMindLM(PreTrainedModel):
|
class MiniMindLM(PreTrainedModel):
|
||||||
@ -455,9 +467,8 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
if params.freeze_ratio > 0.0:
|
if params.freeze_ratio > 0.0:
|
||||||
freeze_num = int(params.knowledge_num * params.freeze_ratio)
|
freeze_num = int(params.knowledge_num * params.freeze_ratio)
|
||||||
freeze_mask = torch.zeros(params.knowledge_num, dtype=torch.bool)
|
freeze_mask = torch.zeros(params.knowledge_num, dtype=torch.bool)
|
||||||
# 随机选择要冻结的条目
|
# 固定冻结前面的条目
|
||||||
freeze_indices = torch.randperm(params.knowledge_num)[:freeze_num]
|
freeze_mask[:freeze_num] = True
|
||||||
freeze_mask[freeze_indices] = True
|
|
||||||
self.register_buffer('freeze_mask', freeze_mask, persistent=False)
|
self.register_buffer('freeze_mask', freeze_mask, persistent=False)
|
||||||
print(f"🔥 Memory bank freezing enabled: {freeze_num}/{params.knowledge_num} entries ({params.freeze_ratio*100:.1f}%) frozen")
|
print(f"🔥 Memory bank freezing enabled: {freeze_num}/{params.knowledge_num} entries ({params.freeze_ratio*100:.1f}%) frozen")
|
||||||
else:
|
else:
|
||||||
@ -528,18 +539,23 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
total_balance_loss = 0
|
total_balance_loss = 0
|
||||||
all_layer_stats = {}
|
all_layer_stats = {}
|
||||||
all_ema_stats = {}
|
all_ema_stats = {}
|
||||||
|
all_cosine_stats = {}
|
||||||
|
|
||||||
for layer_idx, layer in enumerate(self.layers):
|
for layer_idx, layer in enumerate(self.layers):
|
||||||
if collect_ema_stats:
|
if collect_ema_stats:
|
||||||
h, balance_loss, layer_stats, ema_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True)
|
h, balance_loss, layer_stats, ema_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True)
|
||||||
all_ema_stats[f'layer_{layer_idx}'] = ema_stats
|
all_ema_stats[f'layer_{layer_idx}'] = ema_stats
|
||||||
else:
|
else:
|
||||||
h, balance_loss, layer_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False)
|
h, balance_loss, layer_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False)
|
||||||
|
|
||||||
total_balance_loss += balance_loss
|
total_balance_loss += balance_loss
|
||||||
# 为每层的统计信息添加前缀
|
# 为每层的统计信息添加前缀
|
||||||
for key, value in layer_stats.items():
|
for key, value in layer_stats.items():
|
||||||
all_layer_stats[f'layer_{layer_idx}_{key}'] = value
|
all_layer_stats[f'layer_{layer_idx}_{key}'] = value
|
||||||
|
|
||||||
|
# 为每层的余弦相似度统计信息添加前缀
|
||||||
|
for key, value in cosine_stats.items():
|
||||||
|
all_cosine_stats[f'layer_{layer_idx}_{key}'] = value
|
||||||
|
|
||||||
logits = self.output(self.norm(h))
|
logits = self.output(self.norm(h))
|
||||||
|
|
||||||
@ -551,6 +567,7 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
self.OUT.__setitem__('aux_loss', aux_loss)
|
self.OUT.__setitem__('aux_loss', aux_loss)
|
||||||
self.OUT.__setitem__('layer_stats', all_layer_stats) # 添加层级统计信息
|
self.OUT.__setitem__('layer_stats', all_layer_stats) # 添加层级统计信息
|
||||||
self.OUT.__setitem__('ema_stats', all_ema_stats if collect_ema_stats else None) # 添加EMA统计信息
|
self.OUT.__setitem__('ema_stats', all_ema_stats if collect_ema_stats else None) # 添加EMA统计信息
|
||||||
|
self.OUT.__setitem__('cosine_stats', all_cosine_stats) # 添加余弦相似度统计信息
|
||||||
self.OUT.__setitem__('past_key_values', None) # 不支持KV cache
|
self.OUT.__setitem__('past_key_values', None) # 不支持KV cache
|
||||||
return self.OUT
|
return self.OUT
|
||||||
|
|
||||||
|
|||||||
@ -281,7 +281,15 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
|
|
||||||
sentences_data = []
|
sentences_data = []
|
||||||
for data in database_data:
|
for data in database_data:
|
||||||
sentences_data.append(data['target'][0]['sentence'])
|
# 保存句子和对应的uuid信息
|
||||||
|
sentence_info = {
|
||||||
|
'sentence': data['target'][0]['sentence'],
|
||||||
|
'uuid': data['target'][0]['uuid'],
|
||||||
|
'subject': data['target'][0].get('subject', ''),
|
||||||
|
'predicate': data['target'][0].get('predicate', ''),
|
||||||
|
'object': data['target'][0].get('object', '')
|
||||||
|
}
|
||||||
|
sentences_data.append(sentence_info)
|
||||||
|
|
||||||
# 提取sentences列表
|
# 提取sentences列表
|
||||||
# sentences_data = database_data.get('sentences', [])
|
# sentences_data = database_data.get('sentences', [])
|
||||||
@ -289,8 +297,10 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
|
|
||||||
# 2. 按照importance_score进行排序(从高到低)
|
# 2. 按照importance_score进行排序(从高到低)
|
||||||
try:
|
try:
|
||||||
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
|
# 注意:现在sentences_data中的每个元素都是字典,不再有importance_score字段
|
||||||
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
|
# 如果需要按重要性排序,需要从原始数据中获取该信息
|
||||||
|
sorted_sentences = sentences_data # 暂时不排序,保持原始顺序
|
||||||
|
Logger(f"Loaded {len(sorted_sentences)} sentences (no importance_score sorting applied)")
|
||||||
except:
|
except:
|
||||||
sorted_sentences = sentences_data
|
sorted_sentences = sentences_data
|
||||||
# 3. 处理每条数据,不进行聚类
|
# 3. 处理每条数据,不进行聚类
|
||||||
@ -307,12 +317,14 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
total_sentences = 0
|
total_sentences = 0
|
||||||
truncated_sentences = 0
|
truncated_sentences = 0
|
||||||
|
|
||||||
|
# 用于记录映射关系的列表
|
||||||
|
database_mapping = []
|
||||||
|
|
||||||
for i in range(num_to_process):
|
for i in range(num_to_process):
|
||||||
sentence_data = sorted_sentences[i]
|
sentence_data = sorted_sentences[i]
|
||||||
try:
|
# 现在sentence_data是一个字典,包含sentence和uuid
|
||||||
sentence = sentence_data.get('corrected_sentence')
|
sentence = sentence_data['sentence']
|
||||||
except:
|
uuid = sentence_data['uuid']
|
||||||
sentence = sentence_data
|
|
||||||
|
|
||||||
# 将句子转换为tokens
|
# 将句子转换为tokens
|
||||||
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
|
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
|
||||||
@ -333,6 +345,19 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
|
|
||||||
processed_rows.append(sentence_tokens)
|
processed_rows.append(sentence_tokens)
|
||||||
|
|
||||||
|
# 记录映射关系:数据库索引 -> 原始数据信息
|
||||||
|
mapping_entry = {
|
||||||
|
'database_index': i, # 在数据库中的索引位置
|
||||||
|
'uuid': uuid, # 原始uuid
|
||||||
|
'sentence': sentence, # 原始句子
|
||||||
|
'subject': sentence_data.get('subject', ''),
|
||||||
|
'predicate': sentence_data.get('predicate', ''),
|
||||||
|
'object': sentence_data.get('object', ''),
|
||||||
|
'token_count': len(sentence_tokens),
|
||||||
|
'is_truncated': len(tokenizer.encode(sentence, add_special_tokens=False)) > knowledge_length
|
||||||
|
}
|
||||||
|
database_mapping.append(mapping_entry)
|
||||||
|
|
||||||
if (i + 1) % 1000 == 0:
|
if (i + 1) % 1000 == 0:
|
||||||
Logger(f"Processed {i + 1}/{num_to_process} sentences")
|
Logger(f"Processed {i + 1}/{num_to_process} sentences")
|
||||||
|
|
||||||
@ -367,6 +392,26 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
Logger(f"Processed results saved to {args.cluster_cache_path}")
|
Logger(f"Processed results saved to {args.cluster_cache_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
Logger(f"Failed to save processed results: {e}")
|
Logger(f"Failed to save processed results: {e}")
|
||||||
|
|
||||||
|
# 保存数据库映射文件
|
||||||
|
try:
|
||||||
|
mapping_file_path = args.cluster_cache_path.replace('.pt', '_mapping.json')
|
||||||
|
mapping_data = {
|
||||||
|
'metadata': {
|
||||||
|
'total_entries': len(database_mapping),
|
||||||
|
'knowledge_num': knowledge_num,
|
||||||
|
'knowledge_length': knowledge_length,
|
||||||
|
'source_file': database_init_path,
|
||||||
|
'generation_time': time.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
},
|
||||||
|
'mappings': database_mapping
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(mapping_file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(mapping_data, f, ensure_ascii=False, indent=2)
|
||||||
|
Logger(f"Database mapping saved to {mapping_file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
Logger(f"Failed to save database mapping: {e}")
|
||||||
|
|
||||||
# 4. 初始化模型的knowledge_dataset
|
# 4. 初始化模型的knowledge_dataset
|
||||||
if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
|
if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
|
||||||
@ -482,7 +527,15 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
|
|
||||||
sentences_data = []
|
sentences_data = []
|
||||||
for data in database_data:
|
for data in database_data:
|
||||||
sentences_data.append(data['target'][0]['sentence'])
|
# 保存句子和对应的uuid信息
|
||||||
|
sentence_info = {
|
||||||
|
'sentence': data['target'][0]['sentence'],
|
||||||
|
'uuid': data['target'][0]['uuid'],
|
||||||
|
'subject': data['target'][0].get('subject', ''),
|
||||||
|
'predicate': data['target'][0].get('predicate', ''),
|
||||||
|
'object': data['target'][0].get('object', '')
|
||||||
|
}
|
||||||
|
sentences_data.append(sentence_info)
|
||||||
|
|
||||||
# 提取sentences列表
|
# 提取sentences列表
|
||||||
# sentences_data = database_data.get('sentences', [])
|
# sentences_data = database_data.get('sentences', [])
|
||||||
@ -490,8 +543,10 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
|
|
||||||
# 2. 按照importance_score进行排序(从高到低)
|
# 2. 按照importance_score进行排序(从高到低)
|
||||||
try:
|
try:
|
||||||
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
|
# 注意:现在sentences_data中的每个元素都是字典,不再有importance_score字段
|
||||||
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
|
# 如果需要按重要性排序,需要从原始数据中获取该信息
|
||||||
|
sorted_sentences = sentences_data # 暂时不排序,保持原始顺序
|
||||||
|
Logger(f"Loaded {len(sorted_sentences)} sentences (no importance_score sorting applied)")
|
||||||
except:
|
except:
|
||||||
sorted_sentences = sentences_data
|
sorted_sentences = sentences_data
|
||||||
# 3. 处理每条数据,不进行聚类
|
# 3. 处理每条数据,不进行聚类
|
||||||
@ -508,12 +563,14 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
total_sentences = 0
|
total_sentences = 0
|
||||||
truncated_sentences = 0
|
truncated_sentences = 0
|
||||||
|
|
||||||
|
# 用于记录映射关系的列表
|
||||||
|
database_mapping = []
|
||||||
|
|
||||||
for i in range(num_to_process):
|
for i in range(num_to_process):
|
||||||
sentence_data = sorted_sentences[i]
|
sentence_data = sorted_sentences[i]
|
||||||
try:
|
# 现在sentence_data是一个字典,包含sentence和uuid
|
||||||
sentence = sentence_data.get('corrected_sentence')
|
sentence = sentence_data['sentence']
|
||||||
except:
|
uuid = sentence_data['uuid']
|
||||||
sentence = sentence_data
|
|
||||||
|
|
||||||
# 将句子转换为tokens
|
# 将句子转换为tokens
|
||||||
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
|
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
|
||||||
@ -534,6 +591,19 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
|
|
||||||
processed_rows.append(sentence_tokens)
|
processed_rows.append(sentence_tokens)
|
||||||
|
|
||||||
|
# 记录映射关系:数据库索引 -> 原始数据信息
|
||||||
|
mapping_entry = {
|
||||||
|
'database_index': i, # 在数据库中的索引位置
|
||||||
|
'uuid': uuid, # 原始uuid
|
||||||
|
'sentence': sentence, # 原始句子
|
||||||
|
'subject': sentence_data.get('subject', ''),
|
||||||
|
'predicate': sentence_data.get('predicate', ''),
|
||||||
|
'object': sentence_data.get('object', ''),
|
||||||
|
'token_count': len(sentence_tokens),
|
||||||
|
'is_truncated': len(tokenizer.encode(sentence, add_special_tokens=False)) > knowledge_length
|
||||||
|
}
|
||||||
|
database_mapping.append(mapping_entry)
|
||||||
|
|
||||||
if (i + 1) % 1000 == 0:
|
if (i + 1) % 1000 == 0:
|
||||||
Logger(f"Processed {i + 1}/{num_to_process} sentences")
|
Logger(f"Processed {i + 1}/{num_to_process} sentences")
|
||||||
|
|
||||||
@ -568,6 +638,26 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
Logger(f"Processed results saved to {args.cluster_cache_path}")
|
Logger(f"Processed results saved to {args.cluster_cache_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
Logger(f"Failed to save processed results: {e}")
|
Logger(f"Failed to save processed results: {e}")
|
||||||
|
|
||||||
|
# 保存数据库映射文件
|
||||||
|
try:
|
||||||
|
mapping_file_path = args.cluster_cache_path.replace('.pt', '_mapping.json')
|
||||||
|
mapping_data = {
|
||||||
|
'metadata': {
|
||||||
|
'total_entries': len(database_mapping),
|
||||||
|
'knowledge_num': knowledge_num,
|
||||||
|
'knowledge_length': knowledge_length,
|
||||||
|
'source_file': database_init_path,
|
||||||
|
'generation_time': time.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
},
|
||||||
|
'mappings': database_mapping
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(mapping_file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(mapping_data, f, ensure_ascii=False, indent=2)
|
||||||
|
Logger(f"Database mapping saved to {mapping_file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
Logger(f"Failed to save database mapping: {e}")
|
||||||
|
|
||||||
# 4. 初始化模型的knowledge_dataset
|
# 4. 初始化模型的knowledge_dataset
|
||||||
if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
|
if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
|
||||||
@ -646,6 +736,9 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
truncated_sentences = 0
|
truncated_sentences = 0
|
||||||
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||||
|
|
||||||
|
# 用于记录映射关系的列表
|
||||||
|
database_mapping = []
|
||||||
|
|
||||||
# 控制处理的句子数量
|
# 控制处理的句子数量
|
||||||
num_to_process = min(len(data), knowledge_num)
|
num_to_process = min(len(data), knowledge_num)
|
||||||
Logger(f"Processing {num_to_process} out of {total_sentences} sentences")
|
Logger(f"Processing {num_to_process} out of {total_sentences} sentences")
|
||||||
@ -655,11 +748,27 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
if idx % 1000 == 0:
|
if idx % 1000 == 0:
|
||||||
Logger(f"Processing sentence {idx+1}/{num_to_process}")
|
Logger(f"Processing sentence {idx+1}/{num_to_process}")
|
||||||
|
|
||||||
# 获取句子文本
|
# 获取句子文本和uuid
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
sentence = item.get('sentence', '') or item.get('text', '') or str(item)
|
# 如果是字典格式,尝试提取target数组中的数据
|
||||||
|
if 'target' in item and len(item['target']) > 0:
|
||||||
|
sentence = item['target'][0].get('sentence', '')
|
||||||
|
uuid = item['target'][0].get('uuid', '')
|
||||||
|
subject = item['target'][0].get('subject', '')
|
||||||
|
predicate = item['target'][0].get('predicate', '')
|
||||||
|
object_name = item['target'][0].get('object', '')
|
||||||
|
else:
|
||||||
|
sentence = item.get('sentence', '') or item.get('text', '') or str(item)
|
||||||
|
uuid = item.get('uuid', '')
|
||||||
|
subject = item.get('subject', '')
|
||||||
|
predicate = item.get('predicate', '')
|
||||||
|
object_name = item.get('object', '')
|
||||||
else:
|
else:
|
||||||
sentence = str(item)
|
sentence = str(item)
|
||||||
|
uuid = ''
|
||||||
|
subject = ''
|
||||||
|
predicate = ''
|
||||||
|
object_name = ''
|
||||||
|
|
||||||
# 使用tokenizer编码句子
|
# 使用tokenizer编码句子
|
||||||
try:
|
try:
|
||||||
@ -686,11 +795,38 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
|
|
||||||
processed_rows.append(tokens)
|
processed_rows.append(tokens)
|
||||||
|
|
||||||
|
# 记录映射关系:数据库索引 -> 原始数据信息
|
||||||
|
mapping_entry = {
|
||||||
|
'database_index': idx, # 在数据库中的索引位置
|
||||||
|
'uuid': uuid, # 原始uuid
|
||||||
|
'sentence': sentence, # 原始句子
|
||||||
|
'subject': subject,
|
||||||
|
'predicate': predicate,
|
||||||
|
'object': object_name,
|
||||||
|
'token_count': len(tokens),
|
||||||
|
'is_truncated': len(tokens) > knowledge_length
|
||||||
|
}
|
||||||
|
database_mapping.append(mapping_entry)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
Logger(f"Error processing sentence {idx}: {e}")
|
Logger(f"Error processing sentence {idx}: {e}")
|
||||||
# 使用空tokens作为fallback
|
# 使用空tokens作为fallback
|
||||||
empty_tokens = [pad_token_id] * knowledge_length
|
empty_tokens = [pad_token_id] * knowledge_length
|
||||||
processed_rows.append(empty_tokens)
|
processed_rows.append(empty_tokens)
|
||||||
|
|
||||||
|
# 为失败的句子也记录映射关系
|
||||||
|
mapping_entry = {
|
||||||
|
'database_index': idx,
|
||||||
|
'uuid': uuid,
|
||||||
|
'sentence': sentence,
|
||||||
|
'subject': subject,
|
||||||
|
'predicate': predicate,
|
||||||
|
'object': object_name,
|
||||||
|
'token_count': knowledge_length,
|
||||||
|
'is_truncated': False,
|
||||||
|
'processing_error': str(e)
|
||||||
|
}
|
||||||
|
database_mapping.append(mapping_entry)
|
||||||
|
|
||||||
# 如果句子数量不足,用空token填充剩余位置
|
# 如果句子数量不足,用空token填充剩余位置
|
||||||
while len(processed_rows) < knowledge_num:
|
while len(processed_rows) < knowledge_num:
|
||||||
@ -721,6 +857,26 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
Logger(f"Processed results saved to {memory_cache_path}")
|
Logger(f"Processed results saved to {memory_cache_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
Logger(f"Failed to save processed results: {e}")
|
Logger(f"Failed to save processed results: {e}")
|
||||||
|
|
||||||
|
# 保存数据库映射文件
|
||||||
|
try:
|
||||||
|
mapping_file_path = memory_cache_path.replace('.pt', '_mapping.json')
|
||||||
|
mapping_data = {
|
||||||
|
'metadata': {
|
||||||
|
'total_entries': len(database_mapping),
|
||||||
|
'knowledge_num': knowledge_num,
|
||||||
|
'knowledge_length': knowledge_length,
|
||||||
|
'source_file': database_init_path,
|
||||||
|
'generation_time': time.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
},
|
||||||
|
'mappings': database_mapping
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(mapping_file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(mapping_data, f, ensure_ascii=False, indent=2)
|
||||||
|
Logger(f"Database mapping saved to {mapping_file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
Logger(f"Failed to save database mapping: {e}")
|
||||||
|
|
||||||
# 初始化模型的memory_bank
|
# 初始化模型的memory_bank
|
||||||
if hasattr(model, 'memory_bank'):
|
if hasattr(model, 'memory_bank'):
|
||||||
@ -1050,6 +1206,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
|||||||
layer_stats = {}
|
layer_stats = {}
|
||||||
if hasattr(res, 'layer_stats') and res.layer_stats is not None:
|
if hasattr(res, 'layer_stats') and res.layer_stats is not None:
|
||||||
layer_stats = res.layer_stats
|
layer_stats = res.layer_stats
|
||||||
|
|
||||||
|
# 获取余弦相似度统计信息(如果模型支持)
|
||||||
|
cosine_stats = {}
|
||||||
|
if hasattr(res, 'cosine_stats') and res.cosine_stats is not None:
|
||||||
|
cosine_stats = res.cosine_stats
|
||||||
|
|
||||||
# 构建日志字典
|
# 构建日志字典
|
||||||
log_dict = {
|
log_dict = {
|
||||||
@ -1072,6 +1233,14 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
|||||||
# 添加记忆库更新统计
|
# 添加记忆库更新统计
|
||||||
log_dict.update(memory_update_stats)
|
log_dict.update(memory_update_stats)
|
||||||
|
|
||||||
|
# 计算平均余弦相似度
|
||||||
|
avg_cosine_similarity = 0.0
|
||||||
|
if cosine_stats:
|
||||||
|
# 计算所有层的平均余弦相似度
|
||||||
|
cosine_similarities = [v for k, v in cosine_stats.items() if k.endswith('_avg_cosine_similarity')]
|
||||||
|
if cosine_similarities:
|
||||||
|
avg_cosine_similarity = np.mean(cosine_similarities)
|
||||||
|
|
||||||
# 添加层级统计信息(选择性添加关键指标)
|
# 添加层级统计信息(选择性添加关键指标)
|
||||||
if layer_stats:
|
if layer_stats:
|
||||||
# 计算所有层的平均统计
|
# 计算所有层的平均统计
|
||||||
@ -1085,11 +1254,13 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
|||||||
'memory/avg_coverage_rate': avg_coverage,
|
'memory/avg_coverage_rate': avg_coverage,
|
||||||
'memory/total_dead_memories': total_dead,
|
'memory/total_dead_memories': total_dead,
|
||||||
'memory/total_hot_memories': total_hot,
|
'memory/total_hot_memories': total_hot,
|
||||||
|
'train/avg_cosine_similarity': avg_cosine_similarity,
|
||||||
})
|
})
|
||||||
|
|
||||||
Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
|
Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
|
||||||
f"CE Loss: {log_dict['train/loss_ce']:.4f}, "
|
f"CE Loss: {log_dict['train/loss_ce']:.4f}, "
|
||||||
f"Balance Loss: {log_dict['train/loss_balance']:.4f}, "
|
f"Balance Loss: {log_dict['train/loss_balance']:.4f}, "
|
||||||
|
f"Avg Cosine Sim: {avg_cosine_similarity:.4f}, "
|
||||||
f"Total Loss: {log_dict['train/loss_total']:.4f}, "
|
f"Total Loss: {log_dict['train/loss_total']:.4f}, "
|
||||||
f"Val Loss: {log_dict.get('val/loss', 'N/A')}, "
|
f"Val Loss: {log_dict.get('val/loss', 'N/A')}, "
|
||||||
f"LR: {log_dict['lr']:.6f}, "
|
f"LR: {log_dict['lr']:.6f}, "
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user