From f8ca0b5c21ae591858c18a76b88f570143d9e7ae Mon Sep 17 00:00:00 2001 From: EFE AYDIN Date: Fri, 15 May 2026 23:16:01 +0300 Subject: [PATCH] Add chunk mask caching and update attention functions for fixing some performance choking --- nanochat/flash_attention.py | 74 +++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 20 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index af2aee32..5189a4d0 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -63,10 +63,36 @@ def _resolve_use_fa3(): USE_FA3 = _resolve_use_fa3() +# ============================================================================= +# Mask cache for chunked inference (Tq != Tk, Tq > 1) +# ============================================================================= +_MASK_CACHE: dict = {} +_MASK_CACHE_MAX = 32 + + +def _get_chunk_mask(device, Tq, Tk, window): + """Cached causal (+sliding window) mask for chunk inference.""" + key = (device.type, device.index, Tq, Tk, window) + m = _MASK_CACHE.get(key) + if m is not None: + return m + + row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1) + col_idx = torch.arange(Tk, device=device).unsqueeze(0) + m = col_idx <= row_idx + if window >= 0 and window < Tk: + m = m & ((row_idx - col_idx) <= window) + + if len(_MASK_CACHE) >= _MASK_CACHE_MAX: + _MASK_CACHE.clear() + _MASK_CACHE[key] = m + return m + + # ============================================================================= # SDPA helpers # ============================================================================= -def _sdpa_attention(q, k, v, window_size, enable_gqa): +def _sdpa_attention(q, k, v, window_size, enable_gqa, causal=True): """ SDPA attention with sliding window support. q, k, v are (B, H, T, D) format. @@ -77,7 +103,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): # Full context, same length if (window < 0 or window >= Tq) and Tq == Tk: - return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) + return F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=enable_gqa) # Single token generation if Tq == 1: @@ -88,19 +114,11 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): v = v[:, :, start:, :] return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) - # Need explicit mask for sliding window/chunk inference - device = q.device - # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask - row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1) - col_idx = torch.arange(Tk, device=device).unsqueeze(0) - mask = col_idx <= row_idx - - # sliding window (left) - if window >= 0 and window < Tk: - mask = mask & ((row_idx - col_idx) <= window) - + # Chunk inference (Tq > 1, Tq != Tk): use cached explicit bool mask. + mask = _get_chunk_mask(q.device, Tq, Tk, window) return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) + # ============================================================================= # Public API: Same interface as FA3 # ============================================================================= @@ -124,7 +142,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): k = k.transpose(1, 2) v = v.transpose(1, 2) enable_gqa = q.size(1) != k.size(1) - y = _sdpa_attention(q, k, v, window_size, enable_gqa) + y = _sdpa_attention(q, k, v, window_size, enable_gqa, causal=causal) return y.transpose(1, 2) # back to (B, T, H, D) @@ -139,7 +157,8 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N q: Queries, shape (B, T_new, H, D) k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D) k, v: New keys/values to insert, shape (B, T_new, H_kv, D) - cache_seqlens: Current position in cache, shape (B,) int32 + cache_seqlens: Current position in cache. Either an int (fast path, no + GPU->CPU sync) or a tensor of shape (B,) int32 (FA3-compatible). causal: Whether to use causal masking window_size: (left, right) sliding window. -1 means unlimited. @@ -154,17 +173,32 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N # SDPA fallback: manually manage KV cache B, T_new, H, D = q.shape - pos = cache_seqlens[0].item() # assume uniform position across batch + + # Avoid GPU->CPU sync if caller passes a Python int. + if isinstance(cache_seqlens, int): + pos = cache_seqlens + elif isinstance(cache_seqlens, torch.Tensor): + pos = int(cache_seqlens[0].item()) # assume uniform position across batch + else: + pos = int(cache_seqlens) # Insert new k, v into cache (in-place, matching FA3 behavior) if k is not None and v is not None: k_cache[:, pos:pos+T_new, :, :] = k v_cache[:, pos:pos+T_new, :, :] = v - # Get full cache up to current position + new tokens end_pos = pos + T_new - k_full = k_cache[:, :end_pos, :, :] - v_full = v_cache[:, :end_pos, :, :] + + # Sliding-window single-token decode: trim cache slice early so SDPA sees + # only the window instead of the full prefix. + window = window_size[0] + if T_new == 1 and 0 <= window < end_pos: + start = max(0, end_pos - (window + 1)) + k_full = k_cache[:, start:end_pos, :, :] + v_full = v_cache[:, start:end_pos, :, :] + else: + k_full = k_cache[:, :end_pos, :, :] + v_full = v_cache[:, :end_pos, :, :] # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D) q_sdpa = q.transpose(1, 2) @@ -172,7 +206,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N v_sdpa = v_full.transpose(1, 2) enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) - y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) + y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa, causal=causal) return y_sdpa.transpose(1, 2) # back to (B, T, H, D)