update cis init
This commit is contained in:
parent
c1a77f5c0f
commit
5b65bc767e
@ -23,7 +23,7 @@ class RMSNorm(torch.nn.Module):
|
|||||||
return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
|
return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
|
||||||
|
|
||||||
|
|
||||||
def precompute_pos_cis(dim: int, end: int, theta: float = 1e4):
|
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
|
||||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||||
t = torch.arange(end, device=freqs.device) # type: ignore
|
t = torch.arange(end, device=freqs.device) # type: ignore
|
||||||
freqs = torch.outer(t, freqs).float() # type: ignore
|
freqs = torch.outer(t, freqs).float() # type: ignore
|
||||||
@ -295,8 +295,9 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
||||||
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
|
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
|
||||||
self.tok_embeddings.weight = self.output.weight
|
self.tok_embeddings.weight = self.output.weight
|
||||||
self.register_buffer("pos_cis", precompute_pos_cis(params.dim // params.n_heads, params.max_seq_len,
|
self.register_buffer("pos_cis",
|
||||||
theta=params.rope_theta), persistent=False)
|
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
|
||||||
|
persistent=False)
|
||||||
self.OUT = CausalLMOutputWithPast()
|
self.OUT = CausalLMOutputWithPast()
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user