Minimind/model/model.py


import math
import struct
import inspect
import time
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from einops import rearrange, repeat
def exists(val):
return val is not None
# RMSNorm defines a module that normalizes its input tensor (root-mean-square normalization).
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
# precompute_pos_cis precomputes the rotary position encodings (complex-valued version).
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
# apply_rotary_emb applies the rotary position encodings (complex-valued version).
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
# precompute_pos_cis_real precomputes the rotary position encodings (real-valued version).
def precompute_pos_cis_real(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
"""使用实数张量实现位置编码,避免使用复数张量
这个函数与precompute_pos_cis完全等价但使用实数张量而非复数张量。
原始函数生成形状为[seq_len, dim//2]的复数张量其中实部全为1虚部为旋转角度。
这个函数生成形状为[seq_len, dim]的实数张量其中偶数索引是cos(角度)奇数索引是sin(角度)。
"""
    # Ensure dim is even
    if dim % 2 != 0:
        raise ValueError(f"dim must be even, got {dim}")
    # Same frequency computation as the original function
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
    # Compute the cos and sin values.
    # In the complex version, pos_cis = torch.polar(torch.ones_like(freqs), freqs),
    # which equals cos(freqs) + i*sin(freqs).
cos = torch.cos(freqs)
sin = torch.sin(freqs)
    # Build a real tensor with cos and sin interleaved
pos_emb = torch.zeros((end, dim), device=freqs.device)
    pos_emb[:, 0::2] = cos  # cos at the even indices
    pos_emb[:, 1::2] = sin  # sin at the odd indices
return pos_emb
# apply_rotary_emb_real applies the rotary position encodings (real-valued version).
def apply_rotary_emb_real(xq, xk, pos_emb):
"""使用实数张量实现旋转位置编码,避免使用复数张量
这个函数与apply_rotary_emb完全等价但使用实数张量而非复数张量。
原始函数将输入张量转换为复数形式,与位置编码相乘,然后再转回实数形式。
这个函数直接使用实数运算实现相同的旋转操作。
"""
    # Shape information
bsz, seq_len, n_heads, head_dim = xq.shape
    # Validate the position-encoding shape
    assert pos_emb.shape[0] >= seq_len, f"position-encoding length {pos_emb.shape[0]} is shorter than sequence length {seq_len}"
    assert pos_emb.shape[1] == head_dim, f"position-encoding dim {pos_emb.shape[1]} does not match head dim {head_dim}"
    # Slice out the positions we need
pos_emb = pos_emb[:seq_len]
    # Reshape pos_emb for broadcasting: [1, seq_len, 1, head_dim]
pos_emb = pos_emb.unsqueeze(0).unsqueeze(2)
    # Half of head_dim (cos and sin each cover half of the channels)
half_head_dim = head_dim // 2
    # Extract the cos and sin values (cos at even indices, sin at odd indices)
cos = pos_emb[..., 0::2]
sin = pos_emb[..., 1::2]
    # Rearrange xq and xk for the rotation.
    # In the complex version, xq and xk are reshaped into complex tensors whose real
    # and imaginary parts are interleaved; in the real version we handle the even
    # and odd indices separately.
    # Split the even and odd indices
    xq_even = xq[..., 0::2]  # even indices: the real part
    xq_odd = xq[..., 1::2]  # odd indices: the imaginary part
xk_even = xk[..., 0::2]
xk_odd = xk[..., 1::2]
    # Apply the rotation (equivalent to complex multiplication):
    # (a + bi)(cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i,
    # where a is the even-indexed part and b the odd-indexed part
    xq_out_even = xq_even * cos - xq_odd * sin  # new even indices (real part)
    xq_out_odd = xq_even * sin + xq_odd * cos  # new odd indices (imaginary part)
xk_out_even = xk_even * cos - xk_odd * sin
xk_out_odd = xk_even * sin + xk_odd * cos
    # Recombine the even and odd indices
xq_out = torch.zeros_like(xq)
xk_out = torch.zeros_like(xk)
xq_out[..., 0::2] = xq_out_even
xq_out[..., 1::2] = xq_out_odd
xk_out[..., 0::2] = xk_out_even
xk_out[..., 1::2] = xk_out_odd
return xq_out.type_as(xq), xk_out.type_as(xk)
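# A minimal sanity check (illustrative helper, not part of the model) showing that the
# real-valued RoPE path matches the complex-valued one on random inputs.
def _demo_rope_equivalence(bsz=2, seq_len=8, n_heads=4, head_dim=16):
    xq = torch.randn(bsz, seq_len, n_heads, head_dim)
    xk = torch.randn(bsz, seq_len, n_heads, head_dim)
    q1, k1 = apply_rotary_emb(xq, xk, precompute_pos_cis(head_dim, end=seq_len))
    q2, k2 = apply_rotary_emb_real(xq, xk, precompute_pos_cis_real(head_dim, end=seq_len))
    assert torch.allclose(q1, q2, atol=1e-5) and torch.allclose(k1, k2, atol=1e-5)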
# repeat_kv repeats key/value heads n_rep times (grouped-query attention).
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
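# A quick check (illustrative helper) that repeat_kv matches torch.repeat_interleave
# along the KV-head dimension, as the docstring above claims.
def _demo_repeat_kv():
    x = torch.randn(1, 3, 2, 4)  # [bs, slen, n_kv_heads, head_dim]
    assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=2))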
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor,
db_value=None):
        bsz, seq_len, _ = x.shape  # bsz: batch size, seq_len: sequence length, _: hidden dim
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  # project x through wq, wk, wv into queries, keys and values
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)  # reshape xq to (bsz, seq_len, n_local_heads, head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # reshape xk to (bsz, seq_len, n_local_kv_heads, head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # reshape xv to (bsz, seq_len, n_local_kv_heads, head_dim)
        # Apply the rotary position encodings (real-valued version)
xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
        # Repeat the KV heads
xq, xk, xv = (
xq.transpose(1, 2),
repeat_kv(xk, self.n_rep).transpose(1, 2),
repeat_kv(xv, self.n_rep).transpose(1, 2)
)
        # If db_value is provided, adapt its shape to the head layout and merge it into xv
if db_value is not None:
            # Ensure db_value is compatible with xv; db_value is assumed to be [B, N, H, D]
if db_value.ndim == 4: # [B, N, H, D]
db_value = db_value.transpose(1, 2) # -> [B, H, N, D]
                # Check whether the last dimension needs adjusting
if db_value.shape[-1] != xv.shape[-1]:
                    # If db_value's width differs from xv's, a projection layer could be added;
                    # instead, a simple adjustment is used here:
                    # the dimension is matched via mean pooling or repetition.
if db_value.shape[-1] > xv.shape[-1]:
                        # Reduce the dimension
factor = db_value.shape[-1] // xv.shape[-1]
db_value = db_value.view(bsz, self.n_local_heads, seq_len, factor, xv.shape[-1])
db_value = db_value.mean(dim=3)
else:
                        # Expand the dimension
factor = xv.shape[-1] // db_value.shape[-1]
db_value = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor)
db_value = db_value.view(bsz, self.n_local_heads, seq_len, xv.shape[-1])
            # Merge db_value into xv; they are simply added here,
            # but other fusion methods are possible
xv = xv + db_value
        # Use Flash Attention when available
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
        # Split into multiple heads
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
return context
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
if config.hidden_dim is None:
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
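# A minimal sketch (illustrative helper, not the module above) of the SwiGLU feed-forward
# with plain tensors: silu(x @ w1) gates x @ w3 before the down-projection w2, mirroring
# FeedForward.forward.
def _demo_swiglu(dim=4, hidden_dim=8):
    x = torch.randn(2, dim)
    w1, w3 = torch.randn(dim, hidden_dim), torch.randn(dim, hidden_dim)
    w2 = torch.randn(hidden_dim, dim)
    out = (F.silu(x @ w1) * (x @ w3)) @ w2
    assert out.shape == (2, dim)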
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
            raise NotImplementedError(f'Unsupported scoring function for MoE gating: {self.scoring_func}')
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0
return topk_idx, topk_weight, aux_loss
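# A toy walkthrough (illustrative helper, not part of the model) of the non-sequence
# auxiliary load-balancing loss above: with a uniform router and perfectly balanced
# top-1 routing, (Pi * fi).sum() equals 1.0, so aux_loss equals alpha.
def _demo_moe_aux_loss(n_tokens=8, n_experts=4, alpha=0.1):
    scores = torch.full((n_tokens, n_experts), 1.0 / n_experts)  # uniform router probabilities
    topk_idx = torch.arange(n_tokens) % n_experts  # perfectly balanced top-1 routing
    ce = F.one_hot(topk_idx, num_classes=n_experts).float().mean(0)  # fraction of tokens per expert
    Pi = scores.mean(0)  # mean router probability per expert
    fi = ce * n_experts
    aux_loss = (Pi * fi).sum() * alpha
    assert torch.isclose(aux_loss, torch.tensor(alpha))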
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)
def forward(self, x):
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
        # Use the gating network to select experts
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # keep the dtype consistent
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok
        # Example: if tokens_per_expert = [6, 15, 20, 26], then tokens_per_expert.shape[0]
        # equals the number of experts (4 here), and with
        # token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...]:
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the 6 token positions handled by expert 0
        # (a token may be processed by several experts, depending on num_experts_per_tok),
        # the next 9 positions token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...] belong to expert 1, and so on.
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
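# A minimal sketch (illustrative helper) of the argsort/bincount dispatch used in
# moe_infer: routed (token, expert) pairs are grouped by expert id so that each
# expert processes one dense batch of tokens.
def _demo_moe_dispatch():
    num_experts_per_tok = 2
    flat_expert_indices = torch.tensor([1, 0, 0, 2, 1, 2])  # 3 tokens, top-2 experts each
    idxs = flat_expert_indices.argsort()  # group the (token, expert) pairs by expert id
    tokens_per_expert = flat_expert_indices.bincount().cumsum(0)  # group boundaries: [2, 4, 6]
    token_idxs = idxs // num_experts_per_tok  # map pair positions back to token positions
    # Expert 0 handles token_idxs[:2], expert 1 handles token_idxs[2:4], and so on.
    assert sorted(token_idxs[:2].tolist()) == [0, 1]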
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.attention = Attention(config)
self.cross_att = CrossAttention(config)
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, db_value, pos_cis):
        # Attention
h_attn = self.attention(
self.attention_norm(x),
pos_cis,
db_value=db_value
)
h_attn = self.cross_att(h_attn, db_value)
        # Residual connection
h = x + h_attn
        # Feed-forward network
out = h + self.feed_forward(self.ffn_norm(h))
return out
class ExtractDB(nn.Module):
def __init__(self, params, tok_embeddings=None):
        # The expert count and knowledge dim are chosen so knowledge_num has an integer square root
super().__init__()
self.batch_size = None
self.dim = params.dim
self.dim_key = self.dim // 2
        self.knowledge_num = params.knowledge_num  # number of knowledge entries; should be a perfect square (e.g. 100)
        # Set knowledge_dim equal to head_dim so it can be used directly inside attention
self.head_dim = params.dim // params.n_heads
self.knowledge_length = params.knowledge_length
        # Parameters for intelligent load balancing
self.enable_intelligent_balance = getattr(params, 'db_intelligent_balance', True)
self.relevance_threshold = getattr(params, 'db_relevance_threshold', 0.7)
self.base_balance_strength = getattr(params, 'db_balance_strength', 0.3)
self.momentum = getattr(params, 'db_momentum', 0.9)
self.adaptive_weights = getattr(params, 'db_adaptive_weights', True)
        # Dynamic weight-adjustment parameters
        self.current_relevance_weight = 0.8  # favor relevance at the start
        self.current_balance_weight = 0.2
        self.weight_update_frequency = 100  # adjust the weights every 100 steps
self.step_counter = 0
        # Usage-frequency statistics; register_buffer so they move correctly between GPU and CPU
self.register_buffer('usage_counts', torch.zeros(self.knowledge_num))
self.register_buffer('total_queries', torch.tensor(0.0))
        # Knowledge store; register_buffer because these are integer indices that need no gradients
self.register_buffer('weight_down_embed',
torch.randint(low=0, high=6400, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
)
self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
self.product_key_topk = min(16, self.num_keys)
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
self.num_experts_per_head_topk = 1
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.dim_key * 2, bias=False),
)
        # Keep a reference to the token embeddings for computing true semantic relevance
self.tok_embeddings = tok_embeddings
def update_usage_statistics(self, selected_indices):
"""更新数据库条目的使用统计"""
if not self.training or not self.enable_intelligent_balance:
return
with torch.no_grad():
            # Count how often each entry is used in the current batch
batch_usage = torch.zeros(self.knowledge_num, device=selected_indices.device)
unique_indices, counts = torch.unique(selected_indices, return_counts=True)
batch_usage[unique_indices] = counts.float()
            # Update the statistics with plain tensor ops
current_usage = self.usage_counts.clone()
current_total = self.total_queries.clone()
new_usage = self.momentum * current_usage + (1 - self.momentum) * batch_usage
new_total = current_total + selected_indices.numel()
            # Replace the buffer contents in place
self.usage_counts.copy_(new_usage)
self.total_queries.copy_(new_total)
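    # A tiny check (illustrative helper, never called by the model) of the EMA update
    # above: new = momentum * old + (1 - momentum) * batch.
    @staticmethod
    def _demo_usage_ema(momentum=0.9):
        old = torch.tensor([10.0, 0.0])
        batch = torch.tensor([0.0, 4.0])
        new = momentum * old + (1 - momentum) * batch
        assert torch.allclose(new, torch.tensor([9.0, 0.4]))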
def update_dynamic_weights(self):
"""动态调整相关性和平衡权重"""
if not self.adaptive_weights or not self.training:
return
self.step_counter += 1
        # Adjust the weights every weight_update_frequency steps
if self.step_counter % self.weight_update_frequency == 0:
with torch.no_grad():
if self.total_queries > 0:
                    # Variance of the usage distribution (degree of imbalance)
usage_rates = self.usage_counts / self.total_queries
usage_variance = usage_rates.var().item()
                    # Adjust the weights according to the degree of imbalance
                    if usage_variance > 0.01:  # highly imbalanced
self.current_relevance_weight = max(0.5, self.current_relevance_weight - 0.1)
self.current_balance_weight = min(0.5, self.current_balance_weight + 0.1)
                    elif usage_variance < 0.001:  # already well balanced
self.current_relevance_weight = min(0.9, self.current_relevance_weight + 0.1)
self.current_balance_weight = max(0.1, self.current_balance_weight - 0.1)
                    # Ensure the weights sum to 1
total_weight = self.current_relevance_weight + self.current_balance_weight
self.current_relevance_weight /= total_weight
self.current_balance_weight /= total_weight
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if not self.enable_intelligent_balance or not self.training:
return all_scores
with torch.no_grad():
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
            # Update the dynamic weights
            self.update_dynamic_weights()
            # Hierarchical selection for each batch element
            enhanced_scores = all_scores.clone()
            query_features = query.mean(dim=1)  # [batch_size, dim]
            # Precompute embeddings for all candidate entries (batched optimization)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
            # Batch-compute the embeddings of the unique candidates
candidate_tokens = self.weight_down_embed[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
            # Normalize the candidate features (speeds up the similarity computation)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
scores = all_scores[batch_idx]
                # Feature indices for this batch element's candidates
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
                # Similarity from the precomputed normalized features
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
                # Cosine similarity via a matrix-vector product
similarity_scores = torch.mv(batch_candidate_features, query_feature)
                # Apply the relevance-threshold filter
relevance_probs = F.softmax(similarity_scores.float(), dim=-1).to(dtype)
mean_prob = relevance_probs.mean()
adaptive_threshold = max(self.relevance_threshold * mean_prob, mean_prob * 0.5)
relevant_mask = relevance_probs > adaptive_threshold
if relevant_mask.sum() == 0:
                    # No relevant candidates: fall back to the highest similarities
top_k = min(5, len(indices))
_, top_indices = similarity_scores.topk(top_k)
relevant_mask = torch.zeros_like(relevant_mask, dtype=torch.bool)
relevant_mask[top_indices] = True
                # Apply load balancing among the relevant candidates
if relevant_mask.sum() > 1:
relevant_indices = indices[relevant_mask]
relevant_usage = self.usage_counts[relevant_indices]
                    # Balance scores (less-used entries score higher)
balance_scores = 1.0 / (relevant_usage + 1.0)
balance_scores = balance_scores / (balance_scores.sum() + 1e-8)
                    # Relevance scores
relevant_rel_scores = relevance_probs[relevant_mask]
relevant_rel_scores = relevant_rel_scores / (relevant_rel_scores.sum() + 1e-8)
                    # Combined score
combined_scores = (self.current_relevance_weight * relevant_rel_scores +
self.current_balance_weight * balance_scores.to(dtype))
                    # Apply the adjustment
adjustment = self.base_balance_strength * combined_scores.to(dtype)
enhanced_scores[batch_idx, relevant_mask] = scores[relevant_mask] + adjustment
return enhanced_scores.to(device)
    def q_to_k(self, x):
        # 1. Build the queries
self.batch_size, seq_len, dim = x.shape
# collapse sequence dimension by averaging
x_flat = x.mean(dim=1) # [batch_size, dim]
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
queries = queries.reshape(self.batch_size, 2, self.dim_key) # [batch_size, 2, dim_key]
queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key]
        # 2. Similarity between the queries and the keys
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
        # 3. Top-k within each of the two subspaces
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
        # 4. Combine the scores and indices from the two subspaces
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
all_scores = all_scores.view(*all_scores.shape[:-2], -1)
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
        # 5. Apply the intelligent hierarchical selection strategy
enhanced_scores = self.intelligent_selection(x, all_scores, all_indices)
        # 6. Final top-k selection over the enhanced scores
scores, pk_indices = enhanced_scores.topk(self.num_experts_per_head_topk, dim=-1)
indices = all_indices.gather(-1, pk_indices)
flat_indices = indices.view(-1)
        # 7. Update the usage statistics
self.update_usage_statistics(flat_indices)
return flat_indices
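    # A small sketch (illustrative helper) of the product-key composition in steps 2-4
    # above: two top-k searches over sqrt(N) keys each are combined into scores and
    # indices over all N = num_keys**2 knowledge entries.
    @staticmethod
    def _demo_product_key(num_keys=4, topk=2):
        scores_x, indices_x = torch.randn(1, num_keys).topk(topk, dim=-1)
        scores_y, indices_y = torch.randn(1, num_keys).topk(topk, dim=-1)
        all_scores = (scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)).view(1, -1)
        all_indices = (indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)).view(1, -1)
        # Each combined index i * num_keys + j addresses one of num_keys**2 entries.
        assert all_scores.shape == all_indices.shape == (1, topk * topk)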
def get_data(self, index):
        # Fetch directly from the GPU-resident buffer
        db_values = self.weight_down_embed[index]  # these are token ids; they are embedded downstream
# db_value = db_values.view(self.batch_size,-1)
return db_values
@torch.no_grad()
    def update_value(self, k, v):  # TODO: add a step that maps vectors back to indices
        # Update the buffer values directly (no gradients needed)
        v_reshaped = v.view(v.size(0), -1)
        # Ensure the dtypes match
v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
self.weight_down_embed[k] = v_reshaped
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
    def __init__(self, params: LMConfig = None):
        self.params = params = params or LMConfig()  # fall back to a default config when none is supplied
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        # Create the token embeddings first (shared with ExtractDB)
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
        # Select the ExtractDB variant according to the configuration
# use_direct_semantic = getattr(params, 'use_direct_semantic', False)
# if use_direct_semantic:
# self.extract_db = ExtractDB_DirectSemantic(self.params, self.tok_embeddings)
# else:
# self.extract_db = ExtractDB(self.params, self.tok_embeddings)
self.extract_db = ExtractDB_DirectSemantic(self.params, self.tok_embeddings)
        # Build the stack of MiniMindBlock layers
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.database_output = nn.Linear(params.dim, params.knowledge_length, bias=False)
self.tok_embeddings.weight = self.output.weight
self.database_output.weight = self.output.weight
# Calculate input dimension
input_dim = (self.params.max_seq_len-1)*self.params.n_layers
# Use a bottleneck architecture to reduce parameters
bottleneck_dim = 256 # Significantly smaller bottleneck dimension
# Factorized shared downsampling using two smaller convolutions
self.shared_downsample = nn.Sequential(
# First reduce input dimension to bottleneck
nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1, padding='same'),
nn.ReLU(), # Non-linearity to improve representation capacity
# Then expand to target dimension
nn.Conv1d(bottleneck_dim, 128*8, kernel_size=1, padding='same')
)
# Specific layers for v path
self.downsample_v_specific = nn.Sequential(
nn.Conv1d(128*8, 128, kernel_size=1, padding='same'),
nn.Conv1d(128, self.params.knowledge_length, kernel_size=1, padding='same')
)
# Specific layers for q path
self.downsample_q_specific = nn.Sequential(
nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
)
        # Use the real-valued position encoding to avoid segfaults that complex tensors can trigger
self.register_buffer("pos_cis_real",
precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**args):
start_pos = args.get('start_pos', 0)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)]
h_list = []
for l, layer in enumerate(self.layers):
            # Normal mode: query the knowledge database
index = self.extract_db.q_to_k(h)
            token_idx = self.extract_db.get_data(index)  # token ids of the retrieved entries
            db_value = self.tok_embeddings(token_idx)
h = layer(
h, db_value, pos_cis_real
)
h_list.append(h.unsqueeze(0))
h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
        # Run the database-update logic only when the database is not disabled
if not self.params.disable_db:
            # detach() cuts the graph so the update cannot trigger a second backward pass
h_tensor_detached = h_tensor.detach()
h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim)
            # The database update runs outside the main computation graph
with torch.no_grad():
# Compute shared downsampling layer once
shared_features = self.shared_downsample(h_tensor_detached)
# Get features from v path - now we output embedding-dimension vectors
z_v_features = self.downsample_v_specific(shared_features)
batch_z, seq_len, dim_z = z_v_features.shape
# Reshape to batch_size * knowledge_length, dim
z_v_flat = z_v_features.reshape(-1, dim_z)
# Direct token prediction - like the main language model head
token_logits = self.database_output(z_v_flat) # [batch_z * seq_len, vocab_size]
# Get token indices directly from logits
token_indices_flat = torch.argmax(token_logits, dim=-1)
token_indices = token_indices_flat.reshape(batch_z, -1)
# Process query path as before
z_q = self.downsample_q_specific(shared_features)
z_k = self.extract_db.q_to_k(z_q)
                # self.extract_db.update_value(z_k, token_indices)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
        # Simplified output: keep only the required fields
output = CausalLMOutputWithPast(
logits=logits,
)
output.hidden_states = h
output.aux_loss = aux_loss
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        # Streaming generation
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
        # Non-streaming generation
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start, first_seq = input_ids.shape[1], True
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq:
out, first_seq = self(input_ids, **args), False
else:
out = self(input_ids[:, -1:], start_pos=input_ids.shape[1] - 1, **args)
logits = out.logits[:, -1, :]
logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break
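# A compact sketch (illustrative helper) of the nucleus (top-p) filtering used in
# MiniMindLM._stream: tokens beyond the top_p cumulative-probability mass are masked
# to -inf before sampling.
def _demo_top_p_filtering(top_p=0.9):
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()  # keep the first token crossing the threshold
    sorted_indices_to_remove[:, 0] = False  # never mask the most likely token
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(indices_to_remove, -float('Inf'))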
class ExtractDB_DirectSemantic(nn.Module):
"""直接语义匹配的数据库检索模块完全移除Product Key"""
def __init__(self, params, tok_embeddings=None):
super().__init__()
self.batch_size = None
self.dim = params.dim
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
self.tok_embeddings = tok_embeddings
self.num_experts_per_head_topk = 1
        # Training-step bookkeeping
        self.current_step = 0
        self.realtime_threshold = getattr(params, 'realtime_steps', 800)  # compute everything in real time for the first 800 steps
        # Progressive cache-refresh parameters
        self.knowledge_update_rate = 0.01  # refresh 1% of the knowledge per step
        self.knowledge_per_step = max(1, int(self.knowledge_num * self.knowledge_update_rate))
        self.update_cycle = 100  # 100-step refresh cycle
        # Knowledge store
self.register_buffer('weight_down_embed',
torch.randint(low=0, high=6400, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
)
        # Embedding cache
        self.knowledge_embeddings_cache = None
        self.cache_update_mask = torch.zeros(self.knowledge_num, dtype=torch.bool)  # tracks which entries have been refreshed
        # Normalized cache (speeds up the similarity computation)
        self.normalized_knowledge_cache = None
        self.normalization_valid = False
        # Load-balancing components
self.register_buffer('usage_counts', torch.zeros(self.knowledge_num))
self.register_buffer('total_queries', torch.tensor(0.0))
self.momentum = getattr(params, 'db_momentum', 0.9)
self.balance_strength = getattr(params, 'db_balance_strength', 0.1)
def should_use_realtime_computation(self):
"""判断是否应该使用实时计算"""
return self.current_step < self.realtime_threshold
def get_knowledge_indices_to_update(self):
"""获取本步需要更新的知识条目索引"""
if self.should_use_realtime_computation():
# 前800步全部实时计算
return torch.arange(self.knowledge_num)
# 后续步数:循环更新策略
cycle_position = self.current_step % self.update_cycle
start_idx = (cycle_position * self.knowledge_per_step) % self.knowledge_num
end_idx = min(start_idx + self.knowledge_per_step, self.knowledge_num)
return torch.arange(start_idx, end_idx)
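    # A short walkthrough (illustrative helper) of the cyclic refresh schedule above:
    # with 100 entries, 1 entry per step and a 100-step cycle, every entry is visited
    # exactly once per cycle.
    @staticmethod
    def _demo_update_schedule(knowledge_num=100, per_step=1, cycle=100):
        visited = set()
        for step in range(cycle):
            start = (step % cycle) * per_step % knowledge_num
            visited.update(range(start, min(start + per_step, knowledge_num)))
        assert len(visited) == knowledge_num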
def update_knowledge_embeddings(self, force_all=False):
"""智能更新知识嵌入缓存"""
if force_all or self.should_use_realtime_computation():
            # Full refresh
indices_to_update = torch.arange(self.knowledge_num)
else:
            # Incremental refresh
indices_to_update = self.get_knowledge_indices_to_update()
if len(indices_to_update) == 0:
return
        # Lazily initialize the cache
if self.knowledge_embeddings_cache is None:
            # Probe tok_embeddings for its dtype so the cache matches
dummy_input = torch.zeros(1, dtype=torch.long, device=self.weight_down_embed.device)
dummy_embedding = self.tok_embeddings(dummy_input)
embedding_dtype = dummy_embedding.dtype
            self.knowledge_embeddings_cache = torch.zeros(
                self.knowledge_num, self.dim,
                device=self.weight_down_embed.device,
                dtype=embedding_dtype  # same dtype as tok_embeddings
            )
with torch.no_grad():
            # Refresh only the selected entries
tokens_to_update = self.weight_down_embed[indices_to_update] # [num_update, knowledge_length]
flat_tokens = tokens_to_update.view(-1) # [num_update * knowledge_length]
            # Batch-compute the embeddings
flat_embeddings = self.tok_embeddings(flat_tokens) # [num_update * knowledge_length, dim]
            # Reshape and mean-pool over the knowledge length
updated_embeddings = flat_embeddings.view(
len(indices_to_update), self.knowledge_length, -1
).mean(dim=1) # [num_update, dim]
            # Write back to the cache (the dtypes now match)
self.knowledge_embeddings_cache[indices_to_update] = updated_embeddings
self.cache_update_mask[indices_to_update] = True
            # Invalidate the normalized cache
self.normalization_valid = False
def get_normalized_knowledge_embeddings(self):
"""获取归一化的知识嵌入(用于优化相似度计算)"""
if not self.normalization_valid or self.normalized_knowledge_cache is None:
if self.knowledge_embeddings_cache is None:
self.update_knowledge_embeddings(force_all=True)
self.normalized_knowledge_cache = F.normalize(
self.knowledge_embeddings_cache, dim=-1
)
self.normalization_valid = True
return self.normalized_knowledge_cache
def optimized_similarity_computation(self, query_features):
"""优化的相似度计算"""
# 归一化查询特征
normalized_query = F.normalize(query_features, dim=-1) # [batch_size, dim]
# 获取归一化的知识嵌入
normalized_knowledge = self.get_normalized_knowledge_embeddings() # [knowledge_num, dim]
# 使用矩阵乘法计算余弦相似度
similarities = torch.mm(normalized_query, normalized_knowledge.t()) # [batch_size, knowledge_num]
return similarities
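    # A quick check (illustrative helper) that the normalized matrix multiplication
    # above equals pairwise cosine similarity.
    @staticmethod
    def _demo_cosine_similarity():
        q, k = torch.randn(2, 8), torch.randn(5, 8)
        sims = torch.mm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).t())
        ref = F.cosine_similarity(q.unsqueeze(1), k.unsqueeze(0), dim=-1)
        assert torch.allclose(sims, ref, atol=1e-6)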
def apply_load_balancing(self, similarities):
"""应用负载均衡策略"""
if not self.training or self.total_queries == 0:
return similarities
# 计算使用频率
usage_rates = self.usage_counts / (self.total_queries + 1e-8)
# 创建平衡偏置(低频率条目获得正偏置)
max_usage = usage_rates.max()
balance_bias = self.balance_strength * (max_usage - usage_rates + 1e-8).log()
# 应用偏置
balanced_similarities = similarities + balance_bias.unsqueeze(0)
return balanced_similarities
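    # A small sketch (illustrative helper) of the log bias above: the most-used entry
    # receives the strongest negative bias and the least-used the weakest.
    @staticmethod
    def _demo_balance_bias(balance_strength=0.1):
        usage_rates = torch.tensor([0.5, 0.3, 0.0])
        bias = balance_strength * (usage_rates.max() - usage_rates + 1e-8).log()
        assert bias.argmin() == 0 and bias.argmax() == 2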
def update_usage_statistics(self, selected_indices):
"""更新使用统计"""
if not self.training:
return
with torch.no_grad():
            # Count how often each entry is used in the current batch
batch_usage = torch.zeros(self.knowledge_num, device=selected_indices.device)
unique_indices, counts = torch.unique(selected_indices, return_counts=True)
batch_usage[unique_indices] = counts.float()
            # Exponential-moving-average update of the statistics
self.usage_counts.copy_(
self.momentum * self.usage_counts + (1 - self.momentum) * batch_usage
)
self.total_queries.copy_(self.total_queries + selected_indices.numel())
    def q_to_k(self, x):
        """Main entry point for direct semantic retrieval."""
        self.current_step += 1
        batch_size, seq_len, dim = x.shape
        # Smart refresh of the knowledge-embedding cache
        self.update_knowledge_embeddings()
        # Query features (sequence average)
        query_features = x.mean(dim=1)  # [batch_size, dim]
        # Optimized similarity computation
        similarities = self.optimized_similarity_computation(query_features)
        # Apply load balancing
        balanced_similarities = self.apply_load_balancing(similarities)
        # Top-k selection
        _, indices = balanced_similarities.topk(self.num_experts_per_head_topk, dim=-1)
        flat_indices = indices.view(-1)
        # Update the usage statistics
        self.update_usage_statistics(flat_indices)
return flat_indices
    def get_data(self, index):
        """Fetch entries; interface-compatible with ExtractDB."""
        return self.weight_down_embed[index]
    @torch.no_grad()
    def update_value(self, k, v):
        """Update entries; interface-compatible with ExtractDB."""
        v_reshaped = v.view(v.size(0), -1)
        v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
        self.weight_down_embed[k] = v_reshaped
        # Mark the affected cache entries as stale
        self.cache_update_mask[k] = False
        self.normalization_valid = False