diff --git a/model/model.py b/model/model.py index 5976ad2..20535fb 100644 --- a/model/model.py +++ b/model/model.py @@ -27,15 +27,11 @@ class RMSNorm(torch.nn.Module): return output * self.weight -def precompute_pos_cis(dim: int, end: int, theta: float = 10000.0, train_len: int = 512): +def precompute_pos_cis(dim: int, end: int, theta: float = 10000.0): freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) # type: ignore freqs = torch.outer(t, freqs).float() # type: ignore pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - # # 计算缩放因子 - # scale = train_len / end - # # 缩放旋转嵌入,实现线性的长度外推(注释掉不用是因为小模型依赖pos_cis拟合严重,直接做线性外推效果并不好) - # pos_cis = pos_cis * scale return pos_cis