diff --git a/model/model.py b/model/model.py
index 36a5cb4..e858ec7 100644
--- a/model/model.py
+++ b/model/model.py
@@ -85,7 +85,7 @@ class Attention(nn.Module):
             # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
             mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
             mask = torch.triu(mask, diagonal=1)
-            self.register_buffer("mask", mask)
+            self.register_buffer("mask", mask, persistent=False)
 
     def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
         bsz, seqlen, _ = x.shape
@@ -338,6 +338,7 @@ class Transformer(PreTrainedModel):
 
         self.last_loss = None
         self.OUT = CausalLMOutputWithPast()
+        self._no_split_modules = [name for name, _ in self.named_modules()]
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
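
Not part of the patch itself, just a minimal sketch of what `persistent=False` changes, using a hypothetical toy `Demo` module: the buffer stays usable in `forward()` and still follows `.to()`/`.cuda()`, but it is no longer written into `state_dict()`, so checkpoints do not carry the max_seq_len x max_seq_len causal mask.

```python
import torch
import torch.nn as nn

class Demo(nn.Module):  # hypothetical module, not part of model.py
    def __init__(self):
        super().__init__()
        mask = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
        # persistent=False: buffer is registered on the module but
        # excluded from state_dict(), so it is rebuilt in __init__
        # rather than stored in and loaded from checkpoints.
        self.register_buffer("mask", mask, persistent=False)

print("mask" in Demo().state_dict())  # prints: False
```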