Removed the code related to the kv cache
This commit is contained in:
parent f31e17030c
commit fc688ddde4
@@ -179,8 +179,6 @@ class Attention(nn.Module):
     def forward(self,
                 x: torch.Tensor,
                 pos_cis: torch.Tensor,
-                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-                use_cache=True,
                 db_value=None):
         bsz, seq_len, _ = x.shape  # bsz: batch size, seq_len: sequence length, _: hidden dimension
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  # project x through the linear layers wq, wk, wv to get queries, keys and values
@@ -190,11 +188,11 @@ class Attention(nn.Module):

         # Apply rotary positional embedding (real-valued version)
         xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
-        # kv_cache implementation
-        if past_key_value is not None:
-            xk = torch.cat([past_key_value[0], xk], dim=1)
-            xv = torch.cat([past_key_value[1], xv], dim=1)
-        past_kv = (xk, xv) if use_cache else None
+        # kv_cache implementation REMOVED
+        # if past_key_value is not None:
+        #     xk = torch.cat([past_key_value[0], xk], dim=1)
+        #     xv = torch.cat([past_key_value[1], xv], dim=1)
+        # past_kv = (xk, xv) if use_cache else None

         # Repeat key-value pairs
         xq, xk, xv = (
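For context, the lines removed above implement the standard decode-time KV-cache pattern: keys and values cached from earlier steps are concatenated with the newly projected ones along the sequence dimension, and the pair is handed back for the next step. A minimal standalone sketch of that pattern, assuming (batch, seq_len, n_heads, head_dim) tensors; the helper name append_kv_cache is illustrative and not part of this repository:

    import torch
    from typing import Optional, Tuple

    def append_kv_cache(xk: torch.Tensor,
                        xv: torch.Tensor,
                        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                        use_cache: bool = True):
        # Concatenate cached keys/values with the new ones along the sequence
        # dimension (dim=1), mirroring the removed torch.cat calls above.
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        # Cache handed to the next decoding step (None when caching is disabled).
        past_kv = (xk, xv) if use_cache else None
        return xk, xv, past_kv

With the cache removed, Attention.forward attends only over the keys and values computed from its current input.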
@@ -247,7 +245,7 @@ class Attention(nn.Module):

         output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
         output = self.resid_dropout(self.wo(output))
-        return output, past_kv
+        return output


@@ -459,7 +457,7 @@ class MiniMindBlock(nn.Module):
         # self.product_key_topk = min(16, self.num_keys)  # make sure it does not exceed num_keys
         # self.num_experts_per_head_topk = 1  # final number of experts selected per head

-    def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True):
+    def forward(self, x, db_value, pos_cis):
         # import pdb;pdb.set_trace()
         # db_value = None

@@ -504,11 +502,9 @@ class MiniMindBlock(nn.Module):


         # Attention computation
-        h_attn, past_kv = self.attention(
+        h_attn = self.attention(
             self.attention_norm(x),
             pos_cis,
-            past_key_value=past_key_value,
-            use_cache=use_cache,
             db_value=db_value
         )

@@ -519,7 +515,7 @@ class MiniMindBlock(nn.Module):

         # Feed-forward network
         out = h + self.feed_forward(self.ffn_norm(h))
-        return out, past_kv
+        return out


 class ExtractDB(nn.Module):
     def __init__(self,params):
@@ -642,15 +638,11 @@ class MiniMindLM(PreTrainedModel):

     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
-                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
-                use_cache: bool = False,
                 logits_to_keep: Union[int, torch.Tensor] = 0,
                 **args):
-        past_key_values = past_key_values or [None] * len(self.layers)
         start_pos = args.get('start_pos', 0)
         h = self.dropout(self.tok_embeddings(input_ids))
         pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)]
-        past_kvs = []
         h_list = []

         for l, layer in enumerate(self.layers):
@@ -665,13 +657,10 @@ class MiniMindLM(PreTrainedModel):
             index = self.extract_db.q_to_k(h)
             db_value = self.extract_db.get_data(index)

-            h, past_kv = layer(
-                h, db_value, pos_cis_real,
-                past_key_value=past_key_values[l],
-                use_cache=use_cache
+            h = layer(
+                h, db_value, pos_cis_real
             )

-            past_kvs.append(past_kv)
             h_list.append(h.unsqueeze(0))

         h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
@@ -698,7 +687,6 @@ class MiniMindLM(PreTrainedModel):
         # Further simplified: keep only the necessary parameters
         output = CausalLMOutputWithPast(
             logits=logits,
-            past_key_values=past_kvs,
         )
         output.hidden_states = h

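Since past_key_values is no longer attached to the returned CausalLMOutputWithPast, callers should read only logits and the explicitly assigned hidden_states. A minimal usage sketch; model and input_ids are placeholder names for an already constructed MiniMindLM and a tokenized prompt:

    out = model(input_ids)          # forward pass, no cache arguments anymore
    logits = out.logits             # token logits (shape depends on logits_to_keep)
    hidden = out.hidden_states      # set explicitly via `output.hidden_states = h`
    # out.past_key_values stays at its default of None.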
@@ -714,17 +702,17 @@ class MiniMindLM(PreTrainedModel):

     @torch.inference_mode()
     def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
-                 stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
+                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
         # Streaming generation
         if stream:
-            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
+            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

         # Direct (non-streaming) generation
         generated = []
         for i in range(input_ids.size(0)):
             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
             for _ in range(num_return_sequences):
-                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
+                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                 tokens_list = [tokens[:, -1:] for tokens in out]
                 gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                 full_sequence = torch.cat([non_pad, gen], dim=-1)
@@ -741,15 +729,14 @@ class MiniMindLM(PreTrainedModel):
         res = output.view(input_ids.size(0) * num_return_sequences, -1)
         return res

-    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
-        start, first_seq, past_kvs = input_ids.shape[1], True, None
+    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
+        start, first_seq = input_ids.shape[1], True
         while input_ids.shape[1] < max_new_tokens - 1:
-            if first_seq or not use_cache:
-                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
+            if first_seq:
+                out, first_seq = self(input_ids, **args), False
             else:
-                out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
-                           start_pos=input_ids.shape[1] - 1, **args)
-            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
+                out = self(input_ids[:, -1:], start_pos=input_ids.shape[1] - 1, **args)
+            logits = out.logits[:, -1, :]
             logits[:, list(set(input_ids.tolist()[0]))] /= rp
             logits /= (temperature + 1e-9)
             if top_p is not None and top_p < 1.0:
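After this commit, generate() and _stream() no longer take a use_cache argument, and every decoding step after the first feeds only the last token together with an updated start_pos. A minimal calling sketch; model, input_ids and the sampling values are placeholders rather than values taken from the repository:

    out_ids = model.generate(
        input_ids,                  # (batch, prompt_len) prompt token ids
        eos_token_id=2,
        max_new_tokens=64,
        temperature=0.75,
        top_p=0.90,
        stream=False,               # use_cache is no longer part of the signature
    )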