Minimind/model/model_memory_1_4_7.py
import math
import struct
import inspect
import time
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)

def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis
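
# A minimal shape check for the RoPE table above (illustrative only; the `dim`
# and `end` values here are arbitrary, not taken from LMConfig):
#   >>> pos_cis = precompute_pos_cis(dim=64, end=512)
#   >>> pos_cis.shape, pos_cis.dtype
#   (torch.Size([512, 32]), torch.complex64)
# Each row t holds cos(t * f_k) + i * sin(t * f_k) for the dim // 2 rotation
# frequencies, which apply_rotary_emb multiplies against the complex-paired
# query/key values.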

def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
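
# Grouped-query attention relies on repeat_kv to broadcast the smaller set of
# K/V heads up to the query-head count. A rough example (head counts assumed,
# not read from LMConfig): with 8 query heads and 2 KV heads, n_rep = 4, and a
# KV tensor of shape [bs, slen, 2, head_dim] becomes [bs, slen, 8, head_dim],
# each KV head repeated 4 times along the head dimension.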

class Attention(nn.Module):
    """Self attention module without KV cache"""

    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor):
        """Forward pass without KV cache"""
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        # Note: all KV-cache-related code has been removed
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv
        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output

class MemoryGate(nn.Module):
    """Product Key Memory-based gate mechanism for memory selection"""

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.dim = config.dim
        self.knowledge_num = config.knowledge_num
        self.knowledge_dim = config.knowledge_dim
        self.num_selected = getattr(config, 'num_selected', 16)
        # The knowledge bank size must be a perfect square
        assert int(self.knowledge_num ** 0.5) ** 2 == self.knowledge_num, \
            f"knowledge_num ({self.knowledge_num}) must be a perfect square for product key memory"
        self.num_keys = int(self.knowledge_num ** 0.5)
        # Query projection maps the input to knowledge_dim; the result is split
        # into two halves, one for each product-key set
        self.gate_proj = nn.Linear(self.dim, self.knowledge_dim, bias=False)
        # Product Key Memory: two independent key sets
        self.keys = nn.Parameter(torch.randn(2, self.num_keys, self.knowledge_dim // 2))
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [batch_size, seq_len, dim]
        Returns:
            memory_indices: [batch_size, seq_len, num_selected]
            memory_scores: [batch_size, seq_len, num_selected]
            balance_loss: balance loss (KL divergence + Gini coefficient)
            stats: dictionary of monitoring statistics
        """
        bsz, seq_len, _ = x.shape
        # Build query vectors
        queries = self.gate_proj(x)  # [batch, seq_len, knowledge_dim]
        # Split into two halves for the product keys
        q1 = queries[:, :, :self.knowledge_dim // 2]  # [batch, seq_len, knowledge_dim // 2]
        q2 = queries[:, :, self.knowledge_dim // 2:]  # [batch, seq_len, knowledge_dim // 2]
        # Similarity against the two key sets
        scores_1 = torch.einsum('bsd,kd->bsk', q1, self.keys[0])  # [batch, seq_len, num_keys]
        scores_2 = torch.einsum('bsd,kd->bsk', q2, self.keys[1])  # [batch, seq_len, num_keys]
        # Top-k from each key set
        topk_scores_1, topk_indices_1 = scores_1.topk(self.num_selected, dim=-1)
        topk_scores_2, topk_indices_2 = scores_2.topk(self.num_selected, dim=-1)
        # Combine the product-key results
        combined_scores = topk_scores_1.unsqueeze(-1) + topk_scores_2.unsqueeze(-2)  # [batch, seq_len, num_selected, num_selected]
        combined_indices = topk_indices_1.unsqueeze(-1) * self.num_keys + topk_indices_2.unsqueeze(-2)  # [batch, seq_len, num_selected, num_selected]
        # Flatten and pick the final top-k
        combined_scores = combined_scores.view(bsz, seq_len, -1)
        combined_indices = combined_indices.view(bsz, seq_len, -1)
        final_scores, final_pk_indices = combined_scores.topk(self.num_selected, dim=-1)
        memory_indices = combined_indices.gather(-1, final_pk_indices)
        # Normalize the scores
        memory_scores = F.softmax(final_scores, dim=-1)
        memory_scores = self.dropout(memory_scores)
        # Compute balance loss and monitoring statistics
        balance_loss, stats = self._compute_balance_loss_and_stats(memory_indices, memory_scores)
        return memory_indices, memory_scores, balance_loss, stats

    def _compute_balance_loss_and_stats(self, memory_indices, memory_scores):
        """
        Compute the balance loss and monitoring statistics.
        Args:
            memory_indices: [batch_size, seq_len, num_selected]
            memory_scores: [batch_size, seq_len, num_selected]
        Returns:
            balance_loss: scalar tensor
            stats: dictionary of statistics
        """
        bsz, seq_len, num_selected = memory_indices.shape
        device = memory_indices.device
        # 1. Memory selection distribution
        # Flatten all selected memory indices
        flat_indices = memory_indices.view(-1)  # [batch_size * seq_len * num_selected]
        # Count how often each memory entry was selected
        memory_counts = torch.zeros(self.knowledge_num, device=device)
        memory_counts.scatter_add_(0, flat_indices, torch.ones_like(flat_indices, dtype=torch.float))
        # Selection probability distribution
        total_selections = bsz * seq_len * num_selected
        memory_probs = memory_counts / total_selections
        # 2. KL divergence loss against the uniform distribution
        uniform_prob = 1.0 / self.knowledge_num
        # Avoid log(0)
        memory_probs_safe = memory_probs + 1e-10
        kl_loss = F.kl_div(
            torch.log(memory_probs_safe),
            torch.full_like(memory_probs, uniform_prob),
            reduction='sum'
        )
        # 3. Gini coefficient loss (measures how unequal the distribution is)
        sorted_probs, _ = torch.sort(memory_probs)
        n = self.knowledge_num
        index = torch.arange(1, n + 1, device=device, dtype=torch.float)
        gini_coeff = (2 * torch.sum(index * sorted_probs) / (n * torch.sum(sorted_probs))) - (n + 1) / n
        gini_loss = gini_coeff  # the larger the Gini coefficient, the more uneven the distribution
        # 4. Combined balance loss
        balance_loss = 0.5 * kl_loss + 0.5 * gini_loss
        # 5. Monitoring statistics
        with torch.no_grad():
            # Coverage rate: fraction of memory entries selected at least once
            coverage_rate = (memory_counts > 0).float().mean().item()
            # Hot memories: entries in the top 10% by selection count
            top10_threshold = torch.quantile(memory_counts, 0.9)
            hot_memories = (memory_counts >= top10_threshold).sum().item()
            # Dead memories: entries that were never selected
            dead_memories = (memory_counts == 0).sum().item()
            # Selection variance (measures imbalance)
            selection_variance = memory_counts.var().item()
        stats = {
            'gini_coefficient': gini_coeff.item(),
            'kl_divergence': kl_loss.item(),
            'coverage_rate': coverage_rate,
            'hot_memories': hot_memories,
            'dead_memories': dead_memories,
            'selection_variance': selection_variance,
            'max_selections': memory_counts.max().item(),
            'min_selections': memory_counts.min().item(),
        }
        return balance_loss, stats
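
# How a flat memory index decomposes under the product-key scheme above: with
# knowledge_num = 65536 (an assumed value, so num_keys = 256), a flat index is
# i = i1 * 256 + i2, where i1 indexes the first key set and i2 the second.
# A minimal recovery sketch, for illustration only:
#   >>> flat, num_keys = torch.tensor([1025]), 256
#   >>> flat // num_keys, flat % num_keys
#   (tensor([4]), tensor([1]))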

class GatedMemoryFusion(nn.Module):
    """Gated MLP fusion for concatenated h_attn and selected memories"""

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.dim = config.dim
        self.knowledge_dim = config.knowledge_dim
        self.num_selected = getattr(config, 'num_selected', 16)
        # Input dimension: dim (h_attn) + num_selected * knowledge_dim (selected memories)
        # Experiment 1.4.6: memories are compressed back to knowledge_dim right after
        # decoding, to avoid exploding GPU memory
        concat_dim = self.dim + self.num_selected * self.knowledge_dim
        # SwiGLU-style gated MLP
        self.gate_proj = nn.Linear(concat_dim, self.dim, bias=False)
        self.up_proj = nn.Linear(concat_dim, self.dim, bias=False)
        self.down_proj = nn.Linear(self.dim, self.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, h_attn: torch.Tensor, selected_memories: torch.Tensor, memory_scores: torch.Tensor):
        """
        Args:
            h_attn: [batch_size, seq_len, dim] - Self attention output
            selected_memories: [batch_size, seq_len, num_selected, knowledge_dim] - Selected memory data
            memory_scores: [batch_size, seq_len, num_selected] - Memory selection weights (not used in concatenation approach)
        Returns:
            output: [batch_size, seq_len, dim]
        """
        bsz, seq_len, _ = h_attn.shape
        # Flatten the selected memories into one vector per position
        # [batch, seq_len, num_selected, knowledge_dim] -> [batch, seq_len, num_selected * knowledge_dim]
        memory_flat = selected_memories.reshape(bsz, seq_len, -1)
        # Concatenate h_attn with the memory information
        concat_input = torch.cat([h_attn, memory_flat], dim=-1)  # [batch, seq_len, dim + num_selected * knowledge_dim]
        # Gated MLP (SwiGLU-like)
        gate = F.silu(self.gate_proj(concat_input))  # [batch, seq_len, dim]
        up = self.up_proj(concat_input)  # [batch, seq_len, dim]
        fusion_output = gate * up  # Element-wise multiplication
        # Output projection
        output = self.down_proj(fusion_output)  # [batch, seq_len, dim]
        output = self.dropout(output)
        return output

class MiniMindBlock(nn.Module):
    """Transformer block with memory-based cross attention instead of FFN"""

    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.config = config  # keep a reference to the config
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.attention = Attention(config)
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.memory_norm = RMSNorm(config.dim, eps=config.norm_eps)
        # Memory-related modules
        self.memory_gate = MemoryGate(config)
        self.gated_memory_fusion = GatedMemoryFusion(config)

    def forward(self, x, pos_cis, memory_bank, tok_embeddings, collect_ema_stats=False):
        """
        Args:
            x: [batch_size, seq_len, dim]
            pos_cis: positional encoding
            memory_bank: [knowledge_num, knowledge_length] - shared memory bank of token_ids
            tok_embeddings: shared token embedding module used to decode memory token_ids
            collect_ema_stats: whether to collect EMA update statistics
        Returns:
            out: [batch_size, seq_len, dim]
            balance_loss: balance loss for this layer
            layer_stats: monitoring statistics for this layer
            ema_stats: EMA update statistics (only if collect_ema_stats=True)
        """
        # Self attention
        h_attn = self.attention(self.attention_norm(x), pos_cis)
        h = x + h_attn
        # Use h_attn (the core self-attention output) as the input to the gate and fusion
        h_for_memory = self.memory_norm(h_attn)
        # Gate selects the memories
        memory_indices, memory_scores, balance_loss, layer_stats = self.memory_gate(h_for_memory)
        # Fetch memory data by index - experiment 1.4.6: decode token_ids into feature vectors
        bsz, seq_len, num_selected = memory_indices.shape
        memory_indices_flat = memory_indices.view(-1)
        selected_token_ids = memory_bank[memory_indices_flat]  # [batch * seq_len * num_selected, knowledge_length]
        # Decode token_ids into features and compress immediately to avoid exploding GPU memory
        selected_embeddings = tok_embeddings(selected_token_ids)  # [batch * seq_len * num_selected, knowledge_length, dim]
        knowledge_length = selected_token_ids.size(-1)
        # Compress immediately: knowledge_length * dim -> knowledge_dim
        # Mean-pool over the knowledge_length dimension
        pooled_memory = selected_embeddings.mean(dim=1)  # [batch * seq_len * num_selected, dim]
        # Project to knowledge_dim
        if self.dim > self.config.knowledge_dim:
            # Truncate to knowledge_dim
            compressed_memory = pooled_memory[:, :self.config.knowledge_dim]
        elif self.dim < self.config.knowledge_dim:
            # Pad to knowledge_dim
            pad_size = self.config.knowledge_dim - self.dim
            compressed_memory = F.pad(pooled_memory, (0, pad_size), 'constant', 0)
        else:
            compressed_memory = pooled_memory
        selected_memory = compressed_memory.view(bsz, seq_len, num_selected, self.config.knowledge_dim)  # [batch, seq_len, num_selected, knowledge_dim]
        # Gated MLP fusion: concatenate h_attn with the selected memories
        memory_output = self.gated_memory_fusion(h_for_memory, selected_memory, memory_scores)
        # Residual connection
        out = h + memory_output
        # Collect EMA update statistics (only during training and when enabled)
        ema_stats = None
        if collect_ema_stats and self.training:
            ema_stats = {
                'memory_indices': memory_indices,  # [batch, seq_len, num_selected]
                'memory_scores': memory_scores,  # [batch, seq_len, num_selected]
                'h_for_memory': h_for_memory,  # [batch, seq_len, dim]
                'selected_memory': selected_memory,  # [batch, seq_len, num_selected, knowledge_dim]
            }
        if collect_ema_stats:
            return out, balance_loss, layer_stats, ema_stats
        else:
            return out, balance_loss, layer_stats
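
# The truncate/pad step above is a crude projection from the model width to
# knowledge_dim. A rough numeric sketch (dim and knowledge_dim are assumed
# values, not read from LMConfig): with dim = 512 and knowledge_dim = 128 each
# pooled memory keeps only its first 128 channels; with dim = 512 and
# knowledge_dim = 768 it would instead be zero-padded with 256 trailing zeros
# via F.pad(pooled_memory, (0, 256)).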

class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        self.params = params or LMConfig()
        super().__init__(self.params)
        params = self.params  # fall back to the default config when none is given
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        # Initialize the shared memory bank - experiment 1.4.6: store token_ids rather than feature vectors
        # VQ-VAE style: memory_bank acts as a codebook and is updated via EMA instead of gradients
        if params.use_ema_update:
            self.memory_bank = nn.Parameter(
                torch.randint(0, params.vocab_size, (params.knowledge_num, params.knowledge_length)),
                requires_grad=False  # no gradient updates; the bank is updated via EMA
            )
        else:
            self.memory_bank = nn.Parameter(
                torch.randint(0, params.vocab_size, (params.knowledge_num, params.knowledge_length)),
                requires_grad=True  # conventional gradient updates
            )
        # Buffers for EMA updates
        if params.use_ema_update:
            # Per-entry update counters for the memory bank
            self.register_buffer('ema_update_count', torch.zeros(params.knowledge_num), persistent=False)
            # Note: memory_bank now stores token_ids, but EMA runs in feature space, so no sum buffer is needed
            # self.register_buffer('ema_sum_buffer', torch.zeros_like(self.memory_bank), persistent=False)
            # EMA update frequency counter
            self.register_buffer('ema_step_counter', torch.zeros(1, dtype=torch.long), persistent=False)
            # Snapshot of the memory bank from the previous step, used for update statistics
            self.register_buffer('prev_memory_bank', torch.zeros_like(self.memory_bank), persistent=False)
        # 🔥 New: freeze mask - marks which memory_bank entries are frozen (never updated)
        if params.freeze_ratio > 0.0:
            freeze_num = int(params.knowledge_num * params.freeze_ratio)
            freeze_mask = torch.zeros(params.knowledge_num, dtype=torch.bool)
            # Randomly choose the entries to freeze
            freeze_indices = torch.randperm(params.knowledge_num)[:freeze_num]
            freeze_mask[freeze_indices] = True
            self.register_buffer('freeze_mask', freeze_mask, persistent=False)
            print(f"🔥 Memory bank freezing enabled: {freeze_num}/{params.knowledge_num} entries ({params.freeze_ratio*100:.1f}%) frozen")
        else:
            self.register_buffer('freeze_mask', torch.zeros(params.knowledge_num, dtype=torch.bool), persistent=False)
            print(f"🔥 Memory bank freezing disabled: all entries can be updated")
        self.OUT = CausalLMOutputWithPast()

    def get_memory_update_stats(self):
        """
        Compute memory bank update statistics.
        Returns:
            update_stats: dictionary with update statistics
        """
        with torch.no_grad():
            if hasattr(self, 'prev_memory_bank') and self.prev_memory_bank.numel() > 0:
                # L2 distance between the current and previous bank
                # (cast to float: the bank stores integer token_ids)
                curr = self.memory_bank.float()
                prev = self.prev_memory_bank.float()
                l2_distance = torch.norm(curr - prev, p=2, dim=-1)
                avg_l2_distance = l2_distance.mean().item()
                max_l2_distance = l2_distance.max().item()
                # Cosine similarity
                cos_sim = F.cosine_similarity(
                    curr.view(-1),
                    prev.view(-1),
                    dim=0
                ).item()
                # Update rate (fraction of memory entries with a noticeable change)
                threshold = 0.01  # update threshold
                updated_memories = (l2_distance > threshold).sum().item()
                update_rate = updated_memories / self.memory_bank.size(0)
                update_stats = {
                    'memory_avg_l2_change': avg_l2_distance,
                    'memory_max_l2_change': max_l2_distance,
                    'memory_cosine_similarity': cos_sim,
                    'memory_update_rate': update_rate,
                    'memory_updated_count': updated_memories
                }
            else:
                # Default values for the first call
                update_stats = {
                    'memory_avg_l2_change': 0.0,
                    'memory_max_l2_change': 0.0,
                    'memory_cosine_similarity': 1.0,
                    'memory_update_rate': 0.0,
                    'memory_updated_count': 0
                }
            # Refresh prev_memory_bank (only exists when EMA updates are enabled)
            if hasattr(self, 'prev_memory_bank'):
                self.prev_memory_bank.copy_(self.memory_bank)
        return update_stats

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                **args):
        """Forward pass without KV cache support"""
        start_pos = args.get('start_pos', 0)
        collect_ema_stats = args.get('collect_ema_stats', self.params.use_ema_update and self.training)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        # Accumulate balance losses and statistics from all layers
        total_balance_loss = 0
        all_layer_stats = {}
        all_ema_stats = {}
        for layer_idx, layer in enumerate(self.layers):
            if collect_ema_stats:
                h, balance_loss, layer_stats, ema_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=True)
                all_ema_stats[f'layer_{layer_idx}'] = ema_stats
            else:
                h, balance_loss, layer_stats = layer(h, pos_cis, self.memory_bank, self.tok_embeddings, collect_ema_stats=False)
            total_balance_loss += balance_loss
            # Prefix each layer's statistics with the layer index
            for key, value in layer_stats.items():
                all_layer_stats[f'layer_{layer_idx}_{key}'] = value
        logits = self.output(self.norm(h))
        # Use the total balance loss as the aux_loss
        aux_loss = total_balance_loss
        self.OUT.__setitem__('last_hidden_state', h)
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('layer_stats', all_layer_stats)  # per-layer statistics
        self.OUT.__setitem__('ema_stats', all_ema_stats if collect_ema_stats else None)  # EMA statistics
        self.OUT.__setitem__('past_key_values', None)  # KV cache is not supported
        return self.OUT

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        """Generate without KV cache"""
        # Streaming generation
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
        # Non-streaming generation
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)
        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        """Stream generation without KV cache - regenerates the full sequence each step"""
        start = input_ids.shape[1]
        while input_ids.shape[1] < start + max_new_tokens:
            # Recompute the whole sequence every step, since there is no KV cache
            out = self(input_ids, **args)
            logits = out.logits[:, -1, :]
            # Repetition penalty
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            # Top-p (nucleus) sampling
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break
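
    # A small worked example of the top-p cutoff used in _stream (probabilities
    # are made up for illustration): with sorted probs [0.5, 0.3, 0.15, 0.05]
    # and top_p = 0.9, the cumulative sums are [0.5, 0.8, 0.95, 1.0]; the shift
    # by one position keeps the first token that crosses the threshold, so the
    # kept set is {0.5, 0.3, 0.15} and only the 0.05 tail token is masked out.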

    def apply_ema_update(self, ema_stats):
        """
        Apply the token-based EMA update to memory_bank.
        Experiment 1.4.6: batched tensor-op optimized version.
        Args:
            ema_stats: EMA statistics collected from the forward pass, formatted as
                {'layer_0': {'memory_indices': ..., 'h_for_memory': ...}, 'layer_1': ...}
        """
        if not self.params.use_ema_update:
            return {}
        # Increment the EMA step counter
        self.ema_step_counter += 1
        # Check whether an EMA update is due this step
        if self.ema_step_counter % self.params.ema_update_freq != 0:
            return {'ema_update_applied': False, 'reason': 'frequency_check_failed'}
        with torch.no_grad():
            device = self.memory_bank.device
            knowledge_num, knowledge_length = self.memory_bank.shape
            dim = self.params.dim
            # 🚀 Collect data from all layers in batch (avoid per-entry dict operations)
            all_indices = []
            all_features = []
            total_selections = 0
            total_layers = 0
            # Gather EMA statistics from every layer
            for layer_ema_stats in ema_stats.values():
                if layer_ema_stats is None:
                    continue
                total_layers += 1
                memory_indices = layer_ema_stats['memory_indices']  # [batch, seq_len, num_selected]
                h_for_memory = layer_ema_stats['h_for_memory']  # [batch, seq_len, dim]
                bsz, seq_len, num_selected = memory_indices.shape
                total_selections += bsz * seq_len * num_selected
                # Flatten the indices and the matching h_for_memory
                flat_indices = memory_indices.view(-1)  # [batch * seq_len * num_selected]
                # Replicate h_for_memory for every selected slot
                h_expanded = h_for_memory.unsqueeze(2).expand(-1, -1, num_selected, -1)  # [batch, seq_len, num_selected, dim]
                flat_h = h_expanded.reshape(-1, dim)  # [batch * seq_len * num_selected, dim]
                all_indices.append(flat_indices)
                all_features.append(flat_h)
            if not all_indices:
                return {'ema_update_applied': False, 'reason': 'no_ema_stats'}
            # 🚀 Merge data from all layers
            all_indices = torch.cat(all_indices, dim=0)  # [total_selections]
            all_features = torch.cat(all_features, dim=0)  # [total_selections, dim]
            # 🚀 Compute the mean feature per memory entry in one batch (no Python loop)
            unique_indices, inverse_indices = torch.unique(all_indices, return_inverse=True)
            # Batched aggregation via scatter_add (keep dtypes consistent)
            aggregated_features = torch.zeros(unique_indices.size(0), dim, device=device, dtype=all_features.dtype)
            count_per_memory = torch.zeros(unique_indices.size(0), device=device, dtype=all_features.dtype)
            aggregated_features.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, dim), all_features)
            count_per_memory.scatter_add_(0, inverse_indices, torch.ones_like(inverse_indices, dtype=all_features.dtype))
            # Per-entry mean
            avg_features = aggregated_features / count_per_memory.unsqueeze(1)  # [unique_count, dim]
            # 🚀 Apply the EMA update in batches to bound GPU memory usage
            batch_size = 4096  # process 4096 memory entries per batch
            updated_memories = 0
            for i in range(0, unique_indices.size(0), batch_size):
                end_i = min(i + batch_size, unique_indices.size(0))
                batch_indices = unique_indices[i:end_i]
                batch_avg_features = avg_features[i:end_i]
                # Decode the current tokens of this batch
                current_tokens_batch = self.memory_bank[batch_indices]  # [batch_size, knowledge_length]
                current_embeddings_batch = self.tok_embeddings(current_tokens_batch.view(-1)).view(
                    batch_indices.size(0), knowledge_length, dim)  # [batch_size, knowledge_length, dim]
                old_features_batch = current_embeddings_batch.view(batch_indices.size(0), -1)  # [batch_size, knowledge_length * dim]
                expanded_new_features = batch_avg_features.repeat(1, knowledge_length)  # [batch_size, knowledge_length * dim]
                # EMA update: new = γ * old + (1 - γ) * new_avg
                updated_features_batch = (
                    self.params.ema_decay * old_features_batch +
                    (1 - self.params.ema_decay) * expanded_new_features
                )
                # Re-encode into token_ids batch by batch (key point: bounds the input size to the output layer)
                updated_reshaped = updated_features_batch.view(-1, dim)  # [batch_size * knowledge_length, dim]
                logits_batch = self.output(updated_reshaped)  # [batch_size * knowledge_length, vocab_size]
                new_token_ids_batch = torch.argmax(logits_batch, dim=-1).view(batch_indices.size(0), knowledge_length)
                # 🔥 New: apply the freeze mask - only update entries that are not frozen
                # Check which entries referenced by batch_indices are unfrozen
                unfrozen_mask_batch = ~self.freeze_mask[batch_indices]  # [batch_size] - True means not frozen
                # Update only the unfrozen entries
                if unfrozen_mask_batch.any():
                    unfrozen_indices = batch_indices[unfrozen_mask_batch]
                    unfrozen_tokens = new_token_ids_batch[unfrozen_mask_batch]
                    self.memory_bank[unfrozen_indices] = unfrozen_tokens
                    updated_memories += unfrozen_indices.size(0)
                else:
                    # Every entry in this batch is frozen; skip the update
                    pass
            update_ratio = updated_memories / knowledge_num
            # 🔥 New: freeze-related statistics
            frozen_count = self.freeze_mask.sum().item()
            total_memories = knowledge_num
            update_stats = {
                'ema_update_applied': True,
                'ema_step': self.ema_step_counter.item(),
                'total_selections': total_selections,
                'total_layers': total_layers,
                'updated_memories': updated_memories,
                'update_ratio': update_ratio,
                'frozen_memories': frozen_count,
                'frozen_ratio': frozen_count / total_memories,
                'ema_decay': self.params.ema_decay,
                'selected_memory_coverage': updated_memories / knowledge_num,
            }
        return update_stats
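

# A minimal end-to-end usage sketch. It assumes LMConfig accepts the fields
# referenced above (dim, n_layers, n_heads, knowledge_num, knowledge_length,
# knowledge_dim, use_ema_update, ema_update_freq, ema_decay, freeze_ratio, ...)
# as keyword arguments; adjust to the real LMConfig signature before running.
if __name__ == "__main__":
    cfg = LMConfig(
        dim=512, n_layers=2, n_heads=8, n_kv_heads=2,
        vocab_size=6400, max_seq_len=512, dropout=0.0,
        knowledge_num=1024,  # must be a perfect square (32 * 32)
        knowledge_length=8, knowledge_dim=128,
        use_ema_update=True, ema_update_freq=1, ema_decay=0.9,
        freeze_ratio=0.0, flash_attn=True,
    )  # hypothetical values for a smoke test, not the experiment's settings
    model = MiniMindLM(cfg)
    tokens = torch.randint(0, cfg.vocab_size, (2, 16))
    out = model(tokens, collect_ema_stats=True)
    print(out.logits.shape, float(out.aux_loss))  # expected logits: [2, 16, vocab_size]
    # Apply the VQ-VAE style EMA update to the token-id memory bank
    print(model.apply_ema_update(out.ema_stats)['ema_update_applied'])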