mirror of https://github.com/karpathy/nanochat.git
remove unnecessary check to make the logic in CausalSelfAttention.forward() clearer

commit 849d95ae1f
@@ -98,8 +98,7 @@ class CausalSelfAttention(nn.Module):
             # First, each query attends to all the cached keys/values (i.e. full prefix)
             attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
             prefix_len = Tk - Tq
-            if prefix_len > 0: # can't be negative but could be zero
-                attn_mask[:, :prefix_len] = True
+            attn_mask[:, :prefix_len] = True
             # Then, causal attention within this chunk
             attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
             y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
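The dropped guard was dead logic: per the original comment, prefix_len = Tk - Tq can't be negative here, and when it is zero the slice attn_mask[:, :0] is empty, so the assignment is already a no-op without the check. A minimal standalone sketch of the same mask construction follows; the toy shapes (B, H, D, Tq, Tk) are illustrative assumptions, not values from the repo, and enable_gqa is omitted since the sketch uses matching head counts.

import torch
import torch.nn.functional as F

# Assumed toy sizes: a cached prefix of 3 tokens plus a chunk of 2 new tokens.
B, H, D = 1, 2, 8   # batch, heads, head dim (illustrative only)
Tq, Tk = 2, 5       # queries in this chunk, total keys (prefix + chunk)
q = torch.randn(B, H, Tq, D)
k = torch.randn(B, H, Tk, D)
v = torch.randn(B, H, Tk, D)

attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq # >= 0 by construction
# When prefix_len == 0 this assigns to an empty slice and does nothing,
# which is why the `if prefix_len > 0` guard could be removed.
attn_mask[:, :prefix_len] = True
# Causal attention within the current chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))

y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(attn_mask.int())
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])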