Minimind/model/model_memory.py
Yu Chengzhang 57d6d768e1 Experiment 1.4.2: 门控MLP融合串型连接验证连接方式对记忆库性能的影响
## 实验目标
验证连接方式是否是导致实验1.4.1性能下降的主要原因,通过将跳接(交叉注意力)
改为串型连接(门控MLP融合)来测试记忆库机制的有效性。

## 核心改进
- 保留Product Key Memory记忆选择机制
- 使用串型连接替代跳接连接
- 门控MLP融合替代交叉注意力
- 拼接h_attn和选中记忆进行处理

## 实验结果
- 训练Loss: 2.75 (vs 1.4.1的2.84, 1.4.0的2.43)
- 评估Loss: 2.33 (vs 1.4.1的7.68, 1.4.0的1.99)
- 生成质量: 6.2/10 (vs 1.4.1的2.0/10, 1.4.0的7.5/10)
- 训练时间: 15.4小时,GPU内存: ~22GB

## 关键发现
- 连接方式确实是性能差异的关键因素
- 门控MLP融合显著优于交叉注意力
- 记忆库机制本身可行,但需要优化记忆质量

## 技术实现
- 实现GatedMemoryFusion类替代CrossAttentionMemory
- 使用类SwiGLU的门控MLP结构
- 拼接输入维度: dim + num_selected * knowledge_dim
- 门控激活函数: SiLU + 元素级乘法

## 文件变更
- model/model_memory.py: 实现门控MLP融合机制
- run_file/experiment_1_4_2.sh: 实验执行脚本
- experiment/EXPERIMENT_1_4_2.md: 完整实验记录和分析

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-04 20:12:00 +08:00

