# Minimind/model/model_memory_1_4_10.py

import math
import struct
import inspect
import time
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)


def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
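

# Minimal, self-contained sketch (not part of the original model) showing the
# expected shapes for the RoPE helpers above. All sizes are hypothetical:
# batch=2, seq_len=16, 8 query heads, 2 KV heads, head_dim=8; pos_cis must
# have shape (seq_len, head_dim // 2).
def _demo_rotary_emb():
    pos_cis = precompute_pos_cis(dim=8, end=16)  # [16, 4] complex64
    xq = torch.randn(2, 16, 8, 8)                # [bsz, seq_len, n_heads, head_dim]
    xk = torch.randn(2, 16, 2, 8)                # [bsz, seq_len, n_kv_heads, head_dim]
    xq_rot, xk_rot = apply_rotary_emb(xq, xk, pos_cis)
    assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape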


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
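

# Minimal sketch (hypothetical sizes) of grouped-query attention head
# expansion: 2 KV heads are repeated 4 times to match 8 query heads.
def _demo_repeat_kv():
    x = torch.randn(2, 16, 2, 8)  # [bsz, seq_len, n_kv_heads, head_dim]
    y = repeat_kv(x, 4)
    assert y.shape == (2, 16, 8, 8)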


class Attention(nn.Module):
    """Self-attention module without KV cache"""
    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor):
        """Forward pass without KV cache"""
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        # Note: all KV-cache related code has been removed.
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv
        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output
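

# Sketch (hypothetical sizes): the flash path and the explicit masked-softmax
# fallback in Attention.forward compute the same causal attention.
def _demo_causal_attention_equivalence():
    q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))  # [bsz, heads, seq, head_dim]
    fast = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    mask = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
    slow = F.softmax((q @ k.transpose(-2, -1)) / math.sqrt(8) + mask, dim=-1) @ v
    assert torch.allclose(fast, slow, atol=1e-5)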


class MemoryGate(nn.Module):
    """Product-key memory gate that proposes candidate memory entries (with Gumbel-Softmax selection downstream)"""
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.dim = config.dim
        self.knowledge_num = config.knowledge_num
        self.knowledge_dim = config.knowledge_dim
        self.num_candidates = getattr(config, 'num_candidates', 32)  # generate 32 candidates
        self.num_selected = getattr(config, 'num_selected', 1)       # select the single best candidate
        # The knowledge bank size must be a perfect square.
        assert int(self.knowledge_num ** 0.5) ** 2 == self.knowledge_num, \
            f"knowledge_num ({self.knowledge_num}) must be a perfect square for product key memory"
        self.num_keys = int(self.knowledge_num ** 0.5)
        # The query projection maps the input to knowledge_dim; the query is then
        # split into two halves, one per product-key set.
        self.gate_proj = nn.Linear(self.dim, self.knowledge_dim, bias=False)
        # Product-key memory: two independent key sets.
        self.keys = nn.Parameter(torch.randn(2, self.num_keys, self.knowledge_dim // 2))
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [batch_size, seq_len, dim]
        Returns:
            candidate_indices: [batch_size, seq_len, num_candidates]
            candidate_scores: [batch_size, seq_len, num_candidates]
            balance_loss: always None here; computed later from the final selection
            stats: always an empty dict here; collected later in MiniMindBlock
        """
        bsz, seq_len, _ = x.shape
        # Build the query vectors.
        queries = self.gate_proj(x)  # [batch, seq_len, knowledge_dim]
        # Split into two halves for the product keys.
        q1 = queries[:, :, :self.knowledge_dim // 2]  # [batch, seq_len, knowledge_dim // 2]
        q2 = queries[:, :, self.knowledge_dim // 2:]  # [batch, seq_len, knowledge_dim // 2]
        # Similarities against each key set.
        scores_1 = torch.einsum('bsd,kd->bsk', q1, self.keys[0])  # [batch, seq_len, num_keys]
        scores_2 = torch.einsum('bsd,kd->bsk', q2, self.keys[1])  # [batch, seq_len, num_keys]
        # Top-k candidates per key set (num_candidates rather than num_selected).
        topk_scores_1, topk_indices_1 = scores_1.topk(self.num_candidates, dim=-1)
        topk_scores_2, topk_indices_2 = scores_2.topk(self.num_candidates, dim=-1)
        # Combine the product-key results.
        combined_scores = topk_scores_1.unsqueeze(-1) + topk_scores_2.unsqueeze(-2)  # [batch, seq_len, num_candidates, num_candidates]
        combined_indices = topk_indices_1.unsqueeze(-1) * self.num_keys + topk_indices_2.unsqueeze(-2)  # [batch, seq_len, num_candidates, num_candidates]
        # Flatten and keep the final top-k candidates.
        combined_scores = combined_scores.view(bsz, seq_len, -1)
        combined_indices = combined_indices.view(bsz, seq_len, -1)
        candidate_scores, candidate_pk_indices = combined_scores.topk(self.num_candidates, dim=-1)
        candidate_indices = combined_indices.gather(-1, candidate_pk_indices)  # [batch, seq_len, num_candidates]
        # Normalize the candidate scores.
        candidate_scores = F.softmax(candidate_scores, dim=-1)
        candidate_scores = self.dropout(candidate_scores)
        # Return the candidates; similarity-based selection and the diversity
        # loss are computed downstream in MiniMindBlock.
        return candidate_indices, candidate_scores, None, {}

    def _compute_balance_loss_and_stats(self, memory_indices, memory_scores):
        """
        Compute the balance loss (KL divergence + Gini coefficient) and monitoring statistics.
        Args:
            memory_indices: [batch_size, seq_len, num_selected]
            memory_scores: [batch_size, seq_len, num_selected]
        Returns:
            balance_loss: scalar tensor
            stats: dict of monitoring statistics
        """
        bsz, seq_len, num_selected = memory_indices.shape
        device = memory_indices.device
        # 1. Selection distribution over memory entries.
        # Flatten all selected memory indices.
        flat_indices = memory_indices.view(-1)  # [batch_size * seq_len * num_selected]
        # Count how often each memory entry was selected.
        memory_counts = torch.zeros(self.knowledge_num, device=device)
        memory_counts.scatter_add_(0, flat_indices, torch.ones_like(flat_indices, dtype=torch.float))
        # Selection probability distribution.
        total_selections = bsz * seq_len * num_selected
        memory_probs = memory_counts / total_selections
        # 2. KL-divergence loss against the uniform distribution.
        uniform_prob = 1.0 / self.knowledge_num
        # Guard against log(0).
        memory_probs_safe = memory_probs + 1e-10
        kl_loss = F.kl_div(
            torch.log(memory_probs_safe),
            torch.full_like(memory_probs, uniform_prob),
            reduction='sum'
        )
        # 3. Gini-coefficient loss (measures how unequal the distribution is).
        sorted_probs, _ = torch.sort(memory_probs)
        n = self.knowledge_num
        index = torch.arange(1, n + 1, device=device, dtype=torch.float)
        gini_coeff = (2 * torch.sum(index * sorted_probs) / (n * torch.sum(sorted_probs))) - (n + 1) / n
        gini_loss = gini_coeff  # the larger the Gini coefficient, the more uneven the distribution
        # 4. Combined balance loss.
        balance_loss = 0.5 * kl_loss + 0.5 * gini_loss
        # 5. Monitoring statistics.
        with torch.no_grad():
            # Coverage: fraction of memory entries selected at least once.
            coverage_rate = (memory_counts > 0).float().mean().item()
            # Hot memories: entries in the top 10% by selection count.
            top10_threshold = torch.quantile(memory_counts, 0.9)
            hot_memories = (memory_counts >= top10_threshold).sum().item()
            # Dead memories: entries never selected.
            dead_memories = (memory_counts == 0).sum().item()
            # Variance of selection counts (another imbalance measure).
            selection_variance = memory_counts.var().item()
        stats = {
            'gini_coefficient': gini_coeff.item(),
            'kl_divergence': kl_loss.item(),
            'coverage_rate': coverage_rate,
            'hot_memories': hot_memories,
            'dead_memories': dead_memories,
            'selection_variance': selection_variance,
            'max_selections': memory_counts.max().item(),
            'min_selections': memory_counts.min().item(),
        }
        return balance_loss, stats
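

# Two minimal sketches (not part of the model) for the machinery above, with
# hypothetical sizes. First: the product-key trick decomposes an entry index
# as idx = i1 * num_keys + i2, so two top-k searches over num_keys keys each
# replace one search over num_keys ** 2 combined entries.
def _demo_product_key_indexing():
    num_keys = 100                  # i.e. knowledge_num = 10_000 entries
    i1, i2 = 37, 42
    idx = i1 * num_keys + i2        # combined index, as in MemoryGate.forward
    assert (idx // num_keys, idx % num_keys) == (i1, i2)


# Second: the Gini formula used in the balance loss is 0 for a uniform
# selection distribution and approaches (n - 1) / n when all mass sits on a
# single entry.
def _demo_gini_coefficient():
    n = 4
    index = torch.arange(1, n + 1, dtype=torch.float)
    uniform = torch.full((n,), 1.0 / n)
    gini_u = (2 * torch.sum(index * uniform) / (n * uniform.sum())) - (n + 1) / n
    peaked, _ = torch.sort(torch.tensor([1.0, 0.0, 0.0, 0.0]))
    gini_p = (2 * torch.sum(index * peaked) / (n * peaked.sum())) - (n + 1) / n
    assert abs(gini_u.item()) < 1e-6 and abs(gini_p.item() - 0.75) < 1e-6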


class GatedMemoryFusion(nn.Module):
    """Gated MLP fusion of the concatenated h_attn and selected memory"""
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.dim = config.dim
        self.knowledge_dim = config.knowledge_dim
        self.num_selected = getattr(config, 'num_selected', 1)  # we now select the single best memory
        # Input width: dim (h_attn) + num_selected * dim (selected memories;
        # experiment 1.4.9 selects only the single best memory).
        concat_dim = self.dim + self.num_selected * self.dim
        # SwiGLU-like gated MLP.
        self.gate_proj = nn.Linear(concat_dim, self.dim, bias=False)
        self.up_proj = nn.Linear(concat_dim, self.dim, bias=False)
        self.down_proj = nn.Linear(self.dim, self.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, h_attn: torch.Tensor, selected_memory: torch.Tensor):
        """
        Args:
            h_attn: [batch_size, seq_len, dim] - self-attention output
            selected_memory: [batch_size, seq_len, dim] - selected single best memory
        Returns:
            output: [batch_size, seq_len, dim]
        """
        bsz, seq_len, _ = h_attn.shape
        # Concatenate h_attn with the best memory.
        concat_input = torch.cat([h_attn, selected_memory], dim=-1)  # [batch, seq_len, dim + dim]
        # Gated MLP, SwiGLU-style.
        gate = F.silu(self.gate_proj(concat_input))  # [batch, seq_len, dim]
        up = self.up_proj(concat_input)              # [batch, seq_len, dim]
        fusion_output = gate * up                    # element-wise gating
        # Output projection.
        output = self.down_proj(fusion_output)       # [batch, seq_len, dim]
        output = self.dropout(output)
        return output
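

# Minimal sketch of the SwiGLU-style gating used above:
# out = down(silu(gate(x)) * up(x)). All widths are hypothetical
# (concatenated width 16, hidden width 8).
def _demo_swiglu_gate():
    x = torch.randn(2, 5, 16)  # e.g. concatenated [h_attn, selected_memory]
    gate_proj = nn.Linear(16, 8, bias=False)
    up_proj = nn.Linear(16, 8, bias=False)
    down_proj = nn.Linear(8, 8, bias=False)
    out = down_proj(F.silu(gate_proj(x)) * up_proj(x))
    assert out.shape == (2, 5, 8)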


class MiniMindBlock(nn.Module):
    """Transformer block with memory-based cross attention instead of an FFN"""
    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.config = config  # keep a reference to the config
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.attention = Attention(config)
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.memory_norm = RMSNorm(config.dim, eps=config.norm_eps)
        # Memory modules.
        self.memory_gate = MemoryGate(config)
        self.gated_memory_fusion = GatedMemoryFusion(config)
        # Gumbel-Softmax parameters.
        self.gumbel_temperature = getattr(config, 'gumbel_temperature', 1.0)
        # self.attentionpool = nn.Linear(config.dim, 1)

    def gumbel_softmax_selection(self, similarity_scores, temperature=1.0, hard=True):
        """
        Differentiable discrete selection via Gumbel-Softmax.
        Args:
            similarity_scores: [batch_size, seq_len, num_candidates] - similarity scores
            temperature: Gumbel-Softmax temperature
            hard: whether to use hard (one-hot) selection
        Returns:
            selection_weights: [batch_size, seq_len, num_candidates] - selection weights
            selected_indices: [batch_size, seq_len] - selected indices (for statistics)
        """
        # Add Gumbel noise.
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(similarity_scores) + 1e-20) + 1e-20)
        logits = (similarity_scores + gumbel_noise) / temperature
        # Softmax.
        soft_weights = F.softmax(logits, dim=-1)
        if hard:
            # Hard selection: build one-hot vectors.
            _, max_indices = soft_weights.max(dim=-1, keepdim=True)
            hard_weights = torch.zeros_like(soft_weights).scatter_(-1, max_indices, 1.0)
            # Straight-through estimator.
            selection_weights = hard_weights - soft_weights.detach() + soft_weights
            selected_indices = max_indices.squeeze(-1)  # [batch_size, seq_len]
        else:
            # Soft selection.
            selection_weights = soft_weights
            selected_indices = torch.argmax(soft_weights, dim=-1)
        return selection_weights, selected_indices

    def compute_diversity_loss(self, candidate_memories):
        """
        Diversity loss within the candidate set (encourages distinct candidates).
        Args:
            candidate_memories: [batch_size, seq_len, num_candidates, dim]
        Returns:
            diversity_loss: scalar tensor
        """
        bsz, seq_len, num_candidates, dim = candidate_memories.shape
        # Pairwise similarity between candidates:
        # normalize so the dot products become cosine similarities.
        normalized_memories = F.normalize(candidate_memories, p=2, dim=-1)  # [batch, seq_len, num_candidates, dim]
        # Similarity matrix: [batch, seq_len, num_candidates, num_candidates]
        similarity_matrix = torch.matmul(normalized_memories, normalized_memories.transpose(-2, -1))
        # Mask out the diagonal (self-similarity = 1).
        mask = torch.eye(num_candidates, device=candidate_memories.device).bool()
        mask = mask.unsqueeze(0).unsqueeze(0).expand(bsz, seq_len, -1, -1)
        # Mean off-diagonal similarity (lower means more diverse).
        off_diagonal_similarities = similarity_matrix.masked_select(~mask)
        avg_similarity = off_diagonal_similarities.mean()
        # Diversity loss: higher similarity means higher loss.
        diversity_loss = avg_similarity
        return diversity_loss

    def forward(self, x, pos_cis, memory_bank, tok_embeddings, collect_ema_stats=False):
        """
        Experiment 1.4.9: Gumbel-Softmax + diversity loss + differentiable similarity loss
        Args:
            x: [batch_size, seq_len, dim]
            pos_cis: positional encoding
            memory_bank: [knowledge_num, knowledge_length] - shared memory bank holding token IDs
            tok_embeddings: token embedding layer
            collect_ema_stats: whether to collect EMA-update statistics
        Returns:
            out: [batch_size, seq_len, dim]
            balance_loss: this layer's balance loss (computed from the candidates)
            similarity_loss: similarity loss (differentiable)
            diversity_loss: diversity loss
            layer_stats: this layer's monitoring statistics
            ema_stats: EMA-update statistics (if collect_ema_stats=True)
            cosine_stats: cosine-similarity statistics between the query vector and the candidate memories
        """
        # Self attention
        h_attn = self.attention(self.attention_norm(x), pos_cis)
        h = x + h_attn
        # Use h_attn (the core self-attention output) as the input to gating and cross attention.
        h_for_memory = self.memory_norm(h_attn)
        # 🔥 New architecture: generate num_candidates (default 32) candidates.
        candidate_indices, candidate_scores, _, _ = self.memory_gate(h_for_memory)
        # candidate_indices: [batch, seq_len, num_candidates]
        # candidate_scores: [batch, seq_len, num_candidates]
        bsz, seq_len, num_candidates = candidate_indices.shape
        # Decode the candidate token IDs into feature vectors.
        candidate_indices_flat = candidate_indices.view(-1)  # [batch * seq_len * num_candidates]
        candidate_token_ids = memory_bank[candidate_indices_flat]  # [batch * seq_len * num_candidates, knowledge_length]
        # Decode to embeddings and pool.
        candidate_embeddings = tok_embeddings(candidate_token_ids)  # [batch * seq_len * num_candidates, knowledge_length, dim]
        candidate_memories = candidate_embeddings.mean(dim=1)  # [batch * seq_len * num_candidates, dim]
        candidate_memories = candidate_memories.view(bsz, seq_len, num_candidates, self.dim)  # [batch, seq_len, num_candidates, dim]
        # 🔥 Core improvement: compute differentiable similarity scores (no_grad removed).
        h_expanded = h_for_memory.unsqueeze(2).expand(-1, -1, num_candidates, -1)  # [batch, seq_len, num_candidates, dim]
        similarity_scores = F.cosine_similarity(h_expanded, candidate_memories, dim=-1)  # [batch, seq_len, num_candidates]
        # 🔥 Pick the best candidate via Gumbel-Softmax.
        selection_weights, selected_indices = self.gumbel_softmax_selection(
            similarity_scores,
            temperature=self.gumbel_temperature,
            hard=True
        )  # selection_weights: [batch, seq_len, num_candidates], selected_indices: [batch, seq_len]
        # 🔥 Similarity loss (now differentiable!): the selected memory should be
        # as similar as possible to the query vector.
        selected_similarities = (similarity_scores * selection_weights).sum(dim=-1)  # [batch, seq_len]
        similarity_loss = -selected_similarities.mean()  # negated: higher similarity, lower loss
        # 🔥 Candidate-set diversity loss.
        diversity_loss = self.compute_diversity_loss(candidate_memories)
        # 🔥 Weighted selection of the final memory; with hard one-hot weights the
        # sum over candidates reduces to picking the single best one.
        selected_memory = (candidate_memories * selection_weights.unsqueeze(-1)).sum(dim=2)  # [batch, seq_len, dim]
        # Gated MLP fusion of just the selected best memory.
        memory_output = self.gated_memory_fusion(h_for_memory, selected_memory)
        # Residual connection.
        out = h + memory_output
        # 🔥 Balance loss and statistics, based on the candidate selection distribution.
        balance_loss, layer_stats = self._compute_candidate_balance_stats(candidate_indices, selection_weights)
        # 🔥 Detailed similarity statistics.
        cosine_stats = {
            'similarity_scores': similarity_scores,          # [batch, seq_len, num_candidates]
            'selected_similarities': selected_similarities,  # [batch, seq_len]
            'avg_similarity': similarity_scores.mean().item(),  # mean similarity
            'max_similarity': similarity_scores.max().item(),   # max similarity
            'min_similarity': similarity_scores.min().item(),   # min similarity
            'selected_avg_similarity': selected_similarities.mean().item(),  # mean similarity of the selected memories
            'selection_entropy': -torch.sum(selection_weights * torch.log(selection_weights + 1e-10), dim=-1).mean().item()  # selection entropy
        }
        # Collect EMA-update statistics (now based on the selected memory).
        ema_stats = None
        if collect_ema_stats and self.training:
            # Expand the selected indices to match the format expected by the EMA update.
            selected_memory_indices = candidate_indices.gather(2, selected_indices.unsqueeze(-1))  # [batch, seq_len, 1]
            ema_stats = {
                'memory_indices': selected_memory_indices,  # [batch, seq_len, 1]
                'memory_scores': torch.ones_like(selected_memory_indices.float()),  # [batch, seq_len, 1] - the selected weight is 1
                'h_for_memory': h_for_memory,  # [batch, seq_len, dim]
                'selected_memory': selected_memory.unsqueeze(2),  # [batch, seq_len, 1, dim]
            }
        if collect_ema_stats:
            return out, balance_loss, similarity_loss, diversity_loss, layer_stats, ema_stats, cosine_stats
        else:
            return out, balance_loss, similarity_loss, diversity_loss, layer_stats, cosine_stats

    def _compute_candidate_balance_stats(self, candidate_indices, selection_weights):
        """
        Balance loss and statistics based on the candidate selection.
        Args:
            candidate_indices: [batch_size, seq_len, num_candidates]
            selection_weights: [batch_size, seq_len, num_candidates] - Gumbel-Softmax weights
        Returns:
            balance_loss: scalar tensor
            stats: dict of statistics
        """
        bsz, seq_len, num_candidates = candidate_indices.shape
        device = candidate_indices.device
        # Weighted count of how often each memory entry is selected.
        flat_indices = candidate_indices.view(-1)  # [batch * seq_len * num_candidates]
        flat_weights = selection_weights.view(-1)  # [batch * seq_len * num_candidates]
        # Accumulate the weighted selection counts per memory entry.
        memory_counts = torch.zeros(self.config.knowledge_num, device=device)
        memory_counts.scatter_add_(0, flat_indices, flat_weights)
        # Selection probability distribution.
        total_selections = memory_counts.sum()
        memory_probs = memory_counts / (total_selections + 1e-10)
        # KL-divergence loss against the uniform distribution.
        uniform_prob = 1.0 / self.config.knowledge_num
        memory_probs_safe = memory_probs + 1e-10
        kl_loss = F.kl_div(
            torch.log(memory_probs_safe),
            torch.full_like(memory_probs, uniform_prob),
            reduction='sum'
        )
        # Gini-coefficient loss.
        sorted_probs, _ = torch.sort(memory_probs)
        n = self.config.knowledge_num
        index = torch.arange(1, n + 1, device=device, dtype=torch.float)
        gini_coeff = (2 * torch.sum(index * sorted_probs) / (n * torch.sum(sorted_probs))) - (n + 1) / n
        gini_loss = gini_coeff
        # Combined balance loss.
        balance_loss = 0.5 * kl_loss + 0.5 * gini_loss
        # Statistics.
        with torch.no_grad():
            coverage_rate = (memory_counts > 0.01).float().mean().item()  # fraction of entries with selection probability > 1%
            top10_threshold = torch.quantile(memory_counts, 0.9)
            hot_memories = (memory_counts >= top10_threshold).sum().item()
            dead_memories = (memory_counts < 0.01).sum().item()  # entries almost never selected
            selection_variance = memory_counts.var().item()
        stats = {
            'gini_coefficient': gini_coeff.item(),
            'kl_divergence': kl_loss.item(),
            'coverage_rate': coverage_rate,
            'hot_memories': hot_memories,
            'dead_memories': dead_memories,
            'selection_variance': selection_variance,
            'max_selections': memory_counts.max().item(),
            'min_selections': memory_counts.min().item(),
        }
        return balance_loss, stats
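

# Minimal sketch (not part of the model) of the straight-through estimator
# used in gumbel_softmax_selection: the forward pass sees hard one-hot weights
# while gradients flow through the soft softmax. The per-candidate loss
# weights below are arbitrary illustrative values.
def _demo_straight_through():
    logits = torch.tensor([[1.0, 2.0, 0.5]], requires_grad=True)
    soft = F.softmax(logits, dim=-1)
    hard = torch.zeros_like(soft).scatter_(-1, soft.argmax(dim=-1, keepdim=True), 1.0)
    st = hard - soft.detach() + soft  # forward: hard; backward: gradient of soft
    (st * torch.tensor([[3.0, 1.0, 2.0]])).sum().backward()
    assert torch.equal(st.detach(), hard) and logits.grad.abs().sum() > 0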


class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        self.params = params
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        # Shared memory bank - experiment 1.4.6: stores token IDs rather than feature vectors.
        # VQ-VAE style: the memory bank acts as a codebook, updated via EMA instead of gradients.
        if params.use_ema_update:
            self.memory_bank = nn.Parameter(
                torch.randint(0, params.vocab_size, (params.knowledge_num, params.knowledge_length)),
                requires_grad=False  # no gradient updates; EMA updates instead
            )
        else:
            self.memory_bank = nn.Parameter(
                torch.randint(0, params.vocab_size, (params.knowledge_num, params.knowledge_length)),
                requires_grad=True  # conventional gradient updates
                # NOTE: integer token-ID tensors cannot require gradients in
                # PyTorch, so this branch would raise at construction time.
            )
        # Buffers for EMA updates.
        if params.use_ema_update:
            # Per-entry update statistics.
            self.register_buffer('ema_update_count', torch.zeros(params.knowledge_num), persistent=False)
            # Note: memory_bank now stores token IDs and the EMA runs in feature
            # space, so the sum buffer is no longer needed.
            # self.register_buffer('ema_sum_buffer', torch.zeros_like(self.memory_bank), persistent=False)
            # EMA update-frequency counter.
            self.register_buffer('ema_step_counter', torch.zeros(1, dtype=torch.long), persistent=False)
        # Previous state of the memory bank, used to compute update statistics.
        self.register_buffer('prev_memory_bank', torch.zeros_like(self.memory_bank), persistent=False)
        # 🔥 New: freeze mask - marks memory_bank entries that are frozen (never updated).
        if params.freeze_ratio > 0.0:
            freeze_num = int(params.knowledge_num * params.freeze_ratio)
            freeze_mask = torch.zeros(params.knowledge_num, dtype=torch.bool)
            # Always freeze the leading entries.
            freeze_mask[:freeze_num] = True
            self.register_buffer('freeze_mask', freeze_mask, persistent=False)
            print(f"🔥 Memory bank freezing enabled: {freeze_num}/{params.knowledge_num} entries ({params.freeze_ratio*100:.1f}%) frozen", flush=True)
        else:
            self.register_buffer('freeze_mask', torch.zeros(params.knowledge_num, dtype=torch.bool), persistent=False)
            print(f"🔥 Memory bank freezing disabled: all entries can be updated", flush=True)
        self.OUT = CausalLMOutputWithPast()

    def get_memory_update_stats(self):
        """
        Compute memory-bank update statistics.
        Returns:
            update_stats: dict with the update statistics
        """
        with torch.no_grad():
            if hasattr(self, 'prev_memory_bank') and self.prev_memory_bank.numel() > 0:
                # The bank stores integer token IDs, so cast to float before
                # computing norms and cosine similarity.
                diff = (self.memory_bank - self.prev_memory_bank).float()
                # L2 distance per entry.
                l2_distance = torch.norm(diff, p=2, dim=-1)
                avg_l2_distance = l2_distance.mean().item()
                max_l2_distance = l2_distance.max().item()
                # Cosine similarity between the flattened banks.
                cos_sim = F.cosine_similarity(
                    self.memory_bank.view(-1).float(),
                    self.prev_memory_bank.view(-1).float(),
                    dim=0
                ).item()
                # Update rate: fraction of entries that changed significantly.
                threshold = 0.01  # update threshold
                updated_memories = (l2_distance > threshold).sum().item()
                update_rate = updated_memories / self.memory_bank.size(0)
                update_stats = {
                    'memory_avg_l2_change': avg_l2_distance,
                    'memory_max_l2_change': max_l2_distance,
                    'memory_cosine_similarity': cos_sim,
                    'memory_update_rate': update_rate,
                    'memory_updated_count': updated_memories
                }
            else:
                # Defaults for the first call.
                update_stats = {
                    'memory_avg_l2_change': 0.0,
                    'memory_max_l2_change': 0.0,
                    'memory_cosine_similarity': 1.0,
                    'memory_update_rate': 0.0,
                    'memory_updated_count': 0
                }
            # Refresh prev_memory_bank.
            self.prev_memory_bank.copy_(self.memory_bank)
        return update_stats

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                **args):
        """Forward pass without KV-cache support"""
        start_pos = args.get('start_pos', 0)
        collect_ema_stats = args.get('collect_ema_stats', self.params.use_ema_update and self.training)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        # Collect losses and statistics from every layer - experiment 1.4.9:
        # four-loss system (the fourth loss, cross-entropy, is computed outside the model).
        total_balance_loss = 0
        total_similarity_loss = 0
        total_diversity_loss = 0
        all_layer_stats = {}
        all_ema_stats = {}
        all_cosine_stats = {}
        for layer_idx, layer in enumerate(self.layers):
            if collect_ema_stats:
                h, balance_loss, similarity_loss, diversity_loss, layer_stats, ema_stats, cosine_stats = layer(
                    h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True)
                all_ema_stats[f'layer_{layer_idx}'] = ema_stats
            else:
                h, balance_loss, similarity_loss, diversity_loss, layer_stats, cosine_stats = layer(
                    h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False)
            # Accumulate the three auxiliary losses.
            total_balance_loss += balance_loss
            total_similarity_loss += similarity_loss
            total_diversity_loss += diversity_loss
            # Prefix each layer's statistics with its layer index.
            for key, value in layer_stats.items():
                all_layer_stats[f'layer_{layer_idx}_{key}'] = value
            # Same for the per-layer cosine-similarity statistics.
            for key, value in cosine_stats.items():
                all_cosine_stats[f'layer_{layer_idx}_{key}'] = value
        logits = self.output(self.norm(h))
        # 🔥 New loss structure: the auxiliary losses of the four-loss system.
        aux_loss = {
            'balance_loss': total_balance_loss,
            'similarity_loss': total_similarity_loss,
            'diversity_loss': total_diversity_loss,
        }
        self.OUT.__setitem__('last_hidden_state', h)
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('layer_stats', all_layer_stats)  # per-layer statistics
        self.OUT.__setitem__('ema_stats', all_ema_stats if collect_ema_stats else None)  # EMA statistics
        self.OUT.__setitem__('cosine_stats', all_cosine_stats)  # cosine-similarity statistics
        self.OUT.__setitem__('past_key_values', None)  # KV cache is not supported
        return self.OUT

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        """Generate without KV cache"""
        # Streaming generation.
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
        # Direct generation.
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)
        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        """Stream generation without KV cache - recomputes the full sequence each step"""
        start = input_ids.shape[1]
        while input_ids.shape[1] < start + max_new_tokens:
            # Recompute the whole sequence every step, since there is no KV cache.
            out = self(input_ids, **args)
            logits = out.logits[:, -1, :]
            # Repetition penalty.
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            # Top-p (nucleus) sampling.
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break

    def apply_ema_update(self, ema_stats):
        """
        Apply the token-based EMA update to memory_bank.
        Experiment 1.4.6: batched tensor-op optimized version.
        Args:
            ema_stats: EMA statistics collected during the forward pass, formatted as
                {'layer_0': {'memory_indices': ..., 'h_for_memory': ...}, 'layer_1': ...}
        """
        if not self.params.use_ema_update:
            return {}
        # Advance the EMA step counter.
        self.ema_step_counter += 1
        # Check whether an EMA update is due this step.
        if self.ema_step_counter % self.params.ema_update_freq != 0:
            return {'ema_update_applied': False, 'reason': 'frequency_check_failed'}
        with torch.no_grad():
            device = self.memory_bank.device
            knowledge_num, knowledge_length = self.memory_bank.shape
            dim = self.params.dim
            # 🚀 Gather data from all layers in bulk (avoids per-layer dict churn).
            all_indices = []
            all_features = []
            total_selections = 0
            total_layers = 0
            # Collect the EMA statistics from every layer.
            for layer_ema_stats in ema_stats.values():
                if layer_ema_stats is None:
                    continue
                total_layers += 1
                memory_indices = layer_ema_stats['memory_indices']  # [batch, seq_len, num_selected]
                h_for_memory = layer_ema_stats['h_for_memory']      # [batch, seq_len, dim]
                bsz, seq_len, num_selected = memory_indices.shape
                total_selections += bsz * seq_len * num_selected
                # Flatten the indices and the matching h_for_memory.
                flat_indices = memory_indices.view(-1)  # [batch * seq_len * num_selected]
                # Replicate h_for_memory for every selected slot.
                h_expanded = h_for_memory.unsqueeze(2).expand(-1, -1, num_selected, -1)  # [batch, seq_len, num_selected, dim]
                flat_h = h_expanded.reshape(-1, dim)  # [batch * seq_len * num_selected, dim]
                all_indices.append(flat_indices)
                all_features.append(flat_h)
            if not all_indices:
                return {'ema_update_applied': False, 'reason': 'no_ema_stats'}
            # 🚀 Merge everything.
            all_indices = torch.cat(all_indices, dim=0)    # [total_selections]
            all_features = torch.cat(all_features, dim=0)  # [total_selections, dim]
            # 🚀 Batched mean feature per memory entry (no Python loop).
            unique_indices, inverse_indices = torch.unique(all_indices, return_inverse=True)
            # Aggregate via scatter_add, keeping dtypes consistent.
            aggregated_features = torch.zeros(unique_indices.size(0), dim, device=device, dtype=all_features.dtype)
            count_per_memory = torch.zeros(unique_indices.size(0), device=device, dtype=all_features.dtype)
            aggregated_features.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, dim), all_features)
            count_per_memory.scatter_add_(0, inverse_indices, torch.ones_like(inverse_indices, dtype=all_features.dtype))
            # Per-entry mean.
            avg_features = aggregated_features / count_per_memory.unsqueeze(1)  # [unique_count, dim]
            # 🚀 EMA update in batches to bound memory usage.
            batch_size = 4096  # process 4096 memory entries per batch
            updated_memories = 0
            for i in range(0, unique_indices.size(0), batch_size):
                end_i = min(i + batch_size, unique_indices.size(0))
                batch_indices = unique_indices[i:end_i]
                batch_avg_features = avg_features[i:end_i]
                # Decode the current tokens of this batch.
                current_tokens_batch = self.memory_bank[batch_indices]  # [batch_size, knowledge_length]
                current_embeddings_batch = self.tok_embeddings(current_tokens_batch.view(-1)).view(
                    batch_indices.size(0), knowledge_length, dim)  # [batch_size, knowledge_length, dim]
                old_features_batch = current_embeddings_batch.view(batch_indices.size(0), -1)  # [batch_size, knowledge_length * dim]
                expanded_new_features = batch_avg_features.repeat(1, knowledge_length)  # [batch_size, knowledge_length * dim]
                # EMA update: new = γ * old + (1 - γ) * new_avg
                updated_features_batch = (
                    self.params.ema_decay * old_features_batch +
                    (1 - self.params.ema_decay) * expanded_new_features
                )
                # Re-encode to token IDs in batches (key: bounds the input size of the output layer).
                updated_reshaped = updated_features_batch.view(-1, dim)  # [batch_size * knowledge_length, dim]
                logits_batch = self.output(updated_reshaped)  # [batch_size * knowledge_length, vocab_size]
                new_token_ids_batch = torch.argmax(logits_batch, dim=-1).view(batch_indices.size(0), knowledge_length)
                # 🔥 New: apply the freeze mask - only update unfrozen entries.
                # Check which entries of batch_indices are not frozen.
                unfrozen_mask_batch = ~self.freeze_mask[batch_indices]  # [batch_size] - True means unfrozen
                # Update only the unfrozen entries.
                if unfrozen_mask_batch.any():
                    unfrozen_indices = batch_indices[unfrozen_mask_batch]
                    unfrozen_tokens = new_token_ids_batch[unfrozen_mask_batch]
                    self.memory_bank[unfrozen_indices] = unfrozen_tokens
                    updated_memories += unfrozen_indices.size(0)
                # If every entry in this batch is frozen, skip the write entirely.
            update_ratio = updated_memories / knowledge_num
            # 🔥 New: freeze statistics.
            frozen_count = self.freeze_mask.sum().item()
            total_memories = knowledge_num
            update_stats = {
                'ema_update_applied': True,
                'ema_step': self.ema_step_counter.item(),
                'total_selections': total_selections,
                'total_layers': total_layers,
                'updated_memories': updated_memories,
                'update_ratio': update_ratio,
                'frozen_memories': frozen_count,
                'frozen_ratio': frozen_count / total_memories,
                'ema_decay': self.params.ema_decay,
                'selected_memory_coverage': updated_memories / knowledge_num,
            }
        return update_stats
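

# Sketch of the nucleus (top-p) filtering step used in _stream, with made-up
# probabilities: given probs [0.5, 0.3, 0.15, 0.05] and top_p=0.75, the
# shift-by-one keeps the first two tokens (the second one crosses the threshold).
def _demo_top_p_filter():
    logits = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative > 0.75
    remove[:, 1:] = remove[:, :-1].clone()
    remove[:, 0] = False
    assert (~remove).sum().item() == 2


# ---------------------------------------------------------------------------
# Minimal smoke-test sketch. Illustrative only: it assumes LMConfig accepts
# the fields below as keyword arguments (with sensible defaults for anything
# omitted, e.g. norm_eps and rope_theta); adjust to the real LMConfig.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = LMConfig(
        dim=64, n_layers=2, n_heads=8, n_kv_heads=2, vocab_size=512,
        max_seq_len=64, dropout=0.0, flash_attn=False,
        knowledge_num=1024,  # must be a perfect square (32 ** 2)
        knowledge_length=8, knowledge_dim=32,
        use_ema_update=True, freeze_ratio=0.0,
    )
    model = MiniMindLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids)
    print(out.logits.shape)     # expected: torch.Size([2, 16, 512])
    print(out.aux_loss.keys())  # balance/similarity/diversity losses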