diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 69b7144..c79d124 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -2,7 +2,7 @@ Unified Flash Attention interface with automatic FA3/SDPA switching. Exports `flash_attn` module that matches the FA3 API exactly, but falls back -to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU. +to PyTorch SDPA on incompatible CUDA GPUs, MPS, and CPU. Usage (drop-in replacement for FA3): from nanochat.flash_attention import flash_attn diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 208acd1..39c1d60 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -22,7 +22,7 @@ import torch.nn.functional as F from nanochat.common import get_dist_info, print0 from nanochat.optim import MuonAdamW, DistMuonAdamW -# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere +# Our custom Flash Attention module that automatically uses FA3 when compatible and SDPA fallback otherwise from nanochat.flash_attention import flash_attn @dataclass @@ -93,7 +93,7 @@ class CausalSelfAttention(nn.Module): q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) q, k = norm(q), norm(k) # QK norm - # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere) + # Flash Attention (FA3 or SDPA fallback) # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context if kv_cache is None: # Training: causal attention with optional sliding window diff --git a/scripts/base_train.py b/scripts/base_train.py index 9be4b6b..4d17638 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -96,7 +96,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", # Flash Attention status if HAS_FA3: - print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.") + print0("✓ Using Flash Attention 3: efficient, new and awesome.") else: print0("!" * 80) print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback") diff --git a/tests/test_attention_fallback.py b/tests/test_attention_fallback.py index 9741c7f..5c011dc 100644 --- a/tests/test_attention_fallback.py +++ b/tests/test_attention_fallback.py @@ -7,8 +7,8 @@ Note on test structure: Tests are split into two classes due to dtype/device constraints: 1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs - and verify they produce identical results. These require a Hopper GPU (FA3 only - works on sm90+) and use bfloat16 (FA3 doesn't support float32). + and verify they produce identical results. These require a compatible GPU (FA3 only + works on sm80 and sm90) and use bfloat16 (FA3 doesn't support float32). 2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run on any device (CUDA, CPU, MPS) with the appropriate dtype for that device. @@ -45,11 +45,11 @@ def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2): # ============================================================================= -# FA3 vs SDPA comparison tests (require Hopper GPU) +# FA3 vs SDPA comparison tests # ============================================================================= @pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations") class TestFA3VsSDPA: - """Compare FA3 and SDPA produce identical results. Requires Hopper GPU.""" + """Compare FA3 and SDPA produce identical results.""" DEVICE = "cuda" DTYPE = torch.bfloat16