add Flash Attention 2 as a middle tier between FA3 and SDPA

On sm80+ non-Hopper GPUs (Blackwell, Ada, Ampere) with the flash-attn package installed, FA2 kernels replace the SDPA fallback. Priority is FA3 > FA2 > SDPA. FA2 measured 28% faster than SDPA on GB10, and it makes sliding-window attention fast on Blackwell (where FA3 is unavailable). No effect on H100: USE_FA3 wins whenever it is available, so runs/speedrun.sh on 8xH100 runs the same kernels as before. tests/test_attention_fallback.py::TestFA2VsSDPA compares FA2 and SDPA outputs on any sm80+ GPU with flash-attn installed.
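
For example, the selected backend can be checked at runtime via the module-level flags this change adds (a minimal sketch; the flag names are the ones introduced in the diff below):

    # Sketch: inspect which attention backend nanochat resolved at import time.
    # HAS_FA3 / HAS_FA2 report availability; USE_FA3 / USE_FA2 reflect the
    # FA3 > FA2 > SDPA priority described above.
    from nanochat.flash_attention import HAS_FA3, HAS_FA2, USE_FA3, USE_FA2

    backend = "FA3" if USE_FA3 else ("FA2" if USE_FA2 else "SDPA")
    print(f"FA3 available: {HAS_FA3}, FA2 available: {HAS_FA2} -> using {backend}")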

Context: https://github.com/karpathy/nanochat/discussions/710 (the writeup was produced from my dgx-spark branch at https://github.com/matt-langston/nanochat/tree/dgx-spark, which carries these two PRs plus a DGX-Spark-Bundle-specific speedrun script that I kept separate)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Matt Langston 2026-04-17 19:04:44 -07:00
parent 0aaca56805
commit 75bd386b8e
4 changed files with 284 additions and 12 deletions

View File

@@ -1,8 +1,9 @@
"""
Unified Flash Attention interface with automatic FA3/SDPA switching.
Unified Flash Attention interface with automatic FA3/FA2/SDPA switching.
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
If the flash-attn package is installed, FA2 kernels are used instead of SDPA on sm80+ GPUs (FA3 wins if both are usable).
Usage (drop-in replacement for FA3):
from nanochat.flash_attention import flash_attn
@@ -38,10 +39,25 @@ def _load_flash_attention_3():
return None
_fa3 = _load_flash_attention_3()
HAS_FA3 = _fa3 is not None
def _load_flash_attention_2():
"""Try to load Flash Attention 2 (requires flash-attn package, sm80+)."""
if not torch.cuda.is_available():
return None
try:
import flash_attn.flash_attn_interface as fa2
if hasattr(fa2, 'flash_attn_func') and hasattr(fa2, 'flash_attn_with_kvcache'):
return fa2
return None
except Exception:
return None
# Override for testing: set to 'fa3', 'sdpa', or None (auto)
_fa3 = _load_flash_attention_3()
_fa2 = _load_flash_attention_2()
HAS_FA3 = _fa3 is not None
HAS_FA2 = _fa2 is not None
# Override for testing: set to 'fa3', 'fa2', 'sdpa', or None (auto)
_override_impl = None
@@ -50,6 +66,8 @@ def _resolve_use_fa3():
if _override_impl == 'fa3':
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
return True
if _override_impl == 'fa2':
return False
if _override_impl == 'sdpa':
return False
if HAS_FA3:
@@ -63,6 +81,23 @@ def _resolve_use_fa3():
USE_FA3 = _resolve_use_fa3()
def _resolve_use_fa2():
"""Decide once whether to use FA2, based on availability, override, and dtype. FA3 wins if both are usable."""
if _override_impl == 'fa2':
assert HAS_FA2, "Cannot override to FA2: flash-attn package not available"
return True
if _override_impl == 'sdpa' or USE_FA3:
return False
if HAS_FA2:
# FA2 supports bf16 and fp16, not fp32
from nanochat.common import COMPUTE_DTYPE
if COMPUTE_DTYPE != torch.float32:
return True
return False
USE_FA2 = _resolve_use_fa2()
# =============================================================================
# SDPA helpers
# =============================================================================
@@ -118,6 +153,8 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
"""
if USE_FA3:
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
if USE_FA2:
return _fa2.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
q = q.transpose(1, 2)
@@ -151,6 +188,11 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
causal=causal, window_size=window_size
)
if USE_FA2:
return _fa2.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
causal=causal, window_size=window_size
)
# SDPA fallback: manually manage KV cache
B, T_new, H, D = q.shape
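
The exported flash_attn object is a drop-in for the FA3 call signature regardless of which backend ends up handling the call. A minimal usage sketch (shapes, dtype, and the window_size convention follow the tests further below):

    import torch
    from nanochat.flash_attention import flash_attn

    # Sketch: one call site, served by FA3, FA2, or SDPA depending on what
    # resolved at import time. Tensors are (B, T, H, D) as in the FA3 API;
    # window_size=(-1, -1) means full causal context, (W, 0) a sliding window
    # over the previous W tokens.
    B, T, H, D = 2, 128, 4, 32
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    q = torch.randn(B, T, H, D, device=device, dtype=dtype)
    k = torch.randn(B, T, H, D, device=device, dtype=dtype)
    v = torch.randn(B, T, H, D, device=device, dtype=dtype)
    y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))  # (B, T, H, D)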

View File

@@ -32,7 +32,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from nanochat.flash_attention import HAS_FA3, HAS_FA2
from scripts.base_eval import evaluate_core
print_banner()
@@ -100,10 +100,12 @@ use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
# Flash Attention status
from nanochat.flash_attention import USE_FA3
from nanochat.flash_attention import USE_FA3, USE_FA2
using_fa3 = USE_FA3
if using_fa3:
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
elif USE_FA2:
print0("Using Flash Attention 2 (flash-attn package detected)")
else:
print0("!" * 80)
if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:

View File

@@ -21,7 +21,7 @@ from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
from nanochat.loss_eval import evaluate_bpb
import torch.distributed as dist
from nanochat.flash_attention import HAS_FA3
from nanochat.flash_attention import HAS_FA3, HAS_FA2
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval
@@ -89,7 +89,7 @@ use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
# Flash Attention status
if not HAS_FA3:
if not HAS_FA3 and not HAS_FA2:
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
# Load the model and tokenizer

View File

@@ -1,10 +1,10 @@
"""
Test Flash Attention unified interface - verify FA3 and SDPA produce identical results.
Test Flash Attention unified interface - verify FA3, FA2, and SDPA produce identical results.
Run: python -m pytest tests/test_attention_fallback.py -v -s
Note on test structure:
Tests are split into two classes due to dtype/device constraints:
Tests are split into three classes due to dtype/device constraints:
1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
and verify they produce identical results. These require a Hopper GPU (FA3 only
@@ -12,18 +12,22 @@ Note on test structure:
2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
3. TestFA2VsSDPA: Same as (1), but comparing FA2 (flash-attn package, sm80+) against SDPA.
Runs on any sm80+ GPU and is the key correctness test on Blackwell where FA3 is unavailable.
"""
import torch
import pytest
import nanochat.flash_attention as fa_module
from nanochat.flash_attention import flash_attn, HAS_FA3
from nanochat.flash_attention import flash_attn, HAS_FA3, HAS_FA2
from nanochat.engine import KVCache
def set_impl(impl):
"""Set the implementation override ('fa3', 'sdpa', or None for auto) and re-resolve USE_FA3."""
"""Set the implementation override ('fa3', 'fa2', 'sdpa', or None for auto) and re-resolve USE_FA3/USE_FA2."""
fa_module._override_impl = impl
fa_module.USE_FA3 = fa_module._resolve_use_fa3()
fa_module.USE_FA2 = fa_module._resolve_use_fa2()
def run_both_impls(fn):
@@ -36,6 +40,16 @@ def run_both_impls(fn):
return out_fa3, out_sdpa
def run_fa2_and_sdpa(fn):
"""Run a function with both FA2 and SDPA, return both outputs."""
set_impl('fa2')
out_fa2 = fn()
set_impl('sdpa')
out_sdpa = fn()
set_impl(None) # reset
return out_fa2, out_sdpa
def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
"""Assert two tensors are close, with helpful error message."""
max_diff = (t1 - t2).abs().max().item()
@@ -334,6 +348,206 @@ class TestSDPAOnly:
set_impl(None)
# =============================================================================
# FA2 vs SDPA comparison tests (require flash-attn package)
# =============================================================================
@pytest.mark.skipif(not HAS_FA2, reason="FA2 required (flash-attn package)")
class TestFA2VsSDPA:
"""Compare FA2 and SDPA produce identical results. Requires flash-attn package (sm80+)."""
DEVICE = "cuda"
DTYPE = torch.bfloat16
def test_basic_causal(self):
"""Basic causal attention."""
B, T, H, D = 2, 64, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
y_fa2, y_sdpa = run_fa2_and_sdpa(run)
max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "basic_causal")
print(f"basic_causal: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_full_context(self):
"""Full context (window_size=-1)."""
B, T, H, D = 2, 128, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
y_fa2, y_sdpa = run_fa2_and_sdpa(run)
max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "full_context")
print(f"full_context: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_sliding_window(self):
"""Sliding window attention (key test: FA2 has native sliding kernels, SDPA uses an explicit mask - both must agree so GB10 can run SSSL instead of L)."""
B, T, H, D = 2, 128, 4, 32
window = 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))
y_fa2, y_sdpa = run_fa2_and_sdpa(run)
max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "sliding_window")
print(f"sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_sliding_window_realistic(self):
"""Realistic sliding window matching depth=24 config (n_embd=1536, n_head=12, head_dim=128, seq_len=2048, SSSL short_window=768)."""
B, T, H, D = 2, 256, 12, 128 # scaled-down seq_len for speed
window = 768
n_kv_heads = 4 # GQA
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))
y_fa2, y_sdpa = run_fa2_and_sdpa(run)
max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "sliding_window_realistic")
print(f"sliding_window_realistic: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_gqa(self):
"""Group Query Attention (fewer KV heads than Q heads)."""
B, T, D = 2, 64, 32
n_heads = 8
n_kv_heads = 2
q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
y_fa2, y_sdpa = run_fa2_and_sdpa(run)
max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "gqa")
print(f"gqa: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_kvcache_prefill(self):
"""Test prefill (inserting multiple tokens into empty cache)."""
B, T_max, H, D = 2, 64, 4, 32
T_prefill = 16
q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
cache_seqlens = torch.zeros(B, dtype=torch.int32, device=self.DEVICE)
return flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v,
cache_seqlens=cache_seqlens,
causal=True, window_size=(T_max, 0)
)
y_fa2, y_sdpa = run_fa2_and_sdpa(run)
max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "prefill")
print(f"prefill: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_kvcache_single_token(self):
"""Test single token generation (cache already has content)."""
B, T_max, H, D = 2, 64, 4, 32
T_prefill = 16
k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_cache[:, :T_prefill, :, :] = k_init
v_cache[:, :T_prefill, :, :] = v_init
cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
return flash_attn.flash_attn_with_kvcache(
q_single, k_cache, v_cache, k=k_single, v=v_single,
cache_seqlens=cache_seqlens,
causal=True, window_size=(T_max, 0)
)
y_fa2, y_sdpa = run_fa2_and_sdpa(run)
max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "single_token")
print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_kvcache_single_token_sliding_window(self):
"""Test single token decode with sliding window smaller than cache size (SSSL inference path)."""
B, T_max, H, D = 2, 64, 4, 32
T_prefill = 32 # Enough tokens to exceed window
window = 8 # Window SMALLER than cache size
k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_cache[:, :T_prefill, :, :] = k_init
v_cache[:, :T_prefill, :, :] = v_init
cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
return flash_attn.flash_attn_with_kvcache(
q_single, k_cache, v_cache, k=k_single, v=v_single,
cache_seqlens=cache_seqlens,
causal=True, window_size=(window, 0) # window=8 < Tk=33
)
y_fa2, y_sdpa = run_fa2_and_sdpa(run)
max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "single_token_sliding_window")
print(f"single_token_sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_backward_gradients_match(self):
"""Verify gradients are similar between FA2 and SDPA."""
B, T, H, D = 2, 32, 4, 16
q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
q = q_data.clone().requires_grad_(True)
k = k_data.clone().requires_grad_(True)
v = v_data.clone().requires_grad_(True)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
loss = y.sum()
loss.backward()
return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach()
set_impl('fa2')
y_fa2, q_grad_fa2, k_grad_fa2, v_grad_fa2 = run()
set_impl('sdpa')
y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run()
set_impl(None)
max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "backward_output")
print(f"backward_output: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
max_diff, mean_diff = assert_close(q_grad_fa2, q_grad_sdpa, "q_grad", atol=0.05, rtol=0.05)
print(f"q_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
max_diff, mean_diff = assert_close(k_grad_fa2, k_grad_sdpa, "k_grad", atol=0.05, rtol=0.05)
print(f"k_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
max_diff, mean_diff = assert_close(v_grad_fa2, v_grad_sdpa, "v_grad", atol=0.05, rtol=0.05)
print(f"v_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
# =============================================================================
# Override mechanism tests
# =============================================================================
@@ -358,6 +572,19 @@ class TestOverrideMechanism:
set_impl(None)
assert fa_module.USE_FA3 == HAS_FA3
@pytest.mark.skipif(not HAS_FA2, reason="FA2 required")
def test_override_fa2(self):
"""Test that override='fa2' uses FA2."""
set_impl('fa2')
assert fa_module.USE_FA2 == True
set_impl(None)
@pytest.mark.skipif(not (HAS_FA2 and not HAS_FA3), reason="FA2 auto-selection only applies when FA3 is unavailable")
def test_override_auto_fa2(self):
"""Test that override=None picks FA2 when FA3 is unavailable."""
set_impl(None)
assert fa_module.USE_FA2 == True
if __name__ == "__main__":
print(f"PyTorch version: {torch.__version__}")
@@ -367,6 +594,7 @@ if __name__ == "__main__":
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")
print(f"HAS_FA3: {HAS_FA3}")
print(f"HAS_FA2: {HAS_FA2}")
print()
pytest.main([__file__, "-v", "-s"])
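
To run only the FA2-vs-SDPA comparisons on an sm80+ machine, the standard pytest -k filter can be added to the invocation above (a sketch):

    # Sketch: select just the FA2 comparison class via pytest's -k expression.
    import pytest
    pytest.main(["tests/test_attention_fallback.py", "-v", "-s", "-k", "TestFA2VsSDPA"])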