update model mask
commit 6759da45c1
parent 02297df3c1
@@ -85,7 +85,7 @@ class Attention(nn.Module):
             # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
             mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
             mask = torch.triu(mask, diagonal=1)
-            self.register_buffer("mask", mask)
+            self.register_buffer("mask", mask, persistent=False)
 
     def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
         bsz, seqlen, _ = x.shape
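
The only change in this hunk is `persistent=False`, which keeps the causal mask out of the checkpoint: the buffer still follows the module across `.to()`/`.cuda()` calls, but it is rebuilt in `__init__` rather than stored in and loaded from `state_dict`. A minimal sketch (a standalone demo, not this repo's module) of the difference:

import torch
import torch.nn as nn

class MaskDemo(nn.Module):
    """Sketch: one persistent and one non-persistent buffer."""
    def __init__(self, max_seq_len: int = 8):
        super().__init__()
        mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("saved_mask", mask)              # serialized with the model
        self.register_buffer("mask", mask, persistent=False)  # rebuilt here, never saved

print(MaskDemo().state_dict().keys())  # odict_keys(['saved_mask'])

A practical side effect: checkpoints saved after this change are smaller and no longer shape-check the mask, so `args.max_seq_len` can differ between save and load.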
@@ -338,6 +338,7 @@ class Transformer(PreTrainedModel):
 
         self.last_loss = None
         self.OUT = CausalLMOutputWithPast()
+        self._no_split_modules = [name for name, _ in self.named_modules()]
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
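
For context on the added line: `_no_split_modules` is the attribute that Hugging Face's big-model loading path and the `accelerate` library consult when computing a device map, so the modules it names are kept whole on a single device. A hedged sketch of how the attribute is typically consumed downstream; `Transformer` and `args` stand in for this repo's own model class and config, which are not reproduced here:

from accelerate import dispatch_model, infer_auto_device_map

model = Transformer(args)  # assumed: the model class and config from this repo
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=model._no_split_modules,  # these are never split across devices
)
model = dispatch_model(model, device_map=device_map)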