Minimind/model/model.py

881 lines
40 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import struct
import inspect
import time
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch import nn, einsum
from einops import rearrange, repeat
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
# RMSNorm: root-mean-square normalization of the last dimension with a learnable gain.
class RMSNorm(torch.nn.Module):
    """Scale the last dimension of a tensor by the reciprocal of its RMS.

    Args:
        dim: Size of the last (feature) dimension; one gain value is learned per feature.
        eps: Constant added to the mean square before taking the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Return ``x`` divided by the root mean square of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        # The statistics are computed in float32; the result is cast back to the
        # input dtype before the learnable gain is applied.
        normalized = self._norm(x.float()).type_as(x)
        return self.weight * normalized
# precompute_pos_cis: precompute the rotary position table (complex version).
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Build the complex rotary table ``exp(i * position * inv_freq)``.

    Args:
        dim: Head dimension; ``dim // 2`` frequencies are generated.
        end: Number of positions to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex tensor of shape ``[end, dim // 2]`` whose entries have unit magnitude.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)  # complex64
# apply_rotary_emb: apply rotary position embeddings (complex version).
def apply_rotary_emb(xq, xk, pos_cis):
    """Rotate query and key tensors by multiplying them with a complex table.

    Consecutive pairs along the last dimension of ``xq`` / ``xk`` are viewed as
    complex numbers, multiplied by ``pos_cis`` and converted back to real pairs.

    Args:
        xq: Query tensor; the code indexes dim 1 as the sequence axis and flattens
            from dim 3, so a 4-D ``[batch, seq, heads, head_dim]`` layout is expected.
        xk: Key tensor with the same sequence length and head_dim as ``xq``.
        pos_cis: Complex table of shape ``[seq, head_dim // 2]`` (asserted).

    Returns:
        ``(xq_rotated, xk_rotated)``, each cast back to the dtype of its input.
    """

    def _as_complex(t):
        # Pair up neighbouring features: [..., head_dim] -> [..., head_dim // 2] complex.
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    def _broadcastable(table, ref):
        # Keep the sequence axis (1) and the last axis, set every other axis to 1.
        rank = ref.ndim
        assert 0 <= 1 < rank
        assert table.shape == (ref.shape[1], ref.shape[-1])
        target = [size if axis in (1, rank - 1) else 1 for axis, size in enumerate(ref.shape)]
        return table.view(*target)

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)
    rotation = _broadcastable(pos_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
# precompute_pos_cis_real: precompute the rotary position table (real-valued version).
def precompute_pos_cis_real(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Build the rotary table as interleaved cos/sin values instead of complex numbers.

    The angles are computed exactly as in ``precompute_pos_cis``. Where the complex
    version stores ``cos(angle) + i*sin(angle)`` in a ``[end, dim // 2]`` complex
    tensor, this version stores ``cos(angle)`` at even indices and ``sin(angle)`` at
    odd indices of a ``[end, dim]`` real tensor.

    Args:
        dim: Head dimension; must be even.
        end: Number of positions to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        A real tensor of shape ``[end, dim]``.

    Raises:
        ValueError: If ``dim`` is odd.
    """
    if dim % 2 != 0:
        raise ValueError(f"维度必须是偶数,但得到了 {dim}")

    # Same frequency schedule as the complex implementation.
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()

    # Fill the table through a [end, dim // 2, 2] view: slot 0 of every pair is the
    # even index (cos), slot 1 is the odd index (sin).
    pos_emb = torch.zeros((end, dim), device=angles.device)
    pairs = pos_emb.view(end, dim // 2, 2)
    pairs[..., 0] = torch.cos(angles)
    pairs[..., 1] = torch.sin(angles)
    return pos_emb
# apply_rotary_emb_real: apply rotary position embeddings (real-valued version).
def apply_rotary_emb_real(xq, xk, pos_emb):
    """Rotate query and key tensors using an interleaved cos/sin table.

    Even indices of the last dimension play the role of the real part and odd
    indices the imaginary part, so the rotation is the complex product
    ``(a + bi)(cos + i*sin) = (a*cos - b*sin) + (a*sin + b*cos)i`` written with
    real arithmetic only.

    Args:
        xq: Query tensor of shape ``[batch, seq_len, n_heads, head_dim]``.
        xk: Key tensor; sliced and broadcast the same way as ``xq``.
        pos_emb: Table of shape ``[>= seq_len, head_dim]`` with cos at even and sin
            at odd indices; only the first ``seq_len`` rows are used.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the shapes and dtypes of the inputs.
    """
    _, seq_len, _, head_dim = xq.shape
    assert pos_emb.shape[0] >= seq_len, f"位置编码长度 {pos_emb.shape[0]} 小于序列长度 {seq_len}"
    assert pos_emb.shape[1] == head_dim, f"位置编码维度 {pos_emb.shape[1]} 与头维度 {head_dim} 不匹配"

    # [seq_len, head_dim] -> [1, seq_len, 1, head_dim] so it broadcasts over batch and heads.
    table = pos_emb[:seq_len][None, :, None, :]
    cos, sin = table[..., 0::2], table[..., 1::2]

    def _rotate(t):
        real, imag = t[..., 0::2], t[..., 1::2]
        rotated_real = real * cos - imag * sin
        rotated_imag = real * sin + imag * cos
        # Re-interleave the two halves into a tensor laid out like the input.
        out = torch.zeros_like(t)
        out[..., 0::2] = rotated_real
        out[..., 1::2] = rotated_imag
        return out.type_as(t)

    return _rotate(xq), _rotate(xk)
# repeat_kv: repeat every key/value head so it can be shared by several query heads.
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent of ``torch.repeat_interleave(x, dim=2, repeats=n_rep)``.

    Args:
        x: Tensor of shape ``[batch, seq_len, n_kv_heads, head_dim]``.
        n_rep: Number of copies of each head; ``1`` returns ``x`` itself.

    Returns:
        Tensor of shape ``[batch, seq_len, n_kv_heads * n_rep, head_dim]``.
    """
    batch, length, kv_heads, width = x.shape
    if n_rep == 1:
        return x
    # Insert a repeat axis after the head axis, expand it without copying, then
    # fold it into the head axis.
    tiled = x.unsqueeze(3).expand(batch, length, kv_heads, n_rep, width)
    return tiled.reshape(batch, length, kv_heads * n_rep, width)
class Attention(nn.Module):
    """Causal multi-head self-attention with grouped KV heads and rotary embeddings.

    An optional ``db_value`` tensor is added to the value tensor before attention.

    Args:
        args: Configuration providing ``n_heads``, ``n_kv_heads`` (``None`` means
            "same as ``n_heads``"), ``dim``, ``dropout``, ``flash_attn`` and
            ``max_seq_len``.
    """

    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        # Each KV head is shared by n_rep query heads.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        q_width = args.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        self.wq = nn.Linear(args.dim, q_width, bias=False)
        self.wk = nn.Linear(args.dim, kv_width, bias=False)
        self.wv = nn.Linear(args.dim, kv_width, bias=False)
        self.wo = nn.Linear(q_width, args.dim, bias=False)

        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # The fused kernel is used only when PyTorch provides it and the config asks for it.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn

        # Additive causal mask for the manual path: -inf strictly above the diagonal.
        causal = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        causal = torch.triu(causal, diagonal=1)
        self.register_buffer("mask", causal, persistent=False)

    def _match_value_shape(self, db_value, xv, bsz, seq_len):
        """Bring ``db_value`` to the last-dimension size of ``xv``.

        A 4-D input is treated as ``[B, N, H, D]`` and transposed to ``[B, H, N, D]``.
        A wider last dimension is reduced by averaging groups of features, a
        narrower one is enlarged by repeating features.
        NOTE(review): both resize branches view the tensor as
        ``[bsz, n_local_heads, seq_len, ...]``, i.e. they presume N == seq_len and a
        view-compatible layout -- confirm against the caller.
        """
        if db_value.ndim == 4:
            db_value = db_value.transpose(1, 2)

        target = xv.shape[-1]
        width = db_value.shape[-1]
        if width == target:
            return db_value
        if width > target:
            # Shrink: average each group of `factor` chunks of size `target`.
            factor = width // target
            grouped = db_value.view(bsz, self.n_local_heads, seq_len, factor, target)
            return grouped.mean(dim=3)
        # Grow: repeat every feature `factor` times.
        factor = target // width
        repeated = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor)
        return repeated.view(bsz, self.n_local_heads, seq_len, target)

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor,
                db_value=None):
        """Run causal self-attention over ``x``.

        Args:
            x: Input of shape ``[batch, seq_len, dim]``.
            pos_cis: Real-valued rotary table passed to ``apply_rotary_emb_real``.
            db_value: Optional tensor added to the value tensor after it has been
                laid out as ``[batch, heads, seq_len, head_dim]``.

        Returns:
            Tensor of shape ``[batch, seq_len, dim]``.
        """
        bsz, seq_len, _ = x.shape

        # Project and split into heads: [batch, seq_len, heads, head_dim].
        xq = self.wq(x).view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = self.wk(x).view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = self.wv(x).view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        # Rotary position embedding (real-valued implementation).
        xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)

        # Expand KV heads to the query head count and move heads before the sequence axis.
        xq = xq.transpose(1, 2)
        xk = repeat_kv(xk, self.n_rep).transpose(1, 2)
        xv = repeat_kv(xv, self.n_rep).transpose(1, 2)

        if db_value is not None:
            # Fuse the database value into the values by plain addition.
            xv = xv + self._match_value_shape(db_value, xv, bsz, seq_len)

        if self.flash and seq_len != 1:
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            # Softmax in float32, then back to the working dtype.
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        merged = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        return self.resid_dropout(self.wo(merged))
class CrossAttention(nn.Module):
    """Multi-head cross-attention from hidden states to database entries.

    Queries come from ``x``; keys and values come from ``db``. The head count is
    fixed at 8, so ``config.dim`` has to be divisible by 8 for the head split to work.

    Args:
        config: Object exposing ``dim`` (the model width).
    """

    def __init__(
        self,
        config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def _split_heads(self, t, batch_size):
        """Reshape ``[batch, n, dim]`` into ``[batch, heads, n, head_dim]``."""
        return t.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        """Attend from ``x`` to ``db``.

        Args:
            x: Query source of shape ``[batch, q_len, dim]``.
            db: Key/value source of shape ``[batch, kv_len, dim]``.
            context_mask: Optional mask that is unsqueezed at dim 1 and expanded over
                heads; positions where it equals 0 are excluded from attention.
            pos_emb: Optional tensor that is split into heads like the projections
                and added to the queries, keys and values.

        Returns:
            Tensor of shape ``[batch, q_len, dim]``.
        """
        batch_size = x.size(0)

        q = self._split_heads(self.to_q(x), batch_size)
        k = self._split_heads(self.to_k(db), batch_size)
        v = self._split_heads(self.to_v(db), batch_size)

        if pos_emb is not None:
            pos_emb = self._split_heads(pos_emb, batch_size)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            # -1e10 is not representable in float16 and makes masked_fill raise there,
            # so clamp the fill value to the most negative finite value of the score
            # dtype. For float32/bfloat16 this is still exactly -1e10.
            fill_value = max(-1e10, torch.finfo(attn_scores.dtype).min)
            attn_scores = attn_scores.masked_fill(expanded_mask == 0, fill_value)

        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v)
        # Merge heads back: [batch, heads, q_len, head_dim] -> [batch, q_len, dim].
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
        return self.to_out(context)
class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))`` followed by dropout.

    Args:
        config: Configuration providing ``dim``, ``hidden_dim``, ``multiple_of`` and
            ``dropout``. When ``hidden_dim`` is ``None`` it is derived as
            ``int(2 * (4 * dim) / 3)`` rounded up to a multiple of ``multiple_of``
            and written back onto ``config``.
    """

    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            approx = int(2 * (4 * config.dim) / 3)
            step = config.multiple_of
            # Round up to the next multiple of `step`; the result is stored on the
            # config object itself.
            config.hidden_dim = step * ((approx + step - 1) // step)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        gated = gate * self.w3(x)
        return self.dropout(self.w2(gated))
class MoEGate(nn.Module):
    """Top-k softmax router for the mixture-of-experts feed-forward layer.

    Args:
        config: Configuration providing ``num_experts_per_tok``,
            ``n_routed_experts``, ``scoring_func``, ``aux_loss_alpha``, ``seq_aux``,
            ``norm_topk_prob`` and ``dim``.
    """

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        # One routing vector per expert.
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise the routing matrix with Kaiming-uniform values."""
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def _sequence_aux_loss(self, scores, topk_idx, bsz, seq_len):
        """Balance loss computed per sequence and averaged over the batch."""
        per_seq_idx = topk_idx.view(bsz, -1)
        per_seq_scores = scores.view(bsz, seq_len, -1)
        slots = seq_len * self.top_k
        # Count how often each expert was picked in each sequence, then divide by
        # slots / n_routed_experts so a uniform assignment yields 1 for every expert.
        load = torch.zeros(bsz, self.n_routed_experts, device=scores.device)
        load.scatter_add_(
            1, per_seq_idx, torch.ones(bsz, slots, device=scores.device)
        ).div_(slots / self.n_routed_experts)
        return (load * per_seq_scores.mean(dim=1)).sum(dim=1).mean() * self.alpha

    def _token_aux_loss(self, scores, topk_idx):
        """Balance loss computed over all routed slots of the batch at once."""
        one_hot = F.one_hot(topk_idx.view(-1), num_classes=self.n_routed_experts)
        fi = one_hot.float().mean(0) * self.n_routed_experts
        Pi = scores.mean(0)
        return (Pi * fi).sum() * self.alpha

    def forward(self, hidden_states):
        """Route every token to ``top_k`` experts.

        Args:
            hidden_states: Tensor of shape ``[batch, seq_len, dim]``.

        Returns:
            ``(topk_idx, topk_weight, aux_loss)``. The first two have shape
            ``[batch * seq_len, top_k]``; ``aux_loss`` is a scalar tensor while
            training with ``alpha > 0`` and the integer ``0`` otherwise.

        Raises:
            NotImplementedError: If ``scoring_func`` is not ``'softmax'``.
        """
        bsz, seq_len, width = hidden_states.shape
        tokens = hidden_states.view(-1, width)
        logits = F.linear(tokens, self.weight, None)
        if self.scoring_func != 'softmax':
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
        scores = logits.softmax(dim=-1)

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            # Renormalise the selected probabilities so they sum to one per token.
            topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)

        aux_loss = 0
        if self.training and self.alpha > 0.0:
            if self.seq_aux:
                aux_loss = self._sequence_aux_loss(scores, topk_idx, bsz, seq_len)
            else:
                aux_loss = self._token_aux_loss(scores, topk_idx)
        return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
    """Mixture-of-experts feed-forward layer with an optional shared expert.

    Every token is routed by ``MoEGate`` to ``config.num_experts_per_tok`` experts
    and their outputs are combined with the gate weights. The auxiliary balance
    loss of the latest forward pass is stored in ``self.aux_loss``.

    Args:
        config: Configuration providing ``n_routed_experts``,
            ``num_experts_per_tok`` and ``n_shared_experts`` plus everything
            ``FeedForward`` and ``MoEGate`` read.
    """

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        """Apply the routed (and optionally shared) experts to ``x``.

        Args:
            x: Tensor of shape ``[batch, seq_len, dim]``.

        Returns:
            Tensor with the shape of ``x``.
        """
        identity = x
        orig_shape = x.shape
        # Pick the experts for every token.
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # One row per (token, selected expert) slot.
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            # Allocate the buffer in the input dtype. It used to be hard-coded to
            # float16, which truncated expert outputs during fp32/bf16 training and
            # made the training path return a different dtype than the eval path.
            y = torch.empty_like(x)
            for i, expert in enumerate(self.experts):
                selected = flat_topk_idx == i
                y[selected] = expert(x[selected]).to(y.dtype)
            # Weighted sum over the experts chosen for each token.
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Gradient-free expert dispatch used in eval mode.

        Slots are sorted by expert id so every expert runs once on a contiguous
        batch of its tokens; the weighted outputs are scattered back per token.

        Args:
            x: Flattened tokens of shape ``[n_tokens, dim]``.
            flat_expert_indices: Expert id for every (token, slot) pair, flattened.
            flat_expert_weights: Gate weight for every pair, shape ``[n_slots, 1]``.

        Returns:
            Tensor of shape ``[n_tokens, dim]``.
        """
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        # Cumulative slot counts: expert i owns sorted slots [cum[i-1], cum[i]).
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        # A slot index divided by the slots-per-token gives the owning token.
        token_idxs = idxs // self.config.num_experts_per_tok
        # Example: with tokens_per_expert = [6, 15, 20, 26] (4 experts) and
        # token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...],
        # token_idxs[:6] are the tokens handled by expert 0, token_idxs[6:15] those
        # handled by expert 1, and so on. A token can appear under several experts.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache
class MiniMindBlock(nn.Module):
    """One transformer layer: self-attention, cross-attention to the database entry,
    a residual connection, then a (dense or mixture-of-experts) feed-forward block.

    Args:
        layer_id: Index of this layer; stored as ``self.layer_id``.
        config: Model configuration; ``use_moe`` selects the feed-forward variant.
    """

    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.attention = Attention(config)
        self.cross_att = CrossAttention(config)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = MOEFeedForward(config) if config.use_moe else FeedForward(config)

    def forward(self, x, db_value, pos_cis):
        """Apply the layer.

        Args:
            x: Hidden states entering the layer.
            db_value: Database entry handed to both attention modules.
            pos_cis: Rotary position table for the self-attention.

        Returns:
            Hidden states leaving the layer.
        """
        # Pre-norm self-attention, which also receives the database value.
        normed = self.attention_norm(x)
        attended = self.attention(normed, pos_cis, db_value=db_value)
        # Cross-attend from the attention output to the database entry.
        attended = self.cross_att(attended, db_value)

        # Residual connection around both attention steps.
        h = x + attended

        # Pre-norm feed-forward with its own residual connection.
        return h + self.feed_forward(self.ffn_norm(h))
class ExtractDB(nn.Module):
    """Product-key lookup into a buffer of token-id "knowledge" entries.

    A query derived from the mean hidden state is scored against two sets of
    sub-keys; combining the top candidates of both sets addresses
    ``num_keys ** 2`` entries of ``weight_down_embed``. During training the
    candidate scores can additionally be adjusted by a relevance / usage-balance
    heuristic, and per-entry usage statistics are tracked.

    Args:
        params: Configuration providing ``dim``, ``n_heads``, ``knowledge_num`` and
            ``knowledge_length``. Optional attributes (read with ``getattr``):
            ``vocab_size`` (default 6400), ``db_intelligent_balance``,
            ``db_relevance_threshold``, ``db_balance_strength``, ``db_momentum``,
            ``db_adaptive_weights``.
        tok_embeddings: Optional embedding module used to embed the stored token ids
            when estimating relevance. Without it the heuristic is skipped.
    """

    def __init__(self, params, tok_embeddings=None):
        super().__init__()
        self.batch_size = None
        self.dim = params.dim
        self.dim_key = self.dim // 2
        # Only floor(sqrt(knowledge_num)) ** 2 entries are addressable, so this
        # should be a perfect square.
        self.knowledge_num = params.knowledge_num
        self.head_dim = params.dim // params.n_heads
        self.knowledge_length = params.knowledge_length

        # Settings of the usage-balancing heuristic.
        self.enable_intelligent_balance = getattr(params, 'db_intelligent_balance', True)
        self.relevance_threshold = getattr(params, 'db_relevance_threshold', 0.7)
        self.base_balance_strength = getattr(params, 'db_balance_strength', 0.3)
        self.momentum = getattr(params, 'db_momentum', 0.9)
        self.adaptive_weights = getattr(params, 'db_adaptive_weights', True)

        # Dynamic mixing weights: start out favouring relevance over balance.
        self.current_relevance_weight = 0.8
        self.current_balance_weight = 0.2
        self.weight_update_frequency = 100  # re-tune the weights every 100 steps
        self.step_counter = 0

        # Usage statistics live in buffers so they follow the module across devices.
        self.register_buffer('usage_counts', torch.zeros(self.knowledge_num))
        self.register_buffer('total_queries', torch.tensor(0.0))

        # Knowledge store: integer token ids, hence a buffer rather than a parameter.
        # The ids are later passed through the token embedding, so they must stay
        # below the vocabulary size; this used to be hard-coded to 6400.
        vocab_size = getattr(params, 'vocab_size', 6400)
        self.register_buffer(
            'weight_down_embed',
            torch.randint(low=0, high=vocab_size,
                          size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
        )

        self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
        self.product_key_topk = min(16, self.num_keys)
        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
        self.num_experts_per_head_topk = 1
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.dim_key * 2, bias=False),
        )

        # Reference to the token embedding, used to embed stored entries.
        self.tok_embeddings = tok_embeddings

    def update_usage_statistics(self, selected_indices):
        """Fold the entries selected in this batch into the usage statistics.

        Does nothing outside training or when balancing is disabled.
        ``usage_counts`` is an exponential moving average with factor
        ``self.momentum``; ``total_queries`` is a running total.
        """
        if not self.training or not self.enable_intelligent_balance:
            return

        with torch.no_grad():
            # How often each entry was selected in the current batch.
            batch_usage = torch.zeros(self.knowledge_num, device=selected_indices.device)
            unique_indices, counts = torch.unique(selected_indices, return_counts=True)
            batch_usage[unique_indices] = counts.float()

            new_usage = self.momentum * self.usage_counts + (1 - self.momentum) * batch_usage
            new_total = self.total_queries + selected_indices.numel()

            # Write in place so the registered buffers keep their identity.
            self.usage_counts.copy_(new_usage)
            self.total_queries.copy_(new_total)

    def update_dynamic_weights(self):
        """Shift weight between relevance and balance based on usage variance.

        Runs only in training with ``adaptive_weights`` enabled, and only acts
        every ``weight_update_frequency`` calls.
        """
        if not self.adaptive_weights or not self.training:
            return

        self.step_counter += 1
        if self.step_counter % self.weight_update_frequency != 0:
            return

        with torch.no_grad():
            if self.total_queries > 0:
                # Variance of the usage rates measures how unbalanced the usage is.
                usage_rates = self.usage_counts / self.total_queries
                usage_variance = usage_rates.var().item()

                if usage_variance > 0.01:  # strongly unbalanced: favour balance
                    self.current_relevance_weight = max(0.5, self.current_relevance_weight - 0.1)
                    self.current_balance_weight = min(0.5, self.current_balance_weight + 0.1)
                elif usage_variance < 0.001:  # already balanced: favour relevance
                    self.current_relevance_weight = min(0.9, self.current_relevance_weight + 0.1)
                    self.current_balance_weight = max(0.1, self.current_balance_weight - 0.1)

                # Keep the two weights summing to 1.
                total_weight = self.current_relevance_weight + self.current_balance_weight
                self.current_relevance_weight /= total_weight
                self.current_balance_weight /= total_weight

    def intelligent_selection(self, query, all_scores, all_indices):
        """Adjust candidate scores by semantic relevance and usage balance.

        Args:
            query: Hidden states ``[batch, seq_len, dim]``; averaged over dim 1.
            all_scores: Candidate scores ``[batch, n_candidates]``.
            all_indices: Matching entry indices ``[batch, n_candidates]``.

        Returns:
            ``all_scores`` unchanged when the heuristic is disabled, the module is
            not training, or no token embedding is available; otherwise a
            gradient-free copy with a bonus added to the relevant candidates.
        """
        if (not self.enable_intelligent_balance or not self.training
                or self.tok_embeddings is None):
            # Without an embedding module the relevance step cannot run, so fall
            # back to the raw scores instead of failing on a None call.
            return all_scores

        with torch.no_grad():
            dtype = all_scores.dtype

            self.update_dynamic_weights()

            enhanced_scores = all_scores.clone()
            # One feature vector per batch element.
            query_features = query.mean(dim=1)  # [batch, dim]

            for batch_idx in range(all_scores.size(0)):
                indices = all_indices[batch_idx]
                scores = all_scores[batch_idx]

                # Stage 1: relevance from the stored content. Embed every token of
                # every candidate entry and mean-pool per entry.
                candidate_tokens = self.weight_down_embed[indices]  # [n_cand, knowledge_length]
                num_candidates, knowledge_length = candidate_tokens.shape
                token_embeds = self.tok_embeddings(candidate_tokens.view(-1))
                candidate_features = token_embeds.view(
                    num_candidates, knowledge_length, -1
                ).mean(dim=1)  # [n_cand, dim]

                similarity_scores = F.cosine_similarity(
                    query_features[batch_idx].unsqueeze(0), candidate_features, dim=1
                )  # [n_cand]
                relevance_probs = F.softmax(similarity_scores.float(), dim=-1).to(dtype)

                # Threshold relative to the mean probability.
                mean_prob = relevance_probs.mean()
                adaptive_threshold = max(self.relevance_threshold * mean_prob, mean_prob * 0.5)
                relevant_mask = relevance_probs > adaptive_threshold

                if relevant_mask.sum() == 0:
                    # Nothing passed the threshold: keep the most similar candidates.
                    top_k = min(5, len(indices))
                    _, top_indices = similarity_scores.topk(top_k)
                    relevant_mask = torch.zeros_like(relevant_mask, dtype=torch.bool)
                    relevant_mask[top_indices] = True

                # Stage 2: balance usage among the relevant candidates.
                if relevant_mask.sum() > 1:
                    relevant_indices = indices[relevant_mask]
                    relevant_usage = self.usage_counts[relevant_indices]

                    # Rarely used entries score higher; +1 avoids division by zero.
                    balance_scores = 1.0 / (relevant_usage + 1.0)
                    balance_scores = balance_scores / (balance_scores.sum() + 1e-8)

                    relevant_rel_scores = relevance_probs[relevant_mask]
                    relevant_rel_scores = relevant_rel_scores / (relevant_rel_scores.sum() + 1e-8)

                    combined_scores = (self.current_relevance_weight * relevant_rel_scores +
                                       self.current_balance_weight * balance_scores.to(dtype))
                    adjustment = self.base_balance_strength * combined_scores.to(dtype)

                    enhanced_scores[batch_idx, relevant_mask] = (
                        scores[relevant_mask] + adjustment
                    )

        return enhanced_scores

    def q_to_k(self, x):
        """Select one knowledge entry per batch element.

        Args:
            x: Tensor of shape ``[batch, seq_len, dim]``.

        Returns:
            1-D long tensor with ``batch * num_experts_per_head_topk`` entry indices.
            Also stores the batch size on ``self.batch_size`` and, in training,
            updates the usage statistics.
        """
        # 1. Build the queries: average over the sequence, project, split in two.
        self.batch_size, seq_len, dim = x.shape
        x_flat = x.mean(dim=1)  # [batch, dim]
        queries = self.to_queries(x_flat)  # [batch, 2 * dim_key]
        queries = queries.reshape(self.batch_size, 2, self.dim_key)
        queries = queries.permute(1, 0, 2)  # [2, batch, dim_key]

        # 2. Similarity between each half-query and its set of sub-keys.
        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

        # 3. Top-k inside each of the two sub-spaces.
        (scores_x, indices_x), (scores_y, indices_y) = [
            sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)
        ]

        # 4. Cartesian combination: summed scores and flattened entry index x * num_keys + y.
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        all_scores = all_scores.view(*all_scores.shape[:-2], -1)
        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
        all_indices = all_indices.view(*all_indices.shape[:-2], -1)

        # 5. Optional relevance / balance adjustment.
        enhanced_scores = self.intelligent_selection(x, all_scores, all_indices)

        # 6. Final top-k on the adjusted scores.
        scores, pk_indices = enhanced_scores.topk(self.num_experts_per_head_topk, dim=-1)
        indices = all_indices.gather(-1, pk_indices)
        flat_indices = indices.view(-1)

        # 7. Record which entries were used.
        self.update_usage_statistics(flat_indices)

        return flat_indices

    def get_data(self, index):
        """Return the stored token ids of the given entries; embedding them is left
        to the caller."""
        db_values = self.weight_down_embed[index]
        return db_values

    @torch.no_grad()
    def updata_value(self, k, v):
        """Overwrite the entries at indices ``k`` with ``v``.

        ``v`` is flattened to ``[v.size(0), -1]`` and cast to the dtype of the
        store before being written into the buffer in place.
        """
        v_reshaped = v.view(v.size(0), -1)
        v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
        self.weight_down_embed[k] = v_reshaped
class MiniMindLM(PreTrainedModel):
    """Decoder-only language model whose layers also attend to database entries.

    Before every layer one entry of token ids is fetched from ``ExtractDB`` based
    on the current hidden state, embedded with the shared token embedding and
    passed to the layer as ``db_value``.

    Args:
        params: Model configuration; a default ``LMConfig()`` is used when omitted.
    """
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        # Resolve the default before anything dereferences `params`. The fallback
        # used to be stored only on `self.params`, so `params.vocab_size` failed for
        # the default argument and a trailing `self.params = params` reset it to None.
        if params is None:
            params = LMConfig()
        self.params = params
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers

        # Token embedding first: ExtractDB keeps a reference to it.
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.extract_db = ExtractDB(self.params, self.tok_embeddings)

        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.database_output = nn.Linear(params.dim, params.knowledge_length, bias=False)
        # Weight tying: the embedding and both heads share one matrix. As a result
        # `database_output` produces vocab_size logits, not knowledge_length.
        self.tok_embeddings.weight = self.output.weight
        self.database_output.weight = self.output.weight

        # The downsampling stack consumes the hidden states of all layers flattened
        # along the (layer, position) axis, used as the Conv1d channel axis.
        input_dim = (self.params.max_seq_len - 1) * self.params.n_layers
        # A bottleneck keeps the parameter count of the first projection small.
        bottleneck_dim = 256

        # Shared trunk: reduce to the bottleneck, non-linearity, expand again.
        self.shared_downsample = nn.Sequential(
            nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1, padding='same'),
            nn.ReLU(),
            nn.Conv1d(bottleneck_dim, 128 * 8, kernel_size=1, padding='same')
        )

        # Value path: down to knowledge_length channels.
        self.downsample_v_specific = nn.Sequential(
            nn.Conv1d(128 * 8, 128, kernel_size=1, padding='same'),
            nn.Conv1d(128, self.params.knowledge_length, kernel_size=1, padding='same')
        )

        # Query path.
        self.downsample_q_specific = nn.Sequential(
            nn.Conv1d(128 * 8, 512, kernel_size=1, padding='same')
        )

        # Real-valued rotary table (the complex variant is deliberately not used here).
        self.register_buffer(
            "pos_cis_real",
            precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta),
            persistent=False,
        )

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                **args):
        """Run the model.

        Args:
            input_ids: Token ids of shape ``[batch, seq_len]``.
            logits_to_keep: An int keeps the logits of the last that many positions
                (``0`` keeps all); any other value is used directly as the index
                along the sequence axis.
            **args: ``start_pos`` (default 0) offsets the rotary table.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits``, plus ``hidden_states`` (the
            final layer output before the last norm) and ``aux_loss`` set on it.

        NOTE: unless ``params.disable_db`` is true, the database-update branch
        feeds ``n_layers * seq_len`` channels into ``shared_downsample``, which was
        built for ``n_layers * (max_seq_len - 1)``; other sequence lengths fail there.
        """
        start_pos = args.get('start_pos', 0)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)]

        # Per-layer outputs are needed only by the database-update branch, so they
        # are collected (already detached) only when that branch will run.
        db_update_enabled = not self.params.disable_db
        layer_outputs = []
        for layer in self.layers:
            # Pick a database entry from the current hidden state and embed it.
            index = self.extract_db.q_to_k(h)
            token_idx = self.extract_db.get_data(index)
            db_value = self.tok_embeddings(token_idx)

            h = layer(h, db_value, pos_cis_real)
            if db_update_enabled:
                layer_outputs.append(h.detach().unsqueeze(0))

        if db_update_enabled:
            # Kept apart from the main graph: no gradients flow through this branch.
            with torch.no_grad():
                # [layers, batch, seq, dim] -> [batch, layers, seq, dim] -> [batch, layers * seq, dim]
                h_tensor = torch.cat(layer_outputs, dim=0).permute(1, 0, 2, 3)
                stacked = h_tensor.reshape(h_tensor.shape[0], -1, self.params.dim)

                shared_features = self.shared_downsample(stacked)

                # Value path: predict one token per knowledge slot with the tied head.
                z_v_features = self.downsample_v_specific(shared_features)
                batch_z, seq_len, dim_z = z_v_features.shape
                z_v_flat = z_v_features.reshape(-1, dim_z)
                token_logits = self.database_output(z_v_flat)  # [batch_z * seq_len, vocab_size]
                token_indices_flat = torch.argmax(token_logits, dim=-1)
                token_indices = token_indices_flat.reshape(batch_z, -1)

                # Query path: which entry would be written.
                z_q = self.downsample_q_specific(shared_features)
                z_k = self.extract_db.q_to_k(z_q)
                # The write-back itself is currently switched off:
                # self.extract_db.updata_value(z_k, token_indices)

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        output = CausalLMOutputWithPast(
            logits=logits,
        )
        output.hidden_states = h
        output.aux_loss = aux_loss
        return output

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        """Sample continuations for every row of ``input_ids``.

        Args:
            input_ids: Prompt ids ``[batch, seq_len]``; ``pad_token_id`` entries are
                removed from each row before generation.
            eos_token_id: Sampling stops once this token has been produced.
            max_new_tokens: Generation continues while the total sequence length
                (prompt included) is below ``max_new_tokens - 1``.
            temperature: Logits are divided by ``temperature + 1e-9``.
            top_p: Nucleus-sampling threshold; ``None`` or ``>= 1`` disables it.
            stream: When true, return the generator of ``_stream`` for the whole
                ``input_ids`` instead of a padded tensor.
            rp: Logits of tokens already in the sequence are divided by this value.
            pad_token_id: Id used to right-pad the returned sequences.
            num_return_sequences: Number of samples drawn per prompt.

        Returns:
            A generator when ``stream`` is true; otherwise a tensor of shape
            ``[batch * num_return_sequences, max_len]`` holding prompt + continuation.
        """
        # Streaming generation.
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        # Blocking generation, one prompt at a time.
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                # Every yielded chunk ends with the newly sampled token.
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        # Right-pad everything to a common length.
        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        """Yield the tokens generated so far after every sampling step.

        Expects a single sequence (``.item()`` is called on the sampled token).
        """
        start = input_ids.shape[1]
        while input_ids.shape[1] < max_new_tokens - 1:
            # No layer keeps a key/value cache, so the whole sequence is re-fed at
            # every step. Feeding only the last token (as a cached decoder would)
            # makes that token attend to nothing but itself.
            out = self(input_ids, **args)
            logits = out.logits[:, -1, :]
            # Repetition penalty on every token id already present.
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break