diff --git a/model/model.py b/model/model.py index 99e2911..9cb6a49 100644 --- a/model/model.py +++ b/model/model.py @@ -179,8 +179,6 @@ class Attention(nn.Module): def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache=True, db_value=None): bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度 xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换,得到查询、键和值。 @@ -190,11 +188,11 @@ class Attention(nn.Module): # 应用旋转位置编码(使用实数版本) xq, xk = apply_rotary_emb_real(xq, xk, pos_cis) - # kv_cache实现 - if past_key_value is not None: - xk = torch.cat([past_key_value[0], xk], dim=1) - xv = torch.cat([past_key_value[1], xv], dim=1) - past_kv = (xk, xv) if use_cache else None + # kv_cache实现 REMOVED + # if past_key_value is not None: + # xk = torch.cat([past_key_value[0], xk], dim=1) + # xv = torch.cat([past_key_value[1], xv], dim=1) + # past_kv = (xk, xv) if use_cache else None # 重复键值对 xq, xk, xv = ( @@ -247,7 +245,7 @@ class Attention(nn.Module): output = output.transpose(1, 2).reshape(bsz, seq_len, -1) output = self.resid_dropout(self.wo(output)) - return output, past_kv + return output @@ -459,7 +457,7 @@ class MiniMindBlock(nn.Module): # self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys # self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数 - def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True): + def forward(self, x, db_value, pos_cis): # import pdb;pdb.set_trace() # db_value = None @@ -504,11 +502,9 @@ class MiniMindBlock(nn.Module): # 注意力计算 - h_attn, past_kv = self.attention( + h_attn = self.attention( self.attention_norm(x), pos_cis, - past_key_value=past_key_value, - use_cache=use_cache, db_value=db_value ) @@ -519,7 +515,7 @@ class MiniMindBlock(nn.Module): # 前馈神经网络 out = h + self.feed_forward(self.ffn_norm(h)) - return out, past_kv + return out class ExtractDB(nn.Module): def __init__(self,params): @@ -642,15 +638,11 @@ class MiniMindLM(PreTrainedModel): def forward(self, input_ids: Optional[torch.Tensor] = None, - past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, - use_cache: bool = False, logits_to_keep: Union[int, torch.Tensor] = 0, **args): - past_key_values = past_key_values or [None] * len(self.layers) start_pos = args.get('start_pos', 0) h = self.dropout(self.tok_embeddings(input_ids)) pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)] - past_kvs = [] h_list = [] for l, layer in enumerate(self.layers): @@ -665,13 +657,10 @@ class MiniMindLM(PreTrainedModel): index = self.extract_db.q_to_k(h) db_value = self.extract_db.get_data(index) - h, past_kv = layer( - h, db_value, pos_cis_real, - past_key_value=past_key_values[l], - use_cache=use_cache + h = layer( + h, db_value, pos_cis_real ) - past_kvs.append(past_kv) h_list.append(h.unsqueeze(0)) h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3) @@ -698,7 +687,6 @@ class MiniMindLM(PreTrainedModel): # 进一步简化,只保留必要的参数 output = CausalLMOutputWithPast( logits=logits, - past_key_values=past_kvs, ) output.hidden_states = h @@ -714,17 +702,17 @@ class MiniMindLM(PreTrainedModel): @torch.inference_mode() def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90, - stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args): + stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args): # 流式生成 if stream: - return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args) + return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args) # 直接生成 generated = [] for i in range(input_ids.size(0)): non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0) for _ in range(num_return_sequences): - out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args) + out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args) tokens_list = [tokens[:, -1:] for tokens in out] gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad full_sequence = torch.cat([non_pad, gen], dim=-1) @@ -741,15 +729,14 @@ class MiniMindLM(PreTrainedModel): res = output.view(input_ids.size(0) * num_return_sequences, -1) return res - def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args): - start, first_seq, past_kvs = input_ids.shape[1], True, None + def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args): + start, first_seq = input_ids.shape[1], True while input_ids.shape[1] < max_new_tokens - 1: - if first_seq or not use_cache: - out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False + if first_seq: + out, first_seq = self(input_ids, **args), False else: - out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache, - start_pos=input_ids.shape[1] - 1, **args) - logits, past_kvs = out.logits[:, -1, :], out.past_key_values + out = self(input_ids[:, -1:], start_pos=input_ids.shape[1] - 1, **args) + logits = out.logits[:, -1, :] logits[:, list(set(input_ids.tolist()[0]))] /= rp logits /= (temperature + 1e-9) if top_p is not None and top_p < 1.0: