Removed the KV-cache-related code

Jax922 2025-05-13 16:26:31 +08:00
parent f31e17030c
commit fc688ddde4


@@ -179,8 +179,6 @@ class Attention(nn.Module):
     def forward(self,
                 x: torch.Tensor,
                 pos_cis: torch.Tensor,
-                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-                use_cache=True,
                 db_value=None):
         bsz, seq_len, _ = x.shape  # bsz: batch size, seq_len: sequence length, _: hidden dimension
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  # project x through the linear layers wq, wk, wv to get queries, keys, and values
@@ -190,11 +188,11 @@ class Attention(nn.Module):
         # apply rotary position embedding (real-valued version)
         xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
-        # kv_cache implementation
-        if past_key_value is not None:
-            xk = torch.cat([past_key_value[0], xk], dim=1)
-            xv = torch.cat([past_key_value[1], xv], dim=1)
-        past_kv = (xk, xv) if use_cache else None
+        # kv_cache implementation REMOVED
+        # if past_key_value is not None:
+        #     xk = torch.cat([past_key_value[0], xk], dim=1)
+        #     xv = torch.cat([past_key_value[1], xv], dim=1)
+        # past_kv = (xk, xv) if use_cache else None

         # repeat the key/value heads
         xq, xk, xv = (
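For context, the block commented out above is the standard append-style KV cache for incremental decoding: each step projects only the new tokens, and their keys/values are concatenated onto the cached history along the sequence dimension. A minimal self-contained sketch of that pattern (hypothetical function and shapes, not the MiniMind source; causal masking omitted for brevity):

import torch

def attend_with_cache(xq, xk, xv, past_key_value=None, use_cache=True):
    # xq/xk/xv: (batch, new_seq_len, n_heads, head_dim) for the new tokens only.
    if past_key_value is not None:
        # Prepend cached keys/values so attention spans the full history
        # even though only the new tokens were projected this step.
        xk = torch.cat([past_key_value[0], xk], dim=1)
        xv = torch.cat([past_key_value[1], xv], dim=1)
    past_kv = (xk, xv) if use_cache else None
    scores = torch.einsum('bqhd,bkhd->bhqk', xq, xk) / xk.shape[-1] ** 0.5
    out = torch.einsum('bhqk,bkhd->bqhd', scores.softmax(-1), xv)
    return out, past_kv

# Prefill on a 5-token prompt, then decode one token against the cache.
q = k = v = torch.randn(1, 5, 4, 8)
_, cache = attend_with_cache(q, k, v)
q1 = k1 = v1 = torch.randn(1, 1, 4, 8)
out, cache = attend_with_cache(q1, k1, v1, past_key_value=cache)
print(out.shape, cache[0].shape)  # (1, 1, 4, 8) and (1, 6, 4, 8)

The cache trades memory for compute; with it removed, every forward pass attends only over the tokens it is handed, and nothing carries over between decode steps.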
@@ -247,7 +245,7 @@ class Attention(nn.Module):
         output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
         output = self.resid_dropout(self.wo(output))
-        return output, past_kv
+        return output
@@ -459,7 +457,7 @@ class MiniMindBlock(nn.Module):
         # self.product_key_topk = min(16, self.num_keys)  # make sure this does not exceed num_keys
         # self.num_experts_per_head_topk = 1  # number of experts ultimately selected per head

-    def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True):
+    def forward(self, x, db_value, pos_cis):
         # import pdb;pdb.set_trace()
         # db_value = None
@@ -504,11 +502,9 @@ class MiniMindBlock(nn.Module):
         # attention computation
-        h_attn, past_kv = self.attention(
+        h_attn = self.attention(
             self.attention_norm(x),
             pos_cis,
-            past_key_value=past_key_value,
-            use_cache=use_cache,
             db_value=db_value
         )
@@ -519,7 +515,7 @@ class MiniMindBlock(nn.Module):
         # feed-forward network
         out = h + self.feed_forward(self.ffn_norm(h))
-        return out, past_kv
+        return out

 class ExtractDB(nn.Module):
     def __init__(self,params):
@ -642,15 +638,11 @@ class MiniMindLM(PreTrainedModel):
def forward(self, def forward(self,
input_ids: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
logits_to_keep: Union[int, torch.Tensor] = 0, logits_to_keep: Union[int, torch.Tensor] = 0,
**args): **args):
past_key_values = past_key_values or [None] * len(self.layers)
start_pos = args.get('start_pos', 0) start_pos = args.get('start_pos', 0)
h = self.dropout(self.tok_embeddings(input_ids)) h = self.dropout(self.tok_embeddings(input_ids))
pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)] pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)]
past_kvs = []
h_list = [] h_list = []
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
@@ -665,13 +657,10 @@ class MiniMindLM(PreTrainedModel):
             index = self.extract_db.q_to_k(h)
             db_value = self.extract_db.get_data(index)

-            h, past_kv = layer(
-                h, db_value, pos_cis_real,
-                past_key_value=past_key_values[l],
-                use_cache=use_cache
+            h = layer(
+                h, db_value, pos_cis_real
             )
-            past_kvs.append(past_kv)
             h_list.append(h.unsqueeze(0))

         h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
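The stacking on the last context line is untouched by this commit but easy to misread: each layer's hidden state gains a leading layer axis via unsqueeze(0), the list is concatenated along that axis, and permute moves batch back to the front. A small shape sketch under assumed sizes:

import torch

# Assumed sizes for illustration only.
n_layers, bsz, seq_len, dim = 8, 2, 16, 512
h_list = [torch.randn(bsz, seq_len, dim).unsqueeze(0) for _ in range(n_layers)]
h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
print(h_tensor.shape)  # torch.Size([2, 8, 16, 512]): (batch, layers, seq, dim)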
@@ -698,7 +687,6 @@ class MiniMindLM(PreTrainedModel):
         # further simplified: keep only the necessary parameters
         output = CausalLMOutputWithPast(
             logits=logits,
-            past_key_values=past_kvs,
         )
         output.hidden_states = h
@@ -714,17 +702,17 @@ class MiniMindLM(PreTrainedModel):
     @torch.inference_mode()
     def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
-                 stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
+                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
         # streaming generation
         if stream:
-            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
+            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

         # direct generation
         generated = []
         for i in range(input_ids.size(0)):
             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
             for _ in range(num_return_sequences):
-                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
+                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                 tokens_list = [tokens[:, -1:] for tokens in out]
                 gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                 full_sequence = torch.cat([non_pad, gen], dim=-1)
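After this hunk, the public sampling knobs are unchanged except that use_cache is gone. A hypothetical call sketch, assuming a constructed MiniMindLM bound to model and a dummy prompt (the vocabulary size of 6400 is an assumption, not taken from this diff):

import torch

prompt = torch.randint(0, 6400, (1, 16))  # assumed vocab size and prompt length
seq = model.generate(prompt, eos_token_id=2, max_new_tokens=64,
                     temperature=0.75, top_p=0.9, stream=False)  # no use_cache kwarg anymore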
@@ -741,15 +729,14 @@ class MiniMindLM(PreTrainedModel):
         res = output.view(input_ids.size(0) * num_return_sequences, -1)
         return res

-    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
-        start, first_seq, past_kvs = input_ids.shape[1], True, None
+    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
+        start, first_seq = input_ids.shape[1], True
         while input_ids.shape[1] < max_new_tokens - 1:
-            if first_seq or not use_cache:
-                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
-            else:
-                out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
-                           start_pos=input_ids.shape[1] - 1, **args)
-            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
+            if first_seq:
+                out, first_seq = self(input_ids, **args), False
+            else:
+                out = self(input_ids[:, -1:], start_pos=input_ids.shape[1] - 1, **args)
+            logits = out.logits[:, -1, :]
             logits[:, list(set(input_ids.tolist()[0]))] /= rp
             logits /= (temperature + 1e-9)
             if top_p is not None and top_p < 1.0:
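The excerpt ends at the top-p branch. For reference, a minimal sketch of the nucleus filtering such a branch typically performs (a standalone example under assumptions; the exact thresholding in the real code may differ):

import torch

def top_p_filter(logits, top_p=0.9):
    # Sort descending, then mask tokens whose preceding cumulative
    # probability already exceeds top_p (the top token is always kept).
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    probs = torch.softmax(sorted_logits, dim=-1)
    cum = torch.cumsum(probs, dim=-1)
    sorted_logits[cum - probs > top_p] = float('-inf')
    # Scatter the filtered logits back to the original token order.
    return sorted_logits.scatter(-1, sorted_idx, sorted_logits)

filtered = top_p_filter(torch.randn(1, 6400))  # assumed vocab size
next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)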