From 5c93a56be5ea24e87c5756126afe8f4b5fb1458b Mon Sep 17 00:00:00 2001 From: Eric Silberstein Date: Wed, 19 Nov 2025 16:31:41 -0500 Subject: [PATCH] remove unnecessary check --- nanochat/gpt.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 216343c..45ec847 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -98,8 +98,7 @@ class CausalSelfAttention(nn.Module): # First, each query attends to all the cached keys/values (i.e. full prefix) attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask prefix_len = Tk - Tq - if prefix_len > 0: # can't be negative but could be zero - attn_mask[:, :prefix_len] = True + attn_mask[:, :prefix_len] = True # Then, causal attention within this chunk attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)) y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)