mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-15 10:39:08 +00:00
Add chunk mask caching and update attention functions to fix performance bottlenecks
This commit is contained in:
parent
dc54a1a307
commit
f8ca0b5c21
|
|
@ -63,10 +63,36 @@ def _resolve_use_fa3():
|
|||
USE_FA3 = _resolve_use_fa3()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mask cache for chunked inference (Tq != Tk, Tq > 1)
|
||||
# =============================================================================
|
||||
# Maps (device type, device index, Tq, Tk, window) -> bool mask of shape (Tq, Tk).
# Relies on dict insertion order: the first key is the least recently used one.
_MASK_CACHE: dict = {}
# Upper bound on the number of masks kept alive at once.
_MASK_CACHE_MAX = 32


def _get_chunk_mask(device, Tq, Tk, window):
    """Return a cached causal (+ sliding window) bool mask for chunk inference.

    Query row ``i`` is treated as sitting at absolute position
    ``(Tk - Tq) + i``, i.e. the ``Tq`` queries are aligned to the end of the
    ``Tk`` keys. An entry is True where the query may attend to the key.

    Args:
        device: torch.device on which the mask is created; its ``type`` and
            ``index`` are part of the cache key.
        Tq: Number of query positions (mask rows).
        Tk: Number of key positions (mask columns).
        window: Left sliding-window size. When ``0 <= window < Tk`` a query
            only sees keys at distance ``<= window`` (itself included);
            otherwise only the causal constraint is applied.

    Returns:
        Bool tensor of shape ``(Tq, Tk)``. The same tensor object is returned
        on every cache hit, so callers must not modify it in place.
    """
    key = (device.type, device.index, Tq, Tk, window)

    # Hit: pop and re-insert so the entry becomes the most recently used one.
    cached = _MASK_CACHE.pop(key, None)
    if cached is not None:
        _MASK_CACHE[key] = cached
        return cached

    # Miss: build the mask from broadcast (Tq, 1) row and (1, Tk) column indices.
    row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
    col_idx = torch.arange(Tk, device=device).unsqueeze(0)
    mask = col_idx <= row_idx
    if 0 <= window < Tk:
        mask = mask & ((row_idx - col_idx) <= window)

    # Evict only the least recently used entries instead of clearing the whole
    # cache, so masks that are still hot survive when the cache fills up.
    while _MASK_CACHE and len(_MASK_CACHE) >= _MASK_CACHE_MAX:
        _MASK_CACHE.pop(next(iter(_MASK_CACHE)))
    _MASK_CACHE[key] = mask
    return mask
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SDPA helpers
|
||||
# =============================================================================
|
||||
def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
||||
def _sdpa_attention(q, k, v, window_size, enable_gqa, causal=True):
|
||||
"""
|
||||
SDPA attention with sliding window support.
|
||||
q, k, v are (B, H, T, D) format.
|
||||
|
|
@ -77,7 +103,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
|||
|
||||
# Full context, same length
|
||||
if (window < 0 or window >= Tq) and Tq == Tk:
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=enable_gqa)
|
||||
|
||||
# Single token generation
|
||||
if Tq == 1:
|
||||
|
|
@ -88,19 +114,11 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
|||
v = v[:, :, start:, :]
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||
|
||||
# Need explicit mask for sliding window/chunk inference
|
||||
device = q.device
|
||||
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
|
||||
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
|
||||
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
|
||||
mask = col_idx <= row_idx
|
||||
|
||||
# sliding window (left)
|
||||
if window >= 0 and window < Tk:
|
||||
mask = mask & ((row_idx - col_idx) <= window)
|
||||
|
||||
# Chunk inference (Tq > 1, Tq != Tk): use cached explicit bool mask.
|
||||
mask = _get_chunk_mask(q.device, Tq, Tk, window)
|
||||
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Public API: Same interface as FA3
|
||||
# =============================================================================
|
||||
|
|
@ -124,7 +142,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
|||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
enable_gqa = q.size(1) != k.size(1)
|
||||
y = _sdpa_attention(q, k, v, window_size, enable_gqa)
|
||||
y = _sdpa_attention(q, k, v, window_size, enable_gqa, causal=causal)
|
||||
return y.transpose(1, 2) # back to (B, T, H, D)
|
||||
|
||||
|
||||
|
|
@ -139,7 +157,8 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
|||
q: Queries, shape (B, T_new, H, D)
|
||||
k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
|
||||
k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
|
||||
cache_seqlens: Current position in cache, shape (B,) int32
|
||||
cache_seqlens: Current position in cache. Either an int (fast path, no
|
||||
GPU->CPU sync) or a tensor of shape (B,) int32 (FA3-compatible).
|
||||
causal: Whether to use causal masking
|
||||
window_size: (left, right) sliding window. -1 means unlimited.
|
||||
|
||||
|
|
@ -154,17 +173,32 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
|||
|
||||
# SDPA fallback: manually manage KV cache
|
||||
B, T_new, H, D = q.shape
|
||||
pos = cache_seqlens[0].item() # assume uniform position across batch
|
||||
|
||||
# Avoid GPU->CPU sync if caller passes a Python int.
|
||||
if isinstance(cache_seqlens, int):
|
||||
pos = cache_seqlens
|
||||
elif isinstance(cache_seqlens, torch.Tensor):
|
||||
pos = int(cache_seqlens[0].item()) # assume uniform position across batch
|
||||
else:
|
||||
pos = int(cache_seqlens)
|
||||
|
||||
# Insert new k, v into cache (in-place, matching FA3 behavior)
|
||||
if k is not None and v is not None:
|
||||
k_cache[:, pos:pos+T_new, :, :] = k
|
||||
v_cache[:, pos:pos+T_new, :, :] = v
|
||||
|
||||
# Get full cache up to current position + new tokens
|
||||
end_pos = pos + T_new
|
||||
k_full = k_cache[:, :end_pos, :, :]
|
||||
v_full = v_cache[:, :end_pos, :, :]
|
||||
|
||||
# Sliding-window single-token decode: trim cache slice early so SDPA sees
|
||||
# only the window instead of the full prefix.
|
||||
window = window_size[0]
|
||||
if T_new == 1 and 0 <= window < end_pos:
|
||||
start = max(0, end_pos - (window + 1))
|
||||
k_full = k_cache[:, start:end_pos, :, :]
|
||||
v_full = v_cache[:, start:end_pos, :, :]
|
||||
else:
|
||||
k_full = k_cache[:, :end_pos, :, :]
|
||||
v_full = v_cache[:, :end_pos, :, :]
|
||||
|
||||
# Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
|
||||
q_sdpa = q.transpose(1, 2)
|
||||
|
|
@ -172,7 +206,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
|||
v_sdpa = v_full.transpose(1, 2)
|
||||
|
||||
enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
|
||||
y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
|
||||
y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa, causal=causal)
|
||||
|
||||
return y_sdpa.transpose(1, 2) # back to (B, T, H, D)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user