419 lines
18 KiB
Python
419 lines
18 KiB
Python
import math
|
||
import struct
|
||
import inspect
|
||
import time
|
||
|
||
from .LMConfig import LMConfig
|
||
from typing import Any, Optional, Tuple, List, Union
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from torch import nn
|
||
from transformers import PreTrainedModel
|
||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||
|
||
|
||
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learned per-feature scale.

    Computes ``x / sqrt(mean(x**2, last dim) + eps)`` in float32, casts the
    result back to the input dtype, then multiplies by ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # One learnable gain per feature along the last dimension, initialised to 1.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalise in float32 for numerical stability, then restore x's dtype.
        normalized = self._norm(x.float())
        return self.weight * normalized.type_as(x)
|
||
|
||
|
||
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Build the table of rotary phases ``exp(i * p * theta**(-2k/dim))``.

    Args:
        dim: per-head feature size; ``dim // 2`` frequencies are produced.
        end: number of positions ``p`` (rows of the table).
        theta: frequency base.

    Returns:
        A complex64 tensor of shape ``(end, dim // 2)`` with unit magnitude.
    """
    even_offsets = torch.arange(0, dim, 2)[: (dim // 2)].float()
    inv_freq = 1.0 / (theta ** (even_offsets / dim))
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # Unit-magnitude complex numbers with the given angles.
    return torch.polar(torch.ones_like(angles), angles)
|
||
|
||
|
||
def apply_rotary_emb(xq, xk, pos_cis):
    """Apply rotary position embedding to queries and keys.

    Consecutive pairs along the last axis of ``xq``/``xk`` are viewed as complex
    numbers and multiplied by ``pos_cis``. ``pos_cis`` must have shape
    ``(xq.shape[1], xq.shape[-1] // 2)``, so axis 1 of the inputs is the axis
    that is matched position by position. The outputs keep the input dtypes.
    """
    def _broadcastable(freqs, target):
        rank = target.ndim
        assert 0 <= 1 < rank
        assert freqs.shape == (target.shape[1], target.shape[-1])
        # Keep axis 1 and the last axis; every other axis becomes size 1.
        new_shape = [1] * rank
        new_shape[1] = target.shape[1]
        new_shape[rank - 1] = target.shape[-1]
        return freqs.view(*new_shape)

    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    rotation = _broadcastable(pos_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
|
||
|
||
|
||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along axis 2.

    Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)`` for a
    4-D tensor ``(batch, seq, kv_heads, head_dim)``; returns ``x`` itself when
    ``n_rep == 1``.
    """
    batch, length, kv_heads, width = x.shape
    if n_rep == 1:
        return x
    # Insert a repeat axis right after the head axis, broadcast it, then merge.
    tiled = x.unsqueeze(3).expand(batch, length, kv_heads, n_rep, width)
    return tiled.reshape(batch, length, kv_heads * n_rep, width)
|
||
|
||
|
||
class Attention(nn.Module):
    """Causal multi-head self-attention without a KV cache.

    Key/value projections may use fewer heads (``n_kv_heads``) than queries;
    they are repeated with ``repeat_kv`` to match the query head count.
    """

    def __init__(self, args: LMConfig):
        super().__init__()
        kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_kv_heads = kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        # How many query heads share each key/value head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        q_width = args.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        self.wq = nn.Linear(args.dim, q_width, bias=False)
        self.wk = nn.Linear(args.dim, kv_width, bias=False)
        self.wv = nn.Linear(args.dim, kv_width, bias=False)
        self.wo = nn.Linear(q_width, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # Use PyTorch's fused SDPA kernel only if it exists and the config enables it.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn

        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        causal = torch.triu(
            torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")),
            diagonal=1,
        )
        self.register_buffer("mask", causal, persistent=False)

    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor):
        """Attend over the whole of ``x`` (no cache); returns a tensor shaped like ``x``."""
        bsz, seq_len, _ = x.shape
        q = self.wq(x).view(bsz, seq_len, self.n_local_heads, self.head_dim)
        k = self.wk(x).view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        v = self.wv(x).view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        q, k = apply_rotary_emb(q, k, pos_cis)

        # Move heads in front of the sequence axis: (bsz, heads, seq_len, head_dim).
        q = q.transpose(1, 2)
        k = repeat_kv(k, self.n_rep).transpose(1, 2)
        v = repeat_kv(v, self.n_rep).transpose(1, 2)

        if not self.flash or seq_len == 1:
            logits = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            logits += self.mask[:, :, :seq_len, :seq_len]
            probs = F.softmax(logits.float(), dim=-1).type_as(q)
            probs = self.attn_dropout(probs)
            context = probs @ v
        else:
            context = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )

        merged = context.transpose(1, 2).reshape(bsz, seq_len, -1)
        return self.resid_dropout(self.wo(merged))
|
||
|
||
|
||
class MemoryGate(nn.Module):
    """Product-key gate that picks ``num_selected`` memory slots per token.

    The ``knowledge_num`` slots form a ``num_keys x num_keys`` grid. A query of
    size ``knowledge_dim`` is split into two halves; each half is scored against
    its own key set. Slot ``i * num_keys + j`` gets score ``s1[i] + s2[j]``.
    Taking the top ``k`` of each half first, then the top ``num_selected`` of the
    ``k * k`` combinations, yields the exact global top-``num_selected`` without
    scoring every slot.
    """

    def __init__(self, config: "LMConfig"):
        super().__init__()
        self.config = config
        self.dim = config.dim
        self.knowledge_num = config.knowledge_num
        self.knowledge_dim = config.knowledge_dim
        self.num_selected = getattr(config, 'num_selected', 16)

        # The memory bank must be a square grid. Use an exact integer square root;
        # float ``n ** 0.5`` can round incorrectly for large n.
        self.num_keys = math.isqrt(int(self.knowledge_num))
        assert self.num_keys ** 2 == self.knowledge_num, \
            f"knowledge_num ({self.knowledge_num}) must be a perfect square for product key memory"
        # Each query half must match the key width knowledge_dim // 2.
        # Fail here rather than with an einsum shape error in forward().
        assert self.knowledge_dim % 2 == 0, \
            f"knowledge_dim ({self.knowledge_dim}) must be even for product key memory"

        # Query projection: dim -> knowledge_dim. It is split into two halves of
        # knowledge_dim // 2, one per product-key set.
        self.gate_proj = nn.Linear(self.dim, self.knowledge_dim, bias=False)

        # Two independent key sets, each of shape (num_keys, knowledge_dim // 2).
        self.keys = nn.Parameter(torch.randn(2, self.num_keys, self.knowledge_dim // 2))

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor):
        """Select memory slots for every token.

        Args:
            x: [batch_size, seq_len, dim]

        Returns:
            memory_indices: [batch_size, seq_len, num_selected]. These are flat slot
                indices in ``[0, knowledge_num)``, ordered by descending score.
            memory_scores: [batch_size, seq_len, num_selected]. This is the softmax
                over the selected scores, followed by dropout. In training mode,
                dropout may zero some weights.

        Raises:
            RuntimeError: from ``topk`` if ``num_selected > knowledge_num``.
        """
        bsz, seq_len, _ = x.shape
        half = self.knowledge_dim // 2

        queries = self.gate_proj(x)  # [batch, seq_len, knowledge_dim]
        q1 = queries[:, :, :half]    # [batch, seq_len, knowledge_dim // 2]
        q2 = queries[:, :, half:]    # [batch, seq_len, knowledge_dim // 2]

        # Similarity of each half with its own key set.
        scores_1 = torch.einsum('bsd,kd->bsk', q1, self.keys[0])  # [batch, seq_len, num_keys]
        scores_2 = torch.einsum('bsd,kd->bsk', q2, self.keys[1])  # [batch, seq_len, num_keys]

        # Per-half candidate count. Keeping min(num_selected, num_keys) per half
        # still covers the global top-num_selected. The cap lets num_selected
        # exceed num_keys, where topk(num_selected) on a half would fail.
        per_half = min(self.num_selected, self.num_keys)
        topk_scores_1, topk_indices_1 = scores_1.topk(per_half, dim=-1)
        topk_scores_2, topk_indices_2 = scores_2.topk(per_half, dim=-1)

        # Combine the two halves: every (i, j) pair of surviving candidates.
        combined_scores = (
            topk_scores_1.unsqueeze(-1) + topk_scores_2.unsqueeze(-2)
        ).reshape(bsz, seq_len, -1)  # [batch, seq_len, per_half * per_half]
        combined_indices = (
            topk_indices_1.unsqueeze(-1) * self.num_keys + topk_indices_2.unsqueeze(-2)
        ).reshape(bsz, seq_len, -1)  # [batch, seq_len, per_half * per_half]

        # Final top-k over the combined candidates.
        final_scores, final_pk_indices = combined_scores.topk(self.num_selected, dim=-1)
        memory_indices = combined_indices.gather(-1, final_pk_indices)

        # Normalise the selection scores into weights.
        memory_scores = F.softmax(final_scores, dim=-1)
        memory_scores = self.dropout(memory_scores)

        return memory_indices, memory_scores
|
||
|
||
|
||
class CrossAttentionMemory(nn.Module):
    """Per-token cross-attention over that token's selected memory slots.

    Queries come from ``x``. Keys and values are projected from the
    ``num_selected`` memory vectors chosen for the same token. The memory
    selection weights are added to the attention logits in log space, acting as
    a prior.
    """

    def __init__(self, config: "LMConfig"):
        super().__init__()
        self.config = config
        self.n_heads = config.n_heads
        self.head_dim = config.dim // config.n_heads
        self.dim = config.dim
        self.knowledge_dim = config.knowledge_dim

        # Query projection, applied to the token representation.
        self.wq = nn.Linear(config.dim, config.dim, bias=False)

        # Key/value projections, applied to the memory vectors (knowledge_dim -> dim).
        self.wk = nn.Linear(config.knowledge_dim, config.dim, bias=False)
        self.wv = nn.Linear(config.knowledge_dim, config.dim, bias=False)

        # Output projection.
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, memory_data: torch.Tensor, memory_scores: torch.Tensor):
        """
        Args:
            x: [batch_size, seq_len, dim] - query source
            memory_data: [batch_size, seq_len, num_selected, knowledge_dim] - selected
                memory vectors; need not be contiguous
            memory_scores: [batch_size, seq_len, num_selected] - non-negative memory
                selection weights; zeros are allowed

        Returns:
            output: [batch_size, seq_len, dim]
        """
        bsz, seq_len, _ = x.shape
        num_selected = memory_data.shape[2]

        # Query: [batch, n_heads, seq_len, head_dim]
        q = self.wq(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        # Keys/values for every selected memory. Use reshape, not view, so
        # non-contiguous memory_data is accepted.
        memory_flat = memory_data.reshape(bsz * seq_len * num_selected, self.knowledge_dim)
        k_flat = self.wk(memory_flat)  # [batch * seq_len * num_selected, dim]
        v_flat = self.wv(memory_flat)  # [batch * seq_len * num_selected, dim]

        # [batch, n_heads, seq_len, num_selected, head_dim]
        k = k_flat.view(bsz, seq_len, num_selected, self.n_heads, self.head_dim).permute(0, 3, 1, 2, 4)
        v = v_flat.view(bsz, seq_len, num_selected, self.n_heads, self.head_dim).permute(0, 3, 1, 2, 4)

        # Each token's single query attends to its own num_selected keys:
        # [batch, n_heads, seq_len, 1, head_dim] @ [..., head_dim, num_selected]
        scores = torch.matmul(q.unsqueeze(3), k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.squeeze(3)  # [batch, n_heads, seq_len, num_selected]

        # Add the selection weights as a log-space prior, shared across heads.
        # Zero weights (e.g. zeroed by dropout) would give log(0) = -inf. If a token's
        # weights were all zero, the softmax over all -inf would produce NaN.
        # Clamping to the smallest normal float keeps the prior finite.
        floor = torch.finfo(memory_scores.dtype).tiny
        log_prior = memory_scores.clamp_min(floor).log()
        scores = scores + log_prior.unsqueeze(1)

        attn_weights = F.softmax(scores, dim=-1)  # [batch, n_heads, seq_len, num_selected]
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values: [batch, n_heads, seq_len, head_dim]
        output = torch.einsum('bhlk,bhlkd->bhld', attn_weights, v)

        # Merge heads back: [batch, seq_len, dim]
        output = output.transpose(1, 2).reshape(bsz, seq_len, self.dim)
        return self.wo(output)
|
||
|
||
|
||
class MiniMindBlock(nn.Module):
    """Transformer block where memory cross-attention takes the place of the FFN.

    Flow: ``x -> self-attention -> residual``. The normalised self-attention
    output (not the residual sum) drives the memory gate and the memory
    cross-attention, whose result is added as a second residual.
    """

    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.attention = Attention(config)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.memory_norm = RMSNorm(config.dim, eps=config.norm_eps)

        # Memory modules: slot selection, then cross-attention over the chosen slots.
        self.memory_gate = MemoryGate(config)
        self.cross_attention_memory = CrossAttentionMemory(config)

    def forward(self, x, pos_cis, memory_bank):
        """
        Args:
            x: [batch_size, seq_len, dim]
            pos_cis: rotary table slice passed through to self-attention
            memory_bank: [knowledge_num, knowledge_dim] - memory shared by all layers

        Returns:
            Tensor shaped like ``x``.
        """
        attn_out = self.attention(self.attention_norm(x), pos_cis)
        residual = x + attn_out

        # The self-attention output (normalised) is both the gate input and the query.
        memory_query = self.memory_norm(attn_out)
        chosen, weights = self.memory_gate(memory_query)

        # Gather the selected rows of the bank: [batch, seq_len, num_selected, knowledge_dim].
        bsz, seq_len, num_selected = chosen.shape
        gathered = memory_bank[chosen.view(-1)].view(bsz, seq_len, num_selected, -1)

        return residual + self.cross_attention_memory(memory_query, gathered, weights)
|
||
|
||
|
||
class MiniMindLM(PreTrainedModel):
    """Causal LM built from ``MiniMindBlock`` layers that share one memory bank.

    The shared ``memory_bank`` parameter has shape ``(knowledge_num, knowledge_dim)``.
    There is no KV cache: every generation step re-runs the full sequence.
    """

    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        self.params = params or LMConfig()
        super().__init__(self.params)
        # Work with the resolved config from here on; ``params`` itself may be None.
        params = self.params
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(layer_id, params) for layer_id in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        # Tie the input embedding to the output projection.
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)

        # Shared memory bank, read by every layer.
        self.memory_bank = nn.Parameter(
            torch.randn(params.knowledge_num, params.knowledge_dim),
            requires_grad=True
        )

        # Holds the most recent forward() result; kept only for backward compatibility.
        self.OUT = CausalLMOutputWithPast()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                **args):
        """Run the model over ``input_ids`` without a KV cache.

        Args:
            input_ids: token ids; ``size(1)`` is taken as the sequence length.
            **args: only ``start_pos`` (offset into the rotary table, default 0)
                is read; other keyword arguments are ignored.

        Returns:
            A new ``CausalLMOutputWithPast`` per call, holding ``last_hidden_state``,
            ``logits``, ``aux_loss`` (always 0) and ``past_key_values`` (always None).
        """
        start_pos = args.get('start_pos', 0)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]

        for layer in self.layers:
            h = layer(h, pos_cis, self.memory_bank)

        logits = self.output(self.norm(h))

        # No auxiliary loss is used by this model.
        aux_loss = 0
        # Build a fresh output object each call. A single shared instance would
        # let later calls overwrite results a caller is still holding.
        out = CausalLMOutputWithPast()
        out.__setitem__('last_hidden_state', h)
        out.__setitem__('logits', logits)
        out.__setitem__('aux_loss', aux_loss)
        out.__setitem__('past_key_values', None)  # KV cache not supported
        self.OUT = out
        return out

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        """Sample continuations without a KV cache.

        With ``stream=True`` this returns the ``_stream`` generator for ``input_ids``.
        Otherwise each batch row is stripped of ``pad_token_id`` tokens and sampled
        ``num_return_sequences`` times. The resulting sequences (prompt + new
        tokens) are right-padded with ``pad_token_id`` and stacked into a
        ``(batch * num_return_sequences, max_len)`` tensor.
        """
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        generated = []
        for row in input_ids:
            non_pad = row[row != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                # Each yield is the full generated suffix; its last column is the new token.
                new_tokens = [tokens[:, -1:] for tokens in out]
                if new_tokens:
                    full_sequence = torch.cat([non_pad, *new_tokens], dim=-1)
                else:
                    # Nothing generated (e.g. max_new_tokens == 0): return the prompt
                    # once rather than concatenating it with itself.
                    full_sequence = non_pad
                generated.append(full_sequence)

        max_length = max(seq.size(1) for seq in generated)
        padded = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(padded, dim=0)
        return output.view(input_ids.size(0) * num_return_sequences, -1)

    # Decorated separately: when generate(stream=True) returns this generator, its
    # body runs after generate() has exited its inference_mode context.
    @torch.inference_mode()
    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        """Yield ``input_ids[:, start:]`` (all new tokens so far) after each sampled token.

        Recomputes the whole sequence each step (no KV cache). Stops after
        ``max_new_tokens`` tokens or when the sampled token equals
        ``eos_token_id``. The EOS check uses ``.item()``, so it expects a batch of one.
        """
        start = input_ids.shape[1]
        while input_ids.shape[1] < start + max_new_tokens:
            out = self(input_ids, **args)
            logits = out.logits[:, -1, :]

            # Repetition penalty for tokens already present in row 0. Positive
            # logits are divided by rp and negative ones multiplied by it (as in
            # HF's RepetitionPenaltyLogitsProcessor). Dividing a negative logit by
            # rp > 1 would raise it and make repeats more likely.
            if rp != 1.0:
                seen = torch.unique(input_ids[0])
                picked = logits[:, seen]
                logits[:, seen] = torch.where(picked < 0, picked * rp, picked / rp)
            logits /= (temperature + 1e-9)

            # Nucleus (top-p) filtering: keep the smallest prefix of tokens whose
            # cumulative probability exceeds top_p. The top token is always kept.
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')

            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break