diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 5189a4d0..af2aee32 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -63,36 +63,10 @@ def _resolve_use_fa3(): USE_FA3 = _resolve_use_fa3() -# ============================================================================= -# Mask cache for chunked inference (Tq != Tk, Tq > 1) -# ============================================================================= -_MASK_CACHE: dict = {} -_MASK_CACHE_MAX = 32 - - -def _get_chunk_mask(device, Tq, Tk, window): - """Cached causal (+sliding window) mask for chunk inference.""" - key = (device.type, device.index, Tq, Tk, window) - m = _MASK_CACHE.get(key) - if m is not None: - return m - - row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1) - col_idx = torch.arange(Tk, device=device).unsqueeze(0) - m = col_idx <= row_idx - if window >= 0 and window < Tk: - m = m & ((row_idx - col_idx) <= window) - - if len(_MASK_CACHE) >= _MASK_CACHE_MAX: - _MASK_CACHE.clear() - _MASK_CACHE[key] = m - return m - - # ============================================================================= # SDPA helpers # ============================================================================= -def _sdpa_attention(q, k, v, window_size, enable_gqa, causal=True): +def _sdpa_attention(q, k, v, window_size, enable_gqa): """ SDPA attention with sliding window support. q, k, v are (B, H, T, D) format. @@ -103,7 +77,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa, causal=True): # Full context, same length if (window < 0 or window >= Tq) and Tq == Tk: - return F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=enable_gqa) + return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) # Single token generation if Tq == 1: @@ -114,10 +88,18 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa, causal=True): v = v[:, :, start:, :] return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) - # Chunk inference (Tq > 1, Tq != Tk): use cached explicit bool mask. - mask = _get_chunk_mask(q.device, Tq, Tk, window) - return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) + # Need explicit mask for sliding window/chunk inference + device = q.device + # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask + row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1) + col_idx = torch.arange(Tk, device=device).unsqueeze(0) + mask = col_idx <= row_idx + # sliding window (left) + if window >= 0 and window < Tk: + mask = mask & ((row_idx - col_idx) <= window) + + return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) # ============================================================================= # Public API: Same interface as FA3 @@ -142,7 +124,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): k = k.transpose(1, 2) v = v.transpose(1, 2) enable_gqa = q.size(1) != k.size(1) - y = _sdpa_attention(q, k, v, window_size, enable_gqa, causal=causal) + y = _sdpa_attention(q, k, v, window_size, enable_gqa) return y.transpose(1, 2) # back to (B, T, H, D) @@ -157,8 +139,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N q: Queries, shape (B, T_new, H, D) k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D) k, v: New keys/values to insert, shape (B, T_new, H_kv, D) - cache_seqlens: Current position in cache. Either an int (fast path, no - GPU->CPU sync) or a tensor of shape (B,) int32 (FA3-compatible). + cache_seqlens: Current position in cache, shape (B,) int32 causal: Whether to use causal masking window_size: (left, right) sliding window. -1 means unlimited. @@ -173,32 +154,17 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N # SDPA fallback: manually manage KV cache B, T_new, H, D = q.shape - - # Avoid GPU->CPU sync if caller passes a Python int. - if isinstance(cache_seqlens, int): - pos = cache_seqlens - elif isinstance(cache_seqlens, torch.Tensor): - pos = int(cache_seqlens[0].item()) # assume uniform position across batch - else: - pos = int(cache_seqlens) + pos = cache_seqlens[0].item() # assume uniform position across batch # Insert new k, v into cache (in-place, matching FA3 behavior) if k is not None and v is not None: k_cache[:, pos:pos+T_new, :, :] = k v_cache[:, pos:pos+T_new, :, :] = v + # Get full cache up to current position + new tokens end_pos = pos + T_new - - # Sliding-window single-token decode: trim cache slice early so SDPA sees - # only the window instead of the full prefix. - window = window_size[0] - if T_new == 1 and 0 <= window < end_pos: - start = max(0, end_pos - (window + 1)) - k_full = k_cache[:, start:end_pos, :, :] - v_full = v_cache[:, start:end_pos, :, :] - else: - k_full = k_cache[:, :end_pos, :, :] - v_full = v_cache[:, :end_pos, :, :] + k_full = k_cache[:, :end_pos, :, :] + v_full = v_cache[:, :end_pos, :, :] # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D) q_sdpa = q.transpose(1, 2) @@ -206,7 +172,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N v_sdpa = v_full.transpose(1, 2) enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) - y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa, causal=causal) + y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) return y_sdpa.transpose(1, 2) # back to (B, T, H, D)