This commit is contained in:
Sofie Van Landeghem 2026-02-18 22:57:41 +01:00 committed by GitHub
commit 90dda4d95d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 17 additions and 16 deletions

View File

@ -2,7 +2,7 @@
Unified Flash Attention interface with automatic FA3/SDPA switching.
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
to PyTorch SDPA on incompatible CUDA GPUs, MPS, and CPU.
Usage (drop-in replacement for FA3):
from nanochat.flash_attention import flash_attn
@ -18,22 +18,23 @@ import torch.nn.functional as F
# =============================================================================
# Detection: Try to load FA3 on Hopper+ GPUs
# Detection: Try to load FA3 on CUDA GPUs
# =============================================================================
def _load_flash_attention_3():
"""Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
"""Try to load Flash Attention 3."""
hf_kernel = 'kernels-community/flash-attn3'
if not torch.cuda.is_available():
return None
try:
major, _ = torch.cuda.get_device_capability()
# FA3 kernels are compiled for Hopper (sm90) only
# Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
if major != 9:
# FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86)
# Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released
from kernels import get_kernel, has_kernel
supported = has_kernel(hf_kernel)
if not supported:
return None
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
from kernels import get_kernel
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
return get_kernel(hf_kernel).flash_attn_interface
except Exception:
return None

View File

@ -22,7 +22,7 @@ import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.optim import MuonAdamW, DistMuonAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
# Our custom Flash Attention module that automatically uses FA3 when compatible and SDPA fallback otherwise
from nanochat.flash_attention import flash_attn
@dataclass
@ -93,7 +93,7 @@ class CausalSelfAttention(nn.Module):
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k) # QK norm
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
# Flash Attention (FA3 or SDPA fallback)
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
if kv_cache is None:
# Training: causal attention with optional sliding window

View File

@ -102,7 +102,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat",
# Flash Attention status
if HAS_FA3:
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
print0("✓ Using Flash Attention 3: efficient, new and awesome.")
else:
print0("!" * 80)
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")

View File

@ -7,8 +7,8 @@ Note on test structure:
Tests are split into two classes due to dtype/device constraints:
1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
and verify they produce identical results. These require a Hopper GPU (FA3 only
works on sm90+) and use bfloat16 (FA3 doesn't support float32).
and verify they produce identical results. These require a compatible GPU (FA3 only
works on sm80 and sm90) and use bfloat16 (FA3 doesn't support float32).
2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
@ -45,11 +45,11 @@ def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
# =============================================================================
# FA3 vs SDPA comparison tests (require Hopper GPU)
# FA3 vs SDPA comparison tests
# =============================================================================
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations")
class TestFA3VsSDPA:
"""Compare FA3 and SDPA produce identical results. Requires Hopper GPU."""
"""Compare FA3 and SDPA produce identical results."""
DEVICE = "cuda"
DTYPE = torch.bfloat16