diff --git a/model/model_extra.py b/model/model_extra.py new file mode 100644 index 0000000..fca8c54 --- /dev/null +++ b/model/model_extra.py @@ -0,0 +1,745 @@ +import math +import struct +import inspect +import time +import gc +#子空间二维分解+梯度更新 +from .LMConfig import LMConfig +from typing import Any, Optional, Tuple, List, Union +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast + + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + return self.weight * self._norm(x.float()).type_as(x) + + +def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return pos_cis + + +def apply_rotary_emb(xq, xk, pos_cis): + def unite_shape(pos_cis, x): + ndim = x.ndim + assert 0 <= 1 < ndim + assert pos_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return pos_cis.view(*shape) + + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + pos_cis = unite_shape(pos_cis, xq_) + xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + +class KnowledgeDataset(nn.Module): + def __init__(self, params, tok_embeddings, is_train=True): + super().__init__() + self.is_train = is_train + self.params = params + self.tok_embeddings = tok_embeddings + + # 嵌入参数 + self.knowledge_dim = params.knowledge_dim + self.key_dim = self.knowledge_dim // 2 + self.to_queries = nn.Sequential( + nn.Linear(params.dim, self.knowledge_dim, bias=False), + ) + + ## 数据库参数 + self.knowledge_num = params.knowledge_num + self.knowledge_length = params.knowledge_length + + # 修改键存储为二维分解空间,设置为可训练参数 + self.num_keys = int(math.sqrt(self.knowledge_num)) + # 确保keys是可训练参数 + self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True) + self.product_key_topk = min(16, self.num_keys) + + # 知识库存储 - 使用register_buffer因为这是整数索引,不需要梯度 + self.register_buffer('knowledge_dataset', + torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)) + + # 计算step数目,用于动态调整权重 + self.step_counter = 0 + + # 移除批次计数器和更新频率相关代码 + + def intelligent_selection(self, query, all_scores, all_indices): + """智能分层选择策略""" + if self.is_train == False: + return all_scores, all_indices + + batch_size = all_scores.size(0) + device = all_scores.device + dtype = all_scores.dtype + + # 记录进入智能选择前的内存状态 + if hasattr(self, 'step_counter'): + self.step_counter += 1 + # 禁用GPU内存监控记录以提高性能 + # if self.step_counter % 50 == 0: # 每50次调用记录一次 + # if torch.cuda.is_available(): + # allocated_before = torch.cuda.memory_allocated() / (1024**3) + # print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB") + + # 对每个batch进行分层选择 + enhanced_scores = all_scores.clone() + query_features = query.mean(dim=1) # [batch_size, dim] + + # 预先计算所有候选条目的嵌入(批量优化) + all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0) + unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True) + + # 批量计算唯一候选条目的嵌入 + candidate_tokens = self.knowledge_dataset[unique_indices] + flat_tokens = candidate_tokens.view(-1) + flat_embeddings = self.tok_embeddings(flat_tokens) + + # 获取flat_tokens对应的index(保留这些变量以便其他地方使用) + pre_update_indices = unique_indices.view(-1) + pre_update_embeddings = flat_embeddings.view( + len(unique_indices), self.knowledge_length, -1 + ) + + unique_candidate_features = flat_embeddings.view( + len(unique_indices), self.knowledge_length, -1 + ).mean(dim=1) # [num_unique_candidates, dim] + + # 归一化候选特征(优化相似度计算) + normalized_candidates = F.normalize(unique_candidate_features, dim=-1) + normalized_queries = F.normalize(query_features, dim=-1) + + # 收集所有batch的best_tokens + batch_best_tokens = [] + batch_best_tokens_embeddings = [] + + for batch_idx in range(batch_size): + indices = all_indices[batch_idx] + + # 获取当前batch候选条目对应的特征索引 + start_idx = batch_idx * len(indices) + end_idx = start_idx + len(indices) + batch_inverse_indices = inverse_indices[start_idx:end_idx] + + # 使用预计算的归一化特征进行优化相似度计算 + batch_candidate_features = normalized_candidates[batch_inverse_indices] + query_feature = normalized_queries[batch_idx] + + # 使用矩阵乘法计算余弦相似度 + similarity_scores = torch.mv(batch_candidate_features, query_feature) + + # 找到最大相似度分数的索引 + max_similarity_idx = torch.argmax(similarity_scores) + + # 获取最大相似度对应的候选条目索引 + best_candidate_idx = indices[max_similarity_idx] + + # 获取对应的tokens + best_tokens = self.knowledge_dataset[best_candidate_idx] + best_tokens_embeddings = self.tok_embeddings(best_tokens) + + # 将当前batch的best_tokens添加到列表中 + batch_best_tokens.append(best_tokens) + batch_best_tokens_embeddings.append(best_tokens_embeddings) + + # 将所有batch的best_tokens堆叠成一个张量 + # [batch_size, knowledge_length] + all_best_tokens = torch.stack(batch_best_tokens, dim=0) + all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0) + + # 清理中间张量以防止内存泄漏 + del all_candidate_indices, unique_indices, inverse_indices + del unique_candidate_features, normalized_candidates, normalized_queries + del batch_best_tokens, batch_best_tokens_embeddings + del flat_tokens, flat_embeddings, pre_update_embeddings + + # 记录退出智能选择后的内存状态(已禁用以提高性能) + # if hasattr(self, 'step_counter') and self.step_counter % 50 == 0: + # if torch.cuda.is_available(): + # allocated_after = torch.cuda.memory_allocated() / (1024**3) + # print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB") + + # 强制垃圾回收(仅在监控步骤) + if hasattr(self, 'step_counter') and self.step_counter % 100 == 0: + gc.collect() + # if torch.cuda.is_available(): + # torch.cuda.empty_cache() + + return all_best_tokens, all_best_tokens_embeddings + + + + def search_index(self, x): + batch_size, seq_len, dim = x.shape + + # 1. 序列维度平均 + x_flat = x.mean(dim=1) # [batch_size, dim] + + # 2. 生成查询向量并重塑为两个子查询 + queries = self.to_queries(x_flat) # [batch_size, knowledge_dim] + queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim] + # 调整维度顺序,使子空间维度位于首位 + queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim] + + # 3. 计算每个子空间的相似度 + sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys) + + # 4. 在两个子空间分别做top-k + scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)] + scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0] + indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1] + + # 5. 组合两个子空间的结果 + all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk] + all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk] + + # 6. 将结果重塑为二维 + all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk] + all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk] + + # 7. 选择最终的top-k结果 + scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1) + indices = torch.gather(all_indices, 1, indices_of_indices) + + # 8. 应用智能分层选择策略 + best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices) + + + return best_tokens, best_tokens_embeddings + +class CrossAttention(nn.Module): + def __init__( + self, + config + ): + super().__init__() + self.config = config + self.num_heads = 8 + self.head_dim = self.config.dim // self.num_heads + self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False) + self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False) + self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False) + + self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False) + + def forward(self, x, db, context_mask=None, pos_emb=None): + batch_size = x.size(0) + + # 监控交叉注意力开始时的内存(已禁用以提高性能) + if not hasattr(self, 'call_counter'): + self.call_counter = 0 + self.call_counter += 1 + + # 禁用GPU内存监控记录以提高性能 + # if self.call_counter % 100 == 0 and torch.cuda.is_available(): + # allocated_before = torch.cuda.memory_allocated() / (1024**3) + # print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB") + + # 分离多头 + q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if pos_emb is not None: + pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + q = q + pos_emb + k = k + pos_emb + v = v + pos_emb + + attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) + + if context_mask is not None: + expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1) + attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10) + + attn_weights = F.softmax(attn_scores, dim=-1) + + context = torch.matmul(attn_weights, v) + + context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim) + + context = self.to_out(context) + + # 清理中间张量 + del q, k, v, attn_scores, attn_weights + + # 监控交叉注意力结束时的内存(已禁用以提高性能) + # if self.call_counter % 100 == 0 and torch.cuda.is_available(): + # allocated_after = torch.cuda.memory_allocated() / (1024**3) + # print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB") + + return context + +class Attention(nn.Module): + def __init__(self, args: LMConfig): + super().__init__() + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + assert args.n_heads % self.n_kv_heads == 0 + self.n_local_heads = args.n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) + self.attn_dropout = nn.Dropout(args.dropout) + self.resid_dropout = nn.Dropout(args.dropout) + self.dropout = args.dropout + self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn + # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") + mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) + mask = torch.triu(mask, diagonal=1) + self.register_buffer("mask", mask, persistent=False) + + def forward(self, + x: torch.Tensor, + pos_cis: torch.Tensor): + bsz, seq_len, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, pos_cis) + if self.flash and seq_len != 1: + dropout_p = self.dropout if self.training else 0.0 + output = F.scaled_dot_product_attention( + xq, xk, xv, + attn_mask=None, + dropout_p=dropout_p, + is_causal=True + ) + else: + scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) + scores += self.mask[:, :, :seq_len, :seq_len] + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + scores = self.attn_dropout(scores) + output = scores @ xv + + output = output.transpose(1, 2).reshape(bsz, seq_len, -1) + output = self.resid_dropout(self.wo(output)) + return output + + +class FeedForward(nn.Module): + def __init__(self, config: LMConfig): + super().__init__() + if config.hidden_dim is None: + hidden_dim = 4 * config.dim + hidden_dim = int(2 * hidden_dim / 3) + config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of) + self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False) + self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False) + self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x): + return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) + + +class MoEGate(nn.Module): + def __init__(self, config: LMConfig): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + + self.scoring_func = config.scoring_func + self.alpha = config.aux_loss_alpha + self.seq_aux = config.seq_aux + + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.dim + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states, self.weight, None) + if self.scoring_func == 'softmax': + scores = logits.softmax(dim=-1) + else: + raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}') + + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + + if self.training and self.alpha > 0.0: + scores_for_aux = scores + aux_topk = self.top_k + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device) + ce.scatter_add_(1, topk_idx_for_aux_loss, + torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_( + seq_len * aux_topk / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha + else: + mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts) + ce = mask_ce.float().mean(0) + Pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (Pi * fi).sum() * self.alpha + else: + aux_loss = 0 + return topk_idx, topk_weight, aux_loss + + +class MOEFeedForward(nn.Module): + def __init__(self, config: LMConfig): + super().__init__() + self.config = config + self.experts = nn.ModuleList([ + FeedForward(config) + for _ in range(config.n_routed_experts) + ]) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + self.shared_experts = FeedForward(config) + + def forward(self, x): + identity = x + orig_shape = x.shape + bsz, seq_len, _ = x.shape + # 使用门控机制选择专家 + topk_idx, topk_weight, aux_loss = self.gate(x) + x = x.view(-1, x.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training: + x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0) + y = torch.empty_like(x, dtype=torch.float16) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致 + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape) + else: + y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + self.aux_loss = aux_loss + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) + token_idxs = idxs // self.config.num_experts_per_tok + # 当tokens_per_expert = [6, 15, 20, 26],tokens_per_expert.shape[0]即为专家数量(此时为4) + # 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时 + # 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token(每个token有可能被多个专家处理,这取决于num_experts_per_tok) + # 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推 + for i, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if i == 0 else tokens_per_expert[i - 1] + if start_idx == end_idx: + continue + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens).to(expert_cache.dtype) + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out) + + return expert_cache + + +class TripleExtractionHead(nn.Module): + """三元组提取任务头""" + def __init__(self, config: LMConfig): + super().__init__() + self.config = config + + # 三元组长度超参数 + self.max_subject_len = config.max_subject_len + self.max_predicate_len = config.max_predicate_len + self.max_object_len = config.max_object_len + + # 自注意力机制 + self.self_attention = Attention(config) + self.self_attn_norm = RMSNorm(config.dim, eps=config.norm_eps) + + # 交叉注意力机制(用于主语和宾语提取) + self.cross_attention_subject = CrossAttention(config) + self.cross_attention_object = CrossAttention(config) + + # 归一化层 + self.subject_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.object_norm = RMSNorm(config.dim, eps=config.norm_eps) + + # Feed Forward 网络 + self.predicate_ff = FeedForward(config) + self.subject_ff = FeedForward(config) + self.object_ff = FeedForward(config) + + # 输出投影层 - 修改为支持序列预测 + self.predicate_output = nn.Linear(config.dim, self.max_predicate_len *config.dim, bias=False) + self.subject_output = nn.Linear(config.dim, self.max_subject_len * config.dim, bias=False) + self.object_output = nn.Linear(config.dim, self.max_object_len * config.dim, bias=False) + + print(f"三元组提取任务头配置:") + print(f"- 主语最大长度: {self.max_subject_len}") + print(f"- 谓语最大长度: {self.max_predicate_len}") + print(f"- 宾语最大长度: {self.max_object_len}") + + def forward(self, h, pos_cis): + """ + Args: + h: [batch_size, seq_len, dim] - 来自transformer层的隐藏状态 + pos_cis: 位置编码 + Returns: + predicate_logits: [batch_size, seq_len, max_predicate_len, vocab_size] - 谓语序列预测 + subject_logits: [batch_size, seq_len, max_subject_len, vocab_size] - 主语序列预测 + object_logits: [batch_size, seq_len, max_object_len, vocab_size] - 宾语序列预测 + """ + batch_size, seq_len, dim = h.shape + + # 1. h通过自注意力得到h1 + h1 = self.self_attention(self.self_attn_norm(h), pos_cis) + h1 = h + h1 # 残差连接 + + # 2. h1通过feed_forward得到谓语输出 + predicate_features = self.predicate_ff(h1) + predicate_features = predicate_features.mean(dim=1) + predicate_raw = self.predicate_output(predicate_features) # [batch_size, max_predicate_len * vocab_size] + predicate_logits = predicate_raw.view(batch_size, self.max_predicate_len, -1) + + # 3. h1通过交叉注意力(k,v都是h)得到h2 + h2 = self.cross_attention_subject(h1, h) # query是h1,key和value都是h + h2 = h1 + h2 # 残差连接 + + # 4. h2通过feed_forward得到主语输出 + subject_features = self.subject_ff(self.subject_norm(h2)) + subject_features = subject_features.mean(dim=1) + subject_raw = self.subject_output(subject_features) # [batch_size, max_subject_len * vocab_size] + subject_logits = subject_raw.view(batch_size, self.max_subject_len, -1) + + # 5. h2通过交叉注意力(k,v都是h)得到h3 + h3 = self.cross_attention_object(h2, h) # query是h2,key和value都是h + h3 = h2 + h3 # 残差连接 + + # 6. h3通过feed_forward得到宾语输出 + object_features = self.object_ff(self.object_norm(h3)) + object_features = object_features.mean(dim=1) + object_raw = self.object_output(object_features) # [batch_size, max_object_len * vocab_size] + object_logits = object_raw.view(batch_size, self.max_object_len, -1) + + return predicate_logits, subject_logits, object_logits + + +class MiniMindBlock(nn.Module): + def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset): + super().__init__() + self.n_heads = config.n_heads + self.dim = config.dim + self.head_dim = config.dim // config.n_heads + self.self_attention = Attention(config) + self.cross_attention = CrossAttention(config) + self.knowledge_dataset = knowledge_dataset + + self.layer_id = layer_id + self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config) + + def forward(self, x, pos_cis): + h_attn = self.self_attention( + self.attention_norm(x), + pos_cis + ) + db, db_embeddings = self.knowledge_dataset.search_index(h_attn) + h_attn = self.cross_attention(h_attn, db_embeddings) + h = x + h_attn + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class MiniMindLM(PreTrainedModel): + config_class = LMConfig + + def __init__(self, params: LMConfig = None,mode="triple"): + self.params = params or LMConfig() + super().__init__(self.params) + self.vocab_size, self.n_layers = params.vocab_size, params.n_layers + self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) + self.dropout = nn.Dropout(params.dropout) + self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings) + self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)]) + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = nn.Linear(params.dim, params.vocab_size, bias=False) + self.tok_embeddings.weight = self.output.weight + + # 添加三元组提取任务头(可训练) + self.triple_extraction_head = TripleExtractionHead(params) + self.register_buffer("pos_cis", + precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta), + persistent=False) + self.OUT = CausalLMOutputWithPast() + self.freeze_embedding = False + + self.mode = mode + + # 冻结所有指定组件的权重 + self._freeze_components() + + def _freeze_components(self): + """冻结指定组件的权重""" + # 冻结词嵌入层 + for param in self.tok_embeddings.parameters(): + param.requires_grad = False + + # 冻结知识数据库 + for param in self.knowledge_dataset.parameters(): + param.requires_grad = False + + # 冻结所有transformer层 + for param in self.layers.parameters(): + param.requires_grad = False + + # 冻结输出层 + for param in self.output.parameters(): + param.requires_grad = False + + # pos_cis是buffer,本身就不需要梯度,但为了明确起见 + # (实际上buffer默认就是requires_grad=False) + if hasattr(self, 'pos_cis'): + self.pos_cis.requires_grad = False + + print("已冻结以下组件的权重:") + print("- tok_embeddings") + print("- knowledge_dataset") + print("- layers (所有transformer层)") + print("- output") + print("- pos_cis") + print("注意:triple_extraction_head 保持可训练状态") + + def forward(self, + input_ids: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + step: int = 0, + **args): + start_pos = args.get('start_pos', 0) + h = self.dropout(self.tok_embeddings(input_ids)) + pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)] + for l, layer in enumerate(self.layers): + h = layer( + h, pos_cis + ) + + # 应用三元组提取任务头 + predicate_logits, subject_logits, object_logits = self.triple_extraction_head(h, pos_cis) + predicate_logits = predicate_logits.reshape(input_ids.size(0)*self.params.max_predicate_len, -1) + subject_logits = subject_logits.reshape(input_ids.size(0)*self.params.max_subject_len, -1) + object_logits = object_logits.reshape(input_ids.size(0)*self.params.max_object_len, -1) + + predicate_logits = self.output(predicate_logits) + subject_logits = self.output(subject_logits) + object_logits = self.output(object_logits) + + predicate_logits = predicate_logits.reshape(input_ids.size(0), self.params.max_predicate_len, -1) + subject_logits = subject_logits.reshape(input_ids.size(0), self.params.max_subject_len, -1) + object_logits = object_logits.reshape(input_ids.size(0), self.params.max_object_len, -1) + + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.output(self.norm(h)[:, slice_indices, :]) + aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward)) + + # 进一步简化,只保留必要的参数 + output = CausalLMOutputWithPast( + logits=logits, + ) + output.hidden_states = h + output.aux_loss = aux_loss + + # 添加三元组提取结果 + # 注意:现在的维度是 [batch_size, seq_len, max_len, vocab_size] + output.predicate_logits = predicate_logits + output.subject_logits = subject_logits + output.object_logits = object_logits + + return output + + @torch.inference_mode() + def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90, + stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args): + # 流式生成 + if stream: + return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args) + + # 直接生成 + generated = [] + for i in range(input_ids.size(0)): + non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0) + for _ in range(num_return_sequences): + out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args) + tokens_list = [tokens[:, -1:] for tokens in out] + gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad + full_sequence = torch.cat([non_pad, gen], dim=-1) + generated.append(full_sequence) + + max_length = max(seq.size(1) for seq in generated) + generated = [ + torch.cat( + [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)], + dim=-1) + for seq in generated + ] + output = torch.cat(generated, dim=0) + res = output.view(input_ids.size(0) * num_return_sequences, -1) + return res + + def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args): + start, first_seq, past_kvs = input_ids.shape[1], True, None + while input_ids.shape[1] < max_new_tokens - 1: + if first_seq: + out, first_seq = self(input_ids, **args), False + else: + out = self(input_ids[:, -1:], + start_pos=input_ids.shape[1] - 1, **args) + logits, past_kvs = out.logits[:, -1, :], out.past_key_values + logits[:, list(set(input_ids.tolist()[0]))] /= rp + logits /= (temperature + 1e-9) + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + sorted_probs = F.softmax(sorted_logits, dim=-1) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() + sorted_indices_to_remove[:, 0] = False + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = -float('Inf') + input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + input_ids = torch.cat((input_ids, input_ids_next), dim=1) + yield input_ids[:, start:] + if input_ids_next.item() == eos_token_id: + break + diff --git a/model/model_original.py b/model/model_original.py new file mode 100644 index 0000000..6e2dcb7 --- /dev/null +++ b/model/model_original.py @@ -0,0 +1,385 @@ +import math +import struct +import inspect +import time + +from .LMConfig import LMConfig +from typing import Any, Optional, Tuple, List, Union +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + return self.weight * self._norm(x.float()).type_as(x) + + +def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return pos_cis + + +def apply_rotary_emb(xq, xk, pos_cis): + def unite_shape(pos_cis, x): + ndim = x.ndim + assert 0 <= 1 < ndim + assert pos_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return pos_cis.view(*shape) + + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + pos_cis = unite_shape(pos_cis, xq_) + xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + def __init__(self, args: LMConfig): + super().__init__() + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + assert args.n_heads % self.n_kv_heads == 0 + self.n_local_heads = args.n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) + self.attn_dropout = nn.Dropout(args.dropout) + self.resid_dropout = nn.Dropout(args.dropout) + self.dropout = args.dropout + self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn + # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") + mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) + mask = torch.triu(mask, diagonal=1) + self.register_buffer("mask", mask, persistent=False) + + def forward(self, + x: torch.Tensor, + pos_cis: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache=False): + bsz, seq_len, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, pos_cis) + # kv_cache实现 + if past_key_value is not None: + xk = torch.cat([past_key_value[0], xk], dim=1) + xv = torch.cat([past_key_value[1], xv], dim=1) + past_kv = (xk, xv) if use_cache else None + + xq, xk, xv = ( + xq.transpose(1, 2), + repeat_kv(xk, self.n_rep).transpose(1, 2), + repeat_kv(xv, self.n_rep).transpose(1, 2) + ) + if self.flash and seq_len != 1: + dropout_p = self.dropout if self.training else 0.0 + output = F.scaled_dot_product_attention( + xq, xk, xv, + attn_mask=None, + dropout_p=dropout_p, + is_causal=True + ) + else: + scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) + scores += self.mask[:, :, :seq_len, :seq_len] + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + scores = self.attn_dropout(scores) + output = scores @ xv + + output = output.transpose(1, 2).reshape(bsz, seq_len, -1) + output = self.resid_dropout(self.wo(output)) + return output, past_kv + + +class FeedForward(nn.Module): + def __init__(self, config: LMConfig): + super().__init__() + if config.hidden_dim is None: + hidden_dim = 4 * config.dim + hidden_dim = int(2 * hidden_dim / 3) + config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of) + self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False) + self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False) + self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x): + return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) + + +class MoEGate(nn.Module): + def __init__(self, config: LMConfig): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + + self.scoring_func = config.scoring_func + self.alpha = config.aux_loss_alpha + self.seq_aux = config.seq_aux + + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.dim + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states, self.weight, None) + if self.scoring_func == 'softmax': + scores = logits.softmax(dim=-1) + else: + raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}') + + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + + if self.training and self.alpha > 0.0: + scores_for_aux = scores + aux_topk = self.top_k + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device) + ce.scatter_add_(1, topk_idx_for_aux_loss, + torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_( + seq_len * aux_topk / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha + else: + mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts) + ce = mask_ce.float().mean(0) + Pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (Pi * fi).sum() * self.alpha + else: + aux_loss = 0 + return topk_idx, topk_weight, aux_loss + + +class MOEFeedForward(nn.Module): + def __init__(self, config: LMConfig): + super().__init__() + self.config = config + self.experts = nn.ModuleList([ + FeedForward(config) + for _ in range(config.n_routed_experts) + ]) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + self.shared_experts = FeedForward(config) + + def forward(self, x): + identity = x + orig_shape = x.shape + bsz, seq_len, _ = x.shape + # 使用门控机制选择专家 + topk_idx, topk_weight, aux_loss = self.gate(x) + x = x.view(-1, x.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training: + x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0) + y = torch.empty_like(x, dtype=torch.float16) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致 + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape) + else: + y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + self.aux_loss = aux_loss + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) + token_idxs = idxs // self.config.num_experts_per_tok + # 当tokens_per_expert = [6, 15, 20, 26],tokens_per_expert.shape[0]即为专家数量(此时为4) + # 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时 + # 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token(每个token有可能被多个专家处理,这取决于num_experts_per_tok) + # 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推 + for i, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if i == 0 else tokens_per_expert[i - 1] + if start_idx == end_idx: + continue + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens).to(expert_cache.dtype) + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out) + + return expert_cache + + +class MiniMindBlock(nn.Module): + def __init__(self, layer_id: int, config: LMConfig): + super().__init__() + self.n_heads = config.n_heads + self.dim = config.dim + self.head_dim = config.dim // config.n_heads + self.attention = Attention(config) + + self.layer_id = layer_id + self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config) + + def forward(self, x, pos_cis, past_key_value=None, use_cache=False): + h_attn, past_kv = self.attention( + self.attention_norm(x), + pos_cis, + past_key_value=past_key_value, + use_cache=use_cache + ) + h = x + h_attn + out = h + self.feed_forward(self.ffn_norm(h)) + return out, past_kv + + +class MiniMindLM(PreTrainedModel): + config_class = LMConfig + + def __init__(self, params: LMConfig = None): + self.params = params or LMConfig() + super().__init__(self.params) + self.vocab_size, self.n_layers = params.vocab_size, params.n_layers + self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) + self.dropout = nn.Dropout(params.dropout) + self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)]) + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = nn.Linear(params.dim, params.vocab_size, bias=False) + self.tok_embeddings.weight = self.output.weight + self.register_buffer("pos_cis", + precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta), + persistent=False) + self.OUT = CausalLMOutputWithPast() + + def forward(self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + use_cache: bool = False, + logits_to_keep: Union[int, torch.Tensor] = 0, + **args): + past_key_values = past_key_values or [None] * len(self.layers) + start_pos = args.get('start_pos', 0) + h = self.dropout(self.tok_embeddings(input_ids)) + pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)] + past_kvs = [] + for l, layer in enumerate(self.layers): + h, past_kv = layer( + h, pos_cis, + past_key_value=past_key_values[l], + use_cache=use_cache + ) + past_kvs.append(past_kv) + + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.output(self.norm(h)[:, slice_indices, :]) + aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward)) + self.OUT.__setitem__('last_hidden_state', h) + self.OUT.__setitem__('logits', logits) + self.OUT.__setitem__('aux_loss', aux_loss) + self.OUT.__setitem__('past_key_values', past_kvs) + return self.OUT + + @torch.inference_mode() + def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90, + stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args): + # 流式生成 + if stream: + return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args) + + # 直接生成 + generated = [] + for i in range(input_ids.size(0)): + non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0) + for _ in range(num_return_sequences): + out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args) + tokens_list = [tokens[:, -1:] for tokens in out] + gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad + full_sequence = torch.cat([non_pad, gen], dim=-1) + generated.append(full_sequence) + + max_length = max(seq.size(1) for seq in generated) + generated = [ + torch.cat( + [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)], + dim=-1) + for seq in generated + ] + output = torch.cat(generated, dim=0) + res = output.view(input_ids.size(0) * num_return_sequences, -1) + return res + + def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args): + start, first_seq, past_kvs = input_ids.shape[1], True, None + while input_ids.shape[1] < max_new_tokens - 1: + if first_seq or not use_cache: + out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False + else: + out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache, + start_pos=input_ids.shape[1] - 1, **args) + logits, past_kvs = out.logits[:, -1, :], out.past_key_values + logits[:, list(set(input_ids.tolist()[0]))] /= rp + logits /= (temperature + 1e-9) + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + sorted_probs = F.softmax(sorted_logits, dim=-1) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() + sorted_indices_to_remove[:, 0] = False + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = -float('Inf') + input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + input_ids = torch.cat((input_ids, input_ids_next), dim=1) + yield input_ids[:, start:] + if input_ids_next.item() == eos_token_id: + break \ No newline at end of file