From 433aacf77010263644a578e034188fc657be9395 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Mon, 2 Feb 2026 18:38:54 +0100 Subject: [PATCH 1/8] improve FA3 kernel loading --- nanochat/flash_attention.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 89ca42b..69b7144 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -18,22 +18,21 @@ import torch.nn.functional as F # ============================================================================= -# Detection: Try to load FA3 on Hopper+ GPUs +# Detection: Try to load FA3 on CUDA GPUs # ============================================================================= def _load_flash_attention_3(): - """Try to load Flash Attention 3 (requires Hopper GPU, sm90).""" + """Try to load Flash Attention 3.""" + hf_kernel = 'kernels-community/flash-attn3' if not torch.cuda.is_available(): return None try: - major, _ = torch.cuda.get_device_capability() - # FA3 kernels are compiled for Hopper (sm90) only - # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled - if major != 9: + from kernels import get_kernel, has_kernel + supported = has_kernel(hf_kernel) + if not supported: return None import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" - from kernels import get_kernel - return get_kernel('varunneal/flash-attention-3').flash_attn_interface + return get_kernel(hf_kernel).flash_attn_interface except Exception: return None From 535f664cc2738d2106bb22aed29b55e3c5ba63fd Mon Sep 17 00:00:00 2001 From: svlandeg Date: Mon, 2 Feb 2026 19:41:42 +0100 Subject: [PATCH 2/8] also update comments --- nanochat/flash_attention.py | 2 +- nanochat/gpt.py | 4 ++-- scripts/base_train.py | 2 +- tests/test_attention_fallback.py | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 69b7144..c79d124 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -2,7 +2,7 @@ Unified Flash Attention interface with automatic FA3/SDPA switching. Exports `flash_attn` module that matches the FA3 API exactly, but falls back -to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU. +to PyTorch SDPA on incompatible CUDA GPUs, MPS, and CPU. Usage (drop-in replacement for FA3): from nanochat.flash_attention import flash_attn diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 208acd1..39c1d60 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -22,7 +22,7 @@ import torch.nn.functional as F from nanochat.common import get_dist_info, print0 from nanochat.optim import MuonAdamW, DistMuonAdamW -# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere +# Our custom Flash Attention module that automatically uses FA3 when compatible and SDPA fallback otherwise from nanochat.flash_attention import flash_attn @dataclass @@ -93,7 +93,7 @@ class CausalSelfAttention(nn.Module): q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) q, k = norm(q), norm(k) # QK norm - # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere) + # Flash Attention (FA3 or SDPA fallback) # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context if kv_cache is None: # Training: causal attention with optional sliding window diff --git a/scripts/base_train.py b/scripts/base_train.py index 9be4b6b..4d17638 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -96,7 +96,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", # Flash Attention status if HAS_FA3: - print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.") + print0("✓ Using Flash Attention 3: efficient, new and awesome.") else: print0("!" * 80) print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback") diff --git a/tests/test_attention_fallback.py b/tests/test_attention_fallback.py index 9741c7f..5c011dc 100644 --- a/tests/test_attention_fallback.py +++ b/tests/test_attention_fallback.py @@ -7,8 +7,8 @@ Note on test structure: Tests are split into two classes due to dtype/device constraints: 1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs - and verify they produce identical results. These require a Hopper GPU (FA3 only - works on sm90+) and use bfloat16 (FA3 doesn't support float32). + and verify they produce identical results. These require a compatible GPU (FA3 only + works on sm80 and sm90) and use bfloat16 (FA3 doesn't support float32). 2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run on any device (CUDA, CPU, MPS) with the appropriate dtype for that device. @@ -45,11 +45,11 @@ def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2): # ============================================================================= -# FA3 vs SDPA comparison tests (require Hopper GPU) +# FA3 vs SDPA comparison tests # ============================================================================= @pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations") class TestFA3VsSDPA: - """Compare FA3 and SDPA produce identical results. Requires Hopper GPU.""" + """Compare FA3 and SDPA produce identical results.""" DEVICE = "cuda" DTYPE = torch.bfloat16 From 345935c5f30ffe45a507a5a59b97f8cece1ae6e0 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Wed, 18 Feb 2026 21:38:48 +0100 Subject: [PATCH 3/8] Add back docstring for clarification --- nanochat/flash_attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index c79d124..67917b1 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -26,6 +26,8 @@ def _load_flash_attention_3(): if not torch.cuda.is_available(): return None try: + # FA3 kernels are currently compiled for Hopper (sm90) and Ampere (sm80/sm86) + # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled from kernels import get_kernel, has_kernel supported = has_kernel(hf_kernel) if not supported: From 0c4ce3697ea4d4378ba9183299d127502465b984 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Wed, 18 Feb 2026 22:57:15 +0100 Subject: [PATCH 4/8] Ada is also supported --- nanochat/flash_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 67917b1..c7e2d3e 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -26,8 +26,8 @@ def _load_flash_attention_3(): if not torch.cuda.is_available(): return None try: - # FA3 kernels are currently compiled for Hopper (sm90) and Ampere (sm80/sm86) - # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled + # FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86) + # Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released from kernels import get_kernel, has_kernel supported = has_kernel(hf_kernel) if not supported: From 683b88c2b5e534efa025de63777f9aec0d36ed09 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 11 Mar 2026 16:45:25 +0100 Subject: [PATCH 5/8] keep varunneal kernel on H100, use community kernel for other supported cuda architectures --- nanochat/flash_attention.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 2b1f41a..6b5dae0 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -22,16 +22,15 @@ import torch.nn.functional as F # ============================================================================= def _load_flash_attention_3(): """Try to load Flash Attention 3.""" - hf_kernel = 'kernels-community/flash-attn3' if not torch.cuda.is_available(): return None try: + cap = torch.cuda.get_device_capability() # FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86) # Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released - from kernels import get_kernel, has_kernel - supported = has_kernel(hf_kernel) - if not supported: - return None + # varunneal kernel obtains better results for H100 + hf_kernel = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + from kernels import get_kernel import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" return get_kernel(hf_kernel).flash_attn_interface From 5985cd867b9a891aed6f903776dc61e3ceabe2b2 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 11 Mar 2026 16:47:09 +0100 Subject: [PATCH 6/8] use capability in the same way as on master --- nanochat/flash_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 6b5dae0..098afc3 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -25,11 +25,11 @@ def _load_flash_attention_3(): if not torch.cuda.is_available(): return None try: - cap = torch.cuda.get_device_capability() + major, _ = torch.cuda.get_device_capability() # FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86) # Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released - # varunneal kernel obtains better results for H100 - hf_kernel = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + # varunneal kernel obtains better results for H100/Hopper + hf_kernel = "varunneal/flash-attention-3" if major == 9 else "kernels-community/flash-attn3" from kernels import get_kernel import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" From ff833f8137426629305b0e24fe7c90cd1bc034ea Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 11 Mar 2026 16:49:09 +0100 Subject: [PATCH 7/8] small fix --- nanochat/flash_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 098afc3..3950da5 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -29,10 +29,10 @@ def _load_flash_attention_3(): # FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86) # Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released # varunneal kernel obtains better results for H100/Hopper - hf_kernel = "varunneal/flash-attention-3" if major == 9 else "kernels-community/flash-attn3" - from kernels import get_kernel import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" + from kernels import get_kernel + hf_kernel = "varunneal/flash-attention-3" if major == 9 else "kernels-community/flash-attn3" return get_kernel(hf_kernel).flash_attn_interface except Exception: return None From d343c939c0550b1f6780051790adcd5f4d040f22 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 11 Mar 2026 16:52:50 +0100 Subject: [PATCH 8/8] fix for unsupported cuda --- nanochat/flash_attention.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 3950da5..7412b91 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -28,12 +28,20 @@ def _load_flash_attention_3(): major, _ = torch.cuda.get_device_capability() # FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86) # Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released - # varunneal kernel obtains better results for H100/Hopper import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" - from kernels import get_kernel - hf_kernel = "varunneal/flash-attention-3" if major == 9 else "kernels-community/flash-attn3" - return get_kernel(hf_kernel).flash_attn_interface + from kernels import get_kernel, has_kernel + # The varunneal kernel obtains better results for H100/Hopper + if major == 9: + hf_kernel = "varunneal/flash-attention-3" + return get_kernel(hf_kernel).flash_attn_interface + else: + hf_kernel = "kernels-community/flash-attn3" + if has_kernel(hf_kernel): + return get_kernel(hf_kernel).flash_attn_interface + else: + return None + except Exception: return None