diff --git a/model/model.py b/model/model.py
index 070a8f2..3f49236 100644
--- a/model/model.py
+++ b/model/model.py
@@ -23,7 +23,7 @@ class RMSNorm(torch.nn.Module):
         return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
 
 
-def precompute_pos_cis(dim: int, end: int, theta: float = 1e4):
+def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)  # type: ignore
     freqs = torch.outer(t, freqs).float()  # type: ignore
@@ -295,8 +295,9 @@ class MiniMindLM(PreTrainedModel):
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.tok_embeddings.weight = self.output.weight
-        self.register_buffer("pos_cis", precompute_pos_cis(params.dim // params.n_heads, params.max_seq_len,
-                                                           theta=params.rope_theta), persistent=False)
+        self.register_buffer("pos_cis",
+                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
+                             persistent=False)
         self.OUT = CausalLMOutputWithPast()
 
     def forward(self,