diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
index 5d27e5f..7c99149 100644
--- a/nanochat/flash_attention.py
+++ b/nanochat/flash_attention.py
@@ -69,29 +69,29 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
     if (window < 0 or window >= Tq) and Tq == Tk:
         return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
 
+    device = q.device
+
     # Single token generation
     if Tq == 1:
+        if window >= 0 and window < Tk:
+            # window counts the "left" tokens, so keep (window + 1) keys in total
+            start = max(0, Tk - (window + 1))
+            k = k[:, :, start:, :]
+            v = v[:, :, start:, :]
         return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
+
+    # Need explicit mask for sliding window/chunk inference
+    # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
+    row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
+    col_idx = torch.arange(Tk, device=device).unsqueeze(0)
+    mask = col_idx <= row_idx
 
-    # Need explicit mask
-    device = q.device
-    if Tq == Tk:
-        # Causal + sliding window
-        mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
-        if window > 0 and window < Tq:
-            row_idx = torch.arange(Tq, device=device).unsqueeze(1)
-            col_idx = torch.arange(Tk, device=device).unsqueeze(0)
-            mask = mask & ((row_idx - col_idx) <= window)
-    else:
-        # Chunk inference: attend to prefix + causal within chunk
-        prefix_len = Tk - Tq
-        mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
-        mask[:, :prefix_len] = True
-        mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
-
+    # sliding window (left)
+    if window >= 0 and window < Tk:
+        mask = mask & ((row_idx - col_idx) <= window)
+
     return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
-
 
 # =============================================================================
 # Public API: Same interface as FA3
 # =============================================================================
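
For reference, here is a minimal standalone sketch (not part of the patch) that mirrors the new mask construction: each query row is mapped to its absolute cache position `(Tk - Tq) + i`, made causal against the key index, and optionally restricted to the last `window` keys on its left. The helper name `build_mask` and the toy shapes are illustrative only.

```python
import torch

def build_mask(Tq, Tk, window, device="cpu"):
    # Row i corresponds to absolute position (Tk - Tq) + i in the KV cache.
    row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)  # (Tq, 1)
    col_idx = torch.arange(Tk, device=device).unsqueeze(0)              # (1, Tk)
    # Causal: a query may attend to keys at positions <= its own.
    mask = col_idx <= row_idx
    # Sliding window (left): keep at most (window + 1) keys per row.
    if 0 <= window < Tk:
        mask = mask & ((row_idx - col_idx) <= window)
    return mask

# Chunked decode: 3 new queries appended to a cache of 5 keys (Tk = 8), full attention.
print(build_mask(Tq=3, Tk=8, window=-1).int())
# queries at cache positions 5, 6, 7 see keys:
# [1, 1, 1, 1, 1, 1, 0, 0]
# [1, 1, 1, 1, 1, 1, 1, 0]
# [1, 1, 1, 1, 1, 1, 1, 1]

# Same shapes with a sliding window of 3: each row keeps (window + 1) = 4 keys.
print(build_mask(Tq=3, Tk=8, window=3).int())
# [0, 0, 1, 1, 1, 1, 0, 0]
# [0, 0, 0, 1, 1, 1, 1, 0]
# [0, 0, 0, 0, 1, 1, 1, 1]
```

In the single-token path (`Tq == 1`) the patch avoids building any mask and instead slices the last `(window + 1)` keys and values out of the cache, which keeps the generation hot path mask-free.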