update model mask

commit 6759da45c1
parent 02297df3c1
Author: gongjy
Date:   2024-09-21 20:00:25 +08:00


@@ -85,7 +85,7 @@ class Attention(nn.Module):
             # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
             mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
             mask = torch.triu(mask, diagonal=1)
-            self.register_buffer("mask", mask)
+            self.register_buffer("mask", mask, persistent=False)
 
     def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
         bsz, seqlen, _ = x.shape
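
The `persistent=False` flag excludes the causal-mask buffer from the module's `state_dict`, so the (1, 1, max_seq_len, max_seq_len) mask is no longer serialized into checkpoints; it is rebuilt in `__init__` and still follows the module across `.to()` / `.cuda()` moves. A minimal sketch of the difference (the `Demo` module and the 8x8 size are illustrative, not from this repo):

    import torch
    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            mask = torch.triu(torch.full((1, 1, 8, 8), float("-inf")), diagonal=1)
            # default persistent=True: the buffer is written to state_dict
            self.register_buffer("saved_mask", mask)
            # persistent=False: moves with .to()/.cuda(), but is not serialized
            self.register_buffer("mask", mask, persistent=False)

    m = Demo()
    print(list(m.state_dict().keys()))  # ['saved_mask'], 'mask' is excluded

One side effect to be aware of: checkpoints saved before this change still contain a `mask` key, which `load_state_dict(..., strict=True)` will now report as unexpected; loading with `strict=False`, or stripping the key first, avoids that.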
@@ -338,6 +338,7 @@ class Transformer(PreTrainedModel):
         self.last_loss = None
         self.OUT = CausalLMOutputWithPast()
+        self._no_split_modules = [name for name, _ in self.named_modules()]
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
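
For context, `_no_split_modules` is the hint that Hugging Face Accelerate's big-model machinery (e.g. `infer_auto_device_map`, or `device_map="auto"` in `from_pretrained`) reads to decide which modules must stay whole on a single device rather than being sharded. Upstream `transformers` models usually set it as a class attribute holding module class names (such as `["LlamaDecoderLayer"]`); this commit instead assigns, per instance, the path of every submodule. A minimal sketch of what that comprehension produces (toy module, illustrative only):

    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(10, 4)
            self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))

    toy = Toy()
    # named_modules() walks the tree in preorder, yielding (path, module) pairs
    no_split = [name for name, _ in toy.named_modules()]
    print(no_split)  # ['', 'emb', 'layers', 'layers.0', 'layers.1']

Note that `named_modules()` yields instance paths (including '' for the root), not class names, so how a given consumer matches this list against its no-split check is an assumption about intent here rather than a guarantee.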