mirror of https://github.com/karpathy/nanochat.git (synced 2026-02-09 03:59:52 +00:00)
Fix SDPA KV-cache decode to respect sliding window (#456)
The SDPA fallback now respects the sliding window during single-token KV-cache decode by slicing K/V to the last (window + 1) tokens. Also simplifies the mask building for chunk inference so that the sliding window is properly applied in that path as well.

Fixes #452

Co-Authored-By: Kartik Vashishta <kartikv776@gmail.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
parent: ace6740bdd
commit: 3ba42e8135
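To make the fix concrete: for a single decoded token at cache position Tk - 1, a left sliding window of `window` admits keys at positions Tk - 1 - window through Tk - 1, i.e. exactly the last (window + 1) cache entries. Slicing K/V to that range is therefore equivalent to masking the full cache, and avoids building a mask at all. A minimal standalone sketch (toy shapes, not code from the patch):

    # Illustrative sketch only -- shapes and sizes are assumptions, not from the patch.
    import torch
    import torch.nn.functional as F

    B, H, Tk, D = 1, 2, 33, 16   # assumed sizes: 33 cached keys
    window = 8
    q = torch.randn(B, H, 1, D)  # single query token
    k = torch.randn(B, H, Tk, D)
    v = torch.randn(B, H, Tk, D)

    # Reference: attend over the full cache with an explicit sliding-window mask.
    col = torch.arange(Tk)
    mask = ((Tk - 1 - col) <= window).view(1, 1, 1, Tk)  # True for the last (window + 1) keys
    y_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    # The fix's approach: slice the cache to the last (window + 1) keys instead.
    start = max(0, Tk - (window + 1))
    y_slice = F.scaled_dot_product_attention(q, k[:, :, start:, :], v[:, :, start:, :])

    assert torch.allclose(y_mask, y_slice, atol=1e-5)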
@@ -71,27 +71,26 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
     # Single token generation
     if Tq == 1:
+        if window >= 0 and window < Tk:
+            # window is "left" tokens we need to include (window + 1) keys total
+            start = max(0, Tk - (window + 1))
+            k = k[:, :, start:, :]
+            v = v[:, :, start:, :]
         return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
 
-    # Need explicit mask
+    # Need explicit mask for sliding window/chunk inference
     device = q.device
-    if Tq == Tk:
-        # Causal + sliding window
-        mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
-        if window > 0 and window < Tq:
-            row_idx = torch.arange(Tq, device=device).unsqueeze(1)
-            col_idx = torch.arange(Tk, device=device).unsqueeze(0)
-            mask = mask & ((row_idx - col_idx) <= window)
-    else:
-        # Chunk inference: attend to prefix + causal within chunk
-        prefix_len = Tk - Tq
-        mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
-        mask[:, :prefix_len] = True
-        mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
+    # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
+    row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
+    col_idx = torch.arange(Tk, device=device).unsqueeze(0)
+    mask = col_idx <= row_idx
+
+    # sliding window (left)
+    if window >= 0 and window < Tk:
+        mask = mask & ((row_idx - col_idx) <= window)
 
     return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
 
 
 # =============================================================================
 # Public API: Same interface as FA3
 # =============================================================================
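The simplification above replaces two mask-building branches with a single construction keyed on absolute cache position: row_idx holds each query's position in the cache, so `col_idx <= row_idx` gives causality and the window clause applies uniformly. A standalone sanity check (toy sizes, assumed; not part of the patch) that the unified mask reproduces both old branches, while adding the windowing the old chunk path skipped:

    # Illustrative sketch only -- toy sizes are assumptions, not from the patch.
    import torch

    def unified_mask(Tq, Tk, window):
        row_idx = (Tk - Tq) + torch.arange(Tq).unsqueeze(1)  # absolute query positions
        col_idx = torch.arange(Tk).unsqueeze(0)
        mask = col_idx <= row_idx                            # causal w.r.t. cache position
        if 0 <= window < Tk:
            mask = mask & ((row_idx - col_idx) <= window)    # left sliding window
        return mask

    # Case Tq == Tk: identical to the old causal + sliding-window branch.
    Tq = Tk = 6
    window = 2
    old = torch.tril(torch.ones(Tq, Tk, dtype=torch.bool))
    row = torch.arange(Tq).unsqueeze(1)
    col = torch.arange(Tk).unsqueeze(0)
    old = old & ((row - col) <= window)
    assert torch.equal(unified_mask(Tq, Tk, window), old)

    # Case Tq < Tk (chunk inference): matches the old prefix + tril mask when the
    # window is disabled; the old branch never applied the window here at all.
    Tq, Tk = 3, 8
    prefix_len = Tk - Tq
    old = torch.zeros(Tq, Tk, dtype=torch.bool)
    old[:, :prefix_len] = True
    old[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, dtype=torch.bool))
    assert torch.equal(unified_mask(Tq, Tk, window=-1), old)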
@@ -178,6 +178,39 @@ class TestFA3VsSDPA:
         max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
         print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
 
+    def test_kvcache_single_token_sliding_window(self):
+        """Test single token decode with sliding window smaller than cache size.
+
+        This catches the bug where SDPA ignores window_size during Tq=1 decode.
+        When window < Tk, FA3 only attends to the last (window+1) tokens,
+        but SDPA was attending to all cached tokens.
+        """
+        B, T_max, H, D = 2, 64, 4, 32
+        T_prefill = 32  # Enough tokens to exceed window
+        window = 8  # Window SMALLER than cache size
+
+        k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            k_cache[:, :T_prefill, :, :] = k_init
+            v_cache[:, :T_prefill, :, :] = v_init
+            cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
+            return flash_attn.flash_attn_with_kvcache(
+                q_single, k_cache, v_cache, k=k_single, v=v_single,
+                cache_seqlens=cache_seqlens,
+                causal=True, window_size=(window, 0)  # window=8 < Tk=33
+            )
+
+        y_fa3, y_sdpa = run_both_impls(run)
+        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token_sliding_window")
+        print(f"single_token_sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
     def test_backward_gradients_match(self):
         """Verify gradients are similar between FA3 and SDPA."""
         B, T, H, D = 2, 32, 4, 16
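The numbers in this test work out as follows: appending one token to a 32-token prefill leaves Tk = 33 keys in the cache, and window_size=(8, 0) admits only the last window + 1 = 9 of them. A quick arithmetic check (illustrative only):

    # Worked numbers for the test above; values mirror the test, nothing new.
    T_prefill, window = 32, 8
    Tk = T_prefill + 1                   # 32 prefill keys + 1 new token = 33
    start = max(0, Tk - (window + 1))    # 24: first cache position still attended
    print(list(range(start, Tk)))        # [24, ..., 32] -> window + 1 = 9 keys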