mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 13:15:21 +00:00
Merge d343c939c0 into a445144d39
This commit is contained in:
commit
2cbcff2ac5
|
|
@ -2,7 +2,7 @@
|
|||
Unified Flash Attention interface with automatic FA3/SDPA switching.
|
||||
|
||||
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
|
||||
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
|
||||
to PyTorch SDPA on incompatible CUDA GPUs, MPS, and CPU.
|
||||
|
||||
Usage (drop-in replacement for FA3):
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
|
@ -18,22 +18,30 @@ import torch.nn.functional as F
|
|||
|
||||
|
||||
# =============================================================================
|
||||
# Detection: Try to load FA3 on Hopper+ GPUs
|
||||
# Detection: Try to load FA3 on CUDA GPUs
|
||||
# =============================================================================
|
||||
def _load_flash_attention_3():
|
||||
"""Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
|
||||
"""Try to load Flash Attention 3."""
|
||||
if not torch.cuda.is_available():
|
||||
return None
|
||||
try:
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
# FA3 kernels are compiled for Hopper (sm90) only
|
||||
# Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
|
||||
if major != 9:
|
||||
return None
|
||||
# FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86)
|
||||
# Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released
|
||||
import os
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||
from kernels import get_kernel
|
||||
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
|
||||
from kernels import get_kernel, has_kernel
|
||||
# The varunneal kernel obtains better results for H100/Hopper
|
||||
if major == 9:
|
||||
hf_kernel = "varunneal/flash-attention-3"
|
||||
return get_kernel(hf_kernel).flash_attn_interface
|
||||
else:
|
||||
hf_kernel = "kernels-community/flash-attn3"
|
||||
if has_kernel(hf_kernel):
|
||||
return get_kernel(hf_kernel).flash_attn_interface
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ import torch.nn.functional as F
|
|||
from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
|
||||
from nanochat.optim import MuonAdamW, DistMuonAdamW
|
||||
|
||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
# Our custom Flash Attention module that automatically uses FA3 when compatible and SDPA fallback otherwise
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
||||
@dataclass
|
||||
|
|
@ -101,7 +101,7 @@ class CausalSelfAttention(nn.Module):
|
|||
q = q * 1.2 # sharper attention (split scale between Q and K), TODO think through better
|
||||
k = k * 1.2
|
||||
|
||||
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
|
||||
# Flash Attention (FA3 or SDPA fallback)
|
||||
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
|
||||
if kv_cache is None:
|
||||
# Training: causal attention with optional sliding window
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat",
|
|||
from nanochat.flash_attention import USE_FA3
|
||||
using_fa3 = USE_FA3
|
||||
if using_fa3:
|
||||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||
print0("✓ Using Flash Attention 3: efficient, new and awesome.")
|
||||
else:
|
||||
print0("!" * 80)
|
||||
if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ Note on test structure:
|
|||
Tests are split into two classes due to dtype/device constraints:
|
||||
|
||||
1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
|
||||
and verify they produce identical results. These require a Hopper GPU (FA3 only
|
||||
works on sm90+) and use bfloat16 (FA3 doesn't support float32).
|
||||
and verify they produce identical results. These require a compatible GPU (FA3 only
|
||||
works on sm80/sm86, sm89 and sm90) and use bfloat16 (FA3 doesn't support float32).
|
||||
|
||||
2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
|
||||
on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
|
||||
|
|
@ -46,11 +46,11 @@ def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
|
|||
|
||||
|
||||
# =============================================================================
|
||||
# FA3 vs SDPA comparison tests (require Hopper GPU)
|
||||
# FA3 vs SDPA comparison tests
|
||||
# =============================================================================
|
||||
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations")
|
||||
class TestFA3VsSDPA:
|
||||
"""Compare FA3 and SDPA produce identical results. Requires Hopper GPU."""
|
||||
"""Compare FA3 and SDPA produce identical results."""
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user