"""MiniMind language model variant with a retrievable token knowledge base.

Each transformer block runs causal self-attention, retrieves one entry from a
shared ``KnowledgeDataset`` and cross-attends to its token embeddings before
the feed-forward (optionally mixture-of-experts) layer.
"""
import inspect
import math
import struct
import time
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .LMConfig import LMConfig


class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalisation with a learnable per-channel gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Scale by the reciprocal RMS over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalise in float32 for stability, then cast back to the input dtype.
        return self.weight * self._norm(x.float()).type_as(x)


def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Precompute the complex rotary factors ``exp(i * t * freq)``.

    Args:
        dim: Per-head dimension; ``dim // 2`` frequencies are generated.
        end: Number of positions to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, dim // 2)``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    """Apply rotary position embeddings to query and key tensors.

    Args:
        xq: Queries; dimension 1 must be the sequence axis and the last
            dimension the (even-sized) head dimension.
        xk: Keys, laid out like ``xq``.
        pos_cis: Complex factors of shape ``(seq_len, head_dim // 2)``.

    Returns:
        Tuple ``(xq_out, xk_out)`` with the shapes and dtypes of the inputs.
    """

    def unite_shape(pos_cis, x):
        # Reshape pos_cis so it broadcasts over every axis of x except the
        # sequence axis (1) and the last (complex feature) axis.
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    # View consecutive feature pairs as complex numbers.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class KnowledgeDataset(nn.Module):
    """Key-addressed store of token sequences retrieved during the forward pass.

    Holds ``knowledge_num`` entries of ``knowledge_length`` token ids each and
    one key vector per entry. ``search_index`` picks the top-k entries by
    key similarity and then the single best entry by cosine similarity
    between the mean token embedding and the mean query feature.
    """

    def __init__(self, params, tok_embeddings, is_train=True):
        """
        Args:
            params: Config providing ``knowledge_dim``, ``knowledge_num``,
                ``knowledge_length``, ``dim`` and ``vocab_size``.
            tok_embeddings: Token embedding module shared with the language model.
            is_train: When true, keys are refreshed as a side effect of lookups.
        """
        super().__init__()
        self.is_train = is_train
        self.params = params
        self.tok_embeddings = tok_embeddings

        # Embedding / query projection parameters.
        self.knowledge_dim = params.knowledge_dim
        self.key_dim = self.knowledge_dim // 2
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.knowledge_dim, bias=False),
        )

        # Knowledge-base parameters.
        self.knowledge_num = params.knowledge_num
        self.knowledge_length = params.knowledge_length
        self.keys = nn.Parameter(
            torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02,
            requires_grad=True,
        )
        self.product_key_topk = min(16, self.knowledge_num)

        # Per-key "has been refreshed" flag; a buffer so it follows the
        # module across devices.
        self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))

        # Knowledge storage; a buffer because these are integer token ids
        # that need no gradient.
        self.register_buffer(
            'knowledge_dataset',
            torch.randint(
                low=0,
                high=params.vocab_size,
                size=(self.knowledge_num, self.knowledge_length),
                dtype=torch.long,
            ),
        )

        # Step counter (not read anywhere in this file).
        self.step_counter = 0

    def intelligent_selection(self, query, all_scores, all_indices):
        """Pick, per sample, the candidate entry most similar to the query.

        Args:
            query: Hidden states of shape ``(batch, seq_len, dim)``.
            all_scores: Top-k key similarity scores. Currently unused; kept
                for interface compatibility.
            all_indices: Long tensor ``(batch, top_k)`` of candidate entry ids.

        Returns:
            Tuple ``(best_tokens, best_tokens_embeddings)`` with shapes
            ``(batch, knowledge_length)`` and ``(batch, knowledge_length, dim)``.

        Side effects:
            When ``self.is_train`` is true, the keys of all candidate entries
            are recomputed from their current embeddings and flagged in
            ``has_update_keys``. When it is false nothing is mutated, but the
            same ``(tokens, embeddings)`` pair is still returned so callers
            get a consistent return type.
        """
        batch_size = all_indices.size(0)

        # Embed each distinct candidate entry only once.
        flat_candidate_indices = all_indices.reshape(-1)
        unique_indices, inverse_indices = torch.unique(
            flat_candidate_indices, return_inverse=True
        )
        # [num_unique, knowledge_length, dim]
        candidate_embeddings = self.tok_embeddings(self.knowledge_dataset[unique_indices])

        # Cosine similarity between the mean-pooled query and mean-pooled entries.
        normalized_candidates = F.normalize(candidate_embeddings.mean(dim=1), dim=-1)
        normalized_queries = F.normalize(query.mean(dim=1), dim=-1)

        # Map every (sample, slot) back to its row among the unique candidates.
        per_sample_rows = inverse_indices.view(batch_size, -1)
        per_sample_features = normalized_candidates[per_sample_rows]  # [batch, top_k, dim]
        similarity = torch.einsum('bkd,bd->bk', per_sample_features, normalized_queries)

        best_slot = similarity.argmax(dim=1)
        best_entry = all_indices.gather(1, best_slot.unsqueeze(1)).squeeze(1)

        all_best_tokens = self.knowledge_dataset[best_entry]
        all_best_tokens_embeddings = self.tok_embeddings(all_best_tokens)

        if self.is_train:
            # Refresh the keys of every retrieved candidate from its current
            # embeddings, and remember that these keys have been refreshed.
            self._update_keys_with_embeddings(unique_indices, candidate_embeddings)
            with torch.no_grad():
                self.has_update_keys[unique_indices] = 1

        return all_best_tokens, all_best_tokens_embeddings

    def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
        """Overwrite ``self.keys[pre_update_indices]`` without tracking gradients.

        Args:
            pre_update_indices: Long tensor of entry ids to refresh.
            pre_update_embeddings: Token embeddings of those entries, shaped
                ``(len(pre_update_indices), knowledge_length, dim)``.
        """
        with torch.no_grad():
            pooled = pre_update_embeddings.mean(dim=1)
            new_keys = self.to_queries(pooled)
            # Cast explicitly: under autocast the projection may come back in
            # a lower precision than the key parameter.
            self.keys[pre_update_indices] = new_keys.to(self.keys.dtype)

    def search_index(self, x):
        """Retrieve one knowledge entry per sample.

        Args:
            x: Hidden states of shape ``(batch, seq_len, dim)``.

        Returns:
            Tuple ``(best_tokens, best_tokens_embeddings)`` as produced by
            ``intelligent_selection``.
        """
        # Collapse the sequence dimension by averaging.
        x_flat = x.mean(dim=1)  # [batch, dim]
        queries = self.to_queries(x_flat)  # [batch, knowledge_dim]

        # Similarity between every query and every key, then top-k candidates.
        sim = torch.einsum('b d, k d -> b k', queries, self.keys)
        scores, indices = sim.topk(self.product_key_topk, dim=-1)

        best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)

        if self.is_train:
            # Additionally refresh up to 1% of the keys that have never been
            # refreshed, chosen at random.
            # NOTE(review): these keys are not flagged in has_update_keys, so
            # they can be picked again on later calls - confirm this is intended.
            not_updated_indices = torch.where(self.has_update_keys == 0)[0]
            num_update_keys = min(
                int(self.knowledge_num * 0.01), len(not_updated_indices)
            )
            # Skip when nothing is left (or 1% rounds down to zero); the count
            # is clamped so the tail end of the un-refreshed set still works.
            if num_update_keys > 0:
                perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
                pre_update_indices = not_updated_indices[perm.to(not_updated_indices.device)]
                pre_update_tokens = self.knowledge_dataset[pre_update_indices]
                # [num_update_keys, knowledge_length, dim]
                pre_update_embeddings = self.tok_embeddings(pre_update_tokens)
                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)

        return best_tokens, best_tokens_embeddings


class CrossAttention(nn.Module):
    """Multi-head cross-attention from hidden states to retrieved embeddings."""

    def __init__(
        self,
        config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        if self.config.dim % self.num_heads != 0:
            raise ValueError(
                f"config.dim ({self.config.dim}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)

        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        """
        Args:
            x: Query source, ``(batch, q_len, dim)``.
            db: Key/value source, ``(batch, kv_len, dim)``.
            context_mask: Optional mask; positions equal to 0 are excluded.
                Presumably shaped ``(batch, q_len, kv_len)`` - TODO confirm
                with callers (none in this file pass it).
            pos_emb: Optional additive embedding that is reshaped per head and
                added to q, k and v alike.

        Returns:
            Tensor of shape ``(batch, q_len, dim)``.
        """
        batch_size = x.size(0)

        def split_heads(t):
            # (batch, len, dim) -> (batch, heads, len, head_dim)
            return t.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.to_q(x))
        k = split_heads(self.to_k(db))
        v = split_heads(self.to_v(db))

        if pos_emb is not None:
            pos_emb = split_heads(pos_emb)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            # Use the dtype's own minimum: a literal like -1e10 is not
            # representable in float16 and makes masked_fill raise.
            attn_scores = attn_scores.masked_fill(
                expanded_mask == 0, torch.finfo(attn_scores.dtype).min
            )

        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
        context = self.to_out(context)

        return context


class Attention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings and grouped KV heads."""

    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        # How many query heads share each key/value head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # Additive causal mask for the non-fused path: -inf strictly above
        # the diagonal, 0 elsewhere.
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor):
        """
        Args:
            x: Input of shape ``(batch, seq_len, dim)``.
            pos_cis: Rotary factors of shape ``(seq_len, head_dim // 2)``.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.
        """
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        # Rotary embeddings expect the sequence on axis 1, so apply them
        # before moving the head axis forward.
        xq, xk = apply_rotary_emb(xq, xk, pos_cis)

        # Expand grouped key/value heads so every query head has a partner.
        if self.n_rep > 1:
            xk = xk.repeat_interleave(self.n_rep, dim=2)
            xv = xv.repeat_interleave(self.n_rep, dim=2)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim) so
        # that attention is computed across sequence positions, not heads.
        xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)

        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output


