diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 81ccb0c..86f440b 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -23,13 +23,8 @@ from nanochat.common import get_dist_info, print0
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
 
-# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
-import os
-os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
-# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
-# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
-from kernels import get_kernel
-flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
+# Our custom Flash Attention module: uses FA3 on Hopper+ and falls back to SDPA elsewhere
+from nanochat.flash_attention import flash_attn
 
 @dataclass
 class GPTConfig:
@@ -87,8 +82,7 @@ class CausalSelfAttention(nn.Module):
         q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
         q, k = norm(q), norm(k) # QK norm
 
-        # Attention with Flash Attention 3
-        # FA3 handles GQA automatically when n_kv_heads < n_heads
+        # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
         # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
         if kv_cache is None:
             # Training: causal attention with optional sliding window
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 5293cd8..c61986e 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -27,6 +27,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
+from nanochat.flash_attention import HAS_FA3
 from scripts.base_eval import evaluate_model
 
 print_banner()
@@ -86,6 +87,18 @@ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else l
 use_dummy_wandb = args.run == "dummy" or not master_process
 wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
 
+# Flash Attention status
+if HAS_FA3:
+    print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
+else:
+    print0("!" * 80)
+    print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
+    print0("WARNING: Training will be less efficient without FA3")
+    if args.window_pattern != "L":
+        print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
+        print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
+    print0("!" * 80)
+
 # Tokenizer will be useful for evaluation, also we need the vocab size
 tokenizer = get_tokenizer()
 token_bytes = get_token_bytes(device=device)
diff --git a/tests/test_attention_fallback.py b/tests/test_attention_fallback.py
new file mode 100644
index 0000000..2cf3ed7
--- /dev/null
+++ b/tests/test_attention_fallback.py
@@ -0,0 +1,338 @@
+"""
+Test Flash Attention unified interface - verify FA3 and SDPA produce identical results.
+
+Run: python -m pytest tests/test_attention_fallback.py -v -s
+
+Note on test structure:
+    Tests are split into two classes due to dtype/device constraints:
+
+    1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
+       and verify they produce identical results. These require a Hopper GPU (FA3 only
+       works on sm90+) and use bfloat16 (FA3 doesn't support float32).
+
+    2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
+       on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
+"""
+import torch
+import pytest
+import nanochat.flash_attention as fa_module
+from nanochat.flash_attention import flash_attn, HAS_FA3
+from nanochat.engine import KVCache
+
+
+def set_impl(impl):
+    """Set the implementation override ('fa3', 'sdpa', or None for auto)."""
+    fa_module._override_impl = impl
+
+
+def run_both_impls(fn):
+    """Run a function with both FA3 and SDPA, return both outputs."""
+    set_impl('fa3')
+    out_fa3 = fn()
+    set_impl('sdpa')
+    out_sdpa = fn()
+    set_impl(None)  # reset
+    return out_fa3, out_sdpa
+
+
+def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
+    """Assert two tensors are close, with helpful error message."""
+    max_diff = (t1 - t2).abs().max().item()
+    mean_diff = (t1 - t2).abs().mean().item()
+    assert torch.allclose(t1, t2, atol=atol, rtol=rtol), \
+        f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}"
+    return max_diff, mean_diff
+
+
+# =============================================================================
+# FA3 vs SDPA comparison tests (require Hopper GPU)
+# =============================================================================
+@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations")
+class TestFA3VsSDPA:
+    """Verify that FA3 and SDPA produce identical results. Requires a Hopper GPU."""
+
+    DEVICE = "cuda"
+    DTYPE = torch.bfloat16
+
+    def test_basic_causal(self):
+        """Basic causal attention."""
+        B, T, H, D = 2, 64, 4, 32
+        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
+
+        y_fa3, y_sdpa = run_both_impls(run)
+        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "basic_causal")
+        print(f"basic_causal: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_full_context(self):
+        """Full context (window_size=-1)."""
+        B, T, H, D = 2, 128, 4, 32
+        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
+
+        y_fa3, y_sdpa = run_both_impls(run)
+        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "full_context")
+        print(f"full_context: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_sliding_window(self):
+        """Sliding window attention."""
+        B, T, H, D = 2, 128, 4, 32
+        window = 32
+        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))
+
+        y_fa3, y_sdpa = run_both_impls(run)
+        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "sliding_window")
+        print(f"sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_gqa(self):
+        """Group Query Attention (fewer KV heads than Q heads)."""
+        B, T, D = 2, 64, 32
+        n_heads = 8
+        n_kv_heads = 2
+
+        q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
+
+        y_fa3, y_sdpa = run_both_impls(run)
+        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "gqa")
+        print(f"gqa: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_larger_model(self):
+        """Larger dimensions closer to real model."""
+        B, T, H, D = 4, 256, 12, 64
+        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
+
+        y_fa3, y_sdpa = run_both_impls(run)
+        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "larger_model")
+        print(f"larger_model: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_kvcache_prefill(self):
+        """Test prefill (inserting multiple tokens into empty cache)."""
+        B, T_max, H, D = 2, 64, 4, 32
+        T_prefill = 16
+
+        q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            cache_seqlens = torch.zeros(B, dtype=torch.int32, device=self.DEVICE)
+            return flash_attn.flash_attn_with_kvcache(
+                q, k_cache, v_cache, k=k, v=v,
+                cache_seqlens=cache_seqlens,
+                causal=True, window_size=(T_max, 0)
+            )
+
+        y_fa3, y_sdpa = run_both_impls(run)
+        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "prefill")
+        print(f"prefill: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_kvcache_single_token(self):
+        """Test single token generation (cache already has content)."""
+        B, T_max, H, D = 2, 64, 4, 32
+        T_prefill = 16
+
+        k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+            k_cache[:, :T_prefill, :, :] = k_init
+            v_cache[:, :T_prefill, :, :] = v_init
+            cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
+            return flash_attn.flash_attn_with_kvcache(
+                q_single, k_cache, v_cache, k=k_single, v=v_single,
+                cache_seqlens=cache_seqlens,
+                causal=True, window_size=(T_max, 0)
+            )
+
+        y_fa3, y_sdpa = run_both_impls(run)
+        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
+        print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+    def test_backward_gradients_match(self):
+        """Verify gradients are similar between FA3 and SDPA."""
+        B, T, H, D = 2, 32, 4, 16
+
+        q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        def run():
+            q = q_data.clone().requires_grad_(True)
+            k = k_data.clone().requires_grad_(True)
+            v = v_data.clone().requires_grad_(True)
+            y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
+            loss = y.sum()
+            loss.backward()
+            return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach()
+
+        set_impl('fa3')
+        y_fa3, q_grad_fa3, k_grad_fa3, v_grad_fa3 = run()
+        set_impl('sdpa')
+        y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run()
+        set_impl(None)
+
+        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "backward_output")
+        print(f"backward_output: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+        max_diff, mean_diff = assert_close(q_grad_fa3, q_grad_sdpa, "q_grad", atol=0.05, rtol=0.05)
+        print(f"q_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+        max_diff, mean_diff = assert_close(k_grad_fa3, k_grad_sdpa, "k_grad", atol=0.05, rtol=0.05)
+        print(f"k_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+        max_diff, mean_diff = assert_close(v_grad_fa3, v_grad_sdpa, "v_grad", atol=0.05, rtol=0.05)
+        print(f"v_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+
+# =============================================================================
+# SDPA-only tests (run on any device)
+# =============================================================================
+class TestSDPAOnly:
+    """Test that the SDPA fallback works correctly. Runs on any device."""
+
+    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+    DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+    def test_basic_forward(self):
+        """Test SDPA forward pass produces valid output."""
+        set_impl('sdpa')
+        B, T, H, D = 2, 64, 4, 32
+        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
+
+        assert y.shape == (B, T, H, D)
+        assert not torch.isnan(y).any(), "Output contains NaN"
+        set_impl(None)
+
+    def test_backward(self):
+        """Test gradients flow through SDPA."""
+        set_impl('sdpa')
+        B, T, H, D = 2, 32, 4, 16
+        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
+        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
+        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
+
+        y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
+        loss = y.sum()
+        loss.backward()
+
+        assert q.grad is not None, "No gradient for q"
+        assert k.grad is not None, "No gradient for k"
+        assert v.grad is not None, "No gradient for v"
+        assert not torch.isnan(q.grad).any(), "NaN in q gradient"
+        set_impl(None)
+
+    def test_kvcache(self):
+        """Test SDPA with KV cache."""
+        set_impl('sdpa')
+        B, T_max, H, D = 2, 64, 4, 32
+        n_layers = 1
+
+        cache = KVCache(
+            batch_size=B, num_heads=H, seq_len=T_max, head_dim=D,
+            num_layers=n_layers, device=self.DEVICE, dtype=self.DTYPE
+        )
+        k_cache, v_cache = cache.get_layer_cache(0)
+
+        # Prefill
+        T_prefill = 16
+        q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        y = flash_attn.flash_attn_with_kvcache(
+            q, k_cache, v_cache, k=k, v=v,
+            cache_seqlens=cache.cache_seqlens,
+            causal=True, window_size=(T_max, 0)
+        )
+        cache.advance(T_prefill)
+
+        assert y.shape == (B, T_prefill, H, D)
+        assert cache.get_pos() == T_prefill
+
+        # Generate single token
+        q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        y_single = flash_attn.flash_attn_with_kvcache(
+            q_single, k_cache, v_cache, k=k_single, v=v_single,
+            cache_seqlens=cache.cache_seqlens,
+            causal=True, window_size=(T_max, 0)
+        )
+        cache.advance(1)
+
+        assert y_single.shape == (B, 1, H, D)
+        assert cache.get_pos() == T_prefill + 1
+        set_impl(None)
+
+
+# =============================================================================
+# Override mechanism tests
+# =============================================================================
+class TestOverrideMechanism:
+    """Test that the override mechanism works correctly."""
+
+    @pytest.mark.skipif(not HAS_FA3, reason="FA3 required")
+    def test_override_fa3(self):
+        """Test that override='fa3' uses FA3."""
+        set_impl('fa3')
+        assert fa_module._use_fa3() == True
+        set_impl(None)
+
+    def test_override_sdpa(self):
+        """Test that override='sdpa' uses SDPA."""
+        set_impl('sdpa')
+        assert fa_module._use_fa3() == False
+        set_impl(None)
+
+    def test_override_auto(self):
+        """Test that override=None uses auto-detection."""
+        set_impl(None)
+        assert fa_module._use_fa3() == HAS_FA3
+
+
+if __name__ == "__main__":
+    print(f"PyTorch version: {torch.__version__}")
+    print(f"CUDA available: {torch.cuda.is_available()}")
+    if torch.cuda.is_available():
+        print(f"CUDA device: {torch.cuda.get_device_name()}")
+        major, minor = torch.cuda.get_device_capability()
+        print(f"Compute capability: {major}.{minor}")
+    print(f"HAS_FA3: {HAS_FA3}")
+    print()
+
+    pytest.main([__file__, "-v", "-s"])
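
Note for reviewers: nanochat/flash_attention.py itself is not included in this diff, even though gpt.py and the tests import it. For orientation, below is a minimal sketch of a shim that satisfies the interface exercised above. Only the names flash_attn.flash_attn_func, flash_attn.flash_attn_with_kvcache, HAS_FA3, _override_impl, and _use_fa3 come from the patch; everything else (the SimpleNamespace wrapper, the SDPA mask construction, the sm90 capability check, and the assumption that cache_seqlens is uniform across the batch) is illustrative guesswork, not the module's actual implementation.

# SKETCH of nanochat/flash_attention.py -- NOT part of this PR, an assumption
# reconstructed only from the names used by gpt.py and the tests above.
import os
from types import SimpleNamespace

import torch
import torch.nn.functional as F

_override_impl = None  # test hook: 'fa3', 'sdpa', or None for auto-detect

try:
    # Same FA3 loading path the old gpt.py used (see the removed lines above)
    os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
    from kernels import get_kernel
    _fa3 = get_kernel('varunneal/flash-attention-3').flash_attn_interface
    # FA3 kernels only run on Hopper+ (sm90+)
    HAS_FA3 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9
except Exception:
    _fa3, HAS_FA3 = None, False


def _use_fa3():
    return (_override_impl == 'fa3') if _override_impl is not None else HAS_FA3


def _sdpa(q, k, v, window_size):
    # q, k, v arrive as (B, T, H, D) per the FA3 calling convention; SDPA wants (B, H, T, D).
    # Assumes causal attention throughout, which is all the callers above ever request.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    Tq, Tk = q.size(2), k.size(2)
    if k.size(1) != q.size(1):  # GQA: repeat KV heads up to the number of Q heads
        rep = q.size(1) // k.size(1)
        k, v = k.repeat_interleave(rep, dim=1), v.repeat_interleave(rep, dim=1)
    # Boolean mask for causal attention with an optional left window: query i sits at
    # absolute position Tk - Tq + i and may attend to key j iff j <= pos(i) and,
    # when left >= 0, pos(i) - j <= left.
    left = window_size[0]
    pos = torch.arange(Tk - Tq, Tk, device=q.device).unsqueeze(1)
    idx = torch.arange(Tk, device=q.device).unsqueeze(0)
    mask = idx <= pos
    if left >= 0:
        mask &= (pos - idx) <= left
    y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return y.transpose(1, 2)


def _flash_attn_func(q, k, v, causal=True, window_size=(-1, -1)):
    if _use_fa3():
        return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
    return _sdpa(q, k, v, window_size)


def _flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None,
                             cache_seqlens=None, causal=True, window_size=(-1, -1)):
    if _use_fa3():
        return _fa3.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size)
    # Fallback: write the new tokens into the cache, then attend over the prefix.
    # Assumes every row of cache_seqlens holds the same position (true in the tests).
    pos = int(cache_seqlens[0].item())
    T_new = k.size(1)
    k_cache[:, pos:pos + T_new] = k
    v_cache[:, pos:pos + T_new] = v
    return _sdpa(q, k_cache[:, :pos + T_new], v_cache[:, :pos + T_new], window_size)


# Callers do `from nanochat.flash_attention import flash_attn` and then call
# flash_attn.flash_attn_func / flash_attn.flash_attn_with_kvcache, so expose a namespace:
flash_attn = SimpleNamespace(flash_attn_func=_flash_attn_func,
                             flash_attn_with_kvcache=_flash_attn_with_kvcache)

Under a shim like this, set_impl('sdpa') in the tests forces the fallback path even on a Hopper box, which is exactly what run_both_impls relies on to compare the two implementations on identical inputs.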