rmsnorm float convert

This commit is contained in:
jingyaogong 2025-04-01 16:03:44 +08:00
parent e369b33265
commit bf81fd5f5e

View File

@ -14,13 +14,16 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization (RMSNorm).

    Scales the input by the reciprocal RMS of its last dimension and a
    learned per-feature gain. Unlike LayerNorm there is no mean
    subtraction and no bias term.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim: size of the last (feature) dimension to normalize over.
            eps: small constant added inside the square root for
                numerical stability (default 1e-6).
        """
        super().__init__()
        self.eps = eps
        # Learned per-feature scale, initialized to 1 (identity scaling).
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / rms(x): divide by sqrt(mean(x^2)) over the last dim.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to the
        # input dtype before applying the learned gain.
        return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):