update cis init

gongjy 2025-02-15 20:26:34 +08:00
parent c1a77f5c0f
commit 5b65bc767e


@@ -23,7 +23,7 @@ class RMSNorm(torch.nn.Module):
         return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
 
 
-def precompute_pos_cis(dim: int, end: int, theta: float = 1e4):
+def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)  # type: ignore
     freqs = torch.outer(t, freqs).float()  # type: ignore
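
This hunk gives end a 32K default and raises the RoPE base theta from 1e4 to 1e6; a larger base slows the per-pair rotation frequencies, the usual lever for longer-context RoPE tables. For reference, a minimal runnable sketch of the function as it reads after the change; the final torch.polar line is not visible in the hunk and is assumed from the standard complex-valued RoPE formulation:

import torch

def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    # Per-pair inverse frequencies: theta^(-2i/dim) for i in [0, dim/2)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # positions 0..end-1
    freqs = torch.outer(t, freqs).float()       # (end, dim/2) angle table
    # Assumed final step (not shown in the hunk): unit complex numbers e^{i*angle}
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64, (end, dim/2)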
@@ -295,8 +295,9 @@ class MiniMindLM(PreTrainedModel):
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.tok_embeddings.weight = self.output.weight
-        self.register_buffer("pos_cis", precompute_pos_cis(params.dim // params.n_heads, params.max_seq_len,
-                                                           theta=params.rope_theta), persistent=False)
+        self.register_buffer("pos_cis",
+                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
+                             persistent=False)
         self.OUT = CausalLMOutputWithPast()
 
     def forward(self,