update model

gongjy 2024-09-16 16:59:52 +08:00
parent e4ad822c40
commit 8c18b324d0


@@ -66,20 +66,18 @@ class Attention(nn.Module):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
model_parallel_size = 1
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.k_cache, self.v_cache = None, None
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
# use flash attention or a manual implementation?
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
if not self.flash:
@@ -88,57 +86,39 @@ class Attention(nn.Module):
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask)
def forward(
self,
x: torch.Tensor,
pos_cis: torch.Tensor,
use_kv_cache: bool = False,
past_kv: Tuple[torch.Tensor] = None
):
def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, use_kv_cache=False):
bsz, seqlen, _ = x.shape
# QKV
# inference
if use_kv_cache:
# only compute Q for the last token
current_token = x[:, -1:, :]
if not past_kv:
xq = self.wq(x)
xk, xv = self.wk(x), self.wv(x)
if use_kv_cache and self.eval():
if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
else:
past_key, past_value = past_kv
xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(current_token)), dim=1)
xk = torch.cat((past_key, self.wk(current_token)), dim=1)
xv = torch.cat((past_value, self.wv(current_token)), dim=1)
token = x[:, -1:, :]
xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(token)), dim=1)
xk = torch.cat((self.k_cache, self.wk(token)), dim=1)
xv = torch.cat((self.v_cache, self.wv(token)), dim=1)
past_kv = (xk, xv)
self.k_cache, self.v_cache = xk, xv
else:
xq = self.wq(x)
xk, xv = self.wk(x), self.wv(x)
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# RoPE relative positional embeddings
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
# grouped multiquery attention: expand out keys and values
xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
# make heads into a batch dimension
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
# flash implementation
if self.flash:
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
dropout_p=self.dropout if self.training else 0.0,
is_causal=True)
else:
# manual implementation
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
assert hasattr(self, 'mask')
scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
@@ -146,13 +126,11 @@ class Attention(nn.Module):
scores = self.attn_dropout(scores)
output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
# restore time as batch dimension and concat heads
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
# final projection into the residual stream
output = self.wo(output)
output = self.resid_dropout(output)
return output, past_kv
return output
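
Note: the updated forward drops the externally threaded past_kv tuples and instead keeps k_cache / v_cache attributes on the Attention module itself. Below is a minimal, self-contained sketch of the same incremental-decoding idea; it is not the project's exact code, repeat_kv (defined outside this diff) is assumed to follow the usual Llama-style grouped-KV expansion, and the rotary embedding step is omitted for brevity.

import torch
import torch.nn as nn
import torch.nn.functional as F

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bs, seqlen, n_kv_heads, head_dim) -> (bs, seqlen, n_kv_heads * n_rep, head_dim)
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (x[:, :, :, None, :]
            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
            .reshape(bs, slen, n_kv_heads * n_rep, head_dim))

class TinyCachedAttention(nn.Module):
    """Illustrative (hypothetical) attention layer with a module-level KV cache."""

    def __init__(self, dim: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        assert n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
        self.k_cache = self.v_cache = None  # filled during incremental decoding

    def forward(self, x: torch.Tensor, use_kv_cache: bool = False) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        if use_kv_cache and self.k_cache is not None and self.k_cache.shape[1] == seqlen - 1:
            # Only the newest token needs fresh K/V; everything else comes from the cache.
            token = x[:, -1:, :]
            xq = self.wq(x)  # sketch: recompute Q for the whole prefix for simplicity
            xk = torch.cat((self.k_cache, self.wk(token)), dim=1)
            xv = torch.cat((self.v_cache, self.wv(token)), dim=1)
        else:
            xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        if use_kv_cache:
            self.k_cache, self.v_cache = xk, xv
        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xk = repeat_kv(xk, self.n_heads // self.n_kv_heads)
        xv = repeat_kv(xv, self.n_heads // self.n_kv_heads)
        xq, xk, xv = (t.transpose(1, 2) for t in (xq, xk, xv))
        # requires PyTorch >= 2.0 for scaled_dot_product_attention
        out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(out)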
class FeedForward(nn.Module):
@@ -182,7 +160,6 @@ class MoEGate(nn.Module):
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
# topk selection algorithm
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
@@ -194,7 +171,7 @@ class MoEGate(nn.Module):
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
@@ -202,19 +179,15 @@ class MoEGate(nn.Module):
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
### select top-k experts
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
### expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
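
For reference, the top-k softmax gating that MoEGate.forward performs can be exercised in isolation. The sizes below are made up, the 'softmax' scoring function is assumed, and norm_topk_prob is taken as True.

import torch
import torch.nn.functional as F

bsz, seq_len, dim = 2, 4, 8                        # illustrative sizes only
n_routed_experts, top_k = 4, 2
hidden_states = torch.randn(bsz, seq_len, dim)
gate_weight = torch.randn(n_routed_experts, dim)   # stands in for MoEGate.weight

flat = hidden_states.view(-1, dim)                 # (bsz*seq_len, dim)
logits = F.linear(flat, gate_weight, None)         # (bsz*seq_len, n_experts)
scores = logits.softmax(dim=-1)
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)
# Renormalize the selected gates so each token's routed weights sum to 1.
topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
print(topk_weight.sum(dim=-1))                     # ~1.0 per token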
@@ -331,11 +304,10 @@ class TransformerBlock(nn.Module):
dropout=args.dropout,
)
def forward(self, x, pos_cis, use_kv_cache=False, past_kv: Tuple[torch.Tensor] = None):
attn_res, past_kv = self.attention(self.attention_norm(x), pos_cis, use_kv_cache, past_kv)
h = x + attn_res
def forward(self, x, pos_cis, use_kv_cache=False):
h = x + self.attention(self.attention_norm(x), pos_cis, use_kv_cache)
out = h + self.feed_forward(self.ffn_norm(h))
return out, past_kv
return out
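
RMSNorm is used by attention_norm / ffn_norm but its definition sits outside this diff; a typical implementation is sketched below as an assumption (including the eps default), not as the project's exact code.

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by the reciprocal RMS over the last dimension, then apply the learned gain.
        return self.weight * (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps))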
class Transformer(PreTrainedModel):
@@ -357,22 +329,16 @@ class Transformer(PreTrainedModel):
self.layers.append(TransformerBlock(layer_id, params))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
# share the unembedding parameters with the embedding parameters
self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
# some useful precompute for the RoPE relative positional embeddings
self.tok_embeddings.weight = self.output.weight
pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
self.register_buffer("pos_cis", pos_cis, persistent=False)
# init all weights
self.apply(self._init_weights)
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers))
# Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
self.last_loss = None
self.OUT = CausalLMOutputWithPast()
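
The weight tying kept here (tok_embeddings.weight = output.weight) means a single (vocab_size, dim) matrix serves as both embedding and unembedding. A standalone sketch with made-up sizes:

import torch
import torch.nn as nn

vocab_size, dim = 1000, 64                         # illustrative sizes, not the project's config
tok_embeddings = nn.Embedding(vocab_size, dim)
output = nn.Linear(dim, vocab_size, bias=False)
tok_embeddings.weight = output.weight              # weight tying: both modules share one Parameter

assert tok_embeddings.weight.data_ptr() == output.weight.data_ptr()
# Gradients from the input embedding and the final projection update the same matrix,
# and the parameter count drops by vocab_size * dim.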
@@ -384,11 +350,8 @@ class Transformer(PreTrainedModel):
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, tokens: Optional[torch.Tensor] = None,
targets: Optional[torch.Tensor] = None,
use_kv_cache=False, past_kvs=None, **keyargs):
if past_kvs is None:
past_kvs = [None for _ in range(self.n_layers)]
def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
use_kv_cache=False, **keyargs):
if 'input_ids' in keyargs:
tokens = keyargs['input_ids']
if 'attention_mask' in keyargs:
@@ -399,63 +362,45 @@ class Transformer(PreTrainedModel):
h = self.dropout(h)
pos_cis = self.pos_cis[:seqlen]
for idx, layer in enumerate(self.layers):
h, past_kvs[idx] = layer(h, pos_cis, use_kv_cache, past_kvs[idx])
h = layer(h, pos_cis, use_kv_cache)
h = self.norm(h)
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.output(h)
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# inference-time mini-optimization: only forward the output on the very last position
logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
logits = self.output(h[:, [-1], :])
self.last_loss = None
self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('last_loss', self.last_loss)
if use_kv_cache:
return self.OUT, past_kvs
return self.OUT
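
The targets branch at the end of forward trains on every position, while the inference branch projects only the last position. The two paths in isolation, with placeholder shapes:

import torch
import torch.nn.functional as F

bsz, seqlen, dim, vocab = 2, 5, 16, 32             # placeholder shapes
h = torch.randn(bsz, seqlen, dim)                  # stands in for the final hidden states
output = torch.nn.Linear(dim, vocab, bias=False)
targets = torch.randint(0, vocab, (bsz, seqlen))

# Training: flatten (bsz, seqlen, vocab) for cross-entropy; ignore_index=-1 masks padded labels.
logits = output(h)
loss = F.cross_entropy(logits.view(-1, vocab), targets.view(-1), ignore_index=-1)

# Inference: only the last position matters for next-token prediction.
# Indexing with the list [-1] keeps the time dimension, so logits stays (bsz, 1, vocab).
last_logits = output(h[:, [-1], :])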
@torch.inference_mode()
def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.):
def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
use_kv_cache=True):
index = idx.shape[1]
use_kv_cache = True
past_kvs = [None for _ in range(self.n_layers)]
while idx.shape[1] < max_new_tokens - 1:
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx # if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
# forward the model to get the logits for the index in the sequence
inference_res = self(idx_cond, use_kv_cache=use_kv_cache, past_kvs=past_kvs)
if use_kv_cache:
logits, past_kvs = inference_res[0].logits, inference_res[1]
else:
logits = inference_res.logits
inference_res = self(idx, use_kv_cache=use_kv_cache)
logits = inference_res.logits
logits = logits[:, -1, :]
logits = logits[:, -1, :] # crop to just the final time step
# Apply repetition penalty
for token in set(idx.tolist()[0]):
logits[:, token] /= repetition_penalty
if temperature == 0.0:
# "sample" the single most likely index
__, idx_next = torch.topk(logits, k=1, dim=-1)
_, idx_next = torch.topk(logits, k=1, dim=-1)
else:
# pluck the logits at the final step and scale by desired temperature
logits = logits / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, __ = torch.topk(logits, min(top_k, logits.size(-1)))
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1, generator=None)
# append sampled index to the running sequence and continue
if idx_next == eos:
break
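
The sampling block above (temperature scaling, optional top-k truncation, multinomial draw) as a self-contained snippet:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 32)                        # (batch, vocab), stand-in for model output
temperature, top_k = 0.7, 8

logits = logits / temperature
if top_k is not None:
    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
    # Anything below the k-th largest logit is pushed to -inf before the softmax.
    logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1) # shape (1, 1)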
@@ -468,63 +413,8 @@ class Transformer(PreTrainedModel):
@torch.inference_mode()
def eval_answer(self, idx):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
# forward the model to get the logits for the index in the sequence
past_kvs = [None for _ in range(self.n_layers)]
inference_res = self(idx_cond, use_kv_cache=False, past_kvs=past_kvs)
inference_res = self(idx_cond)
logits = inference_res.logits
logits = logits[:, -1, :]
return logits
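
The context-cropping guard in eval_answer simply keeps the most recent max_seq_len tokens; in isolation:

import torch

max_seq_len = 8                                    # illustrative value
idx = torch.arange(12).view(1, -1)                 # pretend token ids, longer than the window
idx_cond = idx if idx.size(1) <= max_seq_len else idx[:, -max_seq_len:]
print(idx_cond.shape)                              # torch.Size([1, 8])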
def export(self, filepath='model.bin'):
"""export the model weights in fp32 into .bin file to be read from C"""
f = open(filepath, 'wb')
def serialize(t):
d = t.detach().cpu().view(-1).numpy().astype(np.float32)
b = struct.pack(f'{len(d)}f', *d)
f.write(b)
# first write out the header
hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
p = self.params
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
n_kv_heads, p.vocab_size, p.max_seq_len)
f.write(header)
# next write out the embedding weights
serialize(self.tok_embeddings.weight)
# now all the layers
# attention weights
for layer in self.layers:
serialize(layer.attention_norm.weight)
for layer in self.layers:
serialize(layer.attention.wq.weight)
for layer in self.layers:
serialize(layer.attention.wk.weight)
for layer in self.layers:
serialize(layer.attention.wv.weight)
for layer in self.layers:
serialize(layer.attention.wo.weight)
# ffn weights
for layer in self.layers:
serialize(layer.ffn_norm.weight)
for layer in self.layers:
serialize(layer.feed_forward.w1.weight)
for layer in self.layers:
serialize(layer.feed_forward.w2.weight)
for layer in self.layers:
serialize(layer.feed_forward.w3.weight)
# final rmsnorm
serialize(self.norm.weight)
# note: no need to write final classifier weights due to weight sharing
# pos_cis
serialize(self.freqs_cos[:p.max_seq_len])
serialize(self.freqs_sin[:p.max_seq_len])
# write to binary file
f.close()
print(f"wrote {filepath}")
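
export writes a seven-field int32 header followed by flat fp32 tensors in the order serialized above. A hedged sketch of reading that header back in Python (the field order mirrors the struct.pack('iiiiiii', ...) call in this diff; the helper name is made up):

import struct

def read_export_header(filepath: str = 'model.bin') -> dict:
    # Header layout from export(): dim, hidden_dim, n_layers, n_heads,
    # n_kv_heads, vocab_size, max_seq_len, each a 4-byte native-order int.
    with open(filepath, 'rb') as f:
        fields = struct.unpack('iiiiiii', f.read(7 * 4))
    names = ('dim', 'hidden_dim', 'n_layers', 'n_heads',
             'n_kv_heads', 'vocab_size', 'max_seq_len')
    return dict(zip(names, fields))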