diff --git a/model/LMConfig.py b/model/LMConfig.py
index 7069dc2..bf0e4b9 100644
--- a/model/LMConfig.py
+++ b/model/LMConfig.py
@@ -7,8 +7,8 @@ class LMConfig(PretrainedConfig):
     def __init__(
             self,
-            dim: int = 768,
-            n_layers: int = 16,
+            dim: int = 512,
+            n_layers: int = 8,
             n_heads: int = 16,
             n_kv_heads: int = 8,
             vocab_size: int = 6400,
diff --git a/model/model.py b/model/model.py
index bc68dd6..36a5cb4 100644
--- a/model/model.py
+++ b/model/model.py
@@ -1,6 +1,8 @@
 import math
 import struct
 import inspect
+import time
+
 from .LMConfig import LMConfig
 from typing import Any, Optional, Tuple
 import numpy as np
@@ -80,26 +82,15 @@ class Attention(nn.Module):
         self.dropout = args.dropout
         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn

-        if not self.flash:
-            # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
-            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
-            mask = torch.triu(mask, diagonal=1)
-            self.register_buffer("mask", mask)
+        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
+        mask = torch.triu(mask, diagonal=1)
+        self.register_buffer("mask", mask)

-    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, use_kv_cache=False):
+    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
         bsz, seqlen, _ = x.shape

-        if use_kv_cache and self.eval():
-            if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
-                xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-            else:
-                token = x[:, -1:, :]
-                xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(token)), dim=1)
-                xk = torch.cat((self.k_cache, self.wk(token)), dim=1)
-                xv = torch.cat((self.v_cache, self.wv(token)), dim=1)
-            self.k_cache, self.v_cache = xk, xv
-        else:
-            xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
@@ -107,6 +98,13 @@ class Attention(nn.Module):

         xq, xk = apply_rotary_emb(xq, xk, pos_cis)

+        # more efficient kv_cache implementation
+        if kv_cache and self.eval():
+            if seqlen == 1 and all(cache is not None for cache in (self.k_cache, self.v_cache)):
+                xk = torch.cat((self.k_cache, xk), dim=1)
+                xv = torch.cat((self.v_cache, xv), dim=1)
+            self.k_cache, self.v_cache = xk, xv
+
         xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
         xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

@@ -114,13 +112,12 @@
         xk = xk.transpose(1, 2)
         xv = xv.transpose(1, 2)

-        if self.flash:
+        if self.flash and seqlen != 1:
             output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
                                                                       dropout_p=self.dropout if self.training else 0.0,
                                                                       is_causal=True)
         else:
             scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
-            assert hasattr(self, 'mask')
             scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
             scores = F.softmax(scores.float(), dim=-1).type_as(xq)
             scores = self.attn_dropout(scores)
@@ -304,8 +301,8 @@ class TransformerBlock(nn.Module):
             dropout=args.dropout,
         )

-    def forward(self, x, pos_cis, use_kv_cache=False):
-        h = x + self.attention(self.attention_norm(x), pos_cis, use_kv_cache)
+    def forward(self, x, pos_cis, kv_cache=False):
+        h = x + self.attention(self.attention_norm(x), pos_cis, kv_cache)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out

@@ -351,18 +348,21 @@ class Transformer(PreTrainedModel):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

     def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
-                use_kv_cache=False, **keyargs):
+                kv_cache=False, **keyargs):
+        current_idx = 0
         if 'input_ids' in keyargs:
             tokens = keyargs['input_ids']
         if 'attention_mask' in keyargs:
             targets = keyargs['attention_mask']
+        if 'current_idx' in keyargs:
+            current_idx = int(keyargs['current_idx'])

         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         h = self.dropout(h)
-        pos_cis = self.pos_cis[:seqlen]
+        pos_cis = self.pos_cis[current_idx:current_idx + seqlen]
         for idx, layer in enumerate(self.layers):
-            h = layer(h, pos_cis, use_kv_cache)
+            h = layer(h, pos_cis, kv_cache)

         h = self.norm(h)
@@ -378,16 +378,21 @@ class Transformer(PreTrainedModel):
         return self.OUT

     @torch.inference_mode()
-    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
-                 use_kv_cache=True):
+    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=8, stream=True, rp=1., kv_cache=True):
+        # rp: repetition_penalty
         index = idx.shape[1]
+        init_inference = True
         while idx.shape[1] < max_new_tokens - 1:
-            inference_res = self(idx, use_kv_cache=use_kv_cache)
+            if init_inference or not kv_cache:
+                inference_res, init_inference = self(idx, kv_cache=kv_cache), False
+            else:
+                inference_res = self(idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1)
+
             logits = inference_res.logits
             logits = logits[:, -1, :]

             for token in set(idx.tolist()[0]):
-                logits[:, token] /= repetition_penalty
+                logits[:, token] /= rp

             if temperature == 0.0:
                 _, idx_next = torch.topk(logits, k=1, dim=-1)
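Note on the new `self.flash and seqlen != 1` guard: with a populated KV cache, an incremental step has a length-1 query attending over cache_len + 1 keys, and PyTorch's scaled_dot_product_attention builds its is_causal mask from the query length (upper-left aligned), which would mask out every cached key except the first. Single-token steps therefore take the manual path, where self.mask[:, :, :1, :1] contributes nothing and the token attends to the whole cache. A minimal sketch of the mismatch, with hypothetical shapes not taken from this patch:

import torch
import torch.nn.functional as F

bs, n_heads, cache_len, head_dim = 1, 16, 10, 32
q = torch.randn(bs, n_heads, 1, head_dim)               # one new token
k = torch.randn(bs, n_heads, cache_len + 1, head_dim)   # cache + new token
v = torch.randn(bs, n_heads, cache_len + 1, head_dim)

# is_causal=True aligns the causal mask to the 1-token query, so only the
# first key survives, which is wrong for incremental decoding.
wrong = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Unmasked attention over all cached positions is the correct behaviour for
# a single new token; the manual path in the patch reproduces this.
right = F.scaled_dot_product_attention(q, k, v)

assert not torch.allclose(wrong, right)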
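For reference, the incremental-decoding contract that generate() now follows: the first call runs the whole prompt with kv_cache=True to prime each layer's k_cache/v_cache, and every later call feeds only the newest token, passing current_idx so forward() slices the matching rotary-embedding entries from pos_cis. A minimal greedy-decoding sketch under those assumptions; decode, prompt_ids, and eos_id are illustrative names, batch size 1, and the patch's temperature/top-k/repetition-penalty handling is omitted:

import torch

@torch.inference_mode()
def decode(model, prompt_ids, eos_id, max_new_tokens=64):
    idx = prompt_ids  # (1, T) LongTensor prompt
    first = True
    while idx.shape[1] < max_new_tokens:
        if first:
            # Prime the per-layer KV caches with the full prompt.
            out, first = model(idx, kv_cache=True), False
        else:
            # Feed only the newest token; current_idx selects the RoPE slice
            # pos_cis[current_idx:current_idx + 1] inside forward().
            out = model(idx[:, -1:], kv_cache=True, current_idx=idx.shape[1] - 1)
        next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        idx = torch.cat((idx, next_id), dim=1)
        if next_id.item() == eos_id:
            break
    return idx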