diff --git a/model/LMConfig.py b/model/LMConfig.py index 7dd1f4e..8eb8a62 100644 --- a/model/LMConfig.py +++ b/model/LMConfig.py @@ -9,13 +9,13 @@ class LMConfig(PretrainedConfig): self, dim: int = 512, n_layers: int = 8, - n_heads: int = 32, + n_heads: int = 16, n_kv_heads: int = 8, vocab_size: int = 6400, hidden_dim: int = None, multiple_of: int = 64, norm_eps: float = 1e-5, - max_seq_len: int = 8192, + max_seq_len: int = 512, rope_theta: int = 1e6, dropout: float = 0.0, flash_attn: bool = True, @@ -38,8 +38,8 @@ class LMConfig(PretrainedConfig): seq_aux: bool = True, norm_topk_prob: bool = True, #################################################### - knowledge_num: int = 64*64, - knowledge_length: int = 8, + knowledge_num: int = 1024*1024, + knowledge_length: int = 16, knowledge_dim: int = 128, #################################################### # EMA update related configurations (inspired by VQ-VAE) diff --git a/model/model_memory_1_4_10.py b/model/model_memory_1_4_10.py new file mode 100644 index 0000000..b40483b --- /dev/null +++ b/model/model_memory_1_4_10.py @@ -0,0 +1,930 @@ +import math +import struct +import inspect +import time + +from .LMConfig import LMConfig +from typing import Any, Optional, Tuple, List, Union +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + return self.weight * self._norm(x.float()).type_as(x) + + +def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return pos_cis + + +def apply_rotary_emb(xq, xk, pos_cis): + def unite_shape(pos_cis, x): + ndim = x.ndim + assert 0 <= 1 < ndim + assert pos_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return pos_cis.view(*shape) + + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + pos_cis = unite_shape(pos_cis, xq_) + xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """Self attention module without KV cache""" + def __init__(self, args: LMConfig): + super().__init__() + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + assert args.n_heads % self.n_kv_heads == 0 + self.n_local_heads = args.n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) + self.attn_dropout = nn.Dropout(args.dropout) + self.resid_dropout = nn.Dropout(args.dropout) + self.dropout = args.dropout + self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn + # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") + mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) + mask = torch.triu(mask, diagonal=1) + self.register_buffer("mask", mask, persistent=False) + + def forward(self, x: torch.Tensor, pos_cis: torch.Tensor): + """Forward pass without KV cache""" + bsz, seq_len, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, pos_cis) + + # 注意:完全去除了KV cache相关代码 + + xq, xk, xv = ( + xq.transpose(1, 2), + repeat_kv(xk, self.n_rep).transpose(1, 2), + repeat_kv(xv, self.n_rep).transpose(1, 2) + ) + if self.flash and seq_len != 1: + dropout_p = self.dropout if self.training else 0.0 + output = F.scaled_dot_product_attention( + xq, xk, xv, + attn_mask=None, + dropout_p=dropout_p, + is_causal=True + ) + else: + scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) + scores += self.mask[:, :, :seq_len, :seq_len] + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + scores = self.attn_dropout(scores) + output = scores @ xv + + output = output.transpose(1, 2).reshape(bsz, seq_len, -1) + output = self.resid_dropout(self.wo(output)) + return output + + +class MemoryGate(nn.Module): + """Product Key Memory-based gate mechanism for memory selection with Gumbel-Softmax""" + def __init__(self, config: LMConfig): + super().__init__() + self.config = config + self.dim = config.dim + self.knowledge_num = config.knowledge_num + self.knowledge_dim = config.knowledge_dim + self.num_candidates = getattr(config, 'num_candidates', 32) # Generate 32 candidates + self.num_selected = getattr(config, 'num_selected', 1) # Select 1 best from candidates + + # 确保知识库数量是完全平方数 + assert int(self.knowledge_num ** 0.5) ** 2 == self.knowledge_num, \ + f"knowledge_num ({self.knowledge_num}) must be a perfect square for product key memory" + + self.num_keys = int(self.knowledge_num ** 0.5) + + # 查询投影:将输入维度映射到knowledge_dim * 2(用于两个product key) + self.gate_proj = nn.Linear(self.dim, self.knowledge_dim, bias=False) + + # Product Key Memory: 两个独立的键集合 + self.keys = nn.Parameter(torch.randn(2, self.num_keys, self.knowledge_dim // 2)) + + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x: torch.Tensor): + """ + Args: + x: [batch_size, seq_len, dim] + Returns: + memory_indices: [batch_size, seq_len, num_selected] + memory_scores: [batch_size, seq_len, num_selected] + balance_loss: 平衡损失(KL散度 + 基尼系数) + stats: 监控统计信息字典 + """ + bsz, seq_len, _ = x.shape + + # 生成查询向量 + queries = self.gate_proj(x) # [batch, seq_len, knowledge_dim] + + # 分割为两部分用于product key + q1 = queries[:, :, :self.knowledge_dim // 2] # [batch, seq_len, knowledge_dim // 2] + q2 = queries[:, :, self.knowledge_dim // 2:] # [batch, seq_len, knowledge_dim // 2] + + # 计算与两个键集合的相似度 + scores_1 = torch.einsum('bsd,kd->bsk', q1, self.keys[0]) # [batch, seq_len, num_keys] + scores_2 = torch.einsum('bsd,kd->bsk', q2, self.keys[1]) # [batch, seq_len, num_keys] + + # 获取top-k candidates (now using num_candidates instead of num_selected) + topk_scores_1, topk_indices_1 = scores_1.topk(self.num_candidates, dim=-1) + topk_scores_2, topk_indices_2 = scores_2.topk(self.num_candidates, dim=-1) + + # 组合product key的结果 + combined_scores = topk_scores_1.unsqueeze(-1) + topk_scores_2.unsqueeze(-2) # [batch, seq_len, num_candidates, num_candidates] + combined_indices = topk_indices_1.unsqueeze(-1) * self.num_keys + topk_indices_2.unsqueeze(-2) # [batch, seq_len, num_candidates, num_candidates] + + # 展平并选择最终的top-k candidates + combined_scores = combined_scores.view(bsz, seq_len, -1) + combined_indices = combined_indices.view(bsz, seq_len, -1) + + candidate_scores, candidate_pk_indices = combined_scores.topk(self.num_candidates, dim=-1) + candidate_indices = combined_indices.gather(-1, candidate_pk_indices) # [batch, seq_len, num_candidates] + + # 归一化候选分数 + candidate_scores = F.softmax(candidate_scores, dim=-1) + candidate_scores = self.dropout(candidate_scores) + + # 返回候选项用于后续的相似度选择 + # 注意:这里返回候选项,在MiniMindBlock中进行相似度选择和多样性损失计算 + return candidate_indices, candidate_scores, None, {} + + def _compute_balance_loss_and_stats(self, memory_indices, memory_scores): + """ + 计算平衡损失和监控统计信息 + + Args: + memory_indices: [batch_size, seq_len, num_selected] + memory_scores: [batch_size, seq_len, num_selected] + + Returns: + balance_loss: 标量张量 + stats: 统计信息字典 + """ + bsz, seq_len, num_selected = memory_indices.shape + device = memory_indices.device + + # 1. 计算记忆选择分布 + # 将所有选择的记忆索引展平 + flat_indices = memory_indices.view(-1) # [batch_size * seq_len * num_selected] + + # 统计每个记忆条目被选中的次数 + memory_counts = torch.zeros(self.knowledge_num, device=device) + memory_counts.scatter_add_(0, flat_indices, torch.ones_like(flat_indices, dtype=torch.float)) + + # 计算选择概率分布 + total_selections = bsz * seq_len * num_selected + memory_probs = memory_counts / total_selections + + # 2. 计算KL散度损失(与均匀分布的KL散度) + uniform_prob = 1.0 / self.knowledge_num + # 避免log(0)的问题 + memory_probs_safe = memory_probs + 1e-10 + kl_loss = F.kl_div( + torch.log(memory_probs_safe), + torch.full_like(memory_probs, uniform_prob), + reduction='sum' + ) + + # 3. 计算基尼系数损失(衡量分布不平等程度) + sorted_probs, _ = torch.sort(memory_probs) + n = self.knowledge_num + index = torch.arange(1, n + 1, device=device, dtype=torch.float) + gini_coeff = (2 * torch.sum(index * sorted_probs) / (n * torch.sum(sorted_probs))) - (n + 1) / n + gini_loss = gini_coeff # 基尼系数越大,分布越不均匀 + + # 4. 组合平衡损失 + balance_loss = 0.5 * kl_loss + 0.5 * gini_loss + + # 5. 计算监控统计信息 + with torch.no_grad(): + # 记忆覆盖率:被选中的记忆条目占总数的比例 + coverage_rate = (memory_counts > 0).float().mean().item() + + # 热点记忆:选择次数前10%的记忆条目 + top10_threshold = torch.quantile(memory_counts, 0.9) + hot_memories = (memory_counts >= top10_threshold).sum().item() + + # 死记忆:从未被选中的记忆条目 + dead_memories = (memory_counts == 0).sum().item() + + # 记忆选择方差(衡量不平衡程度) + selection_variance = memory_counts.var().item() + + stats = { + 'gini_coefficient': gini_coeff.item(), + 'kl_divergence': kl_loss.item(), + 'coverage_rate': coverage_rate, + 'hot_memories': hot_memories, + 'dead_memories': dead_memories, + 'selection_variance': selection_variance, + 'max_selections': memory_counts.max().item(), + 'min_selections': memory_counts.min().item(), + } + + return balance_loss, stats + + +class GatedMemoryFusion(nn.Module): + """Gated MLP fusion for concatenated h_attn and selected memories""" + def __init__(self, config: LMConfig): + super().__init__() + self.config = config + self.dim = config.dim + self.knowledge_dim = config.knowledge_dim + self.num_selected = getattr(config, 'num_selected', 1) # Now we select 1 best memory + + # 输入维度:dim (h_attn) + num_selected * dim (选中的记忆,现在只有1个) + # 实验1.4.9:修改为只选择1个最佳记忆 + concat_dim = self.dim + self.num_selected * self.dim + + # 类似SwiGLU的门控MLP结构 + self.gate_proj = nn.Linear(concat_dim, self.dim, bias=False) + self.up_proj = nn.Linear(concat_dim, self.dim, bias=False) + self.down_proj = nn.Linear(self.dim, self.dim, bias=False) + + self.dropout = nn.Dropout(config.dropout) + + def forward(self, h_attn: torch.Tensor, selected_memory: torch.Tensor): + """ + Args: + h_attn: [batch_size, seq_len, dim] - Self attention output + selected_memory: [batch_size, seq_len, dim] - Selected single best memory + Returns: + output: [batch_size, seq_len, dim] + """ + bsz, seq_len, _ = h_attn.shape + + # 拼接h_attn和最佳记忆 + concat_input = torch.cat([h_attn, selected_memory], dim=-1) # [batch, seq_len, dim + dim] + + # 门控MLP处理(类似SwiGLU) + gate = F.silu(self.gate_proj(concat_input)) # [batch, seq_len, dim] + up = self.up_proj(concat_input) # [batch, seq_len, dim] + fusion_output = gate * up # Element-wise multiplication + + # 输出投影 + output = self.down_proj(fusion_output) # [batch, seq_len, dim] + output = self.dropout(output) + + return output + + +class MiniMindBlock(nn.Module): + """Transformer block with memory-based cross attention instead of FFN""" + def __init__(self, layer_id: int, config: LMConfig): + super().__init__() + self.config = config # 保存config引用 + self.n_heads = config.n_heads + self.dim = config.dim + self.head_dim = config.dim // config.n_heads + self.attention = Attention(config) + + self.layer_id = layer_id + self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.memory_norm = RMSNorm(config.dim, eps=config.norm_eps) + + # 记忆相关模块 + self.memory_gate = MemoryGate(config) + self.gated_memory_fusion = GatedMemoryFusion(config) + + # Gumbel-Softmax参数 + self.gumbel_temperature = getattr(config, 'gumbel_temperature', 1.0) + + # self.attentionpool = nn.Linear(config.dim, 1) + + def gumbel_softmax_selection(self, similarity_scores, temperature=1.0, hard=True): + """ + 使用Gumbel-Softmax进行可微分的离散选择 + + Args: + similarity_scores: [batch_size, seq_len, num_candidates] - 相似度分数 + temperature: Gumbel-Softmax温度参数 + hard: 是否使用硬选择(one-hot) + + Returns: + selection_weights: [batch_size, seq_len, num_candidates] - 选择权重 + selected_indices: [batch_size, seq_len] - 选中的索引(用于统计) + """ + # 添加Gumbel噪声 + gumbel_noise = -torch.log(-torch.log(torch.rand_like(similarity_scores) + 1e-20) + 1e-20) + logits = (similarity_scores + gumbel_noise) / temperature + + # Softmax + soft_weights = F.softmax(logits, dim=-1) + + if hard: + # 硬选择:创建one-hot向量 + _, max_indices = soft_weights.max(dim=-1, keepdim=True) + hard_weights = torch.zeros_like(soft_weights).scatter_(-1, max_indices, 1.0) + # 使用straight-through estimator + selection_weights = hard_weights - soft_weights.detach() + soft_weights + selected_indices = max_indices.squeeze(-1) # [batch_size, seq_len] + else: + # 软选择 + selection_weights = soft_weights + selected_indices = torch.argmax(soft_weights, dim=-1) + + return selection_weights, selected_indices + + def compute_diversity_loss(self, candidate_memories): + """ + 计算候选集内部多样性损失(鼓励候选项之间的差异性) + + Args: + candidate_memories: [batch_size, seq_len, num_candidates, dim] + + Returns: + diversity_loss: 标量张量 + """ + bsz, seq_len, num_candidates, dim = candidate_memories.shape + + # 计算候选项之间的相似度矩阵 + # 归一化候选记忆用于计算余弦相似度 + normalized_memories = F.normalize(candidate_memories, p=2, dim=-1) # [batch, seq_len, num_candidates, dim] + + # 计算相似度矩阵: [batch, seq_len, num_candidates, num_candidates] + similarity_matrix = torch.matmul(normalized_memories, normalized_memories.transpose(-2, -1)) + + # 移除对角线(自相似度=1) + mask = torch.eye(num_candidates, device=candidate_memories.device).bool() + mask = mask.unsqueeze(0).unsqueeze(0).expand(bsz, seq_len, -1, -1) + + # 计算非对角线元素的平均相似度(希望越小越好,表示越多样) + off_diagonal_similarities = similarity_matrix.masked_select(~mask) + avg_similarity = off_diagonal_similarities.mean() + + # 多样性损失:相似度越高,损失越大 + diversity_loss = avg_similarity + + return diversity_loss + + def forward(self, x, pos_cis, memory_bank, tok_embeddings, collect_ema_stats=False): + """ + 实验1.4.9: Gumbel-Softmax + 多样性损失 + 可微分相似度损失 + + Args: + x: [batch_size, seq_len, dim] + pos_cis: positional encoding + memory_bank: [knowledge_num, knowledge_length] - shared memory bank with token IDs + tok_embeddings: token embedding layer + collect_ema_stats: 是否收集EMA更新统计信息 + + Returns: + out: [batch_size, seq_len, dim] + balance_loss: 该层的平衡损失 (从候选项计算) + similarity_loss: 相似度损失 (可微分) + diversity_loss: 多样性损失 + layer_stats: 该层的监控统计信息 + ema_stats: EMA更新统计信息(如果collect_ema_stats=True) + cosine_stats: 查找向量与候选记忆条目的余弦相似度统计信息 + """ + # Self attention + h_attn = self.attention(self.attention_norm(x), pos_cis) + h = x + h_attn + + # 使用h_attn作为门控和交叉注意力的输入(核心:self attention的输出) + h_for_memory = self.memory_norm(h_attn) + + # 🔥 新架构:生成32个候选项 + candidate_indices, candidate_scores, _, _ = self.memory_gate(h_for_memory) + # candidate_indices: [batch, seq_len, num_candidates] + # candidate_scores: [batch, seq_len, num_candidates] + + bsz, seq_len, num_candidates = candidate_indices.shape + + # 解码候选token_ids为特征向量 + candidate_indices_flat = candidate_indices.view(-1) # [batch * seq_len * num_candidates] + candidate_token_ids = memory_bank[candidate_indices_flat] # [batch * seq_len * num_candidates, knowledge_length] + + # 解码为embeddings并池化 + candidate_embeddings = tok_embeddings(candidate_token_ids) # [batch * seq_len * num_candidates, knowledge_length, dim] + candidate_memories = candidate_embeddings.mean(dim=1) # [batch * seq_len * num_candidates, dim] + candidate_memories = candidate_memories.view(bsz, seq_len, num_candidates, self.dim) # [batch, seq_len, num_candidates, dim] + + # 🔥 核心改进: 计算可微分的相似度分数 (移除no_grad) + h_expanded = h_for_memory.unsqueeze(2).expand(-1, -1, num_candidates, -1) # [batch, seq_len, num_candidates, dim] + similarity_scores = F.cosine_similarity(h_expanded, candidate_memories, dim=-1) # [batch, seq_len, num_candidates] + + # 🔥 使用Gumbel-Softmax选择最佳候选项 + selection_weights, selected_indices = self.gumbel_softmax_selection( + similarity_scores, + temperature=self.gumbel_temperature, + hard=True + ) # selection_weights: [batch, seq_len, num_candidates], selected_indices: [batch, seq_len] + + # 🔥 计算相似度损失 (现在是可微分的!) + # 相似度损失:希望选中的记忆与查询向量相似度尽可能高 + selected_similarities = (similarity_scores * selection_weights).sum(dim=-1) # [batch, seq_len] + similarity_loss = -selected_similarities.mean() # 负号:相似度越高,损失越小 + + # 🔥 计算候选集多样性损失 + diversity_loss = self.compute_diversity_loss(candidate_memories) + + # 🔥 使用selection_weights进行加权选择最终记忆 + selected_memory = (candidate_memories * selection_weights.unsqueeze(-1)).sum(dim=2) # [batch, seq_len, dim] + + # 门控MLP融合:只融合选中的单个最佳记忆 + memory_output = self.gated_memory_fusion(h_for_memory, selected_memory) + + # 残差连接 + out = h + memory_output + + # 🔥 计算平衡损失和统计信息 (基于候选项的选择分布) + balance_loss, layer_stats = self._compute_candidate_balance_stats(candidate_indices, selection_weights) + + # 🔥 计算详细的相似度统计信息 + cosine_stats = { + 'similarity_scores': similarity_scores, # [batch, seq_len, num_candidates] + 'selected_similarities': selected_similarities, # [batch, seq_len] + 'avg_similarity': similarity_scores.mean().item(), # 平均相似度 + 'max_similarity': similarity_scores.max().item(), # 最大相似度 + 'min_similarity': similarity_scores.min().item(), # 最小相似度 + 'selected_avg_similarity': selected_similarities.mean().item(), # 选中记忆的平均相似度 + 'selection_entropy': -torch.sum(selection_weights * torch.log(selection_weights + 1e-10), dim=-1).mean().item() # 选择熵 + } + + # 收集EMA更新统计信息(现在基于选中的记忆) + ema_stats = None + if collect_ema_stats and self.training: + # 扩展选中的索引以匹配EMA更新的期望格式 + selected_memory_indices = candidate_indices.gather(2, selected_indices.unsqueeze(-1)) # [batch, seq_len, 1] + ema_stats = { + 'memory_indices': selected_memory_indices, # [batch, seq_len, 1] + 'memory_scores': torch.ones_like(selected_memory_indices.float()), # [batch, seq_len, 1] - 选中的权重为1 + 'h_for_memory': h_for_memory, # [batch, seq_len, dim] + 'selected_memory': selected_memory.unsqueeze(2), # [batch, seq_len, 1, dim] + } + + if collect_ema_stats: + return out, balance_loss, similarity_loss, diversity_loss, layer_stats, ema_stats, cosine_stats + else: + return out, balance_loss, similarity_loss, diversity_loss, layer_stats, cosine_stats + + def _compute_candidate_balance_stats(self, candidate_indices, selection_weights): + """ + 计算基于候选项选择的平衡损失和统计信息 + + Args: + candidate_indices: [batch_size, seq_len, num_candidates] + selection_weights: [batch_size, seq_len, num_candidates] - Gumbel-Softmax权重 + + Returns: + balance_loss: 标量张量 + stats: 统计信息字典 + """ + bsz, seq_len, num_candidates = candidate_indices.shape + device = candidate_indices.device + + # 使用加权统计每个记忆条目被选中的概率 + flat_indices = candidate_indices.view(-1) # [batch * seq_len * num_candidates] + flat_weights = selection_weights.view(-1) # [batch * seq_len * num_candidates] + + # 统计每个记忆条目被选中的加权次数 + memory_counts = torch.zeros(self.config.knowledge_num, device=device) + memory_counts.scatter_add_(0, flat_indices, flat_weights) + + # 计算选择概率分布 + total_selections = memory_counts.sum() + memory_probs = memory_counts / (total_selections + 1e-10) + + # 计算KL散度损失(与均匀分布的KL散度) + uniform_prob = 1.0 / self.config.knowledge_num + memory_probs_safe = memory_probs + 1e-10 + kl_loss = F.kl_div( + torch.log(memory_probs_safe), + torch.full_like(memory_probs, uniform_prob), + reduction='sum' + ) + + # 计算基尼系数损失 + sorted_probs, _ = torch.sort(memory_probs) + n = self.config.knowledge_num + index = torch.arange(1, n + 1, device=device, dtype=torch.float) + gini_coeff = (2 * torch.sum(index * sorted_probs) / (n * torch.sum(sorted_probs))) - (n + 1) / n + gini_loss = gini_coeff + + # 组合平衡损失 + balance_loss = 0.5 * kl_loss + 0.5 * gini_loss + + # 计算统计信息 + with torch.no_grad(): + coverage_rate = (memory_counts > 0.01).float().mean().item() # 被选中概率>1%的记忆比例 + top10_threshold = torch.quantile(memory_counts, 0.9) + hot_memories = (memory_counts >= top10_threshold).sum().item() + dead_memories = (memory_counts < 0.01).sum().item() # 几乎从未被选中的记忆 + selection_variance = memory_counts.var().item() + + stats = { + 'gini_coefficient': gini_coeff.item(), + 'kl_divergence': kl_loss.item(), + 'coverage_rate': coverage_rate, + 'hot_memories': hot_memories, + 'dead_memories': dead_memories, + 'selection_variance': selection_variance, + 'max_selections': memory_counts.max().item(), + 'min_selections': memory_counts.min().item(), + } + + return balance_loss, stats + + +class MiniMindLM(PreTrainedModel): + config_class = LMConfig + + def __init__(self, params: LMConfig = None): + self.params = params + super().__init__(self.params) + self.vocab_size, self.n_layers = params.vocab_size, params.n_layers + self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) + self.dropout = nn.Dropout(params.dropout) + self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)]) + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = nn.Linear(params.dim, params.vocab_size, bias=False) + self.tok_embeddings.weight = self.output.weight + self.register_buffer("pos_cis", + precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta), + persistent=False) + + # 初始化共享记忆库 - 实验1.4.6:存储token_id而非特征向量 + # VQ-VAE风格:memory_bank作为codebook,使用EMA更新而非梯度更新 + if params.use_ema_update: + self.memory_bank = nn.Parameter( + torch.randint(0, params.vocab_size, (params.knowledge_num, params.knowledge_length)), + requires_grad=False # 禁用梯度更新,使用EMA更新 + ) + else: + self.memory_bank = nn.Parameter( + torch.randint(0, params.vocab_size, (params.knowledge_num, params.knowledge_length)), + requires_grad=True # 传统梯度更新 + ) + + # EMA更新相关缓冲区 + if params.use_ema_update: + # 记录每个memory条目的更新统计 + self.register_buffer('ema_update_count', torch.zeros(params.knowledge_num), persistent=False) + # 注意:现在memory_bank存储token_id,但EMA在特征空间进行,所以不需要sum_buffer了 + # self.register_buffer('ema_sum_buffer', torch.zeros_like(self.memory_bank), persistent=False) + # EMA更新频率计数器 + self.register_buffer('ema_step_counter', torch.zeros(1, dtype=torch.long), persistent=False) + + # 记录上一步的记忆库状态,用于计算更新统计 + self.register_buffer('prev_memory_bank', torch.zeros_like(self.memory_bank), persistent=False) + + # 🔥 新增: 冻结mask - 标记哪些memory_bank条目被冻结(不更新) + if params.freeze_ratio > 0.0: + freeze_num = int(params.knowledge_num * params.freeze_ratio) + freeze_mask = torch.zeros(params.knowledge_num, dtype=torch.bool) + # 固定冻结前面的条目 + freeze_mask[:freeze_num] = True + self.register_buffer('freeze_mask', freeze_mask, persistent=False) + print(f"🔥 Memory bank freezing enabled: {freeze_num}/{params.knowledge_num} entries ({params.freeze_ratio*100:.1f}%) frozen", flush=True) + import sys; sys.stdout.flush() + else: + self.register_buffer('freeze_mask', torch.zeros(params.knowledge_num, dtype=torch.bool), persistent=False) + print(f"🔥 Memory bank freezing disabled: all entries can be updated", flush=True) + import sys; sys.stdout.flush() + + self.OUT = CausalLMOutputWithPast() + + def get_memory_update_stats(self): + """ + 计算记忆库更新统计信息 + + Returns: + update_stats: 包含更新统计的字典 + """ + with torch.no_grad(): + if hasattr(self, 'prev_memory_bank') and self.prev_memory_bank.numel() > 0: + # 计算L2距离变化 + l2_distance = torch.norm(self.memory_bank - self.prev_memory_bank, p=2, dim=-1) + avg_l2_distance = l2_distance.mean().item() + max_l2_distance = l2_distance.max().item() + + # 计算余弦相似度 + cos_sim = F.cosine_similarity( + self.memory_bank.view(-1), + self.prev_memory_bank.view(-1), + dim=0 + ).item() + + # 计算更新率(发生显著变化的记忆条目比例) + threshold = 0.01 # 更新阈值 + updated_memories = (l2_distance > threshold).sum().item() + update_rate = updated_memories / self.memory_bank.size(0) + + update_stats = { + 'memory_avg_l2_change': avg_l2_distance, + 'memory_max_l2_change': max_l2_distance, + 'memory_cosine_similarity': cos_sim, + 'memory_update_rate': update_rate, + 'memory_updated_count': updated_memories + } + else: + # 第一次调用时的默认值 + update_stats = { + 'memory_avg_l2_change': 0.0, + 'memory_max_l2_change': 0.0, + 'memory_cosine_similarity': 1.0, + 'memory_update_rate': 0.0, + 'memory_updated_count': 0 + } + + # 更新prev_memory_bank + self.prev_memory_bank.copy_(self.memory_bank) + + return update_stats + + def forward(self, + input_ids: Optional[torch.Tensor] = None, + **args): + """Forward pass without KV cache support""" + start_pos = args.get('start_pos', 0) + collect_ema_stats = args.get('collect_ema_stats', self.params.use_ema_update and self.training) + + h = self.dropout(self.tok_embeddings(input_ids)) + pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)] + + # 收集所有层的损失和统计信息 - 实验1.4.9: 四损失系统 + total_balance_loss = 0 + total_similarity_loss = 0 + total_diversity_loss = 0 + all_layer_stats = {} + all_ema_stats = {} + all_cosine_stats = {} + + for layer_idx, layer in enumerate(self.layers): + if collect_ema_stats: + h, balance_loss, similarity_loss, diversity_loss, layer_stats, ema_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True) + all_ema_stats[f'layer_{layer_idx}'] = ema_stats + else: + h, balance_loss, similarity_loss, diversity_loss, layer_stats, cosine_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False) + + # 累加四种损失 + total_balance_loss += balance_loss + total_similarity_loss += similarity_loss + total_diversity_loss += diversity_loss + + # 为每层的统计信息添加前缀 + for key, value in layer_stats.items(): + all_layer_stats[f'layer_{layer_idx}_{key}'] = value + + # 为每层的余弦相似度统计信息添加前缀 + for key, value in cosine_stats.items(): + all_cosine_stats[f'layer_{layer_idx}_{key}'] = value + + logits = self.output(self.norm(h)) + + # 🔥 新的四损失结构 + aux_loss = { + 'balance_loss': total_balance_loss, + 'similarity_loss': total_similarity_loss, + 'diversity_loss': total_diversity_loss, + } + + self.OUT.__setitem__('last_hidden_state', h) + self.OUT.__setitem__('logits', logits) + self.OUT.__setitem__('aux_loss', aux_loss) + self.OUT.__setitem__('layer_stats', all_layer_stats) # 添加层级统计信息 + self.OUT.__setitem__('ema_stats', all_ema_stats if collect_ema_stats else None) # 添加EMA统计信息 + self.OUT.__setitem__('cosine_stats', all_cosine_stats) # 添加余弦相似度统计信息 + self.OUT.__setitem__('past_key_values', None) # 不支持KV cache + return self.OUT + + @torch.inference_mode() + def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90, + stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args): + """Generate without KV cache""" + # 流式生成 + if stream: + return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args) + + # 直接生成 + generated = [] + for i in range(input_ids.size(0)): + non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0) + for _ in range(num_return_sequences): + out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args) + tokens_list = [tokens[:, -1:] for tokens in out] + gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad + full_sequence = torch.cat([non_pad, gen], dim=-1) + generated.append(full_sequence) + + max_length = max(seq.size(1) for seq in generated) + generated = [ + torch.cat( + [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)], + dim=-1) + for seq in generated + ] + output = torch.cat(generated, dim=0) + res = output.view(input_ids.size(0) * num_return_sequences, -1) + return res + + def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args): + """Stream generation without KV cache - regenerates full sequence each time""" + start = input_ids.shape[1] + while input_ids.shape[1] < start + max_new_tokens: + # 每次都重新计算整个序列(因为没有KV cache) + out = self(input_ids, **args) + logits = out.logits[:, -1, :] + + # 重复惩罚 + logits[:, list(set(input_ids.tolist()[0]))] /= rp + logits /= (temperature + 1e-9) + + # Top-p采样 + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + sorted_probs = F.softmax(sorted_logits, dim=-1) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() + sorted_indices_to_remove[:, 0] = False + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = -float('Inf') + + input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + input_ids = torch.cat((input_ids, input_ids_next), dim=1) + yield input_ids[:, start:] + if input_ids_next.item() == eos_token_id: + break + + def apply_ema_update(self, ema_stats): + """ + 应用token-based EMA更新到memory_bank + 实验1.4.6:批量化tensor操作优化版本 + + Args: + ema_stats: 从forward pass收集的EMA统计信息,格式为: + {'layer_0': {'memory_indices': ..., 'h_for_memory': ...}, 'layer_1': ...} + """ + if not self.params.use_ema_update: + return {} + + # 增加EMA步数计数器 + self.ema_step_counter += 1 + + # 检查是否需要进行EMA更新 + if self.ema_step_counter % self.params.ema_update_freq != 0: + return {'ema_update_applied': False, 'reason': 'frequency_check_failed'} + + with torch.no_grad(): + device = self.memory_bank.device + knowledge_num, knowledge_length = self.memory_bank.shape + dim = self.params.dim + + # 🚀 批量收集所有层的数据(避免字典操作) + all_indices = [] + all_features = [] + total_selections = 0 + total_layers = 0 + + # 收集所有层的EMA统计信息 + for layer_ema_stats in ema_stats.values(): + if layer_ema_stats is None: + continue + + total_layers += 1 + memory_indices = layer_ema_stats['memory_indices'] # [batch, seq_len, num_selected] + h_for_memory = layer_ema_stats['h_for_memory'] # [batch, seq_len, dim] + + bsz, seq_len, num_selected = memory_indices.shape + total_selections += bsz * seq_len * num_selected + + # 展平索引和对应的h_for_memory + flat_indices = memory_indices.view(-1) # [batch * seq_len * num_selected] + + # 为每个选择位置复制对应的h_for_memory + h_expanded = h_for_memory.unsqueeze(2).expand(-1, -1, num_selected, -1) # [batch, seq_len, num_selected, dim] + flat_h = h_expanded.reshape(-1, dim) # [batch * seq_len * num_selected, dim] + + all_indices.append(flat_indices) + all_features.append(flat_h) + + if not all_indices: + return {'ema_update_applied': False, 'reason': 'no_ema_stats'} + + # 🚀 合并所有数据 + all_indices = torch.cat(all_indices, dim=0) # [total_selections] + all_features = torch.cat(all_features, dim=0) # [total_selections, dim] + + # 🚀 批量计算每个memory的平均特征(避免循环) + unique_indices, inverse_indices = torch.unique(all_indices, return_inverse=True) + + # 使用scatter_add批量聚合(确保数据类型一致) + aggregated_features = torch.zeros(unique_indices.size(0), dim, device=device, dtype=all_features.dtype) + count_per_memory = torch.zeros(unique_indices.size(0), device=device, dtype=all_features.dtype) + + aggregated_features.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, dim), all_features) + count_per_memory.scatter_add_(0, inverse_indices, torch.ones_like(inverse_indices, dtype=all_features.dtype)) + + # 计算平均值 + avg_features = aggregated_features / count_per_memory.unsqueeze(1) # [unique_count, dim] + + # 🚀 分批EMA更新(控制显存使用) + batch_size = 4096 # 每批处理4096个memory,控制显存 + updated_memories = 0 + + for i in range(0, unique_indices.size(0), batch_size): + end_i = min(i + batch_size, unique_indices.size(0)) + batch_indices = unique_indices[i:end_i] + batch_avg_features = avg_features[i:end_i] + + # 当前批次的token解码 + current_tokens_batch = self.memory_bank[batch_indices] # [batch_size, knowledge_length] + current_embeddings_batch = self.tok_embeddings(current_tokens_batch.view(-1)).view( + batch_indices.size(0), knowledge_length, dim) # [batch_size, knowledge_length, dim] + + old_features_batch = current_embeddings_batch.view(batch_indices.size(0), -1) # [batch_size, knowledge_length * dim] + expanded_new_features = batch_avg_features.repeat(1, knowledge_length) # [batch_size, knowledge_length * dim] + + # EMA更新:new = γ * old + (1-γ) * new_avg + updated_features_batch = ( + self.params.ema_decay * old_features_batch + + (1 - self.params.ema_decay) * expanded_new_features + ) + + # 分批编码为token_ids(关键:控制输出层的输入大小) + updated_reshaped = updated_features_batch.view(-1, dim) # [batch_size * knowledge_length, dim] + logits_batch = self.output(updated_reshaped) # [batch_size * knowledge_length, vocab_size] + new_token_ids_batch = torch.argmax(logits_batch, dim=-1).view(batch_indices.size(0), knowledge_length) + + # 🔥 新增: 应用冻结mask,只更新未冻结的条目 + # 检查哪些batch_indices对应的条目没有被冻结 + unfrozen_mask_batch = ~self.freeze_mask[batch_indices] # [batch_size] - True表示未冻结 + + # 只更新未冻结的条目 + if unfrozen_mask_batch.any(): + unfrozen_indices = batch_indices[unfrozen_mask_batch] + unfrozen_tokens = new_token_ids_batch[unfrozen_mask_batch] + self.memory_bank[unfrozen_indices] = unfrozen_tokens + updated_memories += unfrozen_indices.size(0) + else: + # 如果这个batch中的所有条目都被冻结,则跳过更新 + pass + + update_ratio = updated_memories / knowledge_num + + # 🔥 新增: 计算冻结统计信息 + frozen_count = self.freeze_mask.sum().item() + total_memories = knowledge_num + + update_stats = { + 'ema_update_applied': True, + 'ema_step': self.ema_step_counter.item(), + 'total_selections': total_selections, + 'total_layers': total_layers, + 'updated_memories': updated_memories, + 'update_ratio': update_ratio, + 'frozen_memories': frozen_count, + 'frozen_ratio': frozen_count / total_memories, + 'ema_decay': self.params.ema_decay, + 'selected_memory_coverage': updated_memories / knowledge_num, + } + + return update_stats \ No newline at end of file diff --git a/run_file/experiment_1_4_10.sh b/run_file/experiment_1_4_10.sh index 0e6eb43..86bd057 100644 --- a/run_file/experiment_1_4_10.sh +++ b/run_file/experiment_1_4_10.sh @@ -40,8 +40,8 @@ LOG_FILE="$LOG_DIR/experiment.log" # ---------------------------------------------------------------------------- # 🤖 硬件配置 # ---------------------------------------------------------------------------- -CUDA_VISIBLE_DEVICES="0,1" -NUM_PROCESSES="2" +CUDA_VISIBLE_DEVICES="0,1,2,3" +NUM_PROCESSES="4" MIXED_PRECISION="bf16" MAIN_PROCESS_PORT="29500" @@ -58,7 +58,7 @@ USE_MOE="false" # 🔥 知识库配置(四损失系统优化) KNOWLEDGE_NUM="1048576" # 1M entries -KNOWLEDGE_LENGTH="16" # 🔥 增加到16个token提升表达能力 +KNOWLEDGE_LENGTH="8" # 🔥 增加到16个token提升表达能力 KNOWLEDGE_DIM="128" # 保留兼容性 DISABLE_DB="false" @@ -67,7 +67,7 @@ DISABLE_DB="false" # ---------------------------------------------------------------------------- EPOCHS="3" EMBEDDING_EPOCH="2" -BATCH_SIZE="64" # 🔥 降低批次大小以适应更复杂的计算 +BATCH_SIZE="48" # 🔥 降低批次大小以适应更复杂的计算 ACCUMULATION_STEPS="8" # 🔥 增加累积步数保持有效批次大小 LEARNING_RATE="2e-4" # 🔥 适度降低学习率提升稳定性 DTYPE="bfloat16"