This commit is contained in:
Kartik Vashishta 2026-01-26 17:59:30 -08:00 committed by GitHub
commit 68168c3522
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -69,29 +69,29 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
if (window < 0 or window >= Tq) and Tq == Tk:
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
device = q.device
# Single token generation
if Tq == 1:
if window >= 0 and window < Tk:
# window is "left" tokens we need to include (window + 1) keys total
start = max(0, Tk - (window + 1))
k = k[:, :, start:, :]
v = v[:, :, start:, :]
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
# Need explicit mask for sliding window/chunk inference
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = col_idx <= row_idx
# Need explicit mask
device = q.device
if Tq == Tk:
# Causal + sliding window
mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
if window > 0 and window < Tq:
row_idx = torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = mask & ((row_idx - col_idx) <= window)
else:
# Chunk inference: attend to prefix + causal within chunk
prefix_len = Tk - Tq
mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
mask[:, :prefix_len] = True
mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
# sliding window (left)
if window >= 0 and window < Tk:
mask = mask & ((row_idx - col_idx) <= window)
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
# =============================================================================
# Public API: Same interface as FA3
# =============================================================================