update model

parent e4ad822c40
commit 8c18b324d0

model/model.py (178 changed lines)
@@ -66,20 +66,18 @@ class Attention(nn.Module):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
model_parallel_size = 1
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.k_cache, self.v_cache = None, None
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout

# use flash attention or a manual implementation?
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn

if not self.flash:
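Context for the head bookkeeping above: when n_kv_heads < n_heads this is grouped-query attention, and self.n_rep is the number of query heads that share each key/value head. The repeat_kv helper used later in forward is defined elsewhere in model.py and not shown in this diff; a minimal sketch of what it is assumed to do, following the llama2.c-style expansion:

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bs, seqlen, n_kv_heads, head_dim) -> (bs, seqlen, n_kv_heads * n_rep, head_dim)
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)  # repeat each kv head n_rep times
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )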
@@ -88,57 +86,39 @@ class Attention(nn.Module):
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask)

def forward(
self,
x: torch.Tensor,
pos_cis: torch.Tensor,
use_kv_cache: bool = False,
past_kv: Tuple[torch.Tensor] = None
):
def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, use_kv_cache=False):
bsz, seqlen, _ = x.shape
# QKV
# inference
if use_kv_cache:
# only compute Q for the last token
current_token = x[:, -1:, :]

if not past_kv:
xq = self.wq(x)
xk, xv = self.wk(x), self.wv(x)
if use_kv_cache and self.eval():
if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
else:
past_key, past_value = past_kv
xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(current_token)), dim=1)
xk = torch.cat((past_key, self.wk(current_token)), dim=1)
xv = torch.cat((past_value, self.wv(current_token)), dim=1)
token = x[:, -1:, :]
xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(token)), dim=1)
xk = torch.cat((self.k_cache, self.wk(token)), dim=1)
xv = torch.cat((self.v_cache, self.wv(token)), dim=1)

past_kv = (xk, xv)
self.k_cache, self.v_cache = xk, xv
else:
xq = self.wq(x)
xk, xv = self.wk(x), self.wv(x)
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

# RoPE relative positional embeddings
xq, xk = apply_rotary_emb(xq, xk, pos_cis)

# grouped multiquery attention: expand out keys and values
xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)

# make heads into a batch dimension
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)

# flash implementation
if self.flash:
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
dropout_p=self.dropout if self.training else 0.0,
is_causal=True)
else:
# manual implementation
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
assert hasattr(self, 'mask')
scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
@@ -146,13 +126,11 @@ class Attention(nn.Module):
scores = self.attn_dropout(scores)
output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)

# restore time as batch dimension and concat heads
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

# final projection into the residual stream
output = self.wo(output)
output = self.resid_dropout(output)
return output, past_kv
return output


class FeedForward(nn.Module):
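The net effect of the Attention changes: the past_kv tuple that used to be threaded through forward is gone, and the layer now keeps its own self.k_cache / self.v_cache, projecting only the newest token once the cache already covers the rest of the sequence. A toy standalone sketch of that caching pattern (illustration only, not the project's API):

import torch
import torch.nn as nn

class TinyKVCache(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.k_cache, self.v_cache = None, None

    def step(self, x: torch.Tensor):
        # x: (bs, seqlen, dim); reuse the cache when it covers all but the last token
        if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
            xk, xv = self.wk(x), self.wv(x)                      # prefill: project the whole prompt
        else:
            tok = x[:, -1:, :]                                   # decode: project only the new token
            xk = torch.cat((self.k_cache, self.wk(tok)), dim=1)
            xv = torch.cat((self.v_cache, self.wv(tok)), dim=1)
        self.k_cache, self.v_cache = xk, xv
        return xk, xv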
@@ -182,7 +160,6 @@ class MoEGate(nn.Module):
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux

# topk selection algorithm
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
@@ -194,7 +171,7 @@ class MoEGate(nn.Module):

def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
### compute gating score

hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
@@ -202,19 +179,15 @@ class MoEGate(nn.Module):
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

### select top-k experts
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator

### expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
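For reference, the gating shown in this hunk is plain softmax over experts followed by a top-k pick and optional renormalization of the kept weights. A standalone sketch of just that selection step (names are illustrative, not the module's API):

import torch
import torch.nn.functional as F

def topk_softmax_gate(hidden, gate_weight, top_k, norm_topk_prob=True):
    # hidden: (num_tokens, dim), gate_weight: (n_experts, dim)
    scores = F.linear(hidden, gate_weight).softmax(dim=-1)        # per-token probability over experts
    topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)
    if top_k > 1 and norm_topk_prob:
        topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
    return topk_weight, topk_idx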
@@ -331,11 +304,10 @@ class TransformerBlock(nn.Module):
dropout=args.dropout,
)

def forward(self, x, pos_cis, use_kv_cache=False, past_kv: Tuple[torch.Tensor] = None):
attn_res, past_kv = self.attention(self.attention_norm(x), pos_cis, use_kv_cache, past_kv)
h = x + attn_res
def forward(self, x, pos_cis, use_kv_cache=False):
h = x + self.attention(self.attention_norm(x), pos_cis, use_kv_cache)
out = h + self.feed_forward(self.ffn_norm(h))
return out, past_kv
return out


class Transformer(PreTrainedModel):
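The block keeps the standard pre-norm residual wiring: normalize, apply the sublayer, add the result back to the residual stream, twice per block. A generic sketch of that pattern (nn.LayerNorm stands in for the project's RMSNorm; attn and ffn are any modules with matching shapes):

import torch.nn as nn

class PreNormBlock(nn.Module):
    def __init__(self, dim, attn, ffn):
        super().__init__()
        self.attn, self.ffn = attn, ffn
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))      # attention sublayer on the normalized input
        return x + self.ffn(self.norm2(x))    # feed-forward sublayer, same residual pattern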
@@ -357,22 +329,16 @@ class Transformer(PreTrainedModel):
self.layers.append(TransformerBlock(layer_id, params))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

# share the unembedding parameters with the embedding parameters
self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying

# some useful precompute for the RoPE relative positional embeddings
self.tok_embeddings.weight = self.output.weight
pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
self.register_buffer("pos_cis", pos_cis, persistent=False)

# init all weights
self.apply(self._init_weights)
# apply special scaled init to the residual projections, per GPT-2 paper

for pn, p in self.named_parameters():
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers))

# Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
self.last_loss = None
self.OUT = CausalLMOutputWithPast()
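The pos_cis buffer registered above comes from precompute_pos_cis, which lives elsewhere in model.py and is not part of this diff. A sketch of the llama-style complex-rotation table it is assumed to compute (one rotation per position and per pair of head dimensions):

import torch

def precompute_pos_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # frequencies for each pair of dimensions, then one complex rotation per (position, pair)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64, shape (end, dim // 2)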
@@ -384,11 +350,8 @@ class Transformer(PreTrainedModel):
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

def forward(self, tokens: Optional[torch.Tensor] = None,
targets: Optional[torch.Tensor] = None,
use_kv_cache=False, past_kvs=None, **keyargs):
if past_kvs is None:
past_kvs = [None for _ in range(self.n_layers)]
def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
use_kv_cache=False, **keyargs):
if 'input_ids' in keyargs:
tokens = keyargs['input_ids']
if 'attention_mask' in keyargs:
@@ -399,63 +362,45 @@ class Transformer(PreTrainedModel):
h = self.dropout(h)
pos_cis = self.pos_cis[:seqlen]
for idx, layer in enumerate(self.layers):
h, past_kvs[idx] = layer(h, pos_cis, use_kv_cache, past_kvs[idx])
h = layer(h, pos_cis, use_kv_cache)

h = self.norm(h)

if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.output(h)
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# inference-time mini-optimization: only forward the output on the very last position
logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
logits = self.output(h[:, [-1], :])
self.last_loss = None

self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('last_loss', self.last_loss)

if use_kv_cache:
return self.OUT, past_kvs
return self.OUT


@torch.inference_mode()
def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.):
def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
use_kv_cache=True):
index = idx.shape[1]
use_kv_cache = True
past_kvs = [None for _ in range(self.n_layers)]
while idx.shape[1] < max_new_tokens - 1:
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx # if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
# forward the model to get the logits for the index in the sequence
inference_res = self(idx_cond, use_kv_cache=use_kv_cache, past_kvs=past_kvs)
if use_kv_cache:
logits, past_kvs = inference_res[0].logits, inference_res[1]
else:
logits = inference_res.logits
inference_res = self(idx, use_kv_cache=use_kv_cache)
logits = inference_res.logits
logits = logits[:, -1, :]

logits = logits[:, -1, :] # crop to just the final time step

# Apply repetition penalty
for token in set(idx.tolist()[0]):
logits[:, token] /= repetition_penalty

if temperature == 0.0:
# "sample" the single most likely index
__, idx_next = torch.topk(logits, k=1, dim=-1)
_, idx_next = torch.topk(logits, k=1, dim=-1)
else:
# pluck the logits at the final step and scale by desired temperature
logits = logits / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, __ = torch.topk(logits, min(top_k, logits.size(-1)))
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')

# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1, generator=None)
# append sampled index to the running sequence and continue

if idx_next == eos:
break
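Each decoding step above turns the final-position logits into a token via a repetition penalty, temperature scaling, optional top-k truncation, and a multinomial draw. A condensed standalone version of that per-step transform (function name and signature are illustrative):

import torch
import torch.nn.functional as F

def sample_next(logits, prev_ids, temperature=0.7, top_k=None, repetition_penalty=1.0):
    # logits: (1, vocab_size) for the last position; prev_ids: 1D tensor of generated token ids
    logits = logits.clone()
    for tok in set(prev_ids.tolist()):
        logits[:, tok] /= repetition_penalty              # discourage tokens already emitted
    if temperature == 0.0:
        return torch.topk(logits, k=1, dim=-1).indices    # greedy pick
    logits = logits / temperature
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('Inf')       # keep only the k most likely entries
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)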
@@ -468,63 +413,8 @@ class Transformer(PreTrainedModel):

@torch.inference_mode()
def eval_answer(self, idx):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
# forward the model to get the logits for the index in the sequence
past_kvs = [None for _ in range(self.n_layers)]
inference_res = self(idx_cond, use_kv_cache=False, past_kvs=past_kvs)
inference_res = self(idx_cond)
logits = inference_res.logits
logits = logits[:, -1, :]
return logits

def export(self, filepath='model.bin'):
"""export the model weights in fp32 into .bin file to be read from C"""
f = open(filepath, 'wb')

def serialize(t):
d = t.detach().cpu().view(-1).numpy().astype(np.float32)
b = struct.pack(f'{len(d)}f', *d)
f.write(b)

# first write out the header
hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
p = self.params
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
n_kv_heads, p.vocab_size, p.max_seq_len)
f.write(header)

# next write out the embedding weights
serialize(self.tok_embeddings.weight)

# now all the layers
# attention weights
for layer in self.layers:
serialize(layer.attention_norm.weight)
for layer in self.layers:
serialize(layer.attention.wq.weight)
for layer in self.layers:
serialize(layer.attention.wk.weight)
for layer in self.layers:
serialize(layer.attention.wv.weight)
for layer in self.layers:
serialize(layer.attention.wo.weight)
# ffn weights
for layer in self.layers:
serialize(layer.ffn_norm.weight)
for layer in self.layers:
serialize(layer.feed_forward.w1.weight)
for layer in self.layers:
serialize(layer.feed_forward.w2.weight)
for layer in self.layers:
serialize(layer.feed_forward.w3.weight)
# final rmsnorm
serialize(self.norm.weight)
# note: no need to write final classifier weights due to weight sharing
# pos_cis
serialize(self.freqs_cos[:p.max_seq_len])
serialize(self.freqs_sin[:p.max_seq_len])

# write to binary file
f.close()
print(f"wrote {filepath}")
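export writes a 7-int header (struct format 'iiiiiii', native byte order) followed by the flat fp32 tensors in the order listed above. A minimal sketch of reading that header back, assuming a file produced by this method:

import struct

with open('model.bin', 'rb') as f:
    dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len = struct.unpack(
        'iiiiiii', f.read(7 * 4))
print(dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len)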