From bf81fd5f5ef4f17ab2ef258860dba1bbab504d42 Mon Sep 17 00:00:00 2001
From: jingyaogong
Date: Tue, 1 Apr 2025 16:03:44 +0800
Subject: [PATCH] rmsnorm float convert

---
 model/model.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/model/model.py b/model/model.py
index 1dc6478..6316a04 100644
--- a/model/model.py
+++ b/model/model.py
@@ -14,13 +14,16 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 
 
 class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float):
+    def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
 
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
     def forward(self, x):
-        return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
+        return self.weight * self._norm(x.float()).type_as(x)
 
 
 def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):