diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
index 5d27e5f..15411de 100644
--- a/nanochat/flash_attention.py
+++ b/nanochat/flash_attention.py
@@ -71,27 +71,26 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
     # Single token generation
     if Tq == 1:
+        if window >= 0 and window < Tk:
+            # window is the number of "left" tokens to include => (window + 1) keys total
+            start = max(0, Tk - (window + 1))
+            k = k[:, :, start:, :]
+            v = v[:, :, start:, :]
         return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
 
-    # Need explicit mask
+    # Need explicit mask for sliding window/chunk inference
     device = q.device
-    if Tq == Tk:
-        # Causal + sliding window
-        mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
-        if window > 0 and window < Tq:
-            row_idx = torch.arange(Tq, device=device).unsqueeze(1)
-            col_idx = torch.arange(Tk, device=device).unsqueeze(0)
-            mask = mask & ((row_idx - col_idx) <= window)
-    else:
-        # Chunk inference: attend to prefix + causal within chunk
-        prefix_len = Tk - Tq
-        mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
-        mask[:, :prefix_len] = True
-        mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
+    # For chunk inference (Tq != Tk), is_causal is not aligned to the cache position => build an explicit bool mask
+    row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
+    col_idx = torch.arange(Tk, device=device).unsqueeze(0)
+    mask = col_idx <= row_idx
+    # sliding window (left)
+    if window >= 0 and window < Tk:
+        mask = mask & ((row_idx - col_idx) <= window)
+
     return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
 
-
 # =============================================================================
 # Public API: Same interface as FA3
 # =============================================================================
diff --git a/tests/test_attention_fallback.py b/tests/test_attention_fallback.py
index 2cf3ed7..9741c7f 100644
--- a/tests/test_attention_fallback.py
+++ b/tests/test_attention_fallback.py
@@ -178,6 +178,39 @@ class TestFA3VsSDPA:
         max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
         print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
 
+    def test_kvcache_single_token_sliding_window(self):
+        """Test single-token decode with a sliding window smaller than the cache size.
+
+        This catches the bug where SDPA ignores window_size during Tq=1 decode.
+        When window < Tk, FA3 only attends to the last (window + 1) tokens,
+        but SDPA was attending to all cached tokens.
+        """
+        B, T_max, H, D = 2, 64, 4, 32
+        T_prefill = 32  # enough tokens to exceed the window
+        window = 8      # window SMALLER than the cache size
+
+        k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            k_cache[:, :T_prefill, :, :] = k_init
+            v_cache[:, :T_prefill, :, :] = v_init
+            cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
+            return flash_attn.flash_attn_with_kvcache(
+                q_single, k_cache, v_cache, k=k_single, v=v_single,
+                cache_seqlens=cache_seqlens,
+                causal=True, window_size=(window, 0)  # window=8 < Tk=33
+            )
+
+        y_fa3, y_sdpa = run_both_impls(run)
+        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token_sliding_window")
+        print(f"single_token_sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
     def test_backward_gradients_match(self):
         """Verify gradients are similar between FA3 and SDPA."""
         B, T, H, D = 2, 32, 4, 16