Efficient implementation of Inference KV cache

gongjy 2024-09-21 00:01:05 +08:00
parent 0cd7d4b2c2
commit 02297df3c1
2 changed files with 35 additions and 30 deletions

model/LMConfig.py

@@ -7,8 +7,8 @@ class LMConfig(PretrainedConfig):
     def __init__(
             self,
-            dim: int = 768,
-            n_layers: int = 16,
+            dim: int = 512,
+            n_layers: int = 8,
             n_heads: int = 16,
             n_kv_heads: int = 8,
             vocab_size: int = 6400,

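This hunk halves both the default width (dim 768 → 512) and depth (n_layers 16 → 8), shrinking the transformer stack roughly 4x while leaving the head layout untouched. A minimal sketch of what the new defaults imply, assuming the package layout model/LMConfig.py and that the remaining constructor arguments keep their defaults:

    # Sketch only: the import path and no-arg construction are assumptions.
    from model.LMConfig import LMConfig

    config = LMConfig()
    assert config.dim == 512 and config.n_layers == 8
    # n_heads=16 with n_kv_heads=8 is grouped-query attention: two query
    # heads share each KV head, so the KV cache holds half as many heads.
    head_dim = config.dim // config.n_heads
    print(head_dim)  # 32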
model/model.py

@@ -1,6 +1,8 @@
 import math
 import struct
 import inspect
+import time
 from .LMConfig import LMConfig
 from typing import Any, Optional, Tuple
 import numpy as np
@@ -80,26 +82,15 @@ class Attention(nn.Module):
         self.dropout = args.dropout
         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
-        if not self.flash:
-            # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
-            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
-            mask = torch.triu(mask, diagonal=1)
-            self.register_buffer("mask", mask)
+        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
+        mask = torch.triu(mask, diagonal=1)
+        self.register_buffer("mask", mask)

-    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, use_kv_cache=False):
+    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
         bsz, seqlen, _ = x.shape
-        if use_kv_cache and self.eval():
-            if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
-                xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-            else:
-                token = x[:, -1:, :]
-                xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(token)), dim=1)
-                xk = torch.cat((self.k_cache, self.wk(token)), dim=1)
-                xv = torch.cat((self.v_cache, self.wv(token)), dim=1)
-            self.k_cache, self.v_cache = xk, xv
-        else:
-            xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
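The deleted block above is the old cache path: on a cache hit it rebuilt a full-length query by zero-padding everything before the newest token, so attention still ran over the entire sequence on every decode step. The mask registration also appears to move out of the `if not self.flash:` guard, because the manual attention branch is now taken even when flash is available (see the `seqlen != 1` gate below). A self-contained sketch of why the old query construction was wasteful:

    # Sketch: the old zero-padded query vs. the single token a cached
    # decode step actually needs. Shapes are illustrative.
    import torch

    seqlen, dim = 128, 512
    x = torch.randn(1, seqlen, dim)
    token = x[:, -1:, :]                                  # the only new input
    old_xq = torch.cat((torch.zeros_like(x[:, :-1, :]), token), dim=1)
    print(old_xq.shape)  # torch.Size([1, 128, 512]) -- full length, every step
    print(token.shape)   # torch.Size([1, 1, 512])   -- all that is needed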
@@ -107,6 +98,13 @@ class Attention(nn.Module):
         xq, xk = apply_rotary_emb(xq, xk, pos_cis)

+        # more efficient kv_cache implementation
+        if kv_cache and self.eval():
+            if seqlen == 1 and all(cache is not None for cache in (self.k_cache, self.v_cache)):
+                xk = torch.cat((self.k_cache, xk), dim=1)
+                xv = torch.cat((self.v_cache, xv), dim=1)
+            self.k_cache, self.v_cache = xk, xv
+
         xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
         xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
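The replacement projects only the tokens actually passed in and, on a one-token step, appends the post-RoPE K/V onto the cache. One aside: `self.eval()` returns the module itself and is therefore always truthy, so this is not a real mode check; the conventional test is `not self.training`, though generate() below runs under torch.inference_mode() either way. A standalone sketch of the append logic with plain functions instead of the module:

    # Sketch of the append-style cache; shapes are (bs, seq, kv_heads, head_dim).
    import torch

    k_cache, v_cache = None, None

    def cache_step(xk, xv):
        """Prefill on the first call, then grow the cache by one step."""
        global k_cache, v_cache
        if xk.shape[1] == 1 and k_cache is not None:
            xk = torch.cat((k_cache, xk), dim=1)  # append new K to cached K
            xv = torch.cat((v_cache, xv), dim=1)  # append new V to cached V
        k_cache, v_cache = xk, xv
        return xk, xv

    k, v = cache_step(torch.randn(1, 4, 8, 32), torch.randn(1, 4, 8, 32))  # prefill
    k, v = cache_step(torch.randn(1, 1, 8, 32), torch.randn(1, 1, 8, 32))  # decode
    print(k.shape)  # torch.Size([1, 5, 8, 32])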
@@ -114,13 +112,12 @@ class Attention(nn.Module):
         xk = xk.transpose(1, 2)
         xv = xv.transpose(1, 2)

-        if self.flash:
+        if self.flash and seqlen != 1:
             output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
                                                                       dropout_p=self.dropout if self.training else 0.0,
                                                                       is_causal=True)
         else:
             scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
-            assert hasattr(self, 'mask')
             scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
             scores = F.softmax(scores.float(), dim=-1).type_as(xq)
             scores = self.attn_dropout(scores)
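The extra `seqlen != 1` condition is presumably needed because `is_causal=True` builds its mask from query positions (aligned to the top-left corner): a single decode-step query would be treated as position 0 and allowed to see only the first cached key. The manual branch is correct for decoding, since `self.mask[:, :, :1, :1]` is a single 0.0 that broadcasts over all cached positions, i.e. no masking. A self-contained sketch of that decode-step attention:

    # Sketch: one query token attending over a longer cached K/V sequence.
    import math
    import torch
    import torch.nn.functional as F

    bs, n_heads, cache_len, head_dim = 1, 16, 5, 32
    xq = torch.randn(bs, n_heads, 1, head_dim)          # single new token
    xk = torch.randn(bs, n_heads, cache_len, head_dim)  # cached + current keys
    xv = torch.randn(bs, n_heads, cache_len, head_dim)

    scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(head_dim)
    output = torch.matmul(F.softmax(scores.float(), dim=-1).type_as(xq), xv)
    print(output.shape)  # torch.Size([1, 16, 1, 32])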
@@ -304,8 +301,8 @@ class TransformerBlock(nn.Module):
             dropout=args.dropout,
         )

-    def forward(self, x, pos_cis, use_kv_cache=False):
-        h = x + self.attention(self.attention_norm(x), pos_cis, use_kv_cache)
+    def forward(self, x, pos_cis, kv_cache=False):
+        h = x + self.attention(self.attention_norm(x), pos_cis, kv_cache)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
@@ -351,18 +348,21 @@ class Transformer(PreTrainedModel):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

     def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
-                use_kv_cache=False, **keyargs):
+                kv_cache=False, **keyargs):
+        current_idx = 0
         if 'input_ids' in keyargs:
             tokens = keyargs['input_ids']
         if 'attention_mask' in keyargs:
             targets = keyargs['attention_mask']
+        if 'current_idx' in keyargs:
+            current_idx = int(keyargs['current_idx'])

         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         h = self.dropout(h)
-        pos_cis = self.pos_cis[:seqlen]
+        pos_cis = self.pos_cis[current_idx:current_idx + seqlen]
         for idx, layer in enumerate(self.layers):
-            h = layer(h, pos_cis, use_kv_cache)
+            h = layer(h, pos_cis, kv_cache)

         h = self.norm(h)
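Slicing `pos_cis[current_idx:current_idx + seqlen]` is what keeps RoPE aligned when only the newest token is fed through: the rotation for a token must come from its absolute position, not from 0. A sketch with a llama-style precompute helper (model.py has its own `precompute_pos_cis`; this reimplementation is an assumption made so the snippet runs standalone):

    # Sketch: pick the rotary frequencies for an absolute position.
    import torch

    def precompute_pos_cis(dim: int, end: int, theta: float = 10000.0):
        # complex rotations exp(i * t * freq), shape (end, dim // 2)
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(torch.arange(end).float(), freqs)
        return torch.polar(torch.ones_like(angles), angles)

    pos_cis = precompute_pos_cis(dim=32, end=512)
    current_idx, seqlen = 41, 1   # decoding the token at absolute position 41
    step_cis = pos_cis[current_idx:current_idx + seqlen]
    print(step_cis.shape)         # torch.Size([1, 16])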
@@ -378,16 +378,21 @@ class Transformer(PreTrainedModel):
         return self.OUT

     @torch.inference_mode()
-    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
-                 use_kv_cache=True):
+    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=8, stream=True, rp=1., kv_cache=True):
+        # rp: repetition_penalty
         index = idx.shape[1]
+        init_inference = True
         while idx.shape[1] < max_new_tokens - 1:
-            inference_res = self(idx, use_kv_cache=use_kv_cache)
+            if init_inference or not kv_cache:
+                inference_res, init_inference = self(idx, kv_cache=kv_cache), False
+            else:
+                inference_res = self(idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1)
             logits = inference_res.logits
             logits = logits[:, -1, :]

             for token in set(idx.tolist()[0]):
-                logits[:, token] /= repetition_penalty
+                logits[:, token] /= rp

             if temperature == 0.0:
                 _, idx_next = torch.topk(logits, k=1, dim=-1)
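With kv_cache enabled, the first iteration prefills the cache with the whole prompt; every later iteration feeds only `idx[:, -1:]` together with its absolute position via current_idx, so each decode step attends once over the cached length instead of recomputing full quadratic attention. The repetition penalty itself is only renamed (repetition_penalty → rp); a self-contained sketch of what it does to the logits:

    # Sketch: every already-generated token's logit is divided by rp, so
    # repeats become less likely for positive logits when rp > 1. (For a
    # negative logit, dividing by rp > 1 raises it, a known quirk of this
    # simple scheme.)
    import torch

    logits = torch.tensor([[2.0, 0.5, 1.5, 0.1]])  # (bsz=1, vocab=4)
    idx = torch.tensor([[0, 2]])                   # tokens generated so far
    rp = 1.2
    for token in set(idx.tolist()[0]):
        logits[:, token] /= rp
    print(logits)  # tensor([[1.6667, 0.5000, 1.2500, 0.1000]])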