Efficient implementation of Inference KV cache
This commit is contained in:
parent
0cd7d4b2c2
commit
02297df3c1
@ -7,8 +7,8 @@ class LMConfig(PretrainedConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 768,
|
||||
n_layers: int = 16,
|
||||
dim: int = 512,
|
||||
n_layers: int = 8,
|
||||
n_heads: int = 16,
|
||||
n_kv_heads: int = 8,
|
||||
vocab_size: int = 6400,
|
||||
|
@ -1,6 +1,8 @@
|
||||
import math
|
||||
import struct
|
||||
import inspect
|
||||
import time
|
||||
|
||||
from .LMConfig import LMConfig
|
||||
from typing import Any, Optional, Tuple
|
||||
import numpy as np
|
||||
@ -80,26 +82,15 @@ class Attention(nn.Module):
|
||||
self.dropout = args.dropout
|
||||
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
|
||||
|
||||
if not self.flash:
|
||||
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
||||
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
self.register_buffer("mask", mask)
|
||||
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
||||
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, use_kv_cache=False):
|
||||
def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
|
||||
bsz, seqlen, _ = x.shape
|
||||
if use_kv_cache and self.eval():
|
||||
if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
else:
|
||||
token = x[:, -1:, :]
|
||||
xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(token)), dim=1)
|
||||
xk = torch.cat((self.k_cache, self.wk(token)), dim=1)
|
||||
xv = torch.cat((self.v_cache, self.wv(token)), dim=1)
|
||||
|
||||
self.k_cache, self.v_cache = xk, xv
|
||||
else:
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
@ -107,6 +98,13 @@ class Attention(nn.Module):
|
||||
|
||||
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
|
||||
|
||||
# 更高效的kv_cache实现
|
||||
if kv_cache and self.eval():
|
||||
if seqlen == 1 and all(cache is not None for cache in (self.k_cache, self.v_cache)):
|
||||
xk = torch.cat((self.k_cache, xk), dim=1)
|
||||
xv = torch.cat((self.v_cache, xv), dim=1)
|
||||
self.k_cache, self.v_cache = xk, xv
|
||||
|
||||
xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
||||
xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
||||
|
||||
@ -114,13 +112,12 @@ class Attention(nn.Module):
|
||||
xk = xk.transpose(1, 2)
|
||||
xv = xv.transpose(1, 2)
|
||||
|
||||
if self.flash:
|
||||
if self.flash and seqlen != 1:
|
||||
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
is_causal=True)
|
||||
else:
|
||||
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
assert hasattr(self, 'mask')
|
||||
scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
|
||||
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||
scores = self.attn_dropout(scores)
|
||||
@ -304,8 +301,8 @@ class TransformerBlock(nn.Module):
|
||||
dropout=args.dropout,
|
||||
)
|
||||
|
||||
def forward(self, x, pos_cis, use_kv_cache=False):
|
||||
h = x + self.attention(self.attention_norm(x), pos_cis, use_kv_cache)
|
||||
def forward(self, x, pos_cis, kv_cache=False):
|
||||
h = x + self.attention(self.attention_norm(x), pos_cis, kv_cache)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
@ -351,18 +348,21 @@ class Transformer(PreTrainedModel):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
|
||||
def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
|
||||
use_kv_cache=False, **keyargs):
|
||||
kv_cache=False, **keyargs):
|
||||
current_idx = 0
|
||||
if 'input_ids' in keyargs:
|
||||
tokens = keyargs['input_ids']
|
||||
if 'attention_mask' in keyargs:
|
||||
targets = keyargs['attention_mask']
|
||||
if 'current_idx' in keyargs:
|
||||
current_idx = int(keyargs['current_idx'])
|
||||
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
h = self.dropout(h)
|
||||
pos_cis = self.pos_cis[:seqlen]
|
||||
pos_cis = self.pos_cis[current_idx:current_idx + seqlen]
|
||||
for idx, layer in enumerate(self.layers):
|
||||
h = layer(h, pos_cis, use_kv_cache)
|
||||
h = layer(h, pos_cis, kv_cache)
|
||||
|
||||
h = self.norm(h)
|
||||
|
||||
@ -378,16 +378,21 @@ class Transformer(PreTrainedModel):
|
||||
return self.OUT
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
|
||||
use_kv_cache=True):
|
||||
def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=8, stream=True, rp=1., kv_cache=True):
|
||||
# rp: repetition_penalty
|
||||
index = idx.shape[1]
|
||||
init_inference = True
|
||||
while idx.shape[1] < max_new_tokens - 1:
|
||||
inference_res = self(idx, use_kv_cache=use_kv_cache)
|
||||
if init_inference or not kv_cache:
|
||||
inference_res, init_inference = self(idx, kv_cache=kv_cache), False
|
||||
else:
|
||||
inference_res = self(idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1)
|
||||
|
||||
logits = inference_res.logits
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
for token in set(idx.tolist()[0]):
|
||||
logits[:, token] /= repetition_penalty
|
||||
logits[:, token] /= rp
|
||||
|
||||
if temperature == 0.0:
|
||||
_, idx_next = torch.topk(logits, k=1, dim=-1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user