From 5b65bc767e7638bcac0fd2121f3078d7a6d7d592 Mon Sep 17 00:00:00 2001
From: gongjy <2474590974@qq.com>
Date: Sat, 15 Feb 2025 20:26:34 +0800
Subject: [PATCH] update cis init

---
 model/model.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/model/model.py b/model/model.py
index 070a8f2..3f49236 100644
--- a/model/model.py
+++ b/model/model.py
@@ -23,7 +23,7 @@ class RMSNorm(torch.nn.Module):
         return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
 
 
-def precompute_pos_cis(dim: int, end: int, theta: float = 1e4):
+def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)  # type: ignore
     freqs = torch.outer(t, freqs).float()  # type: ignore
@@ -295,8 +295,9 @@ class MiniMindLM(PreTrainedModel):
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.tok_embeddings.weight = self.output.weight
-        self.register_buffer("pos_cis", precompute_pos_cis(params.dim // params.n_heads, params.max_seq_len,
-                                                           theta=params.rope_theta), persistent=False)
+        self.register_buffer("pos_cis",
+                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
+                             persistent=False)
         self.OUT = CausalLMOutputWithPast()
 
     def forward(self,
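
For reference, a minimal standalone sketch of the patched function. The body up to torch.outer is taken from the hunk above; the trailing torch.polar/return lines are an assumption (the hunk ends before them), matching the standard RoPE precomputation this function implements. Note that MiniMindLM.__init__ still passes theta=params.rope_theta, so the new 1e6 default only applies to direct calls; the effective change at model init is that the table length is now fixed at 32K positions instead of params.max_seq_len.

import torch

def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    # Per-pair rotation frequencies: theta^(-2i/dim) for i in [0, dim/2)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # positions 0..end-1
    freqs = torch.outer(t, freqs).float()  # (end, dim // 2) angle table
    # Assumed completion (not shown in the hunk): unit-magnitude complex
    # rotations e^{i*angle}, as in standard RoPE precompute.
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)
    return pos_cis

# Illustrative usage with a hypothetical head dim of 512 // 8 = 64:
pos_cis = precompute_pos_cis(dim=64)
print(pos_cis.shape, pos_cis.dtype)  # torch.Size([32768, 32]) torch.complex64

Decoupling the buffer length from params.max_seq_len lets the same non-persistent buffer serve inference contexts longer than the training sequence length, and a larger default base (1e6 vs. 1e4) slows the per-position rotation, a common choice for long-context RoPE.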