diff --git a/model/model.py b/model/model.py
index 0c563d1..4901bd7 100644
--- a/model/model.py
+++ b/model/model.py
@@ -66,20 +66,18 @@ class Attention(nn.Module):
         super().__init__()
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
         assert args.n_heads % self.n_kv_heads == 0
-        model_parallel_size = 1
-        self.n_local_heads = args.n_heads // model_parallel_size
-        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
+        self.n_local_heads = args.n_heads
+        self.n_local_kv_heads = self.n_kv_heads
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
         self.head_dim = args.dim // args.n_heads
         self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
         self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        self.k_cache, self.v_cache = None, None
         self.attn_dropout = nn.Dropout(args.dropout)
         self.resid_dropout = nn.Dropout(args.dropout)
         self.dropout = args.dropout
-
-        # use flash attention or a manual implementation?
         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
 
         if not self.flash:
@@ -88,57 +86,39 @@ class Attention(nn.Module):
             mask = torch.triu(mask, diagonal=1)
             self.register_buffer("mask", mask)
 
-    def forward(
-            self,
-            x: torch.Tensor,
-            pos_cis: torch.Tensor,
-            use_kv_cache: bool = False,
-            past_kv: Tuple[torch.Tensor] = None
-    ):
+    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, use_kv_cache=False):
         bsz, seqlen, _ = x.shape
-        # QKV
-        # inference
-        if use_kv_cache:
-            # 只计算最后一个token的Q
-            current_token = x[:, -1:, :]
-
-            if not past_kv:
-                xq = self.wq(x)
-                xk, xv = self.wk(x), self.wv(x)
+        if use_kv_cache and self.eval():
+            if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
+                xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
             else:
-                past_key, past_value = past_kv
-                xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(current_token)), dim=1)
-                xk = torch.cat((past_key, self.wk(current_token)), dim=1)
-                xv = torch.cat((past_value, self.wv(current_token)), dim=1)
+                token = x[:, -1:, :]
+                xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(token)), dim=1)
+                xk = torch.cat((self.k_cache, self.wk(token)), dim=1)
+                xv = torch.cat((self.v_cache, self.wv(token)), dim=1)
 
-            past_kv = (xk, xv)
+            self.k_cache, self.v_cache = xk, xv
         else:
-            xq = self.wq(x)
-            xk, xv = self.wk(x), self.wv(x)
+            xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
         xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
-        # RoPE relative positional embeddings
         xq, xk = apply_rotary_emb(xq, xk, pos_cis)
 
-        # grouped multiquery attention: expand out keys and values
         xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
         xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
 
-        # make heads into a batch dimension
-        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xq = xq.transpose(1, 2)
         xk = xk.transpose(1, 2)
         xv = xv.transpose(1, 2)
 
-        # flash implementation
         if self.flash:
             output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
                                                                       dropout_p=self.dropout if self.training else 0.0,
                                                                       is_causal=True)
         else:
-            # manual implementation
             scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
             assert hasattr(self, 'mask')
             scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
@@ -146,13 +126,11 @@ class Attention(nn.Module):
             scores = self.attn_dropout(scores)
             output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)
 
-        # restore time as batch dimension and concat heads
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
 
-        # final projection into the residual stream
         output = self.wo(output)
         output = self.resid_dropout(output)
-        return output, past_kv
+        return output
 
 
 class FeedForward(nn.Module):
@@ -182,7 +160,6 @@ class MoEGate(nn.Module):
         self.alpha = config.aux_loss_alpha
         self.seq_aux = config.seq_aux
 
-        # topk selection algorithm
         self.norm_topk_prob = config.norm_topk_prob
         self.gating_dim = config.dim
         self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
@@ -194,7 +171,7 @@ class MoEGate(nn.Module):
 
     def forward(self, hidden_states):
         bsz, seq_len, h = hidden_states.shape
-        ### compute gating score
+
         hidden_states = hidden_states.view(-1, h)
         logits = F.linear(hidden_states, self.weight, None)
         if self.scoring_func == 'softmax':
@@ -202,19 +179,15 @@ class MoEGate(nn.Module):
         else:
             raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
 
-        ### select top-k experts
         topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
 
-        ### norm gate to sum 1
         if self.top_k > 1 and self.norm_topk_prob:
             denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
             topk_weight = topk_weight / denominator
 
-        ### expert-level computation auxiliary loss
         if self.training and self.alpha > 0.0:
             scores_for_aux = scores
             aux_topk = self.top_k
-            # always compute aux loss based on the naive greedy topk method
             topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
             if self.seq_aux:
                 scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
@@ -331,11 +304,10 @@ class TransformerBlock(nn.Module):
             dropout=args.dropout,
         )
 
-    def forward(self, x, pos_cis, use_kv_cache=False, past_kv: Tuple[torch.Tensor] = None):
-        attn_res, past_kv = self.attention(self.attention_norm(x), pos_cis, use_kv_cache, past_kv)
-        h = x + attn_res
+    def forward(self, x, pos_cis, use_kv_cache=False):
+        h = x + self.attention(self.attention_norm(x), pos_cis, use_kv_cache)
         out = h + self.feed_forward(self.ffn_norm(h))
-        return out, past_kv
+        return out
 
 
 class Transformer(PreTrainedModel):
@@ -357,22 +329,16 @@ class Transformer(PreTrainedModel):
             self.layers.append(TransformerBlock(layer_id, params))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
-
-        # share the unembedding parameters with the embedding parameters
-        self.tok_embeddings.weight = self.output.weight  # https://paperswithcode.com/method/weight-tying
-
-        # some useful precompute for the RoPE relative positional embeddings
+        self.tok_embeddings.weight = self.output.weight
         pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
         self.register_buffer("pos_cis", pos_cis, persistent=False)
 
-        # init all weights
         self.apply(self._init_weights)
-        # apply special scaled init to the residual projections, per GPT-2 paper
+
        for pn, p in self.named_parameters():
             if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
                 torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers))
 
-        # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
         self.last_loss = None
 
         self.OUT = CausalLMOutputWithPast()
@@ -384,11 +350,8 @@ class Transformer(PreTrainedModel):
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
-    def forward(self, tokens: Optional[torch.Tensor] = None,
-                targets: Optional[torch.Tensor] = None,
-                use_kv_cache=False, past_kvs=None, **keyargs):
-        if past_kvs is None:
-            past_kvs = [None for _ in range(self.n_layers)]
+    def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
+                use_kv_cache=False, **keyargs):
         if 'input_ids' in keyargs:
             tokens = keyargs['input_ids']
         if 'attention_mask' in keyargs:
@@ -399,63 +362,45 @@ class Transformer(PreTrainedModel):
         h = self.dropout(h)
         pos_cis = self.pos_cis[:seqlen]
         for idx, layer in enumerate(self.layers):
-            h, past_kvs[idx] = layer(h, pos_cis, use_kv_cache, past_kvs[idx])
+            h = layer(h, pos_cis, use_kv_cache)
         h = self.norm(h)
 
         if targets is not None:
-            # if we are given some desired targets also calculate the loss
             logits = self.output(h)
             self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
         else:
-            # inference-time mini-optimization: only forward the output on the very last position
-            logits = self.output(h[:, [-1], :])  # note: using list [-1] to preserve the time dim
+            logits = self.output(h[:, [-1], :])
            self.last_loss = None
 
         self.OUT.__setitem__('logits', logits)
         self.OUT.__setitem__('last_loss', self.last_loss)
 
-        if use_kv_cache:
-            return self.OUT, past_kvs
         return self.OUT
 
-    @torch.inference_mode()
-    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.):
+    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
+                 use_kv_cache=True):
         index = idx.shape[1]
-        use_kv_cache = True
-        past_kvs = [None for _ in range(self.n_layers)]
         while idx.shape[1] < max_new_tokens - 1:
-            # if the sequence context is growing too long we must crop it at block_size
-            idx_cond = idx  # if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
-            # forward the model to get the logits for the index in the sequence
-            inference_res = self(idx_cond, use_kv_cache=use_kv_cache, past_kvs=past_kvs)
-            if use_kv_cache:
-                logits, past_kvs = inference_res[0].logits, inference_res[1]
-            else:
-                logits = inference_res.logits
+            inference_res = self(idx, use_kv_cache=use_kv_cache)
+            logits = inference_res.logits
+            logits = logits[:, -1, :]
 
-            logits = logits[:, -1, :]  # crop to just the final time step
-
-            # Apply repetition penalty
             for token in set(idx.tolist()[0]):
                 logits[:, token] /= repetition_penalty
 
             if temperature == 0.0:
-                # "sample" the single most likely index
-                __, idx_next = torch.topk(logits, k=1, dim=-1)
+                _, idx_next = torch.topk(logits, k=1, dim=-1)
             else:
-                # pluck the logits at the final step and scale by desired temperature
                 logits = logits / temperature
-                # optionally crop the logits to only the top k options
                 if top_k is not None:
-                    v, __ = torch.topk(logits, min(top_k, logits.size(-1)))
+                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                     logits[logits < v[:, [-1]]] = -float('Inf')
-                # apply softmax to convert logits to (normalized) probabilities
                 probs = F.softmax(logits, dim=-1)
                 idx_next = torch.multinomial(probs, num_samples=1, generator=None)
-            # append sampled index to the running sequence and continue
+
             if idx_next == eos:
                 break
@@ -468,63 +413,8 @@ class Transformer(PreTrainedModel):
 
     @torch.inference_mode()
     def eval_answer(self, idx):
-        # if the sequence context is growing too long we must crop it at block_size
         idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
-        # forward the model to get the logits for the index in the sequence
-        past_kvs = [None for _ in range(self.n_layers)]
-        inference_res = self(idx_cond, use_kv_cache=False, past_kvs=past_kvs)
+        inference_res = self(idx_cond)
         logits = inference_res.logits
         logits = logits[:, -1, :]
         return logits
-
-    def export(self, filepath='model.bin'):
-        """export the model weights in fp32 into .bin file to be read from C"""
-        f = open(filepath, 'wb')
-
-        def serialize(t):
-            d = t.detach().cpu().view(-1).numpy().astype(np.float32)
-            b = struct.pack(f'{len(d)}f', *d)
-            f.write(b)
-
-        # first write out the header
-        hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
-        p = self.params
-        n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
-        header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
-                             n_kv_heads, p.vocab_size, p.max_seq_len)
-        f.write(header)
-
-        # next write out the embedding weights
-        serialize(self.tok_embeddings.weight)
-
-        # now all the layers
-        # attention weights
-        for layer in self.layers:
-            serialize(layer.attention_norm.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wq.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wk.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wv.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wo.weight)
-        # ffn weights
-        for layer in self.layers:
-            serialize(layer.ffn_norm.weight)
-        for layer in self.layers:
-            serialize(layer.feed_forward.w1.weight)
-        for layer in self.layers:
-            serialize(layer.feed_forward.w2.weight)
-        for layer in self.layers:
-            serialize(layer.feed_forward.w3.weight)
-        # final rmsnorm
-        serialize(self.norm.weight)
-        # note: no need to write final classifier weights due to weight sharing
-        # pos_cis
-        serialize(self.freqs_cos[:p.max_seq_len])
-        serialize(self.freqs_sin[:p.max_seq_len])
-
-        # write to binary file
-        f.close()
-        print(f"wrote {filepath}")
\ No newline at end of file
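Note for reviewers: below is a minimal usage sketch of the API after this change. The names `LMConfig`, its import path, the `eos` id of 2, the random prompt, and the constructor call are assumptions for illustration only, not part of the diff; `Transformer`, `generate`, `use_kv_cache`, and the per-module `k_cache`/`v_cache` come from the hunks above. The per-layer key/value cache now lives on each `Attention` module, `forward` returns only `self.OUT`, and `generate` drives caching through its new `use_kv_cache=True` parameter instead of threading `past_kvs` through every call.

# Sketch only; assumes the repo layout model/model.py and a config class named LMConfig.
import torch

from model.LMConfig import LMConfig   # assumed config class/location
from model.model import Transformer   # class modified in this diff

cfg = LMConfig()                       # assumed default hyperparameters
model = Transformer(cfg).eval()        # Attention.forward gates caching on `use_kv_cache and self.eval()`

prompt = torch.randint(0, cfg.vocab_size, (1, 8))   # stand-in for tokenized input; eos=2 below is a placeholder id

with torch.inference_mode():           # generate() no longer carries @torch.inference_mode() itself
    out = model.generate(prompt, eos=2, max_new_tokens=64,
                         temperature=0.7, top_k=16, use_kv_cache=True)

Because the caches persist as attributes on each `Attention` module, a fresh prompt is recognized only by the `self.k_cache.shape[1] != x.shape[1] - 1` length check in `Attention.forward`; the streaming/return handling of `generate` sits outside the hunks shown above and is not spelled out here.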