From b62a5bc44aafef02eba6c39236180aa424a82674 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Fri, 16 Jan 2026 17:39:41 +0000
Subject: [PATCH] naturally i failed to include the actual code in the previous commit facepalm

---
 nanochat/flash_attention.py | 178 ++++++++++++++++++++++++++++++++++++
 1 file changed, 178 insertions(+)
 create mode 100644 nanochat/flash_attention.py

diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
new file mode 100644
index 0000000..5d27e5f
--- /dev/null
+++ b/nanochat/flash_attention.py
@@ -0,0 +1,178 @@
"""
Unified Flash Attention interface with automatic FA3/SDPA switching.

Exports `flash_attn` module that matches the FA3 API exactly, but falls back
to PyTorch SDPA on non-Hopper GPUs, MPS, and CPU.

Usage (drop-in replacement for FA3):
    from nanochat.flash_attention import flash_attn

    # Training (no KV cache)
    y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)

    # Inference (with KV cache)
    y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
"""
import torch
import torch.nn.functional as F


# =============================================================================
# Detection: Try to load FA3 on Hopper+ GPUs
# =============================================================================
def _load_flash_attention_3():
    """Try to load Flash Attention 3 (requires Hopper+ GPU)."""
    if not torch.cuda.is_available():
        return None
    try:
        major, _ = torch.cuda.get_device_capability()
        if major < 9:  # Hopper is sm90
            return None
        import os
        os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
        from kernels import get_kernel
        return get_kernel('varunneal/flash-attention-3').flash_attn_interface
    except Exception:
        return None


_fa3 = _load_flash_attention_3()
HAS_FA3 = _fa3 is not None

# Override for testing: set to 'fa3', 'sdpa', or None (auto)
_override_impl = None


def _use_fa3():
    """Determine whether to use FA3 based on availability and override."""
    if _override_impl == 'fa3':
        assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
        return True
    if _override_impl == 'sdpa':
        return False
    return HAS_FA3  # auto


# =============================================================================
# SDPA helpers
# =============================================================================
def _sdpa_attention(q, k, v, window_size, enable_gqa):
    """
    SDPA attention with sliding window support.
    q, k, v are (B, H, T, D) format.
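
    For example (illustrative only): with Tq == Tk == 4 and window_size[0] == 2,
    the query at position 3 attends to keys at positions 1, 2, 3 (causal, with a
    lookback of at most 2 positions).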
+ """ + Tq = q.size(2) + Tk = k.size(2) + window = window_size[0] + + # Full context, same length + if (window < 0 or window >= Tq) and Tq == Tk: + return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) + + # Single token generation + if Tq == 1: + return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) + + # Need explicit mask + device = q.device + if Tq == Tk: + # Causal + sliding window + mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool)) + if window > 0 and window < Tq: + row_idx = torch.arange(Tq, device=device).unsqueeze(1) + col_idx = torch.arange(Tk, device=device).unsqueeze(0) + mask = mask & ((row_idx - col_idx) <= window) + else: + # Chunk inference: attend to prefix + causal within chunk + prefix_len = Tk - Tq + mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool) + mask[:, :prefix_len] = True + mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool)) + + return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) + + +# ============================================================================= +# Public API: Same interface as FA3 +# ============================================================================= +def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): + """ + Flash Attention for training (no KV cache). + + Args: + q, k, v: Tensors of shape (B, T, H, D) + causal: Whether to use causal masking + window_size: (left, right) sliding window. -1 means unlimited. + + Returns: + Output tensor of shape (B, T, H, D) + """ + if _use_fa3(): + return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size) + + # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + enable_gqa = q.size(1) != k.size(1) + y = _sdpa_attention(q, k, v, window_size, enable_gqa) + return y.transpose(1, 2) # back to (B, T, H, D) + + +def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None, + causal=False, window_size=(-1, -1)): + """ + Flash Attention with KV cache for inference. + + FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same. + + Args: + q: Queries, shape (B, T_new, H, D) + k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D) + k, v: New keys/values to insert, shape (B, T_new, H_kv, D) + cache_seqlens: Current position in cache, shape (B,) int32 + causal: Whether to use causal masking + window_size: (left, right) sliding window. -1 means unlimited. 
+ + Returns: + Output tensor of shape (B, T_new, H, D) + """ + if _use_fa3(): + return _fa3.flash_attn_with_kvcache( + q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, + causal=causal, window_size=window_size + ) + + # SDPA fallback: manually manage KV cache + B, T_new, H, D = q.shape + pos = cache_seqlens[0].item() # assume uniform position across batch + + # Insert new k, v into cache (in-place, matching FA3 behavior) + if k is not None and v is not None: + k_cache[:, pos:pos+T_new, :, :] = k + v_cache[:, pos:pos+T_new, :, :] = v + + # Get full cache up to current position + new tokens + end_pos = pos + T_new + k_full = k_cache[:, :end_pos, :, :] + v_full = v_cache[:, :end_pos, :, :] + + # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D) + q_sdpa = q.transpose(1, 2) + k_sdpa = k_full.transpose(1, 2) + v_sdpa = v_full.transpose(1, 2) + + enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) + y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) + + return y_sdpa.transpose(1, 2) # back to (B, T, H, D) + + +# ============================================================================= +# Export: flash_attn module interface (drop-in replacement for FA3) +# ============================================================================= +from types import SimpleNamespace +flash_attn = SimpleNamespace( + flash_attn_func=flash_attn_func, + flash_attn_with_kvcache=flash_attn_with_kvcache, +)
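
A minimal end-to-end sketch of how a caller can exercise both entry points, assuming
FA3 is unavailable so the SDPA fallback runs (e.g. on CPU). The sizes and the names
n_head, head_dim, and T_max are made up for illustration and are not part of the patch:

    import torch
    from nanochat.flash_attention import flash_attn

    B, T, n_head, head_dim = 2, 8, 4, 16
    q = torch.randn(B, T, n_head, head_dim)
    k = torch.randn(B, T, n_head, head_dim)
    v = torch.randn(B, T, n_head, head_dim)

    # Training-style call: causal attention over the full sequence, no KV cache.
    y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
    assert y.shape == (B, T, n_head, head_dim)

    # Decode-style call: append one new token into a pre-allocated KV cache.
    # FA3 updates the cache in-place; the SDPA fallback mirrors that behavior.
    T_max = 32
    k_cache = torch.zeros(B, T_max, n_head, head_dim)
    v_cache = torch.zeros(B, T_max, n_head, head_dim)
    q1 = torch.randn(B, 1, n_head, head_dim)
    k1 = torch.randn(B, 1, n_head, head_dim)
    v1 = torch.randn(B, 1, n_head, head_dim)
    cache_seqlens = torch.full((B,), T, dtype=torch.int32)  # positions 0..T-1 count as already cached
    y1 = flash_attn.flash_attn_with_kvcache(q1, k_cache, v_cache, k=k1, v=v1,
                                            cache_seqlens=cache_seqlens, causal=True)
    assert y1.shape == (B, 1, n_head, head_dim)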