393 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import struct
import inspect
import time
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation with a learnable per-feature scale.

    Args:
        dim: Size of the last (feature) dimension; one scale weight is
            created per feature and initialised to 1.
        eps: Constant added to the mean square before the reciprocal square
            root is taken, to avoid division by zero.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS computed over its last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """Normalise ``x`` and apply the learned scale.

        The normalisation itself is carried out in float32 and the result is
        cast back to the dtype of ``x`` before it is multiplied by the weight.
        """
        return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Precompute the complex rotation factors used for rotary embeddings.

    Args:
        dim: Dimension being rotated; ``dim // 2`` frequencies are produced.
        end: Number of positions to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex64 tensor of shape ``(end, dim // 2)`` whose entry
        ``[p, j]`` is the unit-magnitude number with angle
        ``p / theta ** (2j / dim)``.
    """
    # One frequency per pair of channels: theta ** (-2j / dim).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    # Outer product gives the rotation angle for every (position, frequency).
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
    """Rotate query and key tensors by precomputed complex position factors.

    Consecutive pairs along the last dimension of ``xq`` / ``xk`` are treated
    as (real, imaginary) parts, multiplied by ``pos_cis`` and unpacked again.

    Args:
        xq: Query tensor. The code indexes dim 1 as the sequence axis and
            flattens from dim 3, so a 4-D ``(batch, seq, heads, head_dim)``
            layout is assumed.
        xk: Key tensor with the same sequence length and head_dim as ``xq``.
        pos_cis: Complex tensor of shape ``(seq, head_dim // 2)``.

    Returns:
        Tuple ``(xq_out, xk_out)`` with the shapes and dtypes of the inputs.
    """

    def unite_shape(pos_cis, x):
        """Reshape ``pos_cis`` so it broadcasts over every axis of ``x`` except
        the sequence axis (1) and the last axis."""
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    # Pair up the last dimension and reinterpret each pair as a complex number.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Shaped from xq_; the singleton head axis lets the same factors broadcast to xk_.
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)``.

    Args:
        x: Tensor of shape ``(batch, seq_len, n_kv_heads, head_dim)``.
        n_rep: Number of copies of every head.

    Returns:
        ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
        ``(batch, seq_len, n_kv_heads * n_rep, head_dim)``.
    """
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
class Attention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings and no KV cache.

    Key/value projections may use fewer heads than the queries
    (``n_kv_heads``); they are repeated to match before attention is computed.

    Args:
        args: Model configuration. This class reads ``n_heads``,
            ``n_kv_heads``, ``dim``, ``dropout``, ``flash_attn`` and
            ``max_seq_len``.
    """

    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        # How many query heads share one key/value head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # Use the fused kernel only when PyTorch provides it and the config asks for it.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # Additive causal mask for the manual path: -inf strictly above the diagonal.
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor):
        """Run causal self-attention over the full sequence (no KV cache).

        Args:
            x: Input of shape ``(batch, seq_len, dim)``.
            pos_cis: Rotary factors for these ``seq_len`` positions, as
                expected by ``apply_rotary_emb``.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.
        """
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, pos_cis)

        # KV-cache handling has been removed entirely: keys and values always
        # cover exactly the current input sequence.
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )

        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            # Manual path: scaled dot product, additive causal mask, softmax in float32.
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output
class MemoryGate(nn.Module):
    """Product-key gate that picks ``num_selected`` memory slots per token.

    The ``knowledge_num`` slots are addressed as a ``num_keys x num_keys``
    grid. A query is split in two halves, each half is scored against its own
    set of ``num_keys`` sub-keys, and the score of slot ``i * num_keys + j``
    is the sum of the two sub-scores.

    Args:
        config: Model configuration. This class reads ``dim``,
            ``knowledge_num`` (must be a perfect square), ``knowledge_dim``
            (must be even), ``dropout`` and optionally ``num_selected``
            (defaults to 16).
    """

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.dim = config.dim
        self.knowledge_num = config.knowledge_num
        self.knowledge_dim = config.knowledge_dim
        self.num_selected = getattr(config, 'num_selected', 16)

        # The slot count must be a perfect square; isqrt keeps the check exact
        # (a float square root can be off by one for large integers).
        self.num_keys = math.isqrt(self.knowledge_num)
        assert self.num_keys ** 2 == self.knowledge_num, \
            f"knowledge_num ({self.knowledge_num}) must be a perfect square for product key memory"
        # The query is split into two equal halves, one per sub-key set.
        assert self.knowledge_dim % 2 == 0, \
            f"knowledge_dim ({self.knowledge_dim}) must be even for product key memory"

        # Query projection: dim -> knowledge_dim (half for each sub-key set).
        self.gate_proj = nn.Linear(self.dim, self.knowledge_dim, bias=False)

        # Product Key Memory: two independent sets of sub-keys.
        self.keys = nn.Parameter(torch.randn(2, self.num_keys, self.knowledge_dim // 2))

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor):
        """Select memory slots for every position.

        Args:
            x: [batch_size, seq_len, dim]

        Returns:
            memory_indices: [batch_size, seq_len, num_selected] slot indices
                in ``[0, knowledge_num)``.
            memory_scores: [batch_size, seq_len, num_selected] softmax over
                the selected scores, followed by dropout.
        """
        bsz, seq_len, _ = x.shape
        half = self.knowledge_dim // 2

        # Build the query and split it into the two product-key halves.
        queries = self.gate_proj(x)  # [batch, seq_len, knowledge_dim]
        q1 = queries[:, :, :half]  # [batch, seq_len, knowledge_dim // 2]
        q2 = queries[:, :, half:]  # [batch, seq_len, knowledge_dim // 2]

        # Similarity of each half with its own sub-key set.
        scores_1 = torch.einsum('bsd,kd->bsk', q1, self.keys[0])  # [batch, seq_len, num_keys]
        scores_2 = torch.einsum('bsd,kd->bsk', q2, self.keys[1])  # [batch, seq_len, num_keys]

        # Per-half top-k. Only num_keys sub-keys exist, so the candidate count
        # is clamped; asking topk for more than that would raise.
        k_half = min(self.num_selected, self.num_keys)
        topk_scores_1, topk_indices_1 = scores_1.topk(k_half, dim=-1)
        topk_scores_2, topk_indices_2 = scores_2.topk(k_half, dim=-1)

        # Cartesian combination of the candidates: [batch, seq_len, k_half, k_half]
        combined_scores = topk_scores_1.unsqueeze(-1) + topk_scores_2.unsqueeze(-2)
        combined_indices = topk_indices_1.unsqueeze(-1) * self.num_keys + topk_indices_2.unsqueeze(-2)

        # Flatten the candidate grid and keep the overall best num_selected.
        combined_scores = combined_scores.view(bsz, seq_len, -1)
        combined_indices = combined_indices.view(bsz, seq_len, -1)
        final_scores, final_pk_indices = combined_scores.topk(self.num_selected, dim=-1)
        memory_indices = combined_indices.gather(-1, final_pk_indices)

        # Normalise the selected scores.
        memory_scores = F.softmax(final_scores, dim=-1)
        memory_scores = self.dropout(memory_scores)

        return memory_indices, memory_scores
class GatedMemoryFusion(nn.Module):
    """Gated MLP that fuses a hidden state with its selected memory vectors.

    The hidden state and all selected memories are concatenated along the
    feature axis and passed through a SwiGLU-like block:
    ``down_proj(silu(gate_proj(z)) * up_proj(z))``.

    Args:
        config: Model configuration. This class reads ``dim``,
            ``knowledge_dim``, ``dropout`` and optionally ``num_selected``
            (defaults to 16).
    """

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.dim = config.dim
        self.knowledge_dim = config.knowledge_dim
        self.num_selected = getattr(config, 'num_selected', 16)

        # Input width: dim (hidden state) + num_selected * knowledge_dim (memories).
        concat_dim = self.dim + self.num_selected * self.knowledge_dim

        # SwiGLU-like gated MLP.
        self.gate_proj = nn.Linear(concat_dim, self.dim, bias=False)
        self.up_proj = nn.Linear(concat_dim, self.dim, bias=False)
        self.down_proj = nn.Linear(self.dim, self.dim, bias=False)

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, h_attn: torch.Tensor, selected_memories: torch.Tensor, memory_scores: torch.Tensor):
        """Fuse the hidden state with the selected memories.

        Args:
            h_attn: [batch_size, seq_len, dim] hidden state to fuse.
            selected_memories: [batch_size, seq_len, num_selected, knowledge_dim]
                selected memory vectors.
            memory_scores: [batch_size, seq_len, num_selected] selection
                weights. Accepted for interface compatibility; not used by
                the concatenation approach.

        Returns:
            output: [batch_size, seq_len, dim]
        """
        bsz, seq_len, _ = h_attn.shape

        # Flatten the selected memories into one vector per position:
        # [batch, seq_len, num_selected, knowledge_dim] -> [batch, seq_len, num_selected * knowledge_dim]
        # reshape (rather than view) so non-contiguous inputs are accepted too.
        memory_flat = selected_memories.reshape(bsz, seq_len, -1)

        # Concatenate the hidden state with the memory information.
        concat_input = torch.cat([h_attn, memory_flat], dim=-1)  # [batch, seq_len, dim + num_selected * knowledge_dim]

        # Gated MLP (SwiGLU-like).
        gate = F.silu(self.gate_proj(concat_input))  # [batch, seq_len, dim]
        up = self.up_proj(concat_input)  # [batch, seq_len, dim]
        fusion_output = gate * up  # element-wise gating

        # Output projection.
        output = self.down_proj(fusion_output)  # [batch, seq_len, dim]
        output = self.dropout(output)

        return output
class MiniMindBlock(nn.Module):
    """Transformer block whose FFN is replaced by a gated memory fusion.

    Each block runs self-attention, uses the normalised attention output to
    pick entries from a shared memory bank, fuses them through a gated MLP and
    adds the result back onto the residual stream.

    Args:
        layer_id: Index of this block in the layer stack (stored only).
        config: Model configuration, also forwarded to the sub-modules.
    """

    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.attention = Attention(config)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.memory_norm = RMSNorm(config.dim, eps=config.norm_eps)

        # Memory modules: slot selection followed by gated fusion.
        self.memory_gate = MemoryGate(config)
        self.gated_memory_fusion = GatedMemoryFusion(config)

    def forward(self, x, pos_cis, memory_bank):
        """Apply self-attention followed by memory fusion.

        Args:
            x: [batch_size, seq_len, dim]
            pos_cis: rotary position factors for the current positions
            memory_bank: [knowledge_num, knowledge_dim] - shared memory bank

        Returns:
            Tensor of shape [batch_size, seq_len, dim].
        """
        # Self attention with a residual connection.
        h_attn = self.attention(self.attention_norm(x), pos_cis)
        h = x + h_attn

        # The normalised attention output (h_attn, not the residual sum h)
        # drives both the memory selection and the fusion.
        h_for_memory = self.memory_norm(h_attn)

        # Gate: choose which memory slots to read.
        memory_indices, memory_scores = self.memory_gate(h_for_memory)

        # Gather the selected rows. Indexing with the [batch, seq_len, num_selected]
        # index tensor yields [batch, seq_len, num_selected, knowledge_dim] directly.
        selected_memory = memory_bank[memory_indices]

        # Serial connection: gated MLP over h_for_memory and the selected memories.
        memory_output = self.gated_memory_fusion(h_for_memory, selected_memory, memory_scores)

        # Residual connection.
        out = h + memory_output

        return out
class MiniMindLM(PreTrainedModel):
    """Causal language model built from memory-augmented blocks, without KV cache.

    All layers read from one shared, trainable memory bank of shape
    ``(knowledge_num, knowledge_dim)``. Input embedding and output projection
    share a single weight tensor.
    """

    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        # Resolve the default first: every later line reads from ``params``,
        # so leaving it as None would fail on the first attribute access.
        params = params or LMConfig()
        self.params = params
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        # Weight tying between the embedding and the output projection.
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)

        # Shared memory bank, read by every layer.
        self.memory_bank = nn.Parameter(
            torch.randn(params.knowledge_num, params.knowledge_dim),
            requires_grad=True
        )

        # NOTE(review): a single output container is reused and overwritten by
        # every forward call, so results of earlier calls alias later ones.
        self.OUT = CausalLMOutputWithPast()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                **args):
        """Run the model over the full input sequence (no KV cache).

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``.
            **args: Optional ``start_pos`` (default 0), the offset into the
                precomputed rotary table. Other keys are ignored.

        Returns:
            ``self.OUT`` with ``last_hidden_state``, ``logits``, ``aux_loss``
            (always 0) and ``past_key_values`` (always None) set.
        """
        start_pos = args.get('start_pos', 0)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]

        for layer in self.layers:
            h = layer(h, pos_cis, self.memory_bank)

        logits = self.output(self.norm(h))

        # No auxiliary loss is used by this model.
        aux_loss = 0

        self.OUT.__setitem__('last_hidden_state', h)
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('past_key_values', None)  # KV cache is not supported
        return self.OUT

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        """Sample continuations for a batch of prompts (no KV cache).

        Args:
            input_ids: Prompt token ids of shape ``(batch, seq_len)``.
            eos_token_id: Sampling stops once this token has been produced.
            max_new_tokens: Upper bound on generated tokens per sequence.
            temperature: Logits are divided by this value before sampling.
            top_p: Nucleus threshold; ``None`` or ``>= 1.0`` disables it.
            stream: If true, return the generator from ``_stream`` for the
                given ``input_ids`` instead of a tensor.
            rp: Repetition penalty applied to tokens already in the sequence.
            pad_token_id: Id removed from each prompt and used for padding.
            num_return_sequences: Number of samples drawn per prompt.
            **args: Forwarded to ``forward``.

        Returns:
            With ``stream=True`` a generator; otherwise a tensor of shape
            ``(batch * num_return_sequences, max_len)`` holding prompt plus
            continuation, right-padded with ``pad_token_id``.
        """
        # Streaming generation.
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        # Direct generation.
        generated = []
        for i in range(input_ids.size(0)):
            # Drop every pad token from this prompt and keep a batch axis of 1.
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                # Each yield holds all new tokens so far; its last column is the newest one.
                tokens_list = [tokens[:, -1:] for tokens in out]
                if tokens_list:
                    full_sequence = torch.cat([non_pad, *tokens_list], dim=-1)
                else:
                    # Nothing was generated: return the prompt once, not twice.
                    full_sequence = non_pad
                generated.append(full_sequence)

        # Right-pad every sequence to the longest one so they can be stacked.
        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        """Yield the growing continuation one sampled token at a time.

        The whole sequence is re-run through the model at every step because
        there is no KV cache. Only a batch of one sequence is supported: the
        repetition penalty reads row 0 and the EOS check calls ``.item()``.

        Yields:
            ``input_ids[:, start:]`` - all tokens generated so far, including
            the EOS token on the final yield.
        """
        start = input_ids.shape[1]
        while input_ids.shape[1] < start + max_new_tokens:
            # A generator body runs after generate() has returned, i.e. outside
            # its inference_mode decorator, so the mode is entered per step
            # here. The yield stays outside the context so the mode does not
            # leak into the consumer's code between steps.
            with torch.inference_mode():
                # Recompute the whole sequence every step (no KV cache).
                out = self(input_ids, **args)
                logits = out.logits[:, -1, :]

                # Repetition penalty: make already-seen tokens less likely.
                # Positive logits are divided and negative ones multiplied;
                # dividing a negative logit would raise its probability.
                seen = list(set(input_ids.tolist()[0]))
                seen_logits = logits[:, seen]
                logits[:, seen] = torch.where(seen_logits < 0, seen_logits * rp, seen_logits / rp)
                logits /= (temperature + 1e-9)

                # Top-p (nucleus) sampling.
                if top_p is not None and top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                    sorted_probs = F.softmax(sorted_logits, dim=-1)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the token that crosses the threshold is
                    # kept, and never remove the most likely token.
                    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                    sorted_indices_to_remove[:, 0] = False
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = -float('Inf')

                input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                input_ids = torch.cat((input_ids, input_ids_next), dim=1)

            yield input_ids[:, start:]

            if input_ids_next.item() == eos_token_id:
                break