class FeedForward(nn.Module):
    """Gated SiLU feed-forward layer: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            # Derive the hidden size as 2/3 of 4*dim, rounded up to a multiple
            # of config.multiple_of. The result is written back to the config.
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class MoEGate(nn.Module):
    """Top-k softmax router with an optional auxiliary load-balancing loss."""

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise the routing weight with Kaiming-uniform values."""
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Route every token to ``top_k`` experts.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, dim)``.

        Returns:
            Tuple ``(topk_idx, topk_weight, aux_loss)``. The first two have
            shape ``(batch * seq_len, top_k)``; ``aux_loss`` is a scalar
            tensor while training with ``alpha > 0`` and the int ``0``
            otherwise.

        Raises:
            NotImplementedError: If ``scoring_func`` is not ``'softmax'``.
        """
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            # Renormalise the selected weights so they sum to one per token.
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                # Per-sequence variant: expert usage counts are normalised per
                # sample and weighted by that sample's mean routing score.
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                # Batch-level variant: selection frequency times mean score.
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = 0
        return topk_idx, topk_weight, aux_loss


class MOEFeedForward(nn.Module):
    """Mixture-of-experts feed-forward layer with optional shared expert."""

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        """Apply the routed (and optional shared) experts.

        Args:
            x: Tensor of shape ``(batch, seq_len, dim)``.

        Returns:
            Tensor with the same shape as ``x``. The router's auxiliary loss
            is stored on ``self.aux_loss``.
        """
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Use the gate to choose experts for every token.
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # One copy of each token per selected expert.
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            # Keep the buffer in the input dtype; forcing float16 here would
            # silently degrade float32 / bfloat16 training.
            y = torch.empty_like(x)
            for i, expert in enumerate(self.experts):
                selected = flat_topk_idx == i
                y[selected] = expert(x[selected]).to(y.dtype)
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Inference-time expert dispatch that batches tokens per expert.

        Args:
            x: Flattened tokens ``(num_tokens, dim)``.
            flat_expert_indices: Expert id for every (token, slot) pair,
                flattened to ``(num_tokens * num_experts_per_tok,)``.
            flat_expert_weights: Matching routing weights, ``(..., 1)``.

        Returns:
            Tensor shaped like ``x`` holding the weighted sum of expert outputs.
        """
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # Example: with tokens_per_expert = [6, 15, 20, 26] (cumulative counts,
        # one per expert) and token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, ...],
        # token_idxs[:6] are the tokens handled by expert 0, token_idxs[6:15]
        # those handled by expert 1, and so on. A token can appear under
        # several experts, depending on num_experts_per_tok.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            # Accumulate, since the same token may be routed to several experts.
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache


