diff --git a/model/model.py b/model/model.py index 1dc6478..6316a04 100644 --- a/model/model.py +++ b/model/model.py @@ -14,13 +14,16 @@ from transformers.modeling_outputs import CausalLMOutputWithPast class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float): + def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + def forward(self, x): - return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x) + return self.weight * self._norm(x.float()).type_as(x) def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):