Efficient implementation of Inference KV cache
parent 0cd7d4b2c2
commit 02297df3c1
@@ -7,8 +7,8 @@ class LMConfig(PretrainedConfig):
     def __init__(
             self,
-            dim: int = 768,
-            n_layers: int = 16,
+            dim: int = 512,
+            n_layers: int = 8,
             n_heads: int = 16,
             n_kv_heads: int = 8,
             vocab_size: int = 6400,
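For context, a minimal sketch of how the smaller defaults surface when the config is instantiated. The package path and the assumption that __init__ stores each argument on self are not shown in this hunk and are only illustrative:

from model.LMConfig import LMConfig   # path assumed; the diff only shows a relative import

cfg = LMConfig()                       # picks up the new, smaller defaults
print(cfg.dim, cfg.n_layers)           # expected: 512 8
print(cfg.n_heads, cfg.n_kv_heads)     # 16 8 -> two query heads share each KV head (grouped-query attention)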
@@ -1,6 +1,8 @@
 import math
 import struct
 import inspect
+import time
+
 from .LMConfig import LMConfig
 from typing import Any, Optional, Tuple
 import numpy as np
@@ -80,26 +82,15 @@ class Attention(nn.Module):
         self.dropout = args.dropout
         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn

-        if not self.flash:
-            # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
-            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
-            mask = torch.triu(mask, diagonal=1)
-            self.register_buffer("mask", mask)
+        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
+        mask = torch.triu(mask, diagonal=1)
+        self.register_buffer("mask", mask)

-    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, use_kv_cache=False):
+    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
         bsz, seqlen, _ = x.shape
-        if use_kv_cache and self.eval():
-            if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
-                xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-            else:
-                token = x[:, -1:, :]
-                xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(token)), dim=1)
-                xk = torch.cat((self.k_cache, self.wk(token)), dim=1)
-                xv = torch.cat((self.v_cache, self.wv(token)), dim=1)
-
-            self.k_cache, self.v_cache = xk, xv
-        else:
-            xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
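With the old branch removed, the Q/K/V projections are only ever computed for the tokens actually passed in; the removed path still built a full-length xq padded with zeros, while the new scheme (see the generate() hunk below) feeds a single token per decode step. A standalone sketch of the resulting shapes, with sizes chosen only for illustration:

import torch
import torch.nn as nn

dim = 512                               # matches the new LMConfig default
wq = nn.Linear(dim, dim, bias=False)    # stand-in for the real wq/wk/wv projections

prefill = torch.randn(1, 12, dim)       # whole prompt: (bsz, seqlen, dim)
decode = torch.randn(1, 1, dim)         # one new token per step: (bsz, 1, dim)

print(wq(prefill).shape)  # torch.Size([1, 12, 512]) -- done once
print(wq(decode).shape)   # torch.Size([1, 1, 512])  -- per-step cost is a single token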
@@ -107,6 +98,13 @@ class Attention(nn.Module):

         xq, xk = apply_rotary_emb(xq, xk, pos_cis)

+        # more efficient kv_cache implementation
+        if kv_cache and self.eval():
+            if seqlen == 1 and all(cache is not None for cache in (self.k_cache, self.v_cache)):
+                xk = torch.cat((self.k_cache, xk), dim=1)
+                xv = torch.cat((self.v_cache, xv), dim=1)
+            self.k_cache, self.v_cache = xk, xv
+
         xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
         xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
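A self-contained sketch of what the added cache update does across steps, assuming the same append-on-decode behavior; tensor sizes are illustrative and the real code runs this per attention layer, after RoPE:

import torch

bsz, n_kv_heads, head_dim = 1, 8, 64
k_cache = v_cache = None

def cache_step(xk, xv, seqlen):
    # append the new K/V to the cache when decoding a single token, otherwise reset it
    global k_cache, v_cache
    if seqlen == 1 and k_cache is not None and v_cache is not None:
        xk = torch.cat((k_cache, xk), dim=1)
        xv = torch.cat((v_cache, xv), dim=1)
    k_cache, v_cache = xk, xv
    return xk, xv

xk, xv = cache_step(torch.randn(bsz, 5, n_kv_heads, head_dim),
                    torch.randn(bsz, 5, n_kv_heads, head_dim), seqlen=5)  # prefill on a 5-token prompt
xk, xv = cache_step(torch.randn(bsz, 1, n_kv_heads, head_dim),
                    torch.randn(bsz, 1, n_kv_heads, head_dim), seqlen=1)  # decode token 1
xk, xv = cache_step(torch.randn(bsz, 1, n_kv_heads, head_dim),
                    torch.randn(bsz, 1, n_kv_heads, head_dim), seqlen=1)  # decode token 2
print(xk.shape[1], xv.shape[1])  # 7 7 -- keys/values now cover prompt + 2 generated tokens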
@@ -114,13 +112,12 @@ class Attention(nn.Module):
         xk = xk.transpose(1, 2)
         xv = xv.transpose(1, 2)

-        if self.flash:
+        if self.flash and seqlen != 1:
             output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
                                                                       dropout_p=self.dropout if self.training else 0.0,
                                                                       is_causal=True)
         else:
             scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
-            assert hasattr(self, 'mask')
             scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
             scores = F.softmax(scores.float(), dim=-1).type_as(xq)
             scores = self.attn_dropout(scores)
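The `seqlen != 1` guard presumably exists because `is_causal=True` aligns its mask from position 0 of the query; with a single query attending over cached keys it would hide everything but the first key. The manual branch is safe in that case since `self.mask[:, :, :1, :1]` is just a zero. A sketch of the single-query decode case, with arbitrary sizes:

import math
import torch
import torch.nn.functional as F

bsz, n_heads, head_dim, cache_len = 1, 16, 32, 6

xq = torch.randn(bsz, n_heads, 1, head_dim)              # one query token at decode time
xk = torch.randn(bsz, n_heads, cache_len + 1, head_dim)  # cached keys + the new key
xv = torch.randn(bsz, n_heads, cache_len + 1, head_dim)

scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(head_dim)  # (bsz, n_heads, 1, cache_len + 1)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)  # no causal mask needed: the new token may see everything before it
output = torch.matmul(scores, xv)
print(output.shape)  # torch.Size([1, 16, 1, 32])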
@@ -304,8 +301,8 @@ class TransformerBlock(nn.Module):
             dropout=args.dropout,
         )

-    def forward(self, x, pos_cis, use_kv_cache=False):
-        h = x + self.attention(self.attention_norm(x), pos_cis, use_kv_cache)
+    def forward(self, x, pos_cis, kv_cache=False):
+        h = x + self.attention(self.attention_norm(x), pos_cis, kv_cache)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
@@ -351,18 +348,21 @@ class Transformer(PreTrainedModel):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

     def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
-                use_kv_cache=False, **keyargs):
+                kv_cache=False, **keyargs):
+        current_idx = 0
         if 'input_ids' in keyargs:
             tokens = keyargs['input_ids']
         if 'attention_mask' in keyargs:
             targets = keyargs['attention_mask']
+        if 'current_idx' in keyargs:
+            current_idx = int(keyargs['current_idx'])

         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         h = self.dropout(h)
-        pos_cis = self.pos_cis[:seqlen]
+        pos_cis = self.pos_cis[current_idx:current_idx + seqlen]
         for idx, layer in enumerate(self.layers):
-            h = layer(h, pos_cis, use_kv_cache)
+            h = layer(h, pos_cis, kv_cache)

         h = self.norm(h)
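The `current_idx` plumbing keeps rotary positions aligned during incremental decoding: only the newest token is fed through the model, so its pos_cis row has to come from its absolute position rather than position 0. A toy illustration of the slicing (the real pos_cis holds precomputed complex rotary frequencies):

import torch

max_seq_len = 512
pos_cis = torch.arange(max_seq_len)      # stand-in: entry i represents position i

# prefill over a 5-token prompt: current_idx = 0, seqlen = 5 -> positions 0..4
print(pos_cis[0:0 + 5].tolist())         # [0, 1, 2, 3, 4]

# decode step for the 6th token: current_idx = 5, seqlen = 1 -> position 5, not 0
current_idx, seqlen = 5, 1
print(pos_cis[current_idx:current_idx + seqlen].tolist())  # [5]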
@@ -378,16 +378,21 @@ class Transformer(PreTrainedModel):
         return self.OUT

     @torch.inference_mode()
-    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
-                 use_kv_cache=True):
+    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=8, stream=True, rp=1., kv_cache=True):
+        # rp: repetition_penalty
         index = idx.shape[1]
+        init_inference = True
         while idx.shape[1] < max_new_tokens - 1:
-            inference_res = self(idx, use_kv_cache=use_kv_cache)
+            if init_inference or not kv_cache:
+                inference_res, init_inference = self(idx, kv_cache=kv_cache), False
+            else:
+                inference_res = self(idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1)
+
             logits = inference_res.logits
             logits = logits[:, -1, :]

             for token in set(idx.tolist()[0]):
-                logits[:, token] /= repetition_penalty
+                logits[:, token] /= rp

             if temperature == 0.0:
                 _, idx_next = torch.topk(logits, k=1, dim=-1)
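A hypothetical driver for the new signature, assuming `model` is a trained Transformer instance and the token ids are made up; how the result is consumed depends on `stream`, which this hunk does not show:

import torch

prompt = torch.tensor([[1, 23, 57, 8]])   # (bsz=1, prompt_len), made-up ids

model.eval()
# With kv_cache=True the first iteration prefills on the whole prompt; every later
# iteration passes only idx[:, -1:] plus current_idx, so each layer reuses its k_cache/v_cache.
result = model.generate(prompt, eos=2, max_new_tokens=64,
                        temperature=0.7, top_k=8, stream=True, rp=1.0, kv_cache=True)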