import inspect
import math
import struct
import time
from typing import Any, Optional, Tuple, List, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .LMConfig import LMConfig


class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization with a learnable per-channel gain.

    Args:
        dim: Size of the last (normalized) dimension.
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the RMS over the last dimension; eps avoids division by zero.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32, cast back to the input dtype, then apply the gain.
        return self.weight * self._norm(x.float()).type_as(x)


def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Precompute rotary-embedding phases as unit-magnitude complex numbers.

    Args:
        dim: Per-head dimension; one frequency is produced per pair of channels.
        end: Number of positions to precompute.
        theta: RoPE frequency base.

    Returns:
        complex64 tensor of shape [end, dim // 2] holding exp(i * position * freq).
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    """Apply rotary position embeddings to query and key tensors.

    Adjacent channel pairs of the last dimension are viewed as complex numbers,
    rotated by ``pos_cis`` and converted back to real.

    Args:
        xq: Query tensor [batch, seq_len, n_heads, head_dim].
        xk: Key tensor [batch, seq_len, n_kv_heads, head_dim].
        pos_cis: Complex tensor [seq_len, head_dim // 2].

    Returns:
        (xq_out, xk_out) with the same shapes and dtypes as ``xq`` / ``xk``.
    """

    def unite_shape(pos_cis, x):
        # Reshape pos_cis [seq_len, head_dim // 2] so it broadcasts over batch and heads.
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head ``n_rep`` times along dim 2.

    Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)`` for
    ``x`` of shape [batch, seq_len, n_kv_heads, head_dim].
    """
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


class Attention(nn.Module):
    """Causal multi-head self-attention with grouped KV heads and RoPE (no KV cache)."""

    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        # Number of query heads that share each KV head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # Use the fused SDPA kernel when available (PyTorch >= 2.0) and enabled in the config.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # Additive causal mask for the fallback path: -inf strictly above the diagonal.
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor):
        """Attend over the full input sequence.

        Args:
            x: [batch, seq_len, dim] input.
            pos_cis: Rotary phases for the ``seq_len`` positions of ``x``.

        Returns:
            [batch, seq_len, dim] attention output after the output projection.
        """
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, pos_cis)

        # No KV cache: keys/values come only from the current input.
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2),
        )
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True,
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            # Softmax in float32 for stability, then cast back.
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output


class MemoryGate(nn.Module):
    """Product-key memory gate that selects ``num_selected`` memory entries per token."""

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.dim = config.dim
        self.knowledge_num = config.knowledge_num
        self.knowledge_dim = config.knowledge_dim
        self.num_selected = getattr(config, 'num_selected', 16)

        # Product keys index a sqrt(N) x sqrt(N) grid, so N must be a perfect square.
        assert int(self.knowledge_num ** 0.5) ** 2 == self.knowledge_num, \
            f"knowledge_num ({self.knowledge_num}) must be a perfect square for product key memory"

        self.num_keys = int(self.knowledge_num ** 0.5)

        # Query projection dim -> knowledge_dim; the query is later split into two halves,
        # one per product-key set (knowledge_dim is presumably even — TODO confirm).
        self.gate_proj = nn.Linear(self.dim, self.knowledge_dim, bias=False)

        # Two independent key sets: [2, num_keys, knowledge_dim // 2].
        self.keys = nn.Parameter(torch.randn(2, self.num_keys, self.knowledge_dim // 2))

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor):
        """Select memory entries for each position.

        Args:
            x: [batch_size, seq_len, dim]

        Returns:
            memory_indices: [batch_size, seq_len, num_selected] flat indices into the memory bank.
            memory_scores: [batch_size, seq_len, num_selected] softmax weights (with dropout).
            balance_loss: load-balancing loss (KL divergence + Gini coefficient).
            stats: dictionary of monitoring statistics.
        """
        bsz, seq_len, _ = x.shape

        queries = self.gate_proj(x)  # [batch, seq_len, knowledge_dim]

        # Split the query into the two product-key halves.
        q1 = queries[:, :, :self.knowledge_dim // 2]  # [batch, seq_len, knowledge_dim // 2]
        q2 = queries[:, :, self.knowledge_dim // 2:]  # [batch, seq_len, knowledge_dim // 2]

        # Similarity against each key set.
        scores_1 = torch.einsum('bsd,kd->bsk', q1, self.keys[0])  # [batch, seq_len, num_keys]
        scores_2 = torch.einsum('bsd,kd->bsk', q2, self.keys[1])  # [batch, seq_len, num_keys]

        # Top-k within each key set.
        topk_scores_1, topk_indices_1 = scores_1.topk(self.num_selected, dim=-1)
        topk_scores_2, topk_indices_2 = scores_2.topk(self.num_selected, dim=-1)

        # Cartesian combination of candidates: score is the sum, index is row * num_keys + col.
        combined_scores = topk_scores_1.unsqueeze(-1) + topk_scores_2.unsqueeze(-2)  # [batch, seq_len, num_selected, num_selected]
        combined_indices = topk_indices_1.unsqueeze(-1) * self.num_keys + topk_indices_2.unsqueeze(-2)  # [batch, seq_len, num_selected, num_selected]

        # Flatten the candidate grid and keep the final top-k.
        combined_scores = combined_scores.view(bsz, seq_len, -1)
        combined_indices = combined_indices.view(bsz, seq_len, -1)

        final_scores, final_pk_indices = combined_scores.topk(self.num_selected, dim=-1)
        memory_indices = combined_indices.gather(-1, final_pk_indices)

        # Normalize selection scores.
        memory_scores = F.softmax(final_scores, dim=-1)
        memory_scores = self.dropout(memory_scores)

        balance_loss, stats = self._compute_balance_loss_and_stats(memory_indices, memory_scores)

        return memory_indices, memory_scores, balance_loss, stats

    def _compute_balance_loss_and_stats(self, memory_indices, memory_scores):
        """Compute the memory load-balancing loss and monitoring statistics.

        NOTE(review): the loss is built from integer selection counts (scatter_add of ones),
        so it carries no gradient to any parameter; it only has monitoring value unless a
        differentiable term (e.g. based on memory_scores) is added.

        Args:
            memory_indices: [batch_size, seq_len, num_selected]
            memory_scores: [batch_size, seq_len, num_selected] (currently unused)

        Returns:
            balance_loss: scalar tensor, 0.5 * KL(uniform || usage) + 0.5 * Gini(usage).
            stats: dictionary of statistics.
        """
        bsz, seq_len, num_selected = memory_indices.shape
        device = memory_indices.device

        # 1. Selection distribution over memory entries.
        flat_indices = memory_indices.view(-1)  # [batch_size * seq_len * num_selected]

        memory_counts = torch.zeros(self.knowledge_num, device=device)
        memory_counts.scatter_add_(0, flat_indices, torch.ones_like(flat_indices, dtype=torch.float))

        total_selections = bsz * seq_len * num_selected
        memory_probs = memory_counts / total_selections

        # 2. KL divergence against the uniform distribution (epsilon avoids log(0)).
        uniform_prob = 1.0 / self.knowledge_num
        memory_probs_safe = memory_probs + 1e-10
        kl_loss = F.kl_div(
            torch.log(memory_probs_safe),
            torch.full_like(memory_probs, uniform_prob),
            reduction='sum'
        )

        # 3. Gini coefficient of the usage distribution (larger = more unequal).
        sorted_probs, _ = torch.sort(memory_probs)
        n = self.knowledge_num
        index = torch.arange(1, n + 1, device=device, dtype=torch.float)
        gini_coeff = (2 * torch.sum(index * sorted_probs) / (n * torch.sum(sorted_probs))) - (n + 1) / n
        gini_loss = gini_coeff

        # 4. Combined balance loss.
        balance_loss = 0.5 * kl_loss + 0.5 * gini_loss

        # 5. Monitoring statistics.
        with torch.no_grad():
            # Fraction of entries selected at least once.
            coverage_rate = (memory_counts > 0).float().mean().item()

            # Entries at or above the 90th percentile of selection counts.
            top10_threshold = torch.quantile(memory_counts, 0.9)
            hot_memories = (memory_counts >= top10_threshold).sum().item()

            # Entries never selected.
            dead_memories = (memory_counts == 0).sum().item()

            selection_variance = memory_counts.var().item()

            stats = {
                'gini_coefficient': gini_coeff.item(),
                'kl_divergence': kl_loss.item(),
                'coverage_rate': coverage_rate,
                'hot_memories': hot_memories,
                'dead_memories': dead_memories,
                'selection_variance': selection_variance,
                'max_selections': memory_counts.max().item(),
                'min_selections': memory_counts.min().item(),
            }

        return balance_loss, stats


class GatedMemoryFusion(nn.Module):
    """SwiGLU-style gated MLP over the concatenation of h_attn and the selected memories."""

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.dim = config.dim
        self.knowledge_dim = config.knowledge_dim
        self.num_selected = getattr(config, 'num_selected', 16)

        # Input width: dim (h_attn) + num_selected * dim (each selected memory is pooled to dim).
        concat_dim = self.dim + self.num_selected * self.dim

        self.gate_proj = nn.Linear(concat_dim, self.dim, bias=False)
        self.up_proj = nn.Linear(concat_dim, self.dim, bias=False)
        self.down_proj = nn.Linear(self.dim, self.dim, bias=False)

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, h_attn: torch.Tensor, selected_memories: torch.Tensor, memory_scores: torch.Tensor):
        """Fuse attention output with selected memories.

        Args:
            h_attn: [batch_size, seq_len, dim] self-attention output.
            selected_memories: [batch_size, seq_len, num_selected, dim] pooled memory vectors.
            memory_scores: [batch_size, seq_len, num_selected] (unused by this concatenation approach).

        Returns:
            output: [batch_size, seq_len, dim]
        """
        bsz, seq_len, _ = h_attn.shape

        # [batch, seq_len, num_selected, dim] -> [batch, seq_len, num_selected * dim]
        memory_flat = selected_memories.reshape(bsz, seq_len, -1)

        concat_input = torch.cat([h_attn, memory_flat], dim=-1)  # [batch, seq_len, dim + num_selected * dim]

        gate = F.silu(self.gate_proj(concat_input))  # [batch, seq_len, dim]
        up = self.up_proj(concat_input)  # [batch, seq_len, dim]
        fusion_output = gate * up

        output = self.down_proj(fusion_output)  # [batch, seq_len, dim]
        output = self.dropout(output)

        return output


class MiniMindBlock(nn.Module):
    """Transformer block where the FFN is replaced by a token-id memory lookup and gated fusion."""

    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.attention = Attention(config)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.memory_norm = RMSNorm(config.dim, eps=config.norm_eps)

        self.memory_gate = MemoryGate(config)
        self.gated_memory_fusion = GatedMemoryFusion(config)
        # Scores each decoded memory token (width dim) for attention pooling over knowledge_length.
        self.attentionpool = nn.Linear(config.dim, 1)

    def forward(self, x, pos_cis, memory_bank, tok_embeddings, collect_ema_stats=False):
        """Run self-attention, memory selection and gated fusion.

        Args:
            x: [batch_size, seq_len, dim]
            pos_cis: rotary phases for the sequence positions.
            memory_bank: [knowledge_num, knowledge_length] shared bank of token ids.
            tok_embeddings: embedding module used to decode memory token ids.
            collect_ema_stats: whether to also return EMA update statistics.

        Returns:
            (out, balance_loss, layer_stats, ema_stats, cosine_stats) when collect_ema_stats,
            else (out, balance_loss, layer_stats, cosine_stats). ``ema_stats`` is None
            unless collecting in training mode. ``out`` is [batch_size, seq_len, dim].
        """
        h_attn = self.attention(self.attention_norm(x), pos_cis)
        h = x + h_attn

        # The normalized attention output (not the residual stream) drives memory lookup.
        h_for_memory = self.memory_norm(h_attn)

        memory_indices, memory_scores, balance_loss, layer_stats = self.memory_gate(h_for_memory)

        # Fetch token ids of the selected entries and decode them to embeddings.
        bsz, seq_len, num_selected = memory_indices.shape
        memory_indices_flat = memory_indices.view(-1)
        selected_token_ids = memory_bank[memory_indices_flat]  # [batch * seq_len * num_selected, knowledge_length]

        selected_embeddings = tok_embeddings(selected_token_ids)  # [batch * seq_len * num_selected, knowledge_length, dim]

        # Attention pooling over knowledge_length to compress each entry to one dim-vector.
        attn_weights = self.attentionpool(selected_embeddings)  # [batch * seq_len * num_selected, knowledge_length, 1]
        attn_weights = torch.softmax(attn_weights, dim=1)
        pooled_memory = torch.sum(selected_embeddings * attn_weights, dim=1)  # [batch * seq_len * num_selected, dim]
        selected_memory = pooled_memory.view(bsz, seq_len, num_selected, self.dim)  # [batch, seq_len, num_selected, dim]

        memory_output = self.gated_memory_fusion(h_for_memory, selected_memory, memory_scores)

        out = h + memory_output

        # Monitoring: cosine similarity between each query and its selected memories.
        with torch.no_grad():
            h_expanded = h_for_memory.unsqueeze(2).expand(-1, -1, num_selected, -1)  # [batch, seq_len, num_selected, dim]
            cosine_similarities = F.cosine_similarity(
                h_expanded,
                selected_memory,
                dim=-1
            )  # [batch, seq_len, num_selected]

            cosine_stats = {
                'cosine_similarities': cosine_similarities,
                'avg_cosine_similarity': cosine_similarities.mean().item(),
                'max_cosine_similarity': cosine_similarities.max().item(),
                'min_cosine_similarity': cosine_similarities.min().item(),
                'std_cosine_similarity': cosine_similarities.std().item(),
            }

        # EMA statistics are only collected in training mode.
        ema_stats = None
        if collect_ema_stats and self.training:
            ema_stats = {
                'memory_indices': memory_indices,  # [batch, seq_len, num_selected]
                'memory_scores': memory_scores,  # [batch, seq_len, num_selected]
                'h_for_memory': h_for_memory,  # [batch, seq_len, dim]
                'selected_memory': selected_memory,  # [batch, seq_len, num_selected, dim]
            }

        if collect_ema_stats:
            return out, balance_loss, layer_stats, ema_stats, cosine_stats
        return out, balance_loss, layer_stats, cosine_stats


class MiniMindLM(PreTrainedModel):
    """Causal LM whose blocks use a shared token-id memory bank instead of FFNs."""

    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        self.params = params or LMConfig()
        super().__init__(self.params)
        # Use the resolved config everywhere; the raw argument may be None.
        params = self.params
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(layer_id, params) for layer_id in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        # Tie the input embedding to the output projection.
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer(
            "pos_cis",
            precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
            persistent=False,
        )

        # Shared memory bank of token ids: [knowledge_num, knowledge_length].
        # Integer tensors cannot require gradients in PyTorch, so the bank is never trained by
        # backprop. With use_ema_update it is rewritten by apply_ema_update() (codebook-style);
        # otherwise it stays at its random initialization.
        self.memory_bank = nn.Parameter(
            torch.randint(0, params.vocab_size, (params.knowledge_num, params.knowledge_length)),
            requires_grad=False,
        )

        if params.use_ema_update:
            # Per-entry update statistics.
            self.register_buffer('ema_update_count', torch.zeros(params.knowledge_num), persistent=False)
            # Step counter used for the EMA update frequency check.
            self.register_buffer('ema_step_counter', torch.zeros(1, dtype=torch.long), persistent=False)
            # Snapshot of the bank from the previous stats call (see get_memory_update_stats).
            self.register_buffer('prev_memory_bank', torch.zeros_like(self.memory_bank), persistent=False)

        # Freeze mask: True marks entries that EMA updates must not modify (the first freeze_num).
        if params.freeze_ratio > 0.0:
            freeze_num = int(params.knowledge_num * params.freeze_ratio)
            freeze_mask = torch.zeros(params.knowledge_num, dtype=torch.bool)
            freeze_mask[:freeze_num] = True
            self.register_buffer('freeze_mask', freeze_mask, persistent=False)
            print(f"🔥 Memory bank freezing enabled: {freeze_num}/{params.knowledge_num} entries ({params.freeze_ratio*100:.1f}%) frozen")
        else:
            self.register_buffer('freeze_mask', torch.zeros(params.knowledge_num, dtype=torch.bool), persistent=False)
            print(f"🔥 Memory bank freezing disabled: all entries can be updated")

        # NOTE(review): a single output object is reused and overwritten by every forward call.
        self.OUT = CausalLMOutputWithPast()

    def get_memory_update_stats(self):
        """Compare the memory bank against the previous snapshot and refresh the snapshot.

        Token ids are cast to float for the distance/similarity computations.

        Returns:
            Dictionary with L2 change (avg/max), cosine similarity, update rate and
            updated-entry count. Defaults are returned when no snapshot buffer exists
            (i.e. use_ema_update is disabled).
        """
        with torch.no_grad():
            prev = getattr(self, 'prev_memory_bank', None)
            if prev is not None and prev.numel() > 0:
                current = self.memory_bank.float()
                previous = prev.float()

                l2_distance = torch.norm(current - previous, p=2, dim=-1)
                avg_l2_distance = l2_distance.mean().item()
                max_l2_distance = l2_distance.max().item()

                cos_sim = F.cosine_similarity(
                    current.view(-1),
                    previous.view(-1),
                    dim=0
                ).item()

                # Fraction of entries whose change exceeds the threshold.
                threshold = 0.01
                updated_memories = (l2_distance > threshold).sum().item()
                update_rate = updated_memories / self.memory_bank.size(0)

                update_stats = {
                    'memory_avg_l2_change': avg_l2_distance,
                    'memory_max_l2_change': max_l2_distance,
                    'memory_cosine_similarity': cos_sim,
                    'memory_update_rate': update_rate,
                    'memory_updated_count': updated_memories
                }
            else:
                update_stats = {
                    'memory_avg_l2_change': 0.0,
                    'memory_max_l2_change': 0.0,
                    'memory_cosine_similarity': 1.0,
                    'memory_update_rate': 0.0,
                    'memory_updated_count': 0
                }

            # Refresh the snapshot only when the buffer exists.
            if prev is not None:
                prev.copy_(self.memory_bank)

        return update_stats

    def forward(self, input_ids: Optional[torch.Tensor] = None, **args):
        """Forward pass (no KV cache).

        Recognized keyword args: ``start_pos`` (RoPE offset, default 0) and
        ``collect_ema_stats`` (default: use_ema_update and training mode).

        Returns:
            The shared CausalLMOutputWithPast carrying logits, last_hidden_state,
            aux_loss (sum of per-layer balance losses), layer_stats, ema_stats,
            cosine_stats and past_key_values=None.
        """
        start_pos = args.get('start_pos', 0)
        collect_ema_stats = args.get('collect_ema_stats', self.params.use_ema_update and self.training)

        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]

        total_balance_loss = 0
        all_layer_stats = {}
        all_ema_stats = {}
        all_cosine_stats = {}

        for layer_idx, layer in enumerate(self.layers):
            if collect_ema_stats:
                h, balance_loss, layer_stats, ema_stats, cosine_stats = layer(
                    h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True)
                all_ema_stats[f'layer_{layer_idx}'] = ema_stats
            else:
                h, balance_loss, layer_stats, cosine_stats = layer(
                    h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False)

            total_balance_loss += balance_loss

            # Prefix per-layer statistics with the layer index.
            for key, value in layer_stats.items():
                all_layer_stats[f'layer_{layer_idx}_{key}'] = value
            for key, value in cosine_stats.items():
                all_cosine_stats[f'layer_{layer_idx}_{key}'] = value

        logits = self.output(self.norm(h))

        aux_loss = total_balance_loss

        self.OUT['last_hidden_state'] = h
        self.OUT['logits'] = logits
        self.OUT['aux_loss'] = aux_loss
        self.OUT['layer_stats'] = all_layer_stats
        self.OUT['ema_stats'] = all_ema_stats if collect_ema_stats else None
        self.OUT['cosine_stats'] = all_cosine_stats
        self.OUT['past_key_values'] = None  # KV cache is not supported
        return self.OUT

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        """Sample continuations without a KV cache.

        Returns the streaming generator from ``_stream`` when ``stream`` is True; otherwise
        a [batch * num_return_sequences, max_len] tensor of prompt + generated tokens,
        right-padded with ``pad_token_id``.
        """
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        generated = []
        for i in range(input_ids.size(0)):
            # Strip padding from the prompt before generating.
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                # Each yield is the cumulative generated suffix; its last column is the new token.
                tokens_list = [tokens[:, -1:] for tokens in out]
                # No generated tokens -> empty suffix (not a second copy of the prompt).
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad.new_empty((1, 0))
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    @torch.inference_mode()
    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        """Yield the generated suffix after each sampled token; the whole sequence is
        recomputed every step. Runs under inference mode even when consumed lazily
        (the stream=True path of ``generate`` returns this generator un-iterated).
        Assumes batch size 1 (uses ``.item()`` and row 0 for the repetition penalty).
        """
        start = input_ids.shape[1]
        while input_ids.shape[1] < start + max_new_tokens:
            out = self(input_ids, **args)
            logits = out.logits[:, -1, :]
            # Repetition penalty on tokens already present in the sequence.
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            # Nucleus (top-p) filtering.
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break

    def apply_ema_update(self, ema_stats):
        """Apply a token-space EMA update to the memory bank.

        For every memory entry selected in the batch, the mean of the queries
        (``h_for_memory``) that selected it is blended into the entry's decoded
        token embeddings, and the result is re-encoded to token ids via argmax
        over the output projection. Frozen entries (``freeze_mask``) are skipped.

        Args:
            ema_stats: Per-layer statistics collected in forward, e.g.
                {'layer_0': {'memory_indices': ..., 'h_for_memory': ...}, 'layer_1': ...}.
                May be None or empty.

        Returns:
            Dictionary describing whether and how the update was applied
            ({} when use_ema_update is disabled).
        """
        if not self.params.use_ema_update:
            return {}

        self.ema_step_counter += 1

        # Only update every ema_update_freq steps.
        if self.ema_step_counter % self.params.ema_update_freq != 0:
            return {'ema_update_applied': False, 'reason': 'frequency_check_failed'}

        with torch.no_grad():
            device = self.memory_bank.device
            knowledge_num, knowledge_length = self.memory_bank.shape
            dim = self.params.dim

            all_indices = []
            all_features = []
            total_selections = 0
            total_layers = 0

            # Gather (index, query feature) pairs from every layer; tolerate None stats.
            for layer_ema_stats in (ema_stats or {}).values():
                if layer_ema_stats is None:
                    continue

                total_layers += 1
                memory_indices = layer_ema_stats['memory_indices']  # [batch, seq_len, num_selected]
                h_for_memory = layer_ema_stats['h_for_memory']  # [batch, seq_len, dim]

                bsz, seq_len, num_selected = memory_indices.shape
                total_selections += bsz * seq_len * num_selected

                flat_indices = memory_indices.view(-1)  # [batch * seq_len * num_selected]

                # Repeat each query once per memory it selected.
                h_expanded = h_for_memory.unsqueeze(2).expand(-1, -1, num_selected, -1)  # [batch, seq_len, num_selected, dim]
                flat_h = h_expanded.reshape(-1, dim)  # [batch * seq_len * num_selected, dim]

                all_indices.append(flat_indices)
                all_features.append(flat_h)

            if not all_indices:
                return {'ema_update_applied': False, 'reason': 'no_ema_stats'}

            all_indices = torch.cat(all_indices, dim=0)  # [total_selections]
            all_features = torch.cat(all_features, dim=0)  # [total_selections, dim]

            # Mean query feature per distinct memory entry.
            unique_indices, inverse_indices = torch.unique(all_indices, return_inverse=True)

            aggregated_features = torch.zeros(unique_indices.size(0), dim, device=device, dtype=all_features.dtype)
            count_per_memory = torch.zeros(unique_indices.size(0), device=device, dtype=all_features.dtype)

            aggregated_features.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, dim), all_features)
            count_per_memory.scatter_add_(0, inverse_indices, torch.ones_like(inverse_indices, dtype=all_features.dtype))

            avg_features = aggregated_features / count_per_memory.unsqueeze(1)  # [unique_count, dim]

            # Process entries in chunks to bound the memory of decode/encode.
            batch_size = 4096
            updated_memories = 0

            for i in range(0, unique_indices.size(0), batch_size):
                end_i = min(i + batch_size, unique_indices.size(0))
                batch_indices = unique_indices[i:end_i]
                batch_avg_features = avg_features[i:end_i]

                # Decode current token ids to embeddings.
                current_tokens_batch = self.memory_bank[batch_indices]  # [n, knowledge_length]
                current_embeddings_batch = self.tok_embeddings(current_tokens_batch.view(-1)).view(
                    batch_indices.size(0), knowledge_length, dim)  # [n, knowledge_length, dim]

                old_features_batch = current_embeddings_batch.view(batch_indices.size(0), -1)  # [n, knowledge_length * dim]
                # Broadcast the mean query to every token position of the entry.
                expanded_new_features = batch_avg_features.repeat(1, knowledge_length)  # [n, knowledge_length * dim]

                # EMA: new = decay * old + (1 - decay) * mean_query
                updated_features_batch = (
                    self.params.ema_decay * old_features_batch
                    + (1 - self.params.ema_decay) * expanded_new_features
                )

                # Re-encode to token ids with the (tied) output projection.
                updated_reshaped = updated_features_batch.view(-1, dim)  # [n * knowledge_length, dim]
                logits_batch = self.output(updated_reshaped)  # [n * knowledge_length, vocab_size]
                new_token_ids_batch = torch.argmax(logits_batch, dim=-1).view(batch_indices.size(0), knowledge_length)

                # Only write entries that are not frozen.
                unfrozen_mask_batch = ~self.freeze_mask[batch_indices]
                if unfrozen_mask_batch.any():
                    unfrozen_indices = batch_indices[unfrozen_mask_batch]
                    unfrozen_tokens = new_token_ids_batch[unfrozen_mask_batch]
                    self.memory_bank[unfrozen_indices] = unfrozen_tokens
                    updated_memories += unfrozen_indices.size(0)

            update_ratio = updated_memories / knowledge_num

            frozen_count = self.freeze_mask.sum().item()
            total_memories = knowledge_num

            update_stats = {
                'ema_update_applied': True,
                'ema_step': self.ema_step_counter.item(),
                'total_selections': total_selections,
                'total_layers': total_layers,
                'updated_memories': updated_memories,
                'update_ratio': update_ratio,
                'frozen_memories': frozen_count,
                'frozen_ratio': frozen_count / total_memories,
                'ema_decay': self.params.ema_decay,
                'selected_memory_coverage': updated_memories / knowledge_num,
            }

        return update_stats