diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
index af2aee32..47873755 100644
--- a/nanochat/flash_attention.py
+++ b/nanochat/flash_attention.py
@@ -1,8 +1,9 @@
 """
-Unified Flash Attention interface with automatic FA3/SDPA switching.
+Unified Flash Attention interface with automatic FA3/FA2/SDPA switching.
 
 Exports `flash_attn` module that matches the FA3 API exactly, but falls back
 to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
+If the flash-attn package is installed, FA2 kernels are used instead of SDPA on sm80+ GPUs (FA3 wins if both are usable).
 
 Usage (drop-in replacement for FA3):
     from nanochat.flash_attention import flash_attn
@@ -38,10 +39,25 @@ def _load_flash_attention_3():
         return None
 
 
-_fa3 = _load_flash_attention_3()
-HAS_FA3 = _fa3 is not None
+def _load_flash_attention_2():
+    """Try to load Flash Attention 2 (requires flash-attn package, sm80+)."""
+    if not torch.cuda.is_available():
+        return None
+    try:
+        import flash_attn.flash_attn_interface as fa2
+        if hasattr(fa2, 'flash_attn_func') and hasattr(fa2, 'flash_attn_with_kvcache'):
+            return fa2
+        return None
+    except Exception:
+        return None
 
-# Override for testing: set to 'fa3', 'sdpa', or None (auto)
+
+_fa3 = _load_flash_attention_3()
+_fa2 = _load_flash_attention_2()
+HAS_FA3 = _fa3 is not None
+HAS_FA2 = _fa2 is not None
+
+# Override for testing: set to 'fa3', 'fa2', 'sdpa', or None (auto)
 _override_impl = None
 
 
@@ -50,6 +66,8 @@ def _resolve_use_fa3():
     if _override_impl == 'fa3':
         assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
         return True
+    if _override_impl == 'fa2':
+        return False
     if _override_impl == 'sdpa':
         return False
     if HAS_FA3:
@@ -63,6 +81,23 @@
 USE_FA3 = _resolve_use_fa3()
 
 
+def _resolve_use_fa2():
+    """Decide once whether to use FA2, based on availability, override, and dtype. FA3 wins if both are usable."""
+    if _override_impl == 'fa2':
+        assert HAS_FA2, "Cannot override to FA2: flash-attn package not available"
+        return True
+    if _override_impl == 'sdpa' or USE_FA3:
+        return False
+    if HAS_FA2:
+        # FA2 supports bf16 and fp16, not fp32
+        from nanochat.common import COMPUTE_DTYPE
+        if COMPUTE_DTYPE != torch.float32:
+            return True
+    return False
+
+USE_FA2 = _resolve_use_fa2()
+
+
 # =============================================================================
 # SDPA helpers
 # =============================================================================
@@ -118,6 +153,8 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
     """
     if USE_FA3:
         return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
+    if USE_FA2:
+        return _fa2.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
 
     # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
     q = q.transpose(1, 2)
@@ -151,6 +188,11 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
             q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
             causal=causal, window_size=window_size
         )
+    if USE_FA2:
+        return _fa2.flash_attn_with_kvcache(
+            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
+            causal=causal, window_size=window_size
+        )
 
     # SDPA fallback: manually manage KV cache
     B, T_new, H, D = q.shape
diff --git a/scripts/base_train.py b/scripts/base_train.py
index a161c477..c054e43f 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -32,7 +32,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
-from nanochat.flash_attention import HAS_FA3
+from nanochat.flash_attention import HAS_FA3, HAS_FA2
 from scripts.base_eval import evaluate_core
 
 print_banner()
@@ -100,10 +100,12 @@ use_dummy_wandb = args.run == "dummy" or not master_process
 wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
 
 # Flash Attention status
-from nanochat.flash_attention import USE_FA3
+from nanochat.flash_attention import USE_FA3, USE_FA2
 using_fa3 = USE_FA3
 if using_fa3:
     print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
+elif USE_FA2:
+    print0("Using Flash Attention 2 (flash-attn package detected)")
 else:
     print0("!" * 80)
     if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index b46dd817..1ff398e4 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -21,7 +21,7 @@ from nanochat.tokenizer import get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
 from nanochat.loss_eval import evaluate_bpb
 import torch.distributed as dist
-from nanochat.flash_attention import HAS_FA3
+from nanochat.flash_attention import HAS_FA3, HAS_FA2
 from nanochat.engine import Engine
 from scripts.chat_eval import run_chat_eval
 
@@ -89,7 +89,7 @@ use_dummy_wandb = args.run == "dummy" or not master_process
 wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
 
 # Flash Attention status
-if not HAS_FA3:
+if not HAS_FA3 and not HAS_FA2:
     print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
 
 # Load the model and tokenizer
diff --git a/tests/test_attention_fallback.py b/tests/test_attention_fallback.py
index 3eddc721..96586250 100644
--- a/tests/test_attention_fallback.py
+++ b/tests/test_attention_fallback.py
@@ -1,10 +1,10 @@
 """
-Test Flash Attention unified interface - verify FA3 and SDPA produce identical results.
+Test Flash Attention unified interface - verify FA3, FA2, and SDPA produce identical results.
 
 Run: python -m pytest tests/test_attention_fallback.py -v -s
 
 Note on test structure:
-  Tests are split into two classes due to dtype/device constraints:
+  Tests are split into three classes due to dtype/device constraints:
 
   1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
      and verify they produce identical results. These require a Hopper GPU (FA3 only
@@ -12,18 +12,22 @@ Note on test structure:
 
   2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
      on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
+
+  3. TestFA2VsSDPA: Same as (1), but comparing FA2 (flash-attn package, sm80+) against SDPA.
+     Runs on any sm80+ GPU and is the key correctness test on Blackwell where FA3 is unavailable.
 """
 
 import torch
 import pytest
 import nanochat.flash_attention as fa_module
-from nanochat.flash_attention import flash_attn, HAS_FA3
+from nanochat.flash_attention import flash_attn, HAS_FA3, HAS_FA2
 from nanochat.engine import KVCache
 
 def set_impl(impl):
-    """Set the implementation override ('fa3', 'sdpa', or None for auto) and re-resolve USE_FA3."""
+    """Set the implementation override ('fa3', 'fa2', 'sdpa', or None for auto) and re-resolve USE_FA3/USE_FA2."""
     fa_module._override_impl = impl
     fa_module.USE_FA3 = fa_module._resolve_use_fa3()
+    fa_module.USE_FA2 = fa_module._resolve_use_fa2()
 
 
 def run_both_impls(fn):
@@ -36,6 +40,16 @@
     return out_fa3, out_sdpa
 
 
+def run_fa2_and_sdpa(fn):
+    """Run a function with both FA2 and SDPA, return both outputs."""
+    set_impl('fa2')
+    out_fa2 = fn()
+    set_impl('sdpa')
+    out_sdpa = fn()
+    set_impl(None)  # reset
+    return out_fa2, out_sdpa
+
+
 def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
     """Assert two tensors are close, with helpful error message."""
     max_diff = (t1 - t2).abs().max().item()
@@ -334,6 +348,206 @@ class TestSDPAOnly:
         set_impl(None)
 
 
+# =============================================================================
+# FA2 vs SDPA comparison tests (require flash-attn package)
+# =============================================================================
+@pytest.mark.skipif(not HAS_FA2, reason="FA2 required (flash-attn package)")
+class TestFA2VsSDPA:
+    """Verify that FA2 and SDPA produce identical results. Requires flash-attn package (sm80+)."""
+
+    DEVICE = "cuda"
+    DTYPE = torch.bfloat16
+
+    def test_basic_causal(self):
+        """Basic causal attention."""
+        B, T, H, D = 2, 64, 4, 32
+        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
+
+        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
+        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "basic_causal")
+        print(f"basic_causal: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_full_context(self):
+        """Full context (window_size=-1)."""
+        B, T, H, D = 2, 128, 4, 32
+        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
+
+        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
+        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "full_context")
+        print(f"full_context: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_sliding_window(self):
+        """Sliding window attention (key test: FA2 has native sliding kernels, SDPA uses an explicit mask - both must agree so GB10 can run SSSL instead of L)."""
+        B, T, H, D = 2, 128, 4, 32
+        window = 32
+        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))
+
+        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
+        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "sliding_window")
+        print(f"sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_sliding_window_realistic(self):
+        """Realistic sliding window matching depth=24 config (n_embd=1536, n_head=12, head_dim=128, seq_len=2048, SSSL short_window=768)."""
+        B, T, H, D = 2, 256, 12, 128  # scaled-down seq_len for speed
+        window = 768
+        n_kv_heads = 4  # GQA
+        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))
+
+        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
+        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "sliding_window_realistic")
+        print(f"sliding_window_realistic: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_gqa(self):
+        """Group Query Attention (fewer KV heads than Q heads)."""
+        B, T, D = 2, 64, 32
+        n_heads = 8
+        n_kv_heads = 2
+
+        q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
+
+        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
+        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "gqa")
+        print(f"gqa: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_kvcache_prefill(self):
+        """Test prefill (inserting multiple tokens into empty cache)."""
+        B, T_max, H, D = 2, 64, 4, 32
+        T_prefill = 16
+
+        q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            cache_seqlens = torch.zeros(B, dtype=torch.int32, device=self.DEVICE)
+            return flash_attn.flash_attn_with_kvcache(
+                q, k_cache, v_cache, k=k, v=v,
+                cache_seqlens=cache_seqlens,
+                causal=True, window_size=(T_max, 0)
+            )
+
+        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
+        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "prefill")
+        print(f"prefill: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_kvcache_single_token(self):
+        """Test single token generation (cache already has content)."""
+        B, T_max, H, D = 2, 64, 4, 32
+        T_prefill = 16
+
+        k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            k_cache[:, :T_prefill, :, :] = k_init
+            v_cache[:, :T_prefill, :, :] = v_init
+            cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
+            return flash_attn.flash_attn_with_kvcache(
+                q_single, k_cache, v_cache, k=k_single, v=v_single,
+                cache_seqlens=cache_seqlens,
+                causal=True, window_size=(T_max, 0)
+            )
+
+        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
+        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "single_token")
+        print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_kvcache_single_token_sliding_window(self):
+        """Test single token decode with sliding window smaller than cache size (SSSL inference path)."""
+        B, T_max, H, D = 2, 64, 4, 32
+        T_prefill = 32  # Enough tokens to exceed window
+        window = 8  # Window SMALLER than cache size
+
+        k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            k_cache[:, :T_prefill, :, :] = k_init
+            v_cache[:, :T_prefill, :, :] = v_init
+            cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
+            return flash_attn.flash_attn_with_kvcache(
+                q_single, k_cache, v_cache, k=k_single, v=v_single,
+                cache_seqlens=cache_seqlens,
+                causal=True, window_size=(window, 0)  # window=8 < Tk=33
+            )
+
+        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
+        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "single_token_sliding_window")
+        print(f"single_token_sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_backward_gradients_match(self):
+        """Verify gradients are similar between FA2 and SDPA."""
+        B, T, H, D = 2, 32, 4, 16
+
+        q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            q = q_data.clone().requires_grad_(True)
+            k = k_data.clone().requires_grad_(True)
+            v = v_data.clone().requires_grad_(True)
+            y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
+            loss = y.sum()
+            loss.backward()
+            return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach()
+
+        set_impl('fa2')
+        y_fa2, q_grad_fa2, k_grad_fa2, v_grad_fa2 = run()
+        set_impl('sdpa')
+        y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run()
+        set_impl(None)
+
+        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "backward_output")
+        print(f"backward_output: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+        max_diff, mean_diff = assert_close(q_grad_fa2, q_grad_sdpa, "q_grad", atol=0.05, rtol=0.05)
+        print(f"q_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+        max_diff, mean_diff = assert_close(k_grad_fa2, k_grad_sdpa, "k_grad", atol=0.05, rtol=0.05)
+        print(f"k_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+        max_diff, mean_diff = assert_close(v_grad_fa2, v_grad_sdpa, "v_grad", atol=0.05, rtol=0.05)
+        print(f"v_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+
 # =============================================================================
 # Override mechanism tests
 # =============================================================================
@@ -358,6 +572,19 @@ class TestOverrideMechanism:
         set_impl(None)
         assert fa_module.USE_FA3 == HAS_FA3
 
+    @pytest.mark.skipif(not HAS_FA2, reason="FA2 required")
+    def test_override_fa2(self):
+        """Test that override='fa2' uses FA2."""
+        set_impl('fa2')
+        assert fa_module.USE_FA2 == True
+        set_impl(None)
+
+    @pytest.mark.skipif(not (HAS_FA2 and not HAS_FA3), reason="FA2 auto-selection only applies when FA3 is unavailable")
+    def test_override_auto_fa2(self):
+        """Test that override=None picks FA2 when FA3 is unavailable."""
+        set_impl(None)
+        assert fa_module.USE_FA2 == True
+
 
 if __name__ == "__main__":
     print(f"PyTorch version: {torch.__version__}")
@@ -367,6 +594,7 @@
         major, minor = torch.cuda.get_device_capability()
         print(f"Compute capability: {major}.{minor}")
 
     print(f"HAS_FA3: {HAS_FA3}")
+    print(f"HAS_FA2: {HAS_FA2}")
     print()
     pytest.main([__file__, "-v", "-s"])
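Appended for review convenience and not part of the patch: a minimal sanity-check sketch, assuming the patched nanochat.flash_attention module above is importable on a CUDA machine. It prints which backend was resolved and compares the auto-selected kernel against the SDPA fallback through the same _override_impl testing hook the tests use.

# Not part of the patch: a hypothetical quick check, assuming the patched
# nanochat.flash_attention module is importable and a CUDA GPU is present.
import torch
import nanochat.flash_attention as fa

print(f"HAS_FA3={fa.HAS_FA3} HAS_FA2={fa.HAS_FA2} USE_FA3={fa.USE_FA3} USE_FA2={fa.USE_FA2}")

if torch.cuda.is_available():
    B, T, H, D = 1, 16, 2, 32
    q = torch.randn(B, T, H, D, device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Run whichever backend was auto-selected (FA3 over FA2 over SDPA).
    y_auto = fa.flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))

    # Force the SDPA fallback via the testing override, mirroring set_impl() in the tests.
    fa._override_impl = "sdpa"
    fa.USE_FA3, fa.USE_FA2 = fa._resolve_use_fa3(), fa._resolve_use_fa2()
    y_sdpa = fa.flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))

    # Restore auto-selection and report the largest deviation between the two paths.
    fa._override_impl = None
    fa.USE_FA3, fa.USE_FA2 = fa._resolve_use_fa3(), fa._resolve_use_fa2()
    print("max |auto - sdpa|:", (y_auto - y_sdpa).abs().max().item())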