Efficient implementation of Inference KV cache

gongjy 2024-09-21 00:01:05 +08:00
parent 0cd7d4b2c2
commit 02297df3c1
2 changed files with 35 additions and 30 deletions


@@ -7,8 +7,8 @@ class LMConfig(PretrainedConfig):
     def __init__(
             self,
-            dim: int = 768,
-            n_layers: int = 16,
+            dim: int = 512,
+            n_layers: int = 8,
             n_heads: int = 16,
             n_kv_heads: int = 8,
             vocab_size: int = 6400,
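
The new defaults halve the hidden size and layer count, so the out-of-the-box model is considerably smaller. A quick sketch of what that means for the per-head dimension; the flat import path here is an assumption, the diff only shows the class itself:

from LMConfig import LMConfig   # hypothetical import; adjust to the actual package layout

cfg = LMConfig()                      # picks up the new defaults: dim=512, n_layers=8
head_dim = cfg.dim // cfg.n_heads     # 512 // 16 = 32 per attention head
print(cfg.dim, cfg.n_layers, head_dim)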


@ -1,6 +1,8 @@
import math import math
import struct import struct
import inspect import inspect
import time
from .LMConfig import LMConfig from .LMConfig import LMConfig
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
import numpy as np import numpy as np
@@ -80,25 +82,14 @@ class Attention(nn.Module):
         self.dropout = args.dropout
         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
-        if not self.flash:
-            # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
-            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
-            mask = torch.triu(mask, diagonal=1)
-            self.register_buffer("mask", mask)
+        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
+        mask = torch.triu(mask, diagonal=1)
+        self.register_buffer("mask", mask)

-    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, use_kv_cache=False):
+    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
         bsz, seqlen, _ = x.shape
-        if use_kv_cache and self.eval():
-            if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
-                xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-            else:
-                token = x[:, -1:, :]
-                xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(token)), dim=1)
-                xk = torch.cat((self.k_cache, self.wk(token)), dim=1)
-                xv = torch.cat((self.v_cache, self.wv(token)), dim=1)
-            self.k_cache, self.v_cache = xk, xv
-        else:
-            xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
@@ -107,6 +98,13 @@ class Attention(nn.Module):
         xq, xk = apply_rotary_emb(xq, xk, pos_cis)

+        # More efficient kv_cache implementation
+        if kv_cache and self.eval():
+            if seqlen == 1 and all(cache is not None for cache in (self.k_cache, self.v_cache)):
+                xk = torch.cat((self.k_cache, xk), dim=1)
+                xv = torch.cat((self.v_cache, xv), dim=1)
+            self.k_cache, self.v_cache = xk, xv
+
         xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
         xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
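
The hunk above is the core of the change: the full prompt pass computes and stores K/V for every position, and each later single-token pass computes K/V for just the new token and appends it to the cache. A minimal standalone sketch of that append pattern, using plain tensors rather than the Attention module (shapes are illustrative):

import torch

def step_kv_cache(k_cache, v_cache, xk, xv):
    # Append the new token's K/V to the running cache; dim=1 is the sequence axis.
    if xk.shape[1] == 1 and k_cache is not None and v_cache is not None:
        xk = torch.cat((k_cache, xk), dim=1)
        xv = torch.cat((v_cache, xv), dim=1)
    return xk, xv  # these become the new k_cache / v_cache

bsz, n_kv_heads, head_dim = 1, 8, 64
k_cache = v_cache = None

# Prefill with a 5-token prompt, then decode one extra token.
k, v = torch.randn(bsz, 5, n_kv_heads, head_dim), torch.randn(bsz, 5, n_kv_heads, head_dim)
k_cache, v_cache = step_kv_cache(k_cache, v_cache, k, v)      # cache covers positions 0..4
k1, v1 = torch.randn(bsz, 1, n_kv_heads, head_dim), torch.randn(bsz, 1, n_kv_heads, head_dim)
k_cache, v_cache = step_kv_cache(k_cache, v_cache, k1, v1)    # cache now covers positions 0..5
print(k_cache.shape)                                          # torch.Size([1, 6, 8, 64])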
@@ -114,13 +112,12 @@ class Attention(nn.Module):
         xk = xk.transpose(1, 2)
         xv = xv.transpose(1, 2)

-        if self.flash:
+        if self.flash and seqlen != 1:
             output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
                                                                       dropout_p=self.dropout if self.training else 0.0,
                                                                       is_causal=True)
         else:
             scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
-            assert hasattr(self, 'mask')
             scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
             scores = F.softmax(scores.float(), dim=-1).type_as(xq)
             scores = self.attn_dropout(scores)
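
Note the new seqlen != 1 guard: during cached decoding the query is a single token while the keys cover the whole cache, so the is_causal=True shortcut (which assumes query and key positions start at the same index) no longer matches, and the manual path is used instead. On that path the mask slice for a one-token query is simply zero, so the lone query attends to every cached position. A small sketch of the slice, building the mask the same way as in __init__ above:

import torch

max_seq_len = 8
mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)

print(mask[:, :, :1, :1])   # tensor([[[[0.]]]]) -> a single decode step is never masked
print(mask[:, :, :3, :3])   # upper-triangular -inf pattern used when the manual path handles a multi-token pass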
@@ -304,8 +301,8 @@ class TransformerBlock(nn.Module):
             dropout=args.dropout,
         )

-    def forward(self, x, pos_cis, use_kv_cache=False):
-        h = x + self.attention(self.attention_norm(x), pos_cis, use_kv_cache)
+    def forward(self, x, pos_cis, kv_cache=False):
+        h = x + self.attention(self.attention_norm(x), pos_cis, kv_cache)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
@@ -351,18 +348,21 @@ class Transformer(PreTrainedModel):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

     def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
-                use_kv_cache=False, **keyargs):
+                kv_cache=False, **keyargs):
+        current_idx = 0
         if 'input_ids' in keyargs:
             tokens = keyargs['input_ids']
         if 'attention_mask' in keyargs:
             targets = keyargs['attention_mask']
+        if 'current_idx' in keyargs:
+            current_idx = int(keyargs['current_idx'])

         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         h = self.dropout(h)
-        pos_cis = self.pos_cis[:seqlen]
+        pos_cis = self.pos_cis[current_idx:current_idx + seqlen]
         for idx, layer in enumerate(self.layers):
-            h = layer(h, pos_cis, use_kv_cache)
+            h = layer(h, pos_cis, kv_cache)

         h = self.norm(h)
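
The current_idx offset is what keeps rotary positions correct when only the newest token is fed through the model: that single token must be rotated with its true absolute position, not position 0. A rough standalone sketch of the slicing, with a plain range tensor standing in for the precomputed pos_cis table:

import torch

max_seq_len = 16
pos_cis = torch.arange(max_seq_len)              # stand-in: one entry per absolute position

prompt_len = 5
prefill = pos_cis[:prompt_len]                   # prompt pass uses positions 0..4
current_idx = prompt_len                         # in generate() this is idx.shape[1] - 1 after appending a token
decode = pos_cis[current_idx:current_idx + 1]    # single-token pass uses exactly position 5
print(prefill.tolist(), decode.tolist())         # [0, 1, 2, 3, 4] [5]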
@@ -378,16 +378,21 @@ class Transformer(PreTrainedModel):
         return self.OUT

     @torch.inference_mode()
-    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
-                 use_kv_cache=True):
+    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=8, stream=True, rp=1., kv_cache=True):
+        # rp: repetition_penalty
         index = idx.shape[1]
+        init_inference = True
         while idx.shape[1] < max_new_tokens - 1:
-            inference_res = self(idx, use_kv_cache=use_kv_cache)
+            if init_inference or not kv_cache:
+                inference_res, init_inference = self(idx, kv_cache=kv_cache), False
+            else:
+                inference_res = self(idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1)
             logits = inference_res.logits
             logits = logits[:, -1, :]

             for token in set(idx.tolist()[0]):
-                logits[:, token] /= repetition_penalty
+                logits[:, token] /= rp

             if temperature == 0.0:
                 _, idx_next = torch.topk(logits, k=1, dim=-1)
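
Taken together, generation now runs one full prefill pass over the prompt and then one single-token pass per new token, passing current_idx so rotary positions stay aligned with the cache. A hedged sketch of that control flow with the model call replaced by a dummy forward; greedy decoding only, and the real generate also handles temperature, top_k, streaming and eos:

import torch

def dummy_forward(tokens, kv_cache=False, current_idx=0):
    # Placeholder for the Transformer call: returns zero logits over a toy vocabulary.
    vocab_size = 10
    return torch.zeros(tokens.shape[0], tokens.shape[1], vocab_size)

idx = torch.tensor([[1, 2, 3]])                 # prompt tokens
kv_cache, init_inference = True, True
for _ in range(4):                              # generate four tokens
    if init_inference or not kv_cache:
        logits, init_inference = dummy_forward(idx, kv_cache=kv_cache), False                  # prefill pass
    else:
        logits = dummy_forward(idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1)   # one-token pass
    next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
    idx = torch.cat((idx, next_token), dim=1)
print(idx.shape)                                # torch.Size([1, 7]): 3 prompt + 4 generated tokens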