class MiniMindBlock(nn.Module):
    """Transformer block: self-attention, knowledge cross-attention, feed-forward."""

    def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.self_attention = Attention(config)
        self.cross_attention = CrossAttention(config)
        self.knowledge_dataset = knowledge_dataset

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, x, pos_cis):
        """
        Args:
            x: Hidden states ``(batch, seq_len, dim)``.
            pos_cis: Rotary factors for the current positions.

        Returns:
            Updated hidden states with the same shape as ``x``.
        """
        h_attn = self.self_attention(
            self.attention_norm(x),
            pos_cis
        )
        # Retrieve one knowledge entry per sample and attend to its embeddings.
        db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
        h_attn = self.cross_attention(h_attn, db_embeddings)
        h = x + h_attn
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class MiniMindLM(PreTrainedModel):
    """Causal language model built from ``MiniMindBlock`` layers."""

    config_class = LMConfig

    def __init__(self, params: Optional[LMConfig] = None):
        # Resolve the default first so every later access uses a real config.
        config = params or LMConfig()
        super().__init__(config)
        self.params = config
        self.vocab_size, self.n_layers = config.vocab_size, config.n_layers
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.dropout = nn.Dropout(config.dropout)
        # A single knowledge base is shared by every layer.
        self.knowledge_dataset = KnowledgeDataset(config, self.tok_embeddings)
        self.layers = nn.ModuleList(
            [MiniMindBlock(l, config, self.knowledge_dataset) for l in range(self.n_layers)]
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
        # Tie the input embedding to the output projection.
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer(
            "pos_cis",
            precompute_pos_cis(dim=config.dim // config.n_heads, theta=config.rope_theta),
            persistent=False,
        )
        self.OUT = CausalLMOutputWithPast()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                **args):
        """Run the model.

        Args:
            input_ids: Token ids ``(batch, seq_len)``.
            logits_to_keep: If an int, logits are computed for the last
                ``logits_to_keep`` positions (``0`` keeps all); otherwise it
                is used directly as an index along the sequence axis.
            **args: ``start_pos`` (default 0) selects the rotary offset;
                other keys are ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` set, plus the final
            hidden states on ``hidden_states`` and the summed MoE auxiliary
            loss on ``aux_loss``.
        """
        start_pos = args.get('start_pos', 0)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        for layer in self.layers:
            h = layer(
                h, pos_cis
            )

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        aux_loss = sum(
            layer.feed_forward.aux_loss
            for layer in self.layers
            if isinstance(layer.feed_forward, MOEFeedForward)
        )

        # Only the fields that are actually needed are populated.
        output = CausalLMOutputWithPast(
            logits=logits,
        )
        output.hidden_states = h
        output.aux_loss = aux_loss
        return output

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        """Sample continuations for each prompt.

        Args:
            input_ids: Prompt ids ``(batch, seq_len)``; ``pad_token_id``
                entries are stripped before generation.
            eos_token_id: Sampling stops once this token is produced.
            max_new_tokens: Length bound passed to ``_stream``.
                NOTE(review): ``_stream`` compares it against the total
                sequence length (prompt included) - confirm intended.
            temperature: Softmax temperature.
            top_p: Nucleus sampling threshold; ``None`` or ``>= 1`` disables it.
            stream: If true, return the ``_stream`` generator directly.
            rp: Repetition penalty divisor applied to already-seen tokens.
            pad_token_id: Padding id for stripping prompts and right-padding outputs.
            num_return_sequences: Samples drawn per prompt.

        Returns:
            A generator when ``stream`` is true; otherwise a tensor of shape
            ``(batch * num_return_sequences, max_len)`` right-padded with
            ``pad_token_id``.
        """
        # Streaming generation.
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        # Direct generation.
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                # If nothing was generated, append an empty continuation
                # rather than repeating the prompt.
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad[:, :0]
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    @torch.inference_mode()
    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        """Yield the tokens generated so far after every sampling step.

        Decorated itself because a generator body runs when it is iterated,
        i.e. after ``generate`` (and its inference-mode context) has returned.

        Yields:
            ``input_ids[:, start:]`` - every token produced since the prompt.
        """
        start = input_ids.shape[1]
        while input_ids.shape[1] < max_new_tokens - 1:
            # The model keeps no key/value cache, so the whole sequence must
            # be re-fed each step; feeding only the last token would discard
            # all context.
            out = self(input_ids, **args)
            logits = out.logits[:, -1, :]
            # Repetition penalty. NOTE(review): only the first row's tokens
            # are considered, and ``.item()`` below requires batch size 1.
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break