""" Test Flash Attention unified interface - verify FA3 and SDPA produce identical results. Run: python -m pytest tests/test_attention_fallback.py -v -s Note on test structure: Tests are split into two classes due to dtype/device constraints: 1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs and verify they produce identical results. These require a Hopper GPU (FA3 only works on sm90+) and use bfloat16 (FA3 doesn't support float32). 2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run on any device (CUDA, CPU, MPS) with the appropriate dtype for that device. """ import torch import pytest import nanochat.flash_attention as fa_module from nanochat.flash_attention import flash_attn, HAS_FA3 from nanochat.engine import KVCache def set_impl(impl): """Set the implementation override ('fa3', 'sdpa', or None for auto).""" fa_module._override_impl = impl def run_both_impls(fn): """Run a function with both FA3 and SDPA, return both outputs.""" set_impl('fa3') out_fa3 = fn() set_impl('sdpa') out_sdpa = fn() set_impl(None) # reset return out_fa3, out_sdpa def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2): """Assert two tensors are close, with helpful error message.""" max_diff = (t1 - t2).abs().max().item() mean_diff = (t1 - t2).abs().mean().item() assert torch.allclose(t1, t2, atol=atol, rtol=rtol), \ f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}" return max_diff, mean_diff # ============================================================================= # FA3 vs SDPA comparison tests (require Hopper GPU) # ============================================================================= @pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations") class TestFA3VsSDPA: """Compare FA3 and SDPA produce identical results. 
# =============================================================================
# FA3 vs SDPA comparison tests (require Hopper GPU)
# =============================================================================

@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations")
class TestFA3VsSDPA:
    """Verify that FA3 and SDPA produce identical results. Requires a Hopper GPU."""

    DEVICE = "cuda"
    DTYPE = torch.bfloat16

    def test_basic_causal(self):
        """Basic causal attention."""
        B, T, H, D = 2, 64, 4, 32
        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))

        y_fa3, y_sdpa = run_both_impls(run)
        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "basic_causal")
        print(f"basic_causal: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

    def test_full_context(self):
        """Full context (window_size=-1)."""
        B, T, H, D = 2, 128, 4, 32
        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))

        y_fa3, y_sdpa = run_both_impls(run)
        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "full_context")
        print(f"full_context: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
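
    # Note on the window_size semantics assumed by the test below (flash-attn
    # convention): with causal=True and window_size=(w, 0), query position i attends
    # only to key positions j in [i - w, i]; window_size=(-1, -1) means no window,
    # i.e. full causal context. The banded mask an SDPA-style fallback would need is,
    # schematically:
    #     i = torch.arange(T)[:, None]; j = torch.arange(T)[None, :]
    #     mask = (j <= i) & (j >= i - w)   # True = may attend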
    def test_sliding_window(self):
        """Sliding window attention."""
        B, T, H, D = 2, 128, 4, 32
        window = 32
        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))

        y_fa3, y_sdpa = run_both_impls(run)
        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "sliding_window")
        print(f"sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

    def test_gqa(self):
        """Group Query Attention (fewer KV heads than Q heads)."""
        B, T, D = 2, 64, 32
        n_heads = 8
        n_kv_heads = 2
        q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))

        y_fa3, y_sdpa = run_both_impls(run)
        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "gqa")
        print(f"gqa: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

    def test_larger_model(self):
        """Larger dimensions closer to a real model."""
        B, T, H, D = 4, 256, 12, 64
        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))

        y_fa3, y_sdpa = run_both_impls(run)
        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "larger_model")
        print(f"larger_model: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
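
    # Note on the KV-cache entry point assumed by the two tests below (flash-attn
    # convention): when k/v are passed alongside k_cache/v_cache,
    # flash_attn_with_kvcache writes the new keys/values into the caches starting at
    # each batch element's cache_seqlens offset, then attends q against the cached
    # prefix plus the newly inserted tokens. The caches are updated in place.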
    def test_kvcache_prefill(self):
        """Test prefill (inserting multiple tokens into an empty cache)."""
        B, T_max, H, D = 2, 64, 4, 32
        T_prefill = 16
        q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
            v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
            cache_seqlens = torch.zeros(B, dtype=torch.int32, device=self.DEVICE)
            return flash_attn.flash_attn_with_kvcache(
                q, k_cache, v_cache, k=k, v=v,
                cache_seqlens=cache_seqlens,
                causal=True, window_size=(T_max, 0)
            )

        y_fa3, y_sdpa = run_both_impls(run)
        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "prefill")
        print(f"prefill: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

    def test_kvcache_single_token(self):
        """Test single token generation (cache already has content)."""
        B, T_max, H, D = 2, 64, 4, 32
        T_prefill = 16
        k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
            v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
            k_cache[:, :T_prefill, :, :] = k_init
            v_cache[:, :T_prefill, :, :] = v_init
            cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
            return flash_attn.flash_attn_with_kvcache(
                q_single, k_cache, v_cache, k=k_single, v=v_single,
                cache_seqlens=cache_seqlens,
                causal=True, window_size=(T_max, 0)
            )

        y_fa3, y_sdpa = run_both_impls(run)
        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
        print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

    def test_backward_gradients_match(self):
        """Verify gradients are similar between FA3 and SDPA."""
        B, T, H, D = 2, 32, 4, 16
        q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            q = q_data.clone().requires_grad_(True)
            k = k_data.clone().requires_grad_(True)
            v = v_data.clone().requires_grad_(True)
            y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
            loss = y.sum()
            loss.backward()
            return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach()

        set_impl('fa3')
        y_fa3, q_grad_fa3, k_grad_fa3, v_grad_fa3 = run()
        set_impl('sdpa')
        y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run()
        set_impl(None)

        max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "backward_output")
        print(f"backward_output: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
        max_diff, mean_diff = assert_close(q_grad_fa3, q_grad_sdpa, "q_grad", atol=0.05, rtol=0.05)
        print(f"q_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
        max_diff, mean_diff = assert_close(k_grad_fa3, k_grad_sdpa, "k_grad", atol=0.05, rtol=0.05)
        print(f"k_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
        max_diff, mean_diff = assert_close(v_grad_fa3, v_grad_sdpa, "v_grad", atol=0.05, rtol=0.05)
        print(f"v_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")


# =============================================================================
# SDPA-only tests (run on any device)
# =============================================================================

class TestSDPAOnly:
    """Test that the SDPA fallback works correctly. Runs on any device."""

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    def test_basic_forward(self):
        """Test SDPA forward pass produces valid output."""
        set_impl('sdpa')
        B, T, H, D = 2, 64, 4, 32
        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
        assert y.shape == (B, T, H, D)
        assert not torch.isnan(y).any(), "Output contains NaN"
        set_impl(None)

    def test_backward(self):
        """Test gradients flow through SDPA."""
        set_impl('sdpa')
        B, T, H, D = 2, 32, 4, 16
        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
        y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
        loss = y.sum()
        loss.backward()
        assert q.grad is not None, "No gradient for q"
        assert k.grad is not None, "No gradient for k"
        assert v.grad is not None, "No gradient for v"
        assert not torch.isnan(q.grad).any(), "NaN in q gradient"
        set_impl(None)

    def test_kvcache(self):
        """Test SDPA with KV cache."""
        set_impl('sdpa')
        B, T_max, H, D = 2, 64, 4, 32
        n_layers = 1
        cache = KVCache(
            batch_size=B, num_heads=H, seq_len=T_max, head_dim=D,
            num_layers=n_layers, device=self.DEVICE, dtype=self.DTYPE
        )
        k_cache, v_cache = cache.get_layer_cache(0)

        # Prefill
        T_prefill = 16
        q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        y = flash_attn.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v,
            cache_seqlens=cache.cache_seqlens,
            causal=True, window_size=(T_max, 0)
        )
        cache.advance(T_prefill)
        assert y.shape == (B, T_prefill, H, D)
        assert cache.get_pos() == T_prefill

        # Generate single token
        q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
        y_single = flash_attn.flash_attn_with_kvcache(
            q_single, k_cache, v_cache, k=k_single, v=v_single,
            cache_seqlens=cache.cache_seqlens,
            causal=True, window_size=(T_max, 0)
        )
        cache.advance(1)
        assert y_single.shape == (B, 1, H, D)
        assert cache.get_pos() == T_prefill + 1
        set_impl(None)


# =============================================================================
# Override mechanism tests
# =============================================================================

class TestOverrideMechanism:
    """Test that the override mechanism works correctly."""

    @pytest.mark.skipif(not HAS_FA3, reason="FA3 required")
    def test_override_fa3(self):
        """Test that override='fa3' uses FA3."""
        set_impl('fa3')
        assert fa_module._use_fa3()
        set_impl(None)

    def test_override_sdpa(self):
        """Test that override='sdpa' uses SDPA."""
        set_impl('sdpa')
        assert not fa_module._use_fa3()
        set_impl(None)

    def test_override_auto(self):
        """Test that override=None uses auto-detection."""
        set_impl(None)
        assert fa_module._use_fa3() == HAS_FA3
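
# Usage note (illustrative, using only the hooks exercised above): outside of pytest,
# the same override used by set_impl() can force a specific backend for a quick
# manual check, e.g.:
#     fa_module._override_impl = 'sdpa'   # force the SDPA fallback
#     y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(64, 0))
#     fa_module._override_impl = None     # back to auto-detection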
if __name__ == "__main__":
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name()}")
        major, minor = torch.cuda.get_device_capability()
        print(f"Compute capability: {major}.{minor}")
    print(f"HAS_FA3: {HAS_FA3}")
    print()
    pytest.main([__file__, "-v", "-s"])