diff --git a/model/model_memory.py b/model/model_memory.py
index f21434e..d362437 100644
--- a/model/model_memory.py
+++ b/model/model_memory.py
@@ -280,7 +280,7 @@ class GatedMemoryFusion(nn.Module):
 
         # Input dim: dim (h_attn) + num_selected * knowledge_dim (selected memories)
         # Experiment 1.4.6: compress decoded memories back to knowledge_dim right away to avoid VRAM blow-up
-        concat_dim = self.dim + self.num_selected * self.knowledge_dim
+        concat_dim = self.dim + self.num_selected * self.dim
 
         # Gated MLP structure similar to SwiGLU
         self.gate_proj = nn.Linear(concat_dim, self.dim, bias=False)
@@ -305,7 +305,7 @@ class GatedMemoryFusion(nn.Module):
         memory_flat = selected_memories.reshape(bsz, seq_len, -1)
 
         # Concatenate h_attn with the memory information
-        concat_input = torch.cat([h_attn, memory_flat], dim=-1)  # [batch, seq_len, dim + num_selected * knowledge_dim]
+        concat_input = torch.cat([h_attn, memory_flat], dim=-1)  # [batch, seq_len, dim + num_selected * dim]
 
         # Gated MLP (similar to SwiGLU)
         gate = F.silu(self.gate_proj(concat_input))  # [batch, seq_len, dim]
@@ -350,6 +350,7 @@ class MiniMindBlock(nn.Module):
             balance_loss: balance loss of this layer
             layer_stats: monitoring statistics of this layer
             ema_stats: EMA update statistics (if collect_ema_stats=True)
+            cosine_stats: cosine-similarity statistics between the lookup vectors and the selected memory entries
         """
         # Self attention
         h_attn = self.attention(self.attention_norm(x), pos_cis)
@@ -373,19 +374,8 @@ class MiniMindBlock(nn.Module):
 
         # Compress immediately: knowledge_length * dim -> knowledge_dim to avoid VRAM blow-up
         # Mean-pool over the knowledge_length dimension
         pooled_memory = selected_embeddings.mean(dim=1)  # [batch * seq_len * num_selected, dim]
-
-        # Project to knowledge_dim
-        if self.dim > self.config.knowledge_dim:
-            # Truncate to knowledge_dim
-            compressed_memory = pooled_memory[:, :self.config.knowledge_dim]
-        elif self.dim < self.config.knowledge_dim:
-            # Pad to knowledge_dim
-            pad_size = self.config.knowledge_dim - self.dim
-            compressed_memory = F.pad(pooled_memory, (0, pad_size), 'constant', 0)
-        else:
-            compressed_memory = pooled_memory
-        selected_memory = compressed_memory.view(bsz, seq_len, num_selected, self.config.knowledge_dim)  # [batch, seq_len, num_selected, knowledge_dim]
+        selected_memory = pooled_memory.view(bsz, seq_len, num_selected, self.dim)  # [batch, seq_len, num_selected, dim]
 
         # Gated MLP fusion: serially concatenate h_attn with the selected memories
         memory_output = self.gated_memory_fusion(h_for_memory, selected_memory, memory_scores)
@@ -393,6 +383,28 @@ class MiniMindBlock(nn.Module):
         # Residual connection
         out = h + memory_output
 
+        # 🔍 New: compute cosine similarity between the lookup vectors and the selected memory entries
+        with torch.no_grad():
+
+            # Expand the lookup vectors to match selected_memory
+            h_expanded = h_for_memory.unsqueeze(2).expand(-1, -1, num_selected, -1)  # [batch, seq_len, num_selected, dim]
+
+            # Cosine similarity: cosine_sim(query, memory) for each selected memory
+            cosine_similarities = F.cosine_similarity(
+                h_expanded,       # [batch, seq_len, num_selected, dim]
+                selected_memory,  # [batch, seq_len, num_selected, dim]
+                dim=-1            # similarity over the feature (dim) axis
+            )  # [batch, seq_len, num_selected]
+
+            # Summary statistics of the cosine similarities
+            cosine_stats = {
+                'cosine_similarities': cosine_similarities,  # [batch, seq_len, num_selected]
+                'avg_cosine_similarity': cosine_similarities.mean().item(),  # mean cosine similarity
+                'max_cosine_similarity': cosine_similarities.max().item(),  # max cosine similarity
+                'min_cosine_similarity': cosine_similarities.min().item(),  # min cosine similarity
+                'std_cosine_similarity': cosine_similarities.std().item(),  # std of cosine similarity
+            }
+
         # Collect EMA update statistics (only during training and when enabled)
         ema_stats = None
         if collect_ema_stats and self.training:
@@ -404,9 +416,9 @@ class MiniMindBlock(nn.Module):
             }
 
         if collect_ema_stats:
-            return out, balance_loss, layer_stats, ema_stats
+            return out, balance_loss, layer_stats, ema_stats, cosine_stats
         else:
-            return out, balance_loss, layer_stats
+            return out, balance_loss, layer_stats, cosine_stats
 
 
 class MiniMindLM(PreTrainedModel):
@@ -455,9 +467,8 @@ class MiniMindLM(PreTrainedModel):
         if params.freeze_ratio > 0.0:
             freeze_num = int(params.knowledge_num * params.freeze_ratio)
             freeze_mask = torch.zeros(params.knowledge_num, dtype=torch.bool)
-            # Randomly select the entries to freeze
-            freeze_indices = torch.randperm(params.knowledge_num)[:freeze_num]
-            freeze_mask[freeze_indices] = True
+            # Freeze a fixed prefix of entries
+            freeze_mask[:freeze_num] = True
             self.register_buffer('freeze_mask', freeze_mask, persistent=False)
             print(f"🔥 Memory bank freezing enabled: {freeze_num}/{params.knowledge_num} entries ({params.freeze_ratio*100:.1f}%) frozen")
         else:
@@ -528,18 +539,23 @@ class MiniMindLM(PreTrainedModel):
         total_balance_loss = 0
         all_layer_stats = {}
         all_ema_stats = {}
+        all_cosine_stats = {}
 
         for layer_idx, layer in enumerate(self.layers):
             if collect_ema_stats:
-                h, balance_loss, layer_stats, ema_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True)
+                h, balance_loss, layer_stats, ema_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True)
                 all_ema_stats[f'layer_{layer_idx}'] = ema_stats
             else:
-                h, balance_loss, layer_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False)
+                h, balance_loss, layer_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False)
             total_balance_loss += balance_loss
 
             # Prefix each layer's statistics with its layer index
            for key, value in layer_stats.items():
                 all_layer_stats[f'layer_{layer_idx}_{key}'] = value
+
+            # Prefix each layer's cosine-similarity statistics with its layer index
+            for key, value in cosine_stats.items():
+                all_cosine_stats[f'layer_{layer_idx}_{key}'] = value
 
         logits = self.output(self.norm(h))
 
@@ -551,6 +567,7 @@ class MiniMindLM(PreTrainedModel):
         self.OUT.__setitem__('aux_loss', aux_loss)
         self.OUT.__setitem__('layer_stats', all_layer_stats)  # add per-layer statistics
         self.OUT.__setitem__('ema_stats', all_ema_stats if collect_ema_stats else None)  # add EMA statistics
+        self.OUT.__setitem__('cosine_stats', all_cosine_stats)  # add cosine-similarity statistics
         self.OUT.__setitem__('past_key_values', None)  # KV cache not supported
         return self.OUT
 
diff --git a/train_pretrain_accelerate.py b/train_pretrain_accelerate.py
index 437da2b..b5c9e90 100644
--- a/train_pretrain_accelerate.py
+++ b/train_pretrain_accelerate.py
@@ -281,7 +281,15 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
         sentences_data = []
         for data in database_data:
-            sentences_data.append(data['target'][0]['sentence'])
+            # Keep the sentence together with its uuid information
+            sentence_info = {
+                'sentence': data['target'][0]['sentence'],
+                'uuid': data['target'][0]['uuid'],
+                'subject': data['target'][0].get('subject', ''),
+                'predicate': data['target'][0].get('predicate', ''),
+                'object': data['target'][0].get('object', '')
+            }
+            sentences_data.append(sentence_info)
 
         # Extract the sentences list
         # sentences_data = database_data.get('sentences', [])
 
@@ -289,8 +297,10 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
         # 2. Sort by importance_score (descending)
         try:
-            sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
-            Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
+            # Note: each element of sentences_data is now a dict and no longer carries an importance_score field
+            # If sorting by importance is needed, that information has to come from the raw data
+            sorted_sentences = sentences_data  # no sorting for now, keep the original order
+            Logger(f"Loaded {len(sorted_sentences)} sentences (no importance_score sorting applied)")
         except:
             sorted_sentences = sentences_data
 
        # 3. Process each entry, without clustering
@@ -307,12 +317,14 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
         total_sentences = 0
         truncated_sentences = 0
 
+        # List used to record the database-index -> source mapping
+        database_mapping = []
+
         for i in range(num_to_process):
             sentence_data = sorted_sentences[i]
-            try:
-                sentence = sentence_data.get('corrected_sentence')
-            except:
-                sentence = sentence_data
+            # sentence_data is now a dict containing the sentence and its uuid
+            sentence = sentence_data['sentence']
+            uuid = sentence_data['uuid']
 
             # Convert the sentence to tokens
             sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
@@ -333,6 +345,19 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
             processed_rows.append(sentence_tokens)
 
+            # Record the mapping: database index -> original data info
+            mapping_entry = {
+                'database_index': i,  # index position in the database
+                'uuid': uuid,  # original uuid
+                'sentence': sentence,  # original sentence
+                'subject': sentence_data.get('subject', ''),
+                'predicate': sentence_data.get('predicate', ''),
+                'object': sentence_data.get('object', ''),
+                'token_count': len(sentence_tokens),
+                'is_truncated': len(tokenizer.encode(sentence, add_special_tokens=False)) > knowledge_length
+            }
+            database_mapping.append(mapping_entry)
+
             if (i + 1) % 1000 == 0:
                 Logger(f"Processed {i + 1}/{num_to_process} sentences")
 
@@ -367,6 +392,26 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
             Logger(f"Processed results saved to {args.cluster_cache_path}")
         except Exception as e:
             Logger(f"Failed to save processed results: {e}")
+
+        # Save the database mapping file
+        try:
+            mapping_file_path = args.cluster_cache_path.replace('.pt', '_mapping.json')
+            mapping_data = {
+                'metadata': {
+                    'total_entries': len(database_mapping),
+                    'knowledge_num': knowledge_num,
+                    'knowledge_length': knowledge_length,
+                    'source_file': database_init_path,
+                    'generation_time': time.strftime('%Y-%m-%d %H:%M:%S')
+                },
+                'mappings': database_mapping
+            }
+
+            with open(mapping_file_path, 'w', encoding='utf-8') as f:
+                json.dump(mapping_data, f, ensure_ascii=False, indent=2)
+            Logger(f"Database mapping saved to {mapping_file_path}")
+        except Exception as e:
+            Logger(f"Failed to save database mapping: {e}")
 
         # 4. Initialize the model's knowledge_dataset
         if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
@@ -482,7 +527,15 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
         sentences_data = []
         for data in database_data:
-            sentences_data.append(data['target'][0]['sentence'])
+            # Keep the sentence together with its uuid information
+            sentence_info = {
+                'sentence': data['target'][0]['sentence'],
+                'uuid': data['target'][0]['uuid'],
+                'subject': data['target'][0].get('subject', ''),
+                'predicate': data['target'][0].get('predicate', ''),
+                'object': data['target'][0].get('object', '')
+            }
+            sentences_data.append(sentence_info)
 
         # Extract the sentences list
         # sentences_data = database_data.get('sentences', [])
 
@@ -490,8 +543,10 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
         # 2. Sort by importance_score (descending)
         try:
-            sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
-            Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
+            # Note: each element of sentences_data is now a dict and no longer carries an importance_score field
+            # If sorting by importance is needed, that information has to come from the raw data
+            sorted_sentences = sentences_data  # no sorting for now, keep the original order
+            Logger(f"Loaded {len(sorted_sentences)} sentences (no importance_score sorting applied)")
         except:
             sorted_sentences = sentences_data
 
         # 3. Process each entry, without clustering
@@ -508,12 +563,14 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
         total_sentences = 0
         truncated_sentences = 0
 
+        # List used to record the database-index -> source mapping
+        database_mapping = []
+
         for i in range(num_to_process):
             sentence_data = sorted_sentences[i]
-            try:
-                sentence = sentence_data.get('corrected_sentence')
-            except:
-                sentence = sentence_data
+            # sentence_data is now a dict containing the sentence and its uuid
+            sentence = sentence_data['sentence']
+            uuid = sentence_data['uuid']
 
             # Convert the sentence to tokens
             sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
@@ -534,6 +591,19 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
             processed_rows.append(sentence_tokens)
 
+            # Record the mapping: database index -> original data info
+            mapping_entry = {
+                'database_index': i,  # index position in the database
+                'uuid': uuid,  # original uuid
+                'sentence': sentence,  # original sentence
+                'subject': sentence_data.get('subject', ''),
+                'predicate': sentence_data.get('predicate', ''),
+                'object': sentence_data.get('object', ''),
+                'token_count': len(sentence_tokens),
+                'is_truncated': len(tokenizer.encode(sentence, add_special_tokens=False)) > knowledge_length
+            }
+            database_mapping.append(mapping_entry)
+
             if (i + 1) % 1000 == 0:
                 Logger(f"Processed {i + 1}/{num_to_process} sentences")
 
@@ -568,6 +638,26 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
             Logger(f"Processed results saved to {args.cluster_cache_path}")
         except Exception as e:
             Logger(f"Failed to save processed results: {e}")
+
+        # Save the database mapping file
+        try:
+            mapping_file_path = args.cluster_cache_path.replace('.pt', '_mapping.json')
+            mapping_data = {
+                'metadata': {
+                    'total_entries': len(database_mapping),
+                    'knowledge_num': knowledge_num,
+                    'knowledge_length': knowledge_length,
+                    'source_file': database_init_path,
+                    'generation_time': time.strftime('%Y-%m-%d %H:%M:%S')
+                },
+                'mappings': database_mapping
+            }
+
+            with open(mapping_file_path, 'w', encoding='utf-8') as f:
+                json.dump(mapping_data, f, ensure_ascii=False, indent=2)
+            Logger(f"Database mapping saved to {mapping_file_path}")
+        except Exception as e:
+            Logger(f"Failed to save database mapping: {e}")
 
         # 4. Initialize the model's knowledge_dataset
         if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
@@ -646,6 +736,9 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
     truncated_sentences = 0
     pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
 
+    # List used to record the database-index -> source mapping
+    database_mapping = []
+
     # Cap the number of sentences to process
     num_to_process = min(len(data), knowledge_num)
     Logger(f"Processing {num_to_process} out of {total_sentences} sentences")
 
@@ -655,11 +748,27 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
         if idx % 1000 == 0:
             Logger(f"Processing sentence {idx+1}/{num_to_process}")
 
-        # Get the sentence text
+        # Get the sentence text and uuid
         if isinstance(item, dict):
-            sentence = item.get('sentence', '') or item.get('text', '') or str(item)
+            # For dict entries, try to extract the data from the target array
+            if 'target' in item and len(item['target']) > 0:
+                sentence = item['target'][0].get('sentence', '')
+                uuid = item['target'][0].get('uuid', '')
+                subject = item['target'][0].get('subject', '')
+                predicate = item['target'][0].get('predicate', '')
+                object_name = item['target'][0].get('object', '')
+            else:
+                sentence = item.get('sentence', '') or item.get('text', '') or str(item)
+                uuid = item.get('uuid', '')
+                subject = item.get('subject', '')
+                predicate = item.get('predicate', '')
+                object_name = item.get('object', '')
        else:
             sentence = str(item)
+            uuid = ''
+            subject = ''
+            predicate = ''
+            object_name = ''
 
         # Encode the sentence with the tokenizer
         try:
@@ -686,11 +795,38 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
 
             processed_rows.append(tokens)
 
+            # Record the mapping: database index -> original data info
+            mapping_entry = {
+                'database_index': idx,  # index position in the database
+                'uuid': uuid,  # original uuid
+                'sentence': sentence,  # original sentence
+                'subject': subject,
+                'predicate': predicate,
+                'object': object_name,
+                'token_count': len(tokens),
+                'is_truncated': len(tokens) > knowledge_length
+            }
+            database_mapping.append(mapping_entry)
+
         except Exception as e:
             Logger(f"Error processing sentence {idx}: {e}")
             # Use empty tokens as a fallback
             empty_tokens = [pad_token_id] * knowledge_length
             processed_rows.append(empty_tokens)
+
+            # Record a mapping entry for failed sentences as well
+            mapping_entry = {
+                'database_index': idx,
+                'uuid': uuid,
+                'sentence': sentence,
+                'subject': subject,
+                'predicate': predicate,
+                'object': object_name,
+                'token_count': knowledge_length,
+                'is_truncated': False,
+                'processing_error': str(e)
+            }
+            database_mapping.append(mapping_entry)
 
     # If there are not enough sentences, pad the remaining slots with empty tokens
     while len(processed_rows) < knowledge_num:
@@ -721,6 +857,26 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
         Logger(f"Processed results saved to {memory_cache_path}")
     except Exception as e:
         Logger(f"Failed to save processed results: {e}")
+
+    # Save the database mapping file
+    try:
+        mapping_file_path = memory_cache_path.replace('.pt', '_mapping.json')
+        mapping_data = {
+            'metadata': {
+                'total_entries': len(database_mapping),
+                'knowledge_num': knowledge_num,
+                'knowledge_length': knowledge_length,
+                'source_file': database_init_path,
+                'generation_time': time.strftime('%Y-%m-%d %H:%M:%S')
+            },
+            'mappings': database_mapping
+        }
+
+        with open(mapping_file_path, 'w', encoding='utf-8') as f:
+            json.dump(mapping_data, f, ensure_ascii=False, indent=2)
+        Logger(f"Database mapping saved to {mapping_file_path}")
+    except Exception as e:
+        Logger(f"Failed to save database mapping: {e}")
 
     # Initialize the model's memory_bank
     if hasattr(model, 'memory_bank'):
@@ -1050,6 +1206,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
             layer_stats = {}
             if hasattr(res, 'layer_stats') and res.layer_stats is not None:
                 layer_stats = res.layer_stats
+
+            # Fetch cosine-similarity statistics (if the model provides them)
+            cosine_stats = {}
+            if hasattr(res, 'cosine_stats') and res.cosine_stats is not None:
+                cosine_stats = res.cosine_stats
 
             # Build the logging dict
             log_dict = {
@@ -1072,6 +1233,14 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
             # Add memory-bank update statistics
             log_dict.update(memory_update_stats)
 
+            # Compute the average cosine similarity
+            avg_cosine_similarity = 0.0
+            if cosine_stats:
+                # Average the per-layer average cosine similarities
+                cosine_similarities = [v for k, v in cosine_stats.items() if k.endswith('_avg_cosine_similarity')]
+                if cosine_similarities:
+                    avg_cosine_similarity = np.mean(cosine_similarities)
+
             # Add per-layer statistics (only key metrics)
             if layer_stats:
                 # Compute averages across all layers
@@ -1085,11 +1254,13 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                     'memory/avg_coverage_rate': avg_coverage,
                     'memory/total_dead_memories': total_dead,
                     'memory/total_hot_memories': total_hot,
+                    'train/avg_cosine_similarity': avg_cosine_similarity,
                 })
 
             Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
                    f"CE Loss: {log_dict['train/loss_ce']:.4f}, "
                    f"Balance Loss: {log_dict['train/loss_balance']:.4f}, "
+                   f"Avg Cosine Sim: {avg_cosine_similarity:.4f}, "
                    f"Total Loss: {log_dict['train/loss_total']:.4f}, "
                    f"Val Loss: {log_dict.get('val/loss', 'N/A')}, "
                    f"LR: {log_dict['lr']:.6f}, "
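
Illustrative sketch only, not part of the patch: it mirrors how the per-layer cosine statistics introduced above are reduced to the single 'train/avg_cosine_similarity' scalar logged in train_epoch. Key names follow the diff; the toy shapes, the two-layer loop, and the print are assumptions.

# Standalone sketch of the cosine-similarity monitoring added by this patch.
import torch
import torch.nn.functional as F
import numpy as np

batch, seq_len, num_selected, dim = 2, 4, 3, 8
h_for_memory = torch.randn(batch, seq_len, dim)                 # lookup vectors (toy values)
selected_memory = torch.randn(batch, seq_len, num_selected, dim)  # pooled memory entries (toy values)

# Same computation as the block added to MiniMindBlock.forward()
h_expanded = h_for_memory.unsqueeze(2).expand(-1, -1, num_selected, -1)
cosine_similarities = F.cosine_similarity(h_expanded, selected_memory, dim=-1)  # [batch, seq_len, num_selected]

# Per-layer stats keyed the way MiniMindLM.forward() prefixes them
all_cosine_stats = {
    f'layer_{layer_idx}_avg_cosine_similarity': cosine_similarities.mean().item()
    for layer_idx in range(2)  # pretend two layers produced the same stats
}

# Reduction used in train_epoch: average the per-layer averages
avg_values = [v for k, v in all_cosine_stats.items() if k.endswith('_avg_cosine_similarity')]
avg_cosine_similarity = np.mean(avg_values) if avg_values else 0.0
print(f"train/avg_cosine_similarity = {avg_cosine_similarity:.4f}")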
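Also illustrative, not part of the patch: reading back the *_mapping.json file that init_model now writes next to the token cache. Field names follow the diff; the file path below is a placeholder.

import json

with open('cache/memory_cache_mapping.json', 'r', encoding='utf-8') as f:  # placeholder path
    mapping_data = json.load(f)

print(mapping_data['metadata']['total_entries'], "entries")
first = mapping_data['mappings'][0]
# Each entry maps a memory-bank row back to its source triple:
#   database_index, uuid, sentence, subject, predicate, object,
#   token_count, is_truncated (plus processing_error for failed rows)
print(first['database_index'], first['uuid'], first['is_truncated'])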