603 lines
26 KiB
Python
603 lines
26 KiB
Python
import math
|
||
import struct
|
||
import inspect
|
||
import time
|
||
|
||
from .LMConfig import LMConfig
|
||
from typing import Any, Optional, Tuple, List, Union
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from torch import nn
|
||
from transformers import PreTrainedModel
|
||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||
|
||
|
||
|
||
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable per-channel gain.

    Normalisation is computed in float32 and cast back to the input dtype before
    the gain is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` so the mean of squares along the last axis is ~1."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalized = self._norm(x.float()).type_as(x)
        return self.weight * normalized
|
||
|
||
|
||
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Build the rotary-embedding table of unit complex rotations.

    Args:
        dim: Per-head feature size; the table has ``dim // 2`` frequency columns.
        end: Number of positions (rows).
        theta: Base of the geometric frequency progression.

    Returns:
        complex64 tensor ``[end, dim // 2]`` whose entry ``(p, i)`` is
        ``exp(1j * p * theta ** (-2i / dim))``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    unit = torch.ones_like(angles)
    return torch.polar(unit, angles)
|
||
|
||
|
||
def apply_rotary_emb(xq, xk, pos_cis):
    """Rotate query/key feature pairs by the per-position complex factors in ``pos_cis``.

    Consecutive feature pairs ``(2i, 2i+1)`` of the last axis are treated as one
    complex number and multiplied by ``pos_cis``. ``pos_cis`` must have shape
    ``(xq.shape[1], xq.shape[-1] // 2)``; it is broadcast over every other axis.
    Results are cast back to the dtypes of ``xq`` and ``xk``.
    """
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        # Keep axis 1 and the last axis; every other axis becomes a broadcast 1.
        shape = [size if axis in (1, ndim - 1) else 1 for axis, size in enumerate(x.shape)]
        return pos_cis.view(*shape)

    def as_complex(t):
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = as_complex(xq)
    k_complex = as_complex(xk)
    rotation = unite_shape(pos_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
|
||
|
||
class KnowledgeDataset(nn.Module):
    """Token-sequence knowledge store retrieved by key similarity.

    Holds ``knowledge_num`` entries. Each entry is a row of ``knowledge_length``
    token ids in the ``knowledge_dataset`` buffer and owns a ``knowledge_dim``
    key in ``self.keys``. ``search_index`` scores a sequence-averaged query
    against every key and shortlists the ``product_key_topk`` best entries. It
    then returns the shortlisted entry whose mean token embedding is most
    cosine-similar to the sequence-averaged hidden state.

    While ``is_train`` is true, retrieval also rewrites keys in place under
    ``no_grad``. Each shortlisted key becomes ``to_queries`` of its entry's
    mean token embedding. In addition, a random 1% of the keys not refreshed
    since the last full sweep are rewritten the same way.
    ``freeze_embedding`` turns these rewrites into no-ops.

    NOTE(review): the returned tokens depend on the key scores only through
    discrete top-k / argmax indices, so no gradient reaches ``self.keys`` or
    ``to_queries`` through retrieval.
    """

    def __init__(self, params, tok_embeddings, is_train=True):
        super().__init__()
        self.is_train = is_train
        self.params = params
        self.tok_embeddings = tok_embeddings

        # Query projection: model hidden size -> key size.
        self.knowledge_dim = params.knowledge_dim
        self.key_dim = self.knowledge_dim // 2
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.knowledge_dim, bias=False),
        )

        # Knowledge-base parameters.
        self.knowledge_num = params.knowledge_num
        self.knowledge_length = params.knowledge_length
        self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
        self.product_key_topk = min(16, self.knowledge_num)

        # 1 for keys refreshed in the current sweep, 0 otherwise. A buffer so it follows .to(device).
        self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))

        # Token ids of every entry. A buffer because integer ids need no gradient.
        self.register_buffer('knowledge_dataset',
            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
        )

        # Step counter (intended for dynamic weighting; not read inside this class).
        self.step_counter = 0

        self.freeze_embedding = False

    def intelligent_selection(self, query, all_scores, all_indices):
        """Pick one knowledge entry per batch row from that row's candidates.

        Args:
            query: Hidden states ``[batch, seq_len, dim]``, averaged over the sequence.
            all_scores: Key scores of the candidates, ``[batch, k]``. Not used for
                the choice; kept for interface compatibility.
            all_indices: Candidate entry indices, ``[batch, k]``.

        Returns:
            ``(best_tokens, best_tokens_embeddings)`` of shapes
            ``[batch, knowledge_length]`` and ``[batch, knowledge_length, dim]``,
            in both training and evaluation mode.

        When ``is_train`` is true, every candidate's key is also recomputed from
        its embeddings and marked as refreshed.
        """
        batch_size, topk = all_indices.shape

        # Selection is a hard argmax, so it needs no autograd graph.
        with torch.no_grad():
            # Embed each distinct candidate once, even if several rows share it.
            unique_indices, inverse_indices = torch.unique(all_indices.reshape(-1), return_inverse=True)
            candidate_embeddings = self.tok_embeddings(self.knowledge_dataset[unique_indices])  # [U, L, dim]
            normalized_candidates = F.normalize(candidate_embeddings.mean(dim=1), dim=-1)  # [U, dim]
            normalized_queries = F.normalize(query.mean(dim=1), dim=-1)  # [batch, dim]

            # Cosine similarity of each row's query with each of its own candidates.
            row_candidates = normalized_candidates[inverse_indices.view(batch_size, topk)]  # [batch, k, dim]
            similarity = torch.einsum('bkd,bd->bk', row_candidates, normalized_queries)
            best_positions = similarity.argmax(dim=-1, keepdim=True)  # [batch, 1]
            best_indices = all_indices.gather(1, best_positions).squeeze(1)  # [batch]

        best_tokens = self.knowledge_dataset[best_indices]  # [batch, knowledge_length]
        # Embedded with autograd enabled so gradients reach the embedding table.
        best_tokens_embeddings = self.tok_embeddings(best_tokens)

        if self.is_train:
            self._update_keys_with_embeddings(unique_indices, candidate_embeddings)
            with torch.no_grad():
                self.has_update_keys[unique_indices] = 1

        return best_tokens, best_tokens_embeddings

    def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
        """Overwrite selected keys with the projected mean embedding of their entries.

        Args:
            pre_update_indices: Entry indices whose keys are rewritten, ``[n]``.
            pre_update_embeddings: Token embeddings of those entries,
                ``[n, knowledge_length, dim]``.

        No-op when ``freeze_embedding`` is set or ``n`` is 0.
        """
        if self.freeze_embedding or pre_update_indices.numel() == 0:
            return
        with torch.no_grad():
            new_keys = self.to_queries(pre_update_embeddings.mean(dim=1))  # [n, knowledge_dim]
            # Index assignment rejects mismatched dtypes (e.g. bf16 output under autocast).
            self.keys[pre_update_indices] = new_keys.to(self.keys.dtype)

    def _refresh_random_keys(self):
        """Rewrite up to 1% of the keys that have not been refreshed in the current sweep.

        When every key has been refreshed, the sweep restarts: the flags are
        cleared and the sample is drawn from all keys.
        """
        with torch.no_grad():
            not_updated_indices = torch.where(self.has_update_keys == 0)[0]
            if len(not_updated_indices) == 0:
                print("all keys are updated")
                self.has_update_keys.zero_()
                not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)

            num_update_keys = int(self.knowledge_num * 0.01)
            if num_update_keys == 0:
                # Fewer than 100 entries: 1% rounds to zero, nothing to refresh.
                return

            perm = torch.randperm(len(not_updated_indices))[:num_update_keys].to(not_updated_indices.device)
            pre_update_indices = not_updated_indices[perm]
            # Embedding a 2-D id tensor yields [n, knowledge_length, dim] directly.
            pre_update_embeddings = self.tok_embeddings(self.knowledge_dataset[pre_update_indices])
            self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
            self.has_update_keys[pre_update_indices] = 1

    def search_index(self, x):
        """Retrieve one knowledge entry per batch row for hidden states ``x``.

        Args:
            x: Hidden states ``[batch, seq_len, dim]``.

        Returns:
            ``(best_tokens, best_tokens_embeddings)``; see ``intelligent_selection``.
        """
        # Collapse the sequence axis by averaging, then project into key space.
        queries = self.to_queries(x.mean(dim=1))  # [batch, knowledge_dim]
        sim = torch.einsum('b d, k d -> b k', queries, self.keys)
        scores, indices = sim.topk(self.product_key_topk, dim=-1)

        best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)

        if self.is_train:
            self._refresh_random_keys()

        return best_tokens, best_tokens_embeddings
|
||
|
||
class CrossAttention(nn.Module):
    """Multi-head attention from hidden states ``x`` (queries) onto ``db`` (keys and values).

    The head count is fixed at 8, so ``config.dim`` is assumed divisible by 8.
    """

    def __init__(
        self,
        config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)

        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def _split_heads(self, t, batch_size):
        """Reshape ``[batch, len, dim]`` to ``[batch, heads, len, head_dim]``."""
        return t.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        """Attend from ``x`` to ``db``.

        Args:
            x: Query states ``[batch, q_len, dim]``.
            db: Knowledge states used for keys and values, ``[batch, kv_len, dim]``.
            context_mask: Optional ``[batch, q_len, kv_len]`` mask; positions equal
                to 0 are excluded. It is broadcast over heads.
            pos_emb: Optional ``[batch, len, dim]`` tensor added to q, k and v alike
                (so it assumes ``q_len == kv_len``).

        Returns:
            ``[batch, q_len, dim]``.
        """
        batch_size = x.size(0)

        q = self._split_heads(self.to_q(x), batch_size)
        k = self._split_heads(self.to_k(db), batch_size)
        v = self._split_heads(self.to_v(db), batch_size)

        if pos_emb is not None:
            pos_emb = self._split_heads(pos_emb, batch_size)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            # Use the dtype's own minimum: a literal such as -1e10 overflows
            # float16 and makes masked_fill raise.
            attn_scores = attn_scores.masked_fill(expanded_mask == 0, torch.finfo(attn_scores.dtype).min)

        # Softmax in float32 for stability under reduced precision (matches the
        # self-attention softmax in this file).
        attn_weights = F.softmax(attn_scores.float(), dim=-1).type_as(q)

        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
        return self.to_out(context)
|
||
|
||
class Attention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings and grouped-query KV heads."""

    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        # Number of query heads sharing each KV head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # Fused SDPA is used only when available and enabled in the config.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # Additive causal mask for the manual path: -inf strictly above the diagonal.
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor):
        """Apply causal self-attention.

        Args:
            x: ``[bsz, seq_len, dim]`` hidden states.
            pos_cis: Rotary table for these positions, ``[seq_len, head_dim // 2]``.

        Returns:
            ``[bsz, seq_len, dim]``.
        """
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        # Rotary embedding expects [bsz, seq_len, heads, head_dim].
        xq, xk = apply_rotary_emb(xq, xk, pos_cis)

        # Give every query head its KV head (grouped-query attention), then move
        # heads in front of the sequence axis: attention must run over positions,
        # i.e. on [bsz, heads, seq_len, head_dim].
        if self.n_rep > 1:
            xk = xk.repeat_interleave(self.n_rep, dim=2)
            xv = xv.repeat_interleave(self.n_rep, dim=2)
        xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)

        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        # Back to [bsz, seq_len, heads * head_dim].
        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output
|
||
|
||
|
||
class FeedForward(nn.Module):
    """SwiGLU feed-forward: ``dropout(w2(silu(w1 x) * w3 x))``.

    When ``config.hidden_dim`` is None it is derived as ``int(2 * 4 * dim / 3)``,
    rounded up to a multiple of ``config.multiple_of``. The result is written
    back onto ``config``, so later instances built from the same config reuse it.
    """

    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            base = int(2 * (4 * config.dim) / 3)
            step = config.multiple_of
            config.hidden_dim = step * ((base + step - 1) // step)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.dropout(self.w2(gated))
|
||
|
||
|
||
class MoEGate(nn.Module):
    """Softmax top-k router for the mixture-of-experts feed-forward.

    ``forward`` returns ``(topk_idx, topk_weight, aux_loss)`` for the flattened
    ``bsz * seq_len`` tokens. ``aux_loss`` is a load-balancing loss scaled by
    ``aux_loss_alpha`` while training with a positive alpha, and the int 0
    otherwise.
    """

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Kaiming-uniform initialisation (a = sqrt(5)) of the router weight."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def _aux_loss(self, scores, topk_idx, bsz, seq_len):
        """Load-balancing loss: per-sequence when ``seq_aux`` is set, otherwise over all tokens."""
        routed = topk_idx.view(bsz, -1)
        n_experts = self.n_routed_experts
        if self.seq_aux:
            per_seq_scores = scores.view(bsz, seq_len, -1)
            # Per-sequence expert hit counts, normalised so a uniform load gives 1.
            load = torch.zeros(bsz, n_experts, device=scores.device)
            load.scatter_add_(1, routed, torch.ones(bsz, seq_len * self.top_k, device=scores.device))
            load.div_(seq_len * self.top_k / n_experts)
            return (load * per_seq_scores.mean(dim=1)).sum(dim=1).mean() * self.alpha
        hit_fraction = F.one_hot(routed.view(-1), num_classes=n_experts).float().mean(0)
        mean_prob = scores.mean(0)
        return (mean_prob * (hit_fraction * n_experts)).sum() * self.alpha

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        flat = hidden_states.view(-1, h)
        logits = F.linear(flat, self.weight, None)
        if self.scoring_func != 'softmax':
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
        scores = logits.softmax(dim=-1)

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            # Renormalise the kept probabilities to sum to 1 (epsilon avoids 0/0).
            topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)

        use_aux = self.training and self.alpha > 0.0
        aux_loss = self._aux_loss(scores, topk_idx, bsz, seq_len) if use_aux else 0
        return topk_idx, topk_weight, aux_loss
|
||
|
||
|
||
class MOEFeedForward(nn.Module):
    """Mixture-of-experts feed-forward routed by ``MoEGate``.

    In training each token is duplicated once per selected expert, and the
    expert outputs are combined with the gate weights. In eval, ``moe_infer``
    groups tokens by expert instead. An optional shared expert is added to
    every token. The gate's auxiliary loss is stored on ``self.aux_loss``.
    """

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        """Route ``x`` (``[bsz, seq_len, dim]``) through the experts; output has the same shape."""
        identity = x
        orig_shape = x.shape
        topk_idx, topk_weight, aux_loss = self.gate(x)
        flat_x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # One row per (token, selected expert) pair.
            flat_x = flat_x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            # Keep expert outputs in the activation dtype. A hard-coded float16
            # buffer lost precision for fp32 and overflowed to inf for bf16.
            y = torch.empty_like(flat_x)
            for i, expert in enumerate(self.experts):
                routed = flat_topk_idx == i
                y[routed] = expert(flat_x[routed]).to(y.dtype)
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(flat_x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Inference routing: run each expert once on all of its tokens.

        Args:
            x: Flattened tokens ``[n_tokens, dim]``.
            flat_expert_indices: Chosen expert per (token, slot),
                ``[n_tokens * num_experts_per_tok]``.
            flat_expert_weights: Matching gate weights, ``[n_tokens * num_experts_per_tok, 1]``.

        Returns:
            ``[n_tokens, dim]`` weighted sum of expert outputs per token.
        """
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # Example: with tokens_per_expert = [6, 15, 20, 26] (4 experts) and
        # token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...],
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the tokens handled by
        # expert 0 (a token may be handled by several experts, depending on
        # num_experts_per_tok), and the next 9, token_idxs[6:15], belong to
        # expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache
|
||
|
||
|
||
class MiniMindBlock(nn.Module):
    """Transformer block: self-attention, knowledge retrieval + cross-attention, then feed-forward.

    NOTE(review): the self-attention output is not added to the residual stream
    directly. It only serves as the query for retrieval and cross-attention, and
    the cross-attention result is what gets added to ``x``.
    """

    def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.self_attention = Attention(config)
        self.cross_attention = CrossAttention(config)
        # Shared across all blocks of the model.
        self.knowledge_dataset = knowledge_dataset

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = MOEFeedForward(config) if config.use_moe else FeedForward(config)

    def forward(self, x, pos_cis):
        attended = self.self_attention(self.attention_norm(x), pos_cis)
        _, knowledge_embeddings = self.knowledge_dataset.search_index(attended)
        mixed = self.cross_attention(attended, knowledge_embeddings)
        residual = x + mixed
        return residual + self.feed_forward(self.ffn_norm(residual))
|
||
|
||
|
||
class MiniMindLM(PreTrainedModel):
    """Decoder-only language model whose blocks share one ``KnowledgeDataset``.

    The token embedding table is tied to the output projection. ``forward``
    returns a ``CausalLMOutputWithPast`` carrying ``logits``, ``hidden_states``
    (the last block's output before the final norm) and ``aux_loss`` (the sum
    of MoE balance losses).
    """
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        self.params = params or LMConfig()
        super().__init__(self.params)
        # Use the resolved config from here on; the argument may have been None.
        params = self.params
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
        self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        # Weight tying: embedding lookup and output projection share one tensor.
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()
        self.freeze_embedding = False

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                step: int = 0,
                **args):
        """Run the model on ``input_ids`` (``[batch, seq_len]``).

        Args:
            logits_to_keep: An int ``n`` keeps logits for the last ``n`` positions
                (0 keeps all). A tensor is used directly as a sequence-axis index.
            step: When ``freeze_embedding`` is set and ``step == 0``, the tied
                embedding/output weight stops requiring grad and knowledge-key
                rewrites are disabled.
            **args: ``start_pos`` (default 0) offsets the rotary positions.
        """
        start_pos = args.get('start_pos', 0)
        if self.freeze_embedding and step == 0:
            self.tok_embeddings.weight.requires_grad = False
            # Also freeze key rewrites inside the KnowledgeDataset.
            self.knowledge_dataset.freeze_embedding = True
            print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
            print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        for layer in self.layers:
            h = layer(h, pos_cis)

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        output = CausalLMOutputWithPast(
            logits=logits,
        )
        output.hidden_states = h
        output.aux_loss = aux_loss
        return output

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        """Sample continuations for each row of ``input_ids``.

        With ``stream=True``, returns the ``_stream`` generator directly.
        Otherwise each row has its pad tokens removed and is sampled
        ``num_return_sequences`` times. The results (prompt + generated tokens)
        are right-padded with ``pad_token_id`` and stacked into
        ``[batch * num_return_sequences, max_len]``.
        """
        # Streaming generation.
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        # Direct generation.
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                # When nothing was generated (prompt already at the length limit),
                # append an empty tensor rather than a second copy of the prompt.
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad[:, :0]
                generated.append(torch.cat([non_pad, gen], dim=-1))

        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    @torch.inference_mode()
    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        """Yield the generated suffix after each sampled token.

        ``max_new_tokens`` bounds the TOTAL sequence length: sampling stops once
        ``input_ids`` reaches ``max_new_tokens - 1`` tokens, or after
        ``eos_token_id`` is sampled. The EOS check uses ``.item()``, so a batch of
        one is assumed.
        """
        start = input_ids.shape[1]
        while input_ids.shape[1] < max_new_tokens - 1:
            # The attention layers keep no KV cache, so the full sequence must be
            # re-encoded every step; feeding only the last token would drop all context.
            out = self(input_ids, **args)
            logits = out.logits[:, -1, :]
            if rp != 1.0:
                # Repetition penalty on every token already in the row: shrink
                # positive logits and push negative ones further down (dividing a
                # negative logit would instead make the token MORE likely).
                seen = torch.gather(logits, 1, input_ids)
                logits.scatter_(1, input_ids, torch.where(seen < 0, seen * rp, seen / rp))
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the token that crosses top_p is kept; always keep the best token.
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break