remove unnecessary check to make the logic in CausalSelfAttention.forward() clearer

This commit is contained in:
Andrej 2025-12-08 18:30:37 -08:00 committed by GitHub
commit 849d95ae1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -98,8 +98,7 @@ class CausalSelfAttention(nn.Module):
# First, each query attends to all the cached keys/values (i.e. full prefix)
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq
if prefix_len > 0: # can't be negative but could be zero
attn_mask[:, :prefix_len] = True
